Liveddd's Blog

愛にできることはまだあるかい

P2664 树上游戏

有比较不错的树上差分 $\mathcal O(n)$ 做法,也可以使用点分治 $\mathcal O(n\log n)$ 解决。

点分治做法不再赘述,着重梳理一下线性树上差分。

考虑求出 $t_{x,i}$ 表示从 $x$ 出发不包括颜色 $i$ 的路径条数。考虑当前对颜色 $i$ 求,简单做法是去掉颜色为 $i$ 的 点及其连边,得到若干联通块,对连通块中点的 $t_{x,i}$ 加连通块大小。但我们显然不能直接这样,因为 $t_{x,i}$ 状态是 $\mathcal O(n^2)$,并且每做一次都需要遍历整棵树。

尝试一次遍历求出 $t’_{x}=\sum t_{x,i}$,考虑计算每个点删去之后对该点颜色 $c_x$ 的答案的影响。设 $x$ 子节点为 $y_i$,子树 $y_i$ 变得互相独立,且 $y_i$ 包含的连通块要加上其大小 $s_i$。连通块与子树 $y_i$ 中的点 $z$ 满足 $c_z=c_x$ 有关,因为其也形如对子树的贡献,容易考虑树上差分。对于 $z$,我们对于每种颜色开一个栈维护,每次计算 $x$ 之后遍历的节点 $z$ 并弹出即可。 对于 $s_i$,我们直接维护当前已经计算过的连通块大小 $\operatorname{colsize}(c_x)$,并且在递归子树 $y_i$ 前记录,递归前减去递归即可得到子树 $y_i$ 中连通块大小,即 $s_i=\operatorname{size}(y)-(\operatorname{colsize}(c_x)-\operatorname{colsize}’(c_x))$。显然每个节点只会进出栈一次。最终 $t’_{x}$ 为当前节点 $x$ 到根节点的差分数组之和。时间复杂度为 $\mathcal O(n)$。

代码理解后很好写,注意最后单独处理根节点。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<iostream>
#include<cstdio>
#include<stack>
using namespace std;
using ll=long long;
const int N=1e5+10;
template <class T>
inline void read(T &x)
{
x=0;bool f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=-x;
}
template <class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,m,k;
int c[N];
int tot,head[N],ver[N<<1],ne[N<<1];
int dfn[N],si[N],cols[N];
ll s[N];
stack<int>sta[N];
bool vis[N];
inline void add(int u,int v)
{
ver[++tot]=v;
ne[tot]=head[u];
head[u]=tot;
}
void dfs(int x,int fa)
{
si[x]=1,dfn[x]=++*dfn;
int c=::c[x];
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa)continue;
int pre=cols[c];
dfs(y,x);
si[x]+=si[y];
int add=si[y]-cols[c]+pre,z;
cols[c]+=add,s[y]+=add;
while(!sta[c].empty()&&dfn[x]<dfn[z=sta[c].top()])
s[z]-=add,sta[c].pop();
}
sta[c].push(x);
cols[c]++;
}
void calc(int x,int fa)
{
s[x]+=s[fa];
for(int i=head[x];i;i=ne[i])if(ver[i]!=fa)calc(ver[i],x);
}
int main()
{
read(n);
for(int i=1;i<=n;i++)read(c[i]),m=max(m,c[i]),vis[c[i]]=1;
for(int i=1,u,v;i<n;i++)read(u,v),add(u,v),add(v,u);
dfs(1,0);
for(int i=1;i<=m;i++)
if(vis[i])
{
k++;
int add=n-cols[i];
s[1]+=add;
while(!sta[i].empty())s[sta[i].top()]-=add,sta[i].pop();
}
calc(1,0);
for(int i=1;i<=n;i++)printf("%lld\n",1ll*n*k-s[i]);
return 0;
}