点分治

点分治

点分治常用于解决树上路径统计问题。它的基本思路是:选定某个点 uu,将路径划分为经过 uu 和不经过 uu 两类。对于前一种,我们直接处理。对于后一种,我们递归,在子树中处理。

首先要解决的是递归的层数,考虑一条链,如果从链的一端开始,每次向子结点递归,那么递归层数是 O(n)O(n) 的。解决方法是每次选取重心进行递归:即先处理整棵树的重心,然后递归处理每棵子树的重心。这样每次子树大小至少减少一半,所以递归层数是 O(logn)O(\log n) 级别的。

然后我们要解决的问题就是:选定某个根,统计经过它的路径数量。这个因题而异,没有统一的方法。一种比较常见的思路是逐棵子树统计,并合并答案。我们以一道例题来具体说明。


例 P3806 【模板】点分治1

题目链接:https://www.luogu.com.cn/problem/P3806

题意:

给定一棵有 nn 个点的树。mm 次询问,每次询问树上是否存在距离为 kk 的点对。1n104,1m100,1k1071\leq n\leq 10^4,1\leq m\leq 100,1\leq k\leq 10^7

题解:

我们考虑对于给定的根,如何解决。显然,我们可以遍历每棵子树,求出当前子树内的点到根的距离,对于每个距离 dd,枚举询问,判断前面的子树中是否有距离为 kdk-d 的点,若存在,则这两个点距离为 kk 且经过根,符合。最后把所有点的距离保存下来,处理下一棵子树。

这样单个点的时间复杂度为 O(nm)O(nm),总的时间复杂度为 O(nmlogn)O(nm\log n)

然后看一下具体实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
mx[rt=0]=inf,sum=n;
getsiz(1,0);getsiz(rt,0);

void getsiz(int u,int fa)
{
siz[u]=1,mx[u]=0;
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
getsiz(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}

这里 sum 表示当前处理的子树大小,getsiz 函数用于求出当前子树的重心,做两遍的目的是求出以重心为根时每棵子树的大小,这样分治处理的时候可以更新sumvis数组的作用是标记已经处理过的点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
dfz(rt,0);

void getdis(int u,int fa)
{
if(d[u]<=1e7) vec.push_back(d[u]);
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=d[u]+w;
getdis(v,u);
}
}
void dfz(int u,int fa)
{
exist[0]=1,vis[u]=1;
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=w;
getdis(v,u);
for(auto x:vec)
{
for(int j=1;j<=m;j++)
if(q[j]>=x) ans[j]|=exist[q[j]-x];

}
for(auto x:vec)
{
q1.push(x);exist[x]=1;
}
vec.clear();
}
while(!q1.empty())
{
int x=q1.front();q1.pop();
exist[x]=0;
}
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
mx[rt=0]=inf,sum=siz[v];
getsiz(v,u);getsiz(rt,0);dfz(rt,0);
}
}

getdist函数的作用是计算每棵子树内的点到根结点的距离,并保存在vec中。在点分治的过程中,我们标记已经处理过的点,遍历每棵子树,得到vec。然后根据exist(记录前面子树中已经出现距离)更新答案。最后将vec中的距离写入exist中,直到所有子树都遍历完成。

当该点处理完毕后,exist数组需要清空,所以我们可以用一个队列q1来记录处理该点时,一共有哪些距离被加入。(直接暴力清空会TLE)

最后,我们递归处理每一棵子树。

附上完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <bits/stdc++.h>
using namespace std;
const int maxn=10005;
const int maxm=1e7+5;
const int inf=0x3f3f3f3f;
struct edge
{
int v,w;
};
vector<edge> g[maxn];
int rt,sum,siz[maxn],mx[maxn];
int n,m,q[105],d[maxn];
bool ans[105],vis[maxn],exist[maxm];
queue<int> q1;
vector<int> vec;
void getsiz(int u,int fa)
{
siz[u]=1,mx[u]=0;
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
getsiz(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void getdis(int u,int fa)
{
if(d[u]<=1e7) vec.push_back(d[u]);
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=d[u]+w;
getdis(v,u);
}
}
void dfz(int u,int fa)
{
exist[0]=1,vis[u]=1;
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=w;
getdis(v,u);
for(auto x:vec)
{
for(int j=1;j<=m;j++)
if(q[j]>=x) ans[j]|=exist[q[j]-x];

}
for(auto x:vec)
{
q1.push(x);exist[x]=1;
}
vec.clear();
}
while(!q1.empty())
{
int x=q1.front();q1.pop();
exist[x]=0;
}
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
mx[rt=0]=inf,sum=siz[v];
getsiz(v,u);getsiz(rt,0);dfz(rt,0);
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n>>m;
for(int i=1;i<n;i++)
{
int u,v,w;cin>>u>>v>>w;
g[u].push_back({v,w}),g[v].push_back({u,w});
}
for(int i=1;i<=m;i++) cin>>q[i];
mx[rt=0]=inf,sum=n;
getsiz(1,0);getsiz(rt,0);dfz(rt,0);
for(int i=1;i<=m;i++)
if(ans[i]) cout<<"AYE"<<'\n';
else cout<<"NAY"<<'\n';
return 0;
}

