Liveddd's Blog

愛にできることはまだあるかい

CF1540B Tree Array

有点意思的树上期望问题。

直接做的话发现给当前已经选了的连通块定位是比较困难的,但是我们的答案和这个息息相关。于是我们尝试直接钦定一个点为选择的第一个点,并且以这个点为根进行计算,最终将答案之和除以 $n$ 即可。

我们枚举两个点,逆序对与它们加入连通块的顺序有关。连通块必然是从 $\text{LCA}(x,y)$ 处向 $x,y$ 延申。不妨设 $x<y$,那么答案就是 $\text{LCA}(x,y)$ 先到达 $y$ 的期望。这个只与 $x,y$ 分别到 $\text{LCA}(x,y)$ 的距离有关,可以 $\mathcal O(n^2)$ 预处理。总的时间复杂度为 $\mathcal O(n^3\log n)$。也可以预处理 $\text{LCA}(x,y)$,或者用 $\mathcal O(n\log n)-\mathcal O(1)$ 的求 $\text{LCA}(x,y)$ 的方法去掉 $\log$。下面是前一种实现。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=3e3+10,K=8,mod=1e9+7,inv=(mod+1)/2;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,m;
int tot,head[N],ver[N<<1],ne[N<<1];
int d[N],fa[N][K+5];
int f[N][N];
inline int adj(int x){return (x>=mod)?x-mod:x;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}

inline void add(int u,int v)
{
ver[++tot]=v;
ne[tot]=head[u];
head[u]=tot;
}
void dfs(int x)
{
d[x]=d[fa[x][0]]+1;
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa[x][0])continue;
fa[y][0]=x;
for(int j=1;j<=K;j++)fa[y][j]=fa[fa[y][j-1]][j-1];
dfs(y);
}
}
inline int LCA(int x,int y)
{
if(d[x]<d[y])swap(x,y);
for(int i=K;~i;i--)
if(d[fa[x][i]]>=d[y])
x=fa[x][i];
if(x==y)return x;
for(int i=K;~i;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
inline int calc(int x)
{
memset(fa[x],0,sizeof(fa[x]));
dfs(x);
int ans=0;
for(int i=1;i<n;i++)
for(int j=i+1;j<=n;j++)
{
int lca=LCA(i,j);
ans=adj(ans+f[d[i]-d[lca]][d[j]-d[lca]]);
}
return ans;
}
int main()
{
read(n);
for(int i=1,u,v;i<n;i++)read(u,v),add(u,v),add(v,u);
for(int i=1;i<=n;i++)f[i][0]=1;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
f[i][j]=1ll*adj(f[i-1][j]+f[i][j-1])*inv%mod;
int ans=0;
for(int i=1;i<=n;i++)ans=adj(ans+calc(i));
printf("%d\n",1ll*ans*qpow(n)%mod);
return 0;
}