Liveddd's Blog

愛にできることはまだあるかい

LOJ#3340. 「NOI2020」命运

线段树合并优化 DP。

注意到对于所有 $(u,v)\in \mathcal Q$ 的 $v$ 只需要深度最深的 $u$ 满足即可。考虑设计状态 $f_{x,i}$ 表示以 $x$ 为根的子树中未满足条件的 $u$ 深度最大为 $i$ 的方案数。考虑 $(x,y)$ 为连向子树的边,选取该条边为 $1$ 的贡献是 $\sum_{j=0}^{dep_x} f_{x,i}\times f_{y,j}$,不选取的贡献是 $\sum_{j=0}^{i} f_{x,i}\times f_{y,j}+\sum_{j=0}^{i-1} f_{x,j}\times f_{y,i}$。用前缀和优化,得到总的转移方程:

这样做到 $\mathcal O(n^2)$,对深度离散化可以做到 $\mathcal O(n\min{n,m})$,建虚树可以做到 $\mathcal O((\min{n,m})^2)$。

观察式子,先不管 $sum_y(dep_x)$,剩下的项全都和下标 $i$ 有关,求区间和。而且真正有用的状态数是 $\mathcal O(m)$ 的,于是考虑整体 DP。整个过程用线段树合并来维护,合并过程中维护 $sum_x,sum_y$,先合并左子树,再合并右子树。$sum_y(dep_x)$ 在线段树上查一下就行了。时间复杂度 $\mathcal O(n \log n)$。

全都和下标 $i$ 有关这一点提醒我们也可以用启发式合并,但是感觉比较难写,是 $\mathcal O(n\log^2 n)$,有被卡的风险。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include<iostream>
#include<cstdio>
#include<vector>
#define ll long long
using namespace std;
const int N=5e5+10,mod=998244353;

template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}

struct Node
{
int lc,rc;
ll sum,tag;
}tr[N*80];
int n,m;
int tot,head[N],ver[N<<1],ne[N<<1];
int rt[N],d[N];
int cnt;
vector<int>vec[N];

inline void add(int u,int v)
{
ver[++tot]=v;
ne[tot]=head[u];
head[u]=tot;
}
inline void pushup(int x)
{
tr[x].sum=(tr[tr[x].lc].sum+tr[tr[x].rc].sum)%mod;
}
inline void update(int x,ll k)
{
tr[x].sum=tr[x].sum*k%mod;
tr[x].tag=tr[x].tag*k%mod;
}
inline void pushdown(int x)
{
if(tr[x].tag==1)return ;
if(tr[x].lc)update(tr[x].lc,tr[x].tag);
if(tr[x].rc)update(tr[x].rc,tr[x].tag);
tr[x].tag=1;
}
void insert(int &x,int l,int r,int pos)
{
x=++cnt;
tr[x].sum=tr[x].tag=1;
if(l==r)
return ;
int mid=l+r>>1;
if(pos<=mid)insert(tr[x].lc,l,mid,pos);
else insert(tr[x].rc,mid+1,r,pos);
}
ll query(int x,int l,int r,int pos)
{
if(!x||r<=pos)return tr[x].sum;
int mid=l+r>>1;
pushdown(x);
if(pos<=mid)return query(tr[x].lc,l,mid,pos);
return (tr[tr[x].lc].sum+query(tr[x].rc,mid+1,r,pos))%mod;
}
int merge(int x,int y,int l,int r,ll &sumx,ll &sumy)
{
if(!x&&!y)return 0;
if(!x)
{
sumy=(sumy+tr[y].sum)%mod;
update(y,sumx);
return y;
}
if(!y)
{
sumx=(sumx+tr[x].sum)%mod;
update(x,sumy);
return x;
}
if(l==r)
{
sumx=(sumx+tr[x].sum)%mod;
sumy=(sumy+tr[y].sum)%mod;
tr[x].sum=(tr[x].sum*sumy%mod+tr[y].sum*((sumx+mod-tr[x].sum)%mod)%mod)%mod;
return x;
}
pushdown(x),pushdown(y);
int mid=l+r>>1;
tr[x].lc=merge(tr[x].lc,tr[y].lc,l,mid,sumx,sumy);
tr[x].rc=merge(tr[x].rc,tr[y].rc,mid+1,r,sumx,sumy);
pushup(x);
return x;
}

void dfs(int x,int fa)
{
int res=0;
for(auto y:vec[x])
res=max(res,d[y]);
insert(rt[x],0,n-1,res);
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa)continue;
d[y]=d[x]+1;
dfs(y,x);
ll sumx=0,sumy=query(rt[y],0,n-1,d[x]);
rt[x]=merge(rt[x],rt[y],0,n-1,sumx,sumy);
}
}
int main()
{
freopen("destiny.in","r",stdin);
freopen("destiny.out","w",stdout);
read(n);
for(int i=1;i<n;i++)
{
int u,v;
read(u,v);
add(u,v),add(v,u);
}
read(m);
for(int i=1;i<=m;i++)
{
int u,v;
read(u,v);
vec[v].push_back(u);
}
d[1]=1;
dfs(1,0);
printf("%lld\n",query(rt[1],0,n-1,0));
return 0;
}