学到了树上启发式合并这个东西,其实本质就是暴力,果然应了那句:世界万物皆暴力。
这是处理树上问题强有力的工具。
首先我们有三个函数 dfs2,clear,insert。dfs2用于处理以该节点为根节点的子树的答案,并且保留其信息。clear为清空该子树信息,insert为插入(合并)该子树信息。
大致做法:
1、先处理出轻重儿子。
2、然后dfs2(u,fa)时,先把u的轻儿子dfs2了,然后clear它,以防止对其他子树影响。
3、dfs2重儿子,并且不用clear。
4、insert其他轻儿子。
5、加入该节点(u)信息。
6、统计答案。
按照一篇博客来刷例题(https://blog.csdn.net/pb122401/article/details/84648993)
cf375D:https://vjudge.net/problem/CodeForces-375D
1 #include<iostream>
2 #include<cstdio>
3 #include<vector>
4 using namespace std;
5 typedef long long ll;
6 const int N = 1e5+9;
7 vector<int> G[N];
8 int col[N],hson[N],siz[N],num[N],sum[N];
9 struct Query{
10 int val,ans;
11 }q[N];
12 vector<int> ask[N];
13 int id,mx;
14 void dfs1(int u,int fa){
15 // cerr<<u<<' '<<fa<<endl;
16 siz[u] = 1; hson[u] = 0;
17 for(int i = 0; i < G[u].size();++i){
18 int v = G[u][i];
19 if(v == fa) continue;
20 dfs1(v,u);
21 hson[u] = siz[hson[u]] > siz[v] ? hson[u] : v;
22 siz[u] += siz[v];
23 }
24 }
25 void clear(int u,int fa){
26 --num[col[u]];
27 --sum[num[col[u]] + 1];
28 for(int i = 0;i<G[u].size();++i){
29 int v = G[u][i];
30 if(v==fa) continue;
31 clear(v,u);
32 }
33 }
34 void insert(int u,int fa){
35 ++num[col[u]];
36 ++sum[ num[col[u]] ];
37 for(int i=0;i<G[u].size();++i){
38 int v = G[u][i];
39 if(v==fa) continue;
40 insert(v,u);
41 }
42 }
43 void dfs2(int u,int fa){
44 // cerr<<u<<' '<<fa<<endl;
45 for(int i = 0;i<G[u].size();++i){
46 int v = G[u][i];
47 if(v==fa || v==hson[u]) continue;
48 dfs2(v,u);
49 clear(v,u);
50 }
51 if(hson[u]) dfs2(hson[u],u);
52 for(int i = 0;i<G[u].size();++i){
53 int v = G[u][i];
54 if(v==fa || v==hson[u]) continue;
55 insert(v,u);
56 }
57 ++num[col[u]];
58 ++sum[num[col[u]]];
59 for(int i =0;i<ask[u].size();++i){
60 int id = ask[u][i];
61 q[id].ans = sum[q[id].val];
62 // cerr<<'a'<<sum[q[id].val]<<endl;
63 }
64 }
65 int main(){
66 int n,m;
67 scanf("%d %d",&n,&m);
68 for(int i = 1;i<=n;++i) scanf("%d",col+i);
69 for(int i = 1;i<n;++i){
70 int u,v; scanf("%d %d",&u,&v);
71 G[u].push_back(v);
72 G[v].push_back(u);
73 // cerr<<'a'<<u<<' '<<v<<endl;
74 }
75 for(int i = 1;i<=m;++i){
76 int u,v; scanf("%d %d",&u,&v);
77 ask[u].push_back(i);
78 q[i].val = v;
79 }
80 dfs1(1,0);
81 // cerr<<'a'<<endl;
82 dfs2(1,0);
83 for(int i = 1;i<=m;++i) printf("%d\n",q[i].ans);
84 return 0;
85 }
知识兔View Codecf 600E:https://vjudge.net/problem/CodeForces-600E
1 #include<iostream>
2 #include<cstdio>
3 #include<vector>
4 using namespace std;
5 typedef long long ll;
6 const int N = 1e5+9;
7 vector<int> G[N];
8 int col[N],hson[N],siz[N],num[N];
9 ll sum[N];
10 int id,mx;
11 void dfs1(int u,int fa){
12 siz[u] = 1; hson[u] = 0;
13 for(int i = 0; i < G[u].size();++i){
14 int v = G[u][i];
15 if(v == fa) continue;
16 dfs1(v,u);
17 hson[u] = siz[hson[u]] > siz[v] ? hson[u] : v;
18 siz[u] += siz[v];
19 }
20 }
21 void clear(int u,int fa){
22 --num[col[u]];
23 for(int i = 0;i<G[u].size();++i){
24 int v = G[u][i];
25 if(v==fa) continue;
26 clear(v,u);
27 }
28 }
29 void insert(int u,int fa,int p){
30 // cerr<<u<<' '<<fa<<' '<<p<<' '<<sum[p]<<endl;
31 ++num[col[u]];
32 if(num[col[u]] >= mx){
33 if(num[col[u]] >mx) sum[p] = 0,id = col[u],mx=num[col[u]];
34 sum[p] += col[u];
35 }
36 // if(fa==1) cerr<<u<<' '<<sum[1]<<' '<<id<<' '<<mx<<'a'<<endl;
37 for(int i=0;i<G[u].size();++i){
38 int v = G[u][i];
39 if(v==fa) continue;
40 insert(v,u,p);
41 }
42 }
43 void dfs2(int u,int fa){
44 // cerr<<u<<' '<<fa<<' '<<id<<endl;
45 for(int i = 0;i<G[u].size();++i){
46 int v = G[u][i];
47 if(v==fa || v==hson[u]) continue;
48 dfs2(v,u);
49 clear(v,u);
50 id = 0;
51 mx=0;
52 }
53 if(hson[u]){
54 dfs2(hson[u],u),sum[u] = sum[hson[u]];
55 // if(u==2) cerr<<id<<' '<<hson[u]<<endl;
56 // cerr<<u<<'u'<<endl;
57 }
58 // if(u==1) cerr<<u<<' '<<sum[u]<<' '<<id<<' '<<mx<<endl;
59 for(int i = 0;i<G[u].size();++i){
60 int v = G[u][i];
61 if(v==fa || v==hson[u]) continue;
62 insert(v,u,u);
63 // if(u==1) cerr<<v<<' '<<sum[u]<<' '<<id<<' '<<mx<<endl;
64 }
65 // if(u==1) cerr<<num[1]<<' '<<num[2]<<num[3]<<endl;
66 // if(u==1) cerr<<u<<' '<<sum[u]<<' '<<id<<' '<<mx<<endl;
67 ++num[col[u]];
68 // if(u==1) cerr<<num[col[u]]<<endl;
69 if(num[col[u]] >= mx){
70 if(num[col[u]] >mx) sum[u] = 0,id = col[u],mx=num[col[u]];
71 sum[u] += col[u];
72 }
73 }
74 int main(){
75 int n;
76 scanf("%d",&n);
77 for(int i = 1;i<=n;++i) scanf("%d",col+i);
78 for(int i = 1;i<n;++i){
79 int u,v; scanf("%d %d",&u,&v);
80 G[u].push_back(v);
81 G[v].push_back(u);
82 }
83 dfs1(1,0);
84 dfs2(1,0);
85 for(int i = 1;i<n;++i) printf("%lld ",sum[i]);
86 printf("%lld",sum[n]);
87 return 0;
88 }
知识兔View Code