以下是一些练习题:

P4178 Tree

题意:

给出一棵 nn 个结点的树,求距离小于等于 kk 的点对数量。n2×104,k2×104n\leq 2\times 10^4,k\leq 2\times 10^4

题解:

本题和上一题的区别在于由等于 kk 变成了小于等于 kk。可以想到一个很直接的处理方式:用权值线段树来记录出现过的点。枚举当前子树,对于该子树内每个点与根的距离 dd,它对答案的贡献为距离小于等于 kdk-d 的点个数和。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;
const int maxn=40005;
const int inf=0x3f3f3f3f;
struct edge
{
int v,w;
};
vector<edge> g[maxn];
int rt,sum,siz[maxn],mx[maxn];
int n,m,ans,d[maxn];
bool vis[maxn];
queue<int> q1;
vector<int> vec;
int tr[4*maxn],tag[4*maxn];
void pushdown(int l,int r,int p)
{
int m=l+((r-l)>>1);
tr[p*2]+=tag[p]*(m-l+1),tr[p*2+1]+=tag[p]*(r-m);
tag[p*2]+=tag[p],tag[p*2+1]+=tag[p];
tag[p]=0;
}
void update(int ul,int ur,int k,int l,int r,int p)
{
if(ul<=l&&ur>=r)
{
tr[p]+=(r-l+1)*k,tag[p]+=k;
return ;
}
int m=l+((r-l)>>1);
if(tag[p]&&l!=r) pushdown(l,r,p);
if(ul<=m) update(ul,ur,k,l,m,p*2);
if(ur>m) update(ul,ur,k,m+1,r,p*2+1);
tr[p]=tr[p*2]+tr[p*2+1];
}
int query(int ql,int qr,int l,int r,int p)
{
if(ql<=l&&qr>=r) return tr[p];
int m=l+((r-l)>>1);
if(tag[p]) pushdown(l,r,p);
int sum=0;
if(ql<=m) sum+=query(ql,qr,l,m,p*2);
if(qr>m) sum+=query(ql,qr,m+1,r,p*2+1);
return sum;
}
void getsiz(int u,int fa)
{
siz[u]=1,mx[u]=0;
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
getsiz(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void getdis(int u,int fa)
{
if(d[u]<=m) vec.push_back(d[u]);
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=d[u]+w;
getdis(v,u);
}
}
void dfz(int u,int fa)
{
update(0,0,1,0,m,1),vis[u]=1;
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=w;
getdis(v,u);
for(auto x:vec)
{
if(m-x>=0) ans+=query(0,m-x,0,m,1);
}
for(auto x:vec)
{
q1.push(x);update(x,x,1,0,m,1);
}
vec.clear();
}
update(0,0,-1,0,m,1);
while(!q1.empty())
{
int x=q1.front();q1.pop();
update(x,x,-1,0,m,1);
}
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
mx[rt=0]=inf,sum=siz[v];
getsiz(v,u);getsiz(rt,0);dfz(rt,0);
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n;
for(int i=1;i<n;i++)
{
int u,v,w;cin>>u>>v>>w;
g[u].push_back({v,w}),g[v].push_back({u,w});
}
cin>>m;
mx[rt=0]=inf,sum=n;
getsiz(1,0);getsiz(rt,0);dfz(rt,0);
cout<<ans<<'\n';
return 0;
}

P2634 [国家集训队] 聪聪可可

题意:

给出一个 nn 个结点的树,任意选取树上两点,求两点间距离能被 33 整除的概率。n2×104n\leq 2\times 10^4

题解:

只需要维护模 330,1,20,1,2 的路径数量即可。

