Liveddd's Blog

愛にできることはまだあるかい

QOJ#2065. Cyclic Distance

DP 趣题,运用到函数的凸性。以及 mydcwfy 提供了的贪心解法。

考虑选点后每条边的贡献。容易发现对于子树 $i$ 中选择 $j$ 个点,存在一种顺序使得父亲边被贡献 $2\times \min(j,k-j)$ 次,调整法容易证明。于是我们不需要考虑顺序, 考虑子树中点集即可。设计 $f(i,j)$ 表示子树 $i$ 中选取 $j$ 个点,容易得到 $\mathcal O(nk)$ 的朴素树上背包的解法。

考虑研究 $f(i,j)$ 的性质,不妨猜测 $f(i,x)$ 具有凸性。当前点是否选中相当于平移加取 $\max$,而背包的转移是 $(\max,+)$ 卷积,相当于求凸包的闵可夫斯基和,得到的还是一个凸包。所以 $f(i,x)$ 具有凸性质。

于是我们考虑用维护 $f(i,x)$ 的差分数组,求闵可夫斯基和就直接进行启发式合并。因为贡献是 $\min(j,k-j)$ 的样子,我们需要根据 $j$ 按照 $\frac{k}{2}$ 分类,甚至需要讨论 $k\equiv 1\pmod 2$ 的情况,中间可能需要多维护一个值。而这部分的贡献是区间修改,直接在平衡树上打 $\text{tag}$ 就行。

但是实际上我们直接用堆就可以维护了。我们将函数分为两段 $j\le \lfloor \frac{k}{2}\rfloor,j> \lceil \frac{k}{2}\rceil$,这两段用堆来维护。再一个变量表示 $j=\lceil \frac{k}{2}\rceil$ 的答案。根据凸性质,$f(i,x)$ 差分数组单减,每次我们取出堆顶元素进行更新。同样加上贡献对于两个堆打上 $\text{tag}$ 即可。

两种写法时间复杂度均为 $\mathcal O(n\log^2 n)$。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define ll long long
using namespace std;
const int N=2e5+10;
template<class T>
inline void read(T &x)
{
x=0;int f=0;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f)x=~x+1;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1)
{
read(x),read(x1...);
}

int T,n,k;
int tot,head[N],ver[N<<1],e[N<<1],ne[N<<1];

struct Heap
{
priority_queue<ll,vector<ll>,greater<ll> >h1,h2;
ll v,tag1,tag2;
Heap() : v(-1),tag1(0),tag2(0){}
int size(){return h1.size()+h2.size()+(v!=-1);}
void add(ll val)
{
if(h1.size()<k/2)return h1.push(val-tag1);
if(val>h1.top()+tag1)
{
ll ne=h1.top()+tag1;
return h1.pop(),h1.push(val-tag1),add(ne);
}
if((k&1)&&v==-1)return v=val,void();
if((k&1)&&val>v)
{
ll ne=v;
return v=val,add(ne);
}
if(h2.size()<k/2)return h2.push(val-tag2);
if(val>h2.top()+tag2)h2.pop(),h2.push(val-tag2);
}
}f[N];

inline void add(int u,int v,int w)
{
ver[++tot]=v;
e[tot]=w;
ne[tot]=head[u];
head[u]=tot;
}
inline void init()
{
tot=0;
memset(head,0,sizeof(head));
}
inline void merge(Heap &x,Heap &y)
{
if(x.size()<y.size())swap(x,y);
while(!y.h1.empty())x.add(y.h1.top()+y.tag1),y.h1.pop();
while(!y.h2.empty())x.add(y.h2.top()+y.tag2),y.h2.pop();
if(~y.v)x.add(y.v),y.v=-1;
}
void dfs(int x,int fa)
{
f[x].add(0);
for(int i=head[x];i;i=ne[i])
{
int y=ver[i];
if(y==fa)continue;
dfs(y,x);
f[y].tag1+=e[i]<<1,f[y].tag2-=e[i]<<1;
merge(f[x],f[y]);
}
}
inline void solve()
{
init();
read(n,k);
for(int i=1,u,v,w;i<n;i++)read(u,v,w),add(u,v,w),add(v,u,w);
dfs(1,0);
ll ans=0;
while(!f[1].h1.empty())ans+=f[1].h1.top(),f[1].h1.pop();
while(!f[1].h2.empty())ans+=f[1].h2.top(),f[1].h2.pop();
if((k&1)&&(~f[1].v))ans+=f[1].v,f[1].v=-1;
printf("%lld\n",ans);
}
int main()
{
// freopen("2.in","r",stdin);
read(T);
// T=1;
while(T--)solve();
return 0;
}

还有 mydcwfy 的贪心解法,此处简述。我们每次贪心选取贡献最大的点。具体来讲,未选的点都被视为在根节点,并且维护在该点到根节点上链的贡献之和。过程中需要维护点到根节点的各条边上的贡献系数是 $1$ 还是 $-1$。具体以 $\lfloor \frac{k}{2}\rfloor$ 作为分界,该边被贡献 $\le \lfloor \frac{k}{2}\rfloor$ 次,那么贡献还是 $1$,超过就将其变为 $-1$。这个可以用树链剖分维护。正确性可以调整法证明,但是给人感觉就是很对。

总的时间复杂度为 $\mathcal O(n\log^2 n)$。贴一份巨佬 mydcwfy 的代码qwq。

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#pragma GCC optimize(2)
#include <iostream>
#include <cstdio>

