Liveddd's Blog

愛にできることはまだあるかい

CF Nasty Donchik

数据结构好题。

题目链接

如何快速判断数集相等?考虑处理出每个位置的数在前面和后面第一次出现的位置 $pre_i,ne_i$。对于三元组 $(i,j,k)$ 合法的充要条件是: $\displaystyle \min_{t=j+1}^k pre_t\ge i,\max_{t=i}^j ne_t\le k$。考虑固定 $j,k$,第一个条件限制的是 $i$ 的上界,第二个条件限制的是 $i$ 的下界,二者之间的 $i$ 都是满足条件的。我们考虑从左到右枚举 $k$,维护前面所有 $j$,对应的 $i$ 的上界与下界,以及它们之和,最后相减即可得到答案。

我们需要每个位置 $j$ 的数是否在后面 $(j,k]$ 中出现过,这样的位置我们称之为“合法”,我们只对合法位置计算答案。第一个条件是容易的,区间取 min 即可。考虑第二个条件,相当于求 $j$ 向前最大的合法连续段(这样的段才满足 $\max_{t=i}^j ne_t\le k$)。加入位置 $k$ 会使位置 $pre_k$ 的点变为合法,用向前连续段的答案更新向后连续段的答案即可,线段树二分和区间赋值操作完成,也可以直接区间取 min。

jls 线段树维护一下即可,时间复杂度 $\mathcal O(n\log n)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const int N=2e5+10;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}
int n;
int a[N];
int pre[N],suf[N];
int p[N];

struct Node
{
int l,r;
ll sum;
int si,pre,suf,fir,se,cnt,tag;
};
struct Seg
{
Node tr[N<<2];
void pushup(int x)
{
tr[x].sum=tr[x<<1].sum+tr[x<<1|1].sum;
tr[x].si=tr[x<<1].si&tr[x<<1|1].si;
tr[x].pre=tr[x<<1].pre;
if(tr[x<<1].si)tr[x].pre+=tr[x<<1|1].pre;
tr[x].suf=tr[x<<1|1].suf;
if(tr[x<<1|1].si)tr[x].suf+=tr[x<<1].suf;

if(tr[x<<1].fir==tr[x<<1|1].fir)
{
tr[x].fir=tr[x<<1].fir;
tr[x].cnt=tr[x<<1].cnt+tr[x<<1|1].cnt;
tr[x].se=max(tr[x<<1].se,tr[x<<1|1].se);
}
else if(tr[x<<1].fir>tr[x<<1|1].fir)
{
tr[x].fir=tr[x<<1].fir;
tr[x].cnt=tr[x<<1].cnt;
tr[x].se=max(tr[x<<1].se,tr[x<<1|1].fir);
}
else
{
tr[x].fir=tr[x<<1|1].fir;
tr[x].cnt=tr[x<<1|1].cnt;
tr[x].se=max(tr[x<<1].fir,tr[x<<1|1].se);
}
}
void update(int x,int k)
{
if(tr[x].fir<=k)return ;
tr[x].sum+=1ll*(k-tr[x].fir)*tr[x].cnt;
tr[x].fir=tr[x].tag=k;
}
void pushdown(int x)
{
if(!~tr[x].tag)return ;
update(x<<1,tr[x].tag);
update(x<<1|1,tr[x].tag);
tr[x].tag=-1;
}
void build(int x,int l,int r)
{
tr[x].l=l,tr[x].r=r,tr[x].tag=-1;
if(l==r)return tr[x].sum=tr[x].si=tr[x].pre=tr[x].suf=0,tr[x].fir=l,tr[x].se=-1,tr[x].cnt=1,void();
int mid=l+r>>1;
build(x<<1,l,mid),build(x<<1|1,mid+1,r);
pushup(x);
}
void modify_min(int x,int l,int r,int k)
{
if(tr[x].fir<=k)return ;
if(l<=tr[x].l&&r>=tr[x].r&&tr[x].se<k)return update(x,k);
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
if(l<=mid)modify_min(x<<1,l,r,k);
if(r>mid)modify_min(x<<1|1,l,r,k);
pushup(x);
}
void modify(int x,int pos)
{
if(tr[x].l==tr[x].r)return tr[x].pre=tr[x].suf=tr[x].si=1,tr[x].sum=tr[x].fir,void();
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
if(pos<=mid)modify(x<<1,pos);
else modify(x<<1|1,pos);
pushup(x);
// cout<<x<<" : "<<tr[x].l<<" "<<tr[x].r<<" "<<mid<<"\n";
// cout<<tr[x<<1].suf<<" "<<tr[x<<1|1].pre<<"\n";
}
ll query(int x,int l,int r)
{
if(l<=tr[x].l&&r>=tr[x].r)return tr[x].sum;
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
ll res=0;
if(l<=mid)res+=query(x<<1,l,r);
if(r>mid)res+=query(x<<1|1,l,r);
return res;
}
int get_pre(int x,int pos)
{
if(tr[x].l==tr[x].r)return pos;
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
// cout<<x<<" : "<<tr[x].l<<" "<<tr[x].r<<" "<<mid<<"\n";
// cout<<tr[x<<1].suf<<" "<<tr[x<<1|1].pre<<"\n";
if(pos<=mid)return get_pre(x<<1,pos);
if(mid+tr[x<<1|1].pre>=pos)return mid-tr[x<<1].suf;
return get_pre(x<<1|1,pos);
}
int get_ne(int x,int pos)
{
if(tr[x].l==tr[x].r)return pos;
pushdown(x);
int mid=tr[x].l+tr[x].r>>1;
// cout<<x<<" : "<<tr[x].l<<" "<<tr[x].r<<" "<<mid<<"\n";
// cout<<tr[x<<1].suf<<" "<<tr[x<<1|1].pre<<"\n";
if(pos>mid)return get_ne(x<<1|1,pos);
if(mid-tr[x<<1].suf+1<=pos)return mid+tr[x<<1|1].pre+1;
return get_ne(x<<1,pos);
}
}t1,t2;
int main()
{
// freopen("2.in","r",stdin);
// freopen("1.out","w",stdout);
read(n);
for(int i=1;i<=n;i++)read(a[i]),pre[i]=p[a[i]],p[a[i]]=i;
// for(int i=1;i<=n;i++)printf("%d ",a[i]);
// puts("");
// for(int i=1;i<=n;i++)printf("%d ",pre[i]);
// puts("");
t1.build(1,0,n+1);
t2.build(1,0,n+1);
ll ans=0;
for(int i=1;i<=n;i++)
{
if(pre[i])
{
// cout<<"get : "<<i<<"\n";
int x=t1.get_pre(1,pre[i]-1);
int y=t1.get_ne(1,pre[i]+1)-1;
t1.modify_min(1,pre[i],y,x);
// cout<<i<<" : "<<x<<" "<<y<<"\n";
// cout<<"end_get : \n";

// cout<<"modify: "<<pre[i]<<"\n";
t1.modify(1,pre[i]);
t2.modify(1,pre[i]);
// cout<<"end_modify : \n";
}
t1.modify_min(1,1,i-1,pre[i]);
t2.modify_min(1,1,i-1,pre[i]);
ans+=t2.tr[1].sum-t1.tr[1].sum;
}
printf("%lld\n",ans);
return 0;
}