采蘑菇
Time Limit: 20 Sec Memory Limit: 256 MBDescription
Input
Output
Sample Input
5 1 2 3 2 3 1 2 1 3 2 4 2 5
Sample Output
10 9 12 9 11
HINT
Main idea
询问从以每个点为起始点时,各条路径上的颜色种类的和。
Solution
我们看到题目,立马想到了O(n^2)的做法,然后从这个做法研究一下本质,我们确定了可以以点分治作为框架。
我们先用点分治来确定一个center(重心)。然后计算跟这个center有关的路径。设现在要统计的是经过center,对x提供贡献的路径。
我们先记录一个记录Sum[x]表示1~i-1子树中 颜色x 第一次出现的位置的那个点 的子树和,然后我们就利用这个Sum来解题。
我们显然可以分两种情况来讨论:
(1)统计center->x出现颜色的贡献:
显然,这时候,对于center->x这一段,直接像O(n^2)做法那样记录一个color表示到目前为止出现的颜色个数,然后加一下即可。再记录一个record表示当前可有的贡献和,一旦出现过一个颜色,那么这个颜色在1~i-1子树上出现第一次以下的点,对于x就不再提供贡献了,record减去Sum[这个颜色],然后这样深搜往下计算即可。(2)统计center->x没出现过的颜色的贡献:
显然,对于center->x上没出现过的颜色,直接往下深搜,一开始为record为(All - Sum[center]),一旦出现了一个颜色,record则减去这个Sum。同样表示不再提供贡献即可。我们这样做就可以求出每个子树前缀对于其的贡献了,倒着再做一边即可求出全部的贡献。统计x的时候,顺便统计一下center。可以满足效率,成功AC这道题。
Code
1 #include2 #include 3 #include 4 #include 5 #include 6 #include 7 using namespace std; 8 9 const int ONE = 600005; 10 const int INF = 214783640; 11 const int MOD = 1e9+7; 12 13 int n,x,y; 14 int Val[ONE]; 15 int next[ONE],first[ONE],go[ONE],tot; 16 int vis[ONE]; 17 int Ans[ONE],Sum[ONE]; 18 int All; 19 20 21 int get() 22 { 23 int res,Q=1; char c; 24 while( (c=getchar())<48 || c>57) 25 if(c=='-')Q=-1; 26 if(Q) res=c-48; 27 while((c=getchar())>=48 && c<=57) 28 res=res*10+c-48; 29 return res*Q; 30 } 31 32 void Add(int u,int v) 33 { 34 next[++tot]=first[u]; first[u]=tot; go[tot]=v; 35 next[++tot]=first[v]; first[v]=tot; go[tot]=u; 36 } 37 38 namespace Point 39 { 40 int center; 41 int Stack[ONE],top; 42 int total,Max,center_vis[ONE]; 43 int num,V[ONE]; 44 45 struct power 46 { 47 int size,maxx; 48 }S[ONE]; 49 50 void Getsize(int u,int father) 51 { 52 S[u].size=1; 53 S[u].maxx=0; 54 for(int e=first[u];e;e=next[e]) 55 { 56 int v=go[e]; 57 if(v==father || center_vis[v]) continue; 58 Getsize(v,u); 59 S[u].size += S[v].size; 60 S[u].maxx = max(S[u].maxx,S[v].size); 61 } 62 } 63 64 void Getcenter(int u,int father,int total) 65 { 66 S[u].maxx = max(S[u].maxx,total-S[u].size); 67 if(S[u].maxx < Max) 68 { 69 Max = S[u].maxx; 70 center = u; 71 } 72 73 for(int e=first[u];e;e=next[e]) 74 { 75 int v=go[e]; 76 if(v==father || center_vis[v]) continue; 77 Getcenter(v,u,total); 78 } 79 } 80 81 void Ad_sum(int u,int father) 82 { 83 if(!vis[Val[u]]) 84 { 85 Stack[++top] = Val[u]; 86 All += S[u].size; Sum[Val[u]] += S[u].size; 87 } 88 vis[Val[u]]++; 89 for(int e=first[u];e;e=next[e]) 90 { 91 int v=go[e]; 92 if(v==father || center_vis[v]) continue; 93 Ad_sum(v,u); 94 } 95 vis[Val[u]]--; 96 } 97 98 void Calc_in(int u,int father,int center,int Size,int f_time,int record) 99 {100 if(!vis[Val[u]]) f_time++, record += Size, record -= Sum[Val[u]];101 Ans[u] += record; Ans[center]+=f_time;102 Ans[u] += f_time; vis[Val[u]] ++;103 for(int e=first[u];e;e=next[e])104 {105 int v=go[e];106 if(v==father || center_vis[v]) continue;107 Calc_in(v,u,center,Size,f_time,record);108 }109 vis[Val[u]] --;110 }111 112 void Calc_not(int u,int father,int record)113 {114 if(!vis[Val[u]]) record -= Sum[ Val[u] ];115 Ans[u] += record; vis[Val[u]] ++;116 for(int e=first[u];e;e=next[e])117 {118 int v=go[e];119 if(v==father || center_vis[v]) continue;120 Calc_not(v,u,record);121 }122 vis[Val[u]] --;123 }124 125 void Dfs(int u)126 {127 Max = n;128 Getsize(u,0);129 Getcenter(u,0,S[u].size);130 Getsize(center,0);131 center_vis[center] = 1;132 133 int num=0; for(int e=first[center];e;e=next[e]) if(!center_vis[go[e]]) V[++num]=go[e];134 135 for(int i=1;i<=num;i++)136 {137 int v=V[i];138 int Size = S[center].size - S[v].size - 1;139 vis[Val[center]] = 1;140 Calc_in(v,center,center, Size,1,All - Sum[Val[center]] + Size);141 vis[Val[center]] = 0;142 Ad_sum(v,center);143 }144 while(top) Sum[Stack[top--]]=0; All=0;145 146 for(int i=num;i>=1;i--)147 {148 int v=V[i];149 vis[Val[center]] = 1;150 Calc_not(v,center, All-Sum[Val[center]]);151 vis[Val[center]] = 0;152 Ad_sum(v,center);153 }154 155 while(top) Sum[Stack[top--]]=0; All=0;156 for(int e=first[center];e;e=next[e])157 {158 int v=go[e];159 if(center_vis[v]) continue;160 Dfs(v);161 }162 }163 164 }165 166 int main()167 { 168 n=get();169 for(int i=1;i<=n;i++) Val[i]=get();170 171 for(int i=1;i< n;i++)172 {173 x=get(); y=get();174 Add(x,y);175 }176 177 Point:: Dfs(1);178 for(int i=1;i<=n;i++)179 printf("%d\n",Ans[i]+1);180 }