using LL = long long;

const int N = 2e5 + 10;
const LL INF = 1e17;
int h[N], e[N << 1], ne[N << 1], w[N << 1], idx;
int n, k, added, delst, fw[N], fa[N];
int dfn[N], sz[N], nw[N], top[N], dep[N], son[N];
bool vis[N];

struct Sgt1 {
struct Node {
int l, r, id;
LL lt, mx;
} tr[N << 2];

void pushup(int x)
{
if (tr[x << 1].mx > tr[x << 1 | 1].mx)
tr[x].mx = tr[x << 1].mx, tr[x].id = tr[x << 1].id;
else
tr[x].mx = tr[x << 1 | 1].mx, tr[x].id = tr[x << 1 | 1].id;
}

void build(int x, int l, int r)
{
tr[x] = {l, r, l, 0, 0};
if (l == r) return;
int mid = (l + r) >> 1;
build(x << 1, l, mid), build(x << 1 | 1, mid + 1, r);
}

void update(int x, LL c) { tr[x].mx += c, tr[x].lt += c; }

void pushdown(int x) {
if (!tr[x].lt) return;
update(x << 1, tr[x].lt), update(x << 1 | 1, tr[x].lt);
tr[x].lt = 0;
}

void modify(int x, int l, int r, LL c)
{
if (tr[x].l > r || tr[x].r < l) return;
if (tr[x].l >= l && tr[x].r <= r) return update(x, c);
pushdown(x);
modify(x << 1, l, r, c), modify(x << 1 | 1, l, r, c);
pushup(x);
}
} seg1;

struct Sgt2 {
struct Node {
int l, r, mx, mn, lt;
} tr[N << 2];

void pushup(int x) {
tr[x].mx = std::max(tr[x << 1].mx, tr[x << 1 | 1].mx);
tr[x].mn = std::min(tr[x << 1].mn, tr[x << 1 | 1].mn);
}

void build(int x, int l, int r)
{
tr[x] = {l, r, 0, 0, 0};
if (l == r) return;
int mid = (l + r) >> 1;
build(x << 1, l, mid), build(x << 1 | 1, mid + 1, r);
}

void update(int x, int c) { tr[x].lt += c, tr[x].mx += c, tr[x].mn += c; }

void pushdown(int x) {
if (!tr[x].lt) return;
update(x << 1, tr[x].lt), update(x << 1 | 1, tr[x].lt);
tr[x].lt = 0;
}

void modify(int x, int l, int r, int c)
{
if (tr[x].l > r || tr[x].r < l) return;
if (tr[x].l >= l && tr[x].r <= r && (tr[x].mn >= delst || tr[x].mx + c <= added))
return update(x, c);
if (tr[x].l == tr[x].r) {
int to = 0;
if (tr[x].mn <= added) to --;
if (tr[x].mn + c <= added) to ++;
else if (tr[x].mn + c >= delst) to --;
update(x, c);
seg1.modify(1, tr[x].l, tr[x].l + sz[nw[tr[x].l]] - 1, to * fw[nw[tr[x].l]]);
return;
}
pushdown(x);
modify(x << 1, l, r, c), modify(x << 1 | 1, l, r, c);
pushup(x);
}
} seg2;

template <class T>
inline bool chkmax(T &x, T y) { return x < y ? x = y, 1 : 0; }

void add(int a, int b, int c)
{
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}

void dfs1(int x, int fa = 0)
{
sz[x] = 1, ::fa[x] = fa, dep[x] = dep[fa] + 1;
for (int i = h[x], v; ~i; i = ne[i])
{
if ((v = e[i]) == fa) continue;
fw[v] = w[i], dfs1(v, x);
if (sz[v] > sz[son[x]]) son[x] = v;
sz[x] += sz[v];
}
}

void dfs2(int x, int tp = 1)
{
top[x] = tp, nw[dfn[x] = ++ *dfn] = x;
if (!son[x]) return;
dfs2(son[x], tp);
for (int i = h[x], v; ~i; i = ne[i])
if ((v = e[i]) != fa[x] && v != son[x]) dfs2(v, v);
}

void update(int x)
{
while (x) seg2.modify(1, dfn[top[x]], dfn[x], 1), x = fa[top[x]];
}

int main()
{
std::cin >> n >> k;
for (int i = 1; i <= n; ++ i) h[i] = -1;
for (int i = 1, u, v, w; i < n; ++ i)
{
scanf("%d %d %d", &u, &v, &w);
add(u, v, w), add(v, u, w);
}
dfs1(1), dfs2(1);
added = k / 2 - 1, delst = (k + 1) / 2;
seg1.build(1, 1, n), seg2.build(1, 1, n);
for (int i = 1; i <= n; ++ i)
seg1.modify(1, dfn[i], dfn[i] + sz[i] - 1, fw[i]);
/*for (int i = 1; i <= n; ++ i) printf("%d ", sz[i]);
puts("");*/
LL res = 0;
while (k --) {
int cur = nw[seg1.tr[1].id];
res += seg1.tr[1].mx;
// std::cout << seg1.tr[1].id << ' ' << seg1.tr[1].mx << '\n';
seg1.modify(1, dfn[cur], dfn[cur], -INF), update(cur);
}
std::cout << res * 2 << '\n';
return 0;
}