Liveddd's Blog

愛にできることはまだあるかい

LOJ#3701. 「联合省选 2022」填树

做完之后感觉有些套路,但还是挺不错的一题。

先考虑暴力一点的 DP 做法。考虑直接钦定左端点 $l$,要求链上必须存在一个点权值为 $l$,所有点权值满足 $v\in [l,l+K]$。这样还是有点难做,尝试放宽限制然后容斥,直接钦定 $l,r=l+K$,要求链上的点都满足 $v\in [l,l+K]$,发现与 $v\in [l+1,l+K+1]$ 算重的部分为 $v\in [l+1,l+K]$。于是再算一遍 $l,r=l+K-1$ 的答案并减去即可。这样容易做到 $\mathcal O(nV)$,其中 $V$ 为点值值域。

观察数据范围,我们显然需要优化掉时间复杂度中的 $V$。注意到转移形如当前函数加上 乘上一个转移的另一函数。我们设 $f(x)$ 表示左端点为 $x$ 的函数对应的答案。那么随着 $x$ 的增大,当且仅当 $[x,x+K]$ 碰到区间的左右端点时,其函数会发生变化。那么答案函数这是一个有 $\mathcal O(n)$ 段的分段函数,其中每一段都是一个普通多项式。观察到第一问对应的函数多项式次数不超过 $n$,第二问是第一问的前缀和,次数不超过 $n+1$。那么问题变得简单了,我们只需要对于分段函数 $f(x)$ 的每一段,求出其 $n+2$ 个点值,再使用拉格朗日插值插出答案即可。时间复杂度为 $\mathcal O(n^3)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>

using namespace std;
using ll=long long;
const int N=200+10,mod=1e9+7;
template<class T>
inline void read(T &x)
{
x=0;bool f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,k;
int cl,cr,l[N],r[N];
int tot,head[N],ver[N<<1],ne[N<<1];
inline void add(int u,int v)
{
ver[++tot]=v;
ne[tot]=head[u];
head[u]=tot;
}
inline void adj(int &x){x+=x>>31&mod;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
struct Node
{
int cnt,sum;
Node():cnt(0),sum(0){}
Node(int c,int s):cnt(c),sum(s){}
void operator+=(const Node &t){adj(cnt+=t.cnt-mod),adj(sum+=t.sum-mod);}
void operator-=(const Node &t){adj(cnt-=t.cnt),adj(sum-=t.sum);}
Node operator*(const Node &t){return Node(1ll*cnt*t.cnt%mod,(1ll*cnt*t.sum+1ll*t.cnt*sum)%mod);}

}f[N],all;

int fac[N],ifac[N],pre[N],suf[N];
inline void init(int n)
{
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=qpow(fac[n]);
for(int i=n;i;i--)ifac[i-1]=1ll*ifac[i]*i%mod;
}
inline int lagrange(int *y,int n,int x)
{
pre[0]=suf[n+1]=1;
for(int i=1;i<=n;i++)pre[i]=1ll*pre[i-1]*(x-i)%mod;
for(int i=n;i;i--)suf[i]=1ll*suf[i+1]*(x-i)%mod;
int res=0;
for(int i=1;i<=n;i++)
adj(res+=((n-i)&1?mod-1ll:1ll)*pre[i-1]%mod*suf[i+1]%mod
*ifac[i-1]%mod*ifac[n-i]%mod*y[i]%mod-mod);
return res;
}
inline ll S(int x){return 1ll*x*(x+1)>>1;}
void dfs(int x,int fa=0)
{
int il=max(cl,l[x]),ir=min(cr,r[x]);
Node res;
if(il<=ir)res=Node(ir-il+1,(S(ir)-S(il-1))%mod);
f[x]=res,all+=res;
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa)continue;
dfs(y,x);
all+=f[x]*f[y];
f[x]+=f[y]*res;
}
}

inline Node calc(int l,int r)
{
cl=l,cr=r,all=Node();
dfs(1);
return all;
}
inline Node solve(int k)
{
static int cnt[N],sum[N];
vector<int>vec;
for(int i=1;i<=n;i++)
{
vec.push_back(l[i]-k),vec.push_back(l[i]+1);
vec.push_back(r[i]-k),vec.push_back(r[i]+1);
}
sort(vec.begin(),vec.end());
vec.erase(unique(vec.begin(),vec.end()),vec.end());
Node ans;
int si=vec.size();
for(int i=0;i<si-1;i++)
{
int cl=vec[i],cr=vec[i+1]-1;
if(cr-cl<=n+2)
{
for(int v=cl;v<=cr;v++)ans+=calc(v,v+k);
continue;
}
for(int v=0;v<n+2;v++)
{
Node res=calc(cl+v,cl+v+k);
adj(cnt[v+1]=cnt[v]+res.cnt-mod);
adj(sum[v+1]=sum[v]+res.sum-mod);
}
ans+=Node(lagrange(cnt,n+2,cr-cl+1),lagrange(sum,n+2,cr-cl+1));
}
return ans;
}

int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
init(N-1);
read(n,k);
for(int i=1;i<=n;i++)read(l[i],r[i]);
for(int i=1,u,v;i<n;i++)read(u,v),add(u,v),add(v,u);
Node ans=solve(k);
ans-=solve(k-1);
printf("%d\n%d\n",ans.cnt,ans.sum);
return 0;
}