Liveddd's Blog

愛にできることはまだあるかい

[ABC242Ex] Random Painting

很好的 min-max 容斥。

之前 ZR 考了原题,不会。于是就去学了 min-max 容斥,现在重新来看这道题。

设每个位置被覆盖的时间为 ti,我们要求的是:

E(maxiSti)

直接 min-max 容斥,得到:

E(maxiSti)=TS(1)|T|+1E(miniTti)

考虑求如何求 E(miniSti),设有 f(S) 个区间能够覆盖 S 中任意一个位置,容易得到 E(miniSti)=mf(S),于是考虑用 DP 求出 f(S)=i(1)|S|+1

设计状态 fi,j 表示上一个加的位置为 i ,有 j 个区间覆盖 S1,2,,i。转移时枚举新加的位置 p,与 [i,p) 有交的区间都可以覆盖当前点集 S。因为集合 S 中只多选了一个点,转移时直接将容斥系数乘 1 即可。时间复杂度 O(n2m)

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<iostream>
#include<cstdio>
using namespace std;
const int N=5e2+10,mod=998244353;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n,m;
int cnt[N][N],f[N][N];
inline int adj(int x){return (x>=mod)?x-mod:x;}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
int main()
{
read(n,m);
for(int i=1;i<=m;i++)
{
int l,r;
read(l,r);
for(int j=0;j<l;j++)
cnt[j][l]++,cnt[j][r+1]--;
}
for(int i=0;i<=n;i++)
for(int j=i+1;j<=n;j++)
cnt[i][j]+=cnt[i][j-1];
f[0][0]=mod-1;
int res=0;
for(int i=0;i<=n;i++)
{
for(int j=0;j<=m;j++)
{
if(!f[i][j])continue;
res=adj(res+1ll*m*qpow(j)%mod*f[i][j]%mod);
for(int k=i+1;k<=n;k++)
f[k][j+cnt[i][k]]=adj(f[k][j+cnt[i][k]]-f[i][j]+mod);
}
}
printf("%d\n",res);
return 0;
}

数据加强到 n,m3000 呢?原来我们统计的是能覆盖的区间。方便起见,我们考虑统计不覆盖点集中任意点的的区间 f(S)=i(1)|S|+1(此处的朴素转移方程也是类似的。利用原来的定义也可以,只是稍显麻烦)。

设多项式 fi(x),其中 [xj]fi(x),表示只考虑 [1,i],其中有 j 个区间不覆盖点集中任意点的方案数。我们对于每个区间 [l,r] 枚举其右端点,分别考虑。发现未加入该区间时有转移 fr(x)=(1)×1i<rfi(x),而加入一个新的区间,就会有 1il1,从 [xj]fi 转移到 [xj+1]fi 上,即为 fi(x)=fi(x)×x。用线段树维护这个转移,具体可看看代码。时间复杂度 O(nmlogn)

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
const int N=3e3+10,mod=998244353;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}

int n,m;
inline int adj(int x){return x>=mod?x-mod:x;}
struct poly
{
int a[N];
poly operator+(const poly &t)const
{
poly res;
for(int i=0;i<=m;i++)
res.a[i]=adj(a[i]+t.a[i]);
return res;
}
};
struct Node
{
int l,r;
poly sum;
int tag;
}tr[N<<2];
vector<int>v[N];
inline void pushup(int x)
{
tr[x].sum=tr[x<<1].sum+tr[x<<1|1].sum;
}
inline void update(int x,int k)
{
for(int i=m;i>=k;i--)tr[x].sum.a[i]=tr[x].sum.a[i-k];
for(int i=0;i<k;i++)tr[x].sum.a[i]=0;
tr[x].tag+=k;
}
inline void pushdown(int x)
{
if(!tr[x].tag)return ;
update(x<<1,tr[x].tag),update(x<<1|1,tr[x].tag);
tr[x].tag=0;
}
void build(int x,int l,int r)
{
tr[x].l=l,tr[x].r=r;
if(l==r)return ;
int mid=l+r>>1;
build(x<<1,l,mid),build(x<<1|1,mid+1,r);
pushup(x);
}
void add(int x,int l,int r,int k)
{
if(tr[x].l>=l&&tr[x].r<=r)
{
update(x,k);
return ;
}
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
if(l<=mid)add(x<<1,l,r,k);
if(r>mid)add(x<<1|1,l,r,k);
pushup(x);
}
void modify(int x,int pos,const poly &k)
{
if(tr[x].l==tr[x].r)return tr[x].sum=k,void();
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
if(pos<=mid)modify(x<<1,pos,k);
else modify(x<<1|1,pos,k);
pushup(x);
}
inline int qpow(int x,int k=mod-2)
{
int res=1;
while(k)
{
if(k&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
k>>=1;
}
return res;
}
int main()
{
read(n,m);
for(int i=1;i<=m;i++)
{
int l,r;
read(l,r);
v[r].push_back(l);
}
build(1,0,n);
modify(1,0,(poly){mod-1});
for(int i=1;i<=n;i++)
{
poly res=tr[1].sum;
for(auto j:v[i])add(1,0,j-1,1);
for(int j=0;j<=m;j++)res.a[j]=mod-res.a[j];
modify(1,i,res);
}
poly res=tr[1].sum;
int ans=0;
for(int i=0;i<=m;i++)
ans=adj(ans+1ll*res.a[i]*qpow(m-i)%mod);
printf("%d\n",1ll*ans*m%mod);
return 0;
}

Gitalk 加载中 ...