注意:由于子树依次加入,所以模板中求出的点对是无序的,而这里是有序的。同时,还要加上两个点重合的情况。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=20005;
const int inf=0x3f3f3f3f;
struct edge
{
int v,w;
};
vector<edge> g[maxn];
int rt,sum,siz[maxn],mx[maxn];
int n,m,q[105],d[maxn],exist[3];
ll ans;
bool vis[maxn];
vector<int> vec;
void getsiz(int u,int fa)
{
siz[u]=1,mx[u]=0;
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
getsiz(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void getdis(int u,int fa)
{
vec.push_back(d[u]%3);
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=d[u]+w;
getdis(v,u);
}
}
void dfz(int u,int fa)
{
exist[0]++,vis[u]=1;
for(auto i:g[u])
{
int v=i.v,w=i.w;
if(v==fa||vis[v]) continue;
d[v]=w;
getdis(v,u);
for(auto x:vec)
{
ans+=exist[(3-x)%3];
}
for(auto x:vec)
{
exist[x]++;
}
vec.clear();
}
exist[0]=exist[1]=exist[2]=0;
for(auto i:g[u])
{
int v=i.v;
if(v==fa||vis[v]) continue;
mx[rt=0]=inf,sum=siz[v];
getsiz(v,u);getsiz(rt,0);dfz(rt,0);
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n;
for(int i=1;i<n;i++)
{
int u,v,w;cin>>u>>v>>w;
g[u].push_back({v,w}),g[v].push_back({u,w});
}
mx[rt=0]=inf,sum=n;
getsiz(1,0);getsiz(rt,0);dfz(rt,0);
ans=ans*2+n;
ll g=__gcd(ans,1ll*n*n);
cout<<ans/g<<'/'<<1ll*n*n/g<<'\n';
return 0;
}

CF1101D GCD Counting

题意:

给出一棵有 nn 个点的树,每个点有点权 aia_i。求一条最长的路径长度,使得路径上所有点点权的 gcd\gcd 大于 11n,ai2×105n,a_i\leq 2\times 10^5

题解:

对于每个点,我们枚举它的质因子 pp,分别处理出点权 gcd\gcd 等于 pp 的最长路径长度,这是容易维护的。而 2×1052\times 10^5 范围内的整数最多只有 66 个质因子,时间复杂度为 O(nlogn)O(n\log n)

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;
const int maxn=2e5+5;
const int inf=0x3f3f3f3f;
int a[maxn];
vector<int> g[maxn];
int cnt,pri[maxn],mn[maxn];
int rt,sum,siz[maxn],mx[maxn],d[maxn],exist[maxn];
bool vis0[maxn],vis[maxn];
int p,ans;
queue<int> q1;
vector<int> vec;
vector<int> fac[maxn];
void getsiz(int u,int fa)
{
siz[u]=1,mx[u]=0;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
getsiz(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void getdis(int u,int fa)
{
if(d[u]) vec.push_back(d[u]);
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
if(a[v]%p==0&&d[u])
{
d[v]=d[u]+1;
getdis(v,u);
}
}
}
void init()
{
for(int i=2;i<maxn;i++)
{
if(!vis0[i]) pri[++cnt]=i,mn[i]=i;
for(int j=1;j<=cnt&&i*pri[j]<maxn;j++)
{
vis0[i*pri[j]]=1,mn[i*pri[j]]=pri[j];
if(i%pri[j]==0) break;
}
}
for(int i=2;i<maxn;i++)
{
int j=i;
while(j>1)
{
fac[i].push_back(mn[j]);
j/=mn[j];
}
int len=unique(fac[i].begin(),fac[i].end())-fac[i].begin();
fac[i].resize(len);
}
}
void dfz(int u,int fa)
{
vis[u]=1;
for(auto i:fac[a[u]])
{
p=i;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
if(a[v]%p==0) d[v]=1;
else d[v]=0;
getdis(v,u);
for(auto x:vec)
{
ans=max(ans,exist[p]+x+1);
}
for(auto x:vec)
{
exist[p]=max(x,exist[p]);
}
vec.clear();
}
q1.push(p);
}
while(!q1.empty())
{
int x=q1.front();q1.pop();
exist[x]=0;
}
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
mx[rt=0]=inf,sum=siz[v];
getsiz(v,u);getsiz(rt,0);dfz(rt,0);
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
init();
int n;cin>>n;
for(int i=1;i<=n;i++)
{
cin>>a[i];
if(a[i]>1) ans=1;
}
for(int i=1;i<n;i++)
{
int u,v;cin>>u>>v;
g[u].push_back(v),g[v].push_back(u);
}
mx[rt=0]=inf,sum=n;
getsiz(1,0);getsiz(rt,0);dfz(rt,0);
cout<<ans<<'\n';
return 0;
}

*P2664 树上游戏

题意:

给出一棵有 nn 个点的树,树的每个节点有个颜色。定义 s(i,j)s(i,j)iijj 的颜色数量。以及sumi=j=1ns(i,j)sum_i=\sum_{j=1}^n s(i, j)。求出所有的 sumisum_i1n,ci1051\leq n,c_i\leq 10^5

题解:

(感觉这题理解得还不是很透彻)

考虑固定的分治中心 ii。记 cntjcnt_j 表示以 ii 为一个端点且含有颜色 jj 的路径数量,则 sumi=cntjsum_i=\sum cnt_jcntjcnt_j 很容易处理,进行dfs,每遇到一个新的颜色就加上它的子树大小。并且它对子树信息的合并很有帮助。

对于固定的分治中心 ii,我们只需要关心经过它的路径。可以分为两类,一类是以分治中心为端点的路径,一类是端点位于分治中心的两棵子树内的路径。

对于第一类,我们在遍历子树的时候就可以顺便统计。对于第二类,我们再分为两部分。设当前处理的点为 vv,对于 (i,v)(i,v) 路径上出现的颜色 jj,它对 vv 的贡献为其它子树的大小,对于没有出现的颜色 jj,它对 vv 的贡献为其它子树的 cntjcnt_j 之和。

在实际实现的时候,我们可以考虑正反序各遍历一遍子树,计算每个点与它前面子树所有点的路径和每个点与它后面子树所有点的路径。同时,对于路径上出现的每一种颜色 jj,还需要扣除掉其他子树内的 cntjcnt_j ,因为它不能重复计算。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
const int inf=0x3f3f3f3f;
int c[maxn];
vector<int> g[maxn];
int rt,sum,siz[maxn],mx[maxn];
int cnt[maxn],v[maxn];
bool vis[maxn];
ll tot,ans[maxn];
void getsiz(int u,int fa)
{
siz[u]=1,mx[u]=0;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
getsiz(v,u);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt]) rt=u;
}
void getdis(int u,int fa,int now)
{
siz[u]=1;
if(!v[c[u]])
{
tot-=cnt[c[u]];
now++;
}
v[c[u]]++;
ans[u]+=tot+now*siz[rt];
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
getdis(v,u,now);
siz[u]+=siz[v];
}
v[c[u]]--;
if(!v[c[u]]) tot+=cnt[c[u]];
}
void getcnt(int u,int fa)
{
if(!v[c[u]])
{
cnt[c[u]]+=siz[u];
tot+=siz[u];
}
v[c[u]]++;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
getcnt(v,u);
}
v[c[u]]--;
}
void clear(int u,int fa,int now)
{
if(!v[c[u]]) now++;
v[c[u]]++;
ans[u]-=now;
ans[rt]+=now;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
clear(v,u,now);
}
v[c[u]]--;
cnt[c[u]]=0;
}
void clear2(int u,int fa)
{
cnt[c[u]]=0;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
clear2(v,u);
}
}
void dfz(int u,int fa)
{
vis[u]=1;ans[u]++;
siz[u]=tot=cnt[c[u]]=1;
v[c[u]]++;
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
getdis(v,u,0);
getcnt(v,u);
siz[u]+=siz[v];
cnt[c[u]]+=siz[v];
tot+=siz[v];
}
clear2(u,0);
siz[u]=tot=cnt[c[u]]=1;
reverse(g[u].begin(),g[u].end());
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
getdis(v,u,0);
getcnt(v,u);
siz[u]+=siz[v];
cnt[c[u]]+=siz[v];
tot+=siz[v];
}
v[c[u]]--;
clear(u,0,0);
for(auto v:g[u])
{
if(v==fa||vis[v]) continue;
mx[rt=0]=inf,sum=siz[v];
getsiz(v,u);dfz(rt,0);
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;cin>>n;
for(int i=1;i<=n;i++) cin>>c[i];
for(int i=1;i<n;i++)
{
int u,v;cin>>u>>v;
g[u].push_back(v),g[v].push_back(u);
}
mx[rt=0]=inf,sum=n;
getsiz(1,0);dfz(rt,0);
for(int i=1;i<=n;i++) cout<<ans[i]<<'\n';
return 0;
}

点分治
https://je3ter.github.io/2023/06/29/ACM/点分治/
作者
Je3ter
发布于
2023年6月29日
许可协议