COCI 2020/2021 Svjetlo
题目大意
求最短的树上路径(可以重复经过点或边)长度使得经过每个点的次数满足给定的奇偶性。树的大小为
N
N
N。
N
≤
500000
Nle 500000
N≤500000
题解
路径是可以重复的,简单的树形DP可能难以处理,考虑路径的拼接。设
f
i
,
j
,
k
f_{i,j,k}
fi,j,k表示第
i
i
i个点的子树内(除了自己)的奇偶性已经满足,且子树内(包括自己)的路径端点数有
j
j
j个,第
i
i
i个点的奇偶性为
k
k
k的最短路径长度,其中
j
∈
{
0
,
1
,
2
}
,
k
∈
{
0
,
1
}
jin{0,1,2},kin{0,1}
j∈{0,1,2},k∈{0,1}。转移的时候有很多种情况,但它们都是类似的,端点个数(状态第二维)的转移有:1、儿子子树内均为
0
0
0个端点 –> 自己子树
0
0
0个端点2、儿子子树内均为
0
0
0个端点 + 自己作为某一个端点 –> 自己子树内
1
1
1个端点3、儿子子树内均为
0
0
0个端点 + 自己作为两个端点 –> 自己子树内
2
2
2个端点4、一个儿子子树内
1
1
1个端点 + 其他儿子子树内均为
0
0
0个端点 –> 自己子树内
1
1
1个端点5、一个儿子子树内
1
1
1个端点 + 其他儿子子树内均为
0
0
0个端点 + 自己作为某个端点 –> 自己子树内
2
2
2个端点6、一个儿子子树内
2
2
2个端点 + 其他儿子子树内均为
0
0
0个端点 –> 自己子树内
2
2
2个端点7、两个儿子子树内各
1
1
1个端点 + 其他儿子子树内均为
0
0
0个端点 –> 自己子树内
2
2
2个端点第二维的
j
j
j可以理解为是伸出了多少个“头”,然后每个子树相连拼接上,再用剩下的“头”继续往上转移。
0
0
0和
2
2
2都是两个“头”,
1
1
1是一个“头”。儿子之间合并的时候要注意答案所求的是路径点的个数,所以不能把每个儿子所有“2”都延长
2
2
2的长度到父亲,不然儿子之间相接时会算重,而应该少延长一个”头“,最后更新答案时再只加多
1
1
1。如何保证儿子子树内的奇偶性都满足条件?如果从儿子节点尚不满足奇偶性的点转移时,需要多加上
2
2
2的长度,表示到了父亲再往下到儿子走一个来回。至于第三维是转移到当前节点的
0
0
0还是
1
1
1,需要看转移上来的偶儿子(指
j
j
j为偶数的儿子)个数的奇偶性。还要注意,根节点剩下的两个“头”会相连,不仅答案会减
1
1
1,而且对他而言奇偶性还会再多变一次。这样就做完了吗?写到后面可能会很容易忽略的是,这样写会默认每个节点都至少被经过一次,而其实并不然,所以根节点设为任意一个奇偶性条件为
1
1
1的点,同时在枚举儿子转移时,若整个子树都已经满足了就直接跳过。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 500010
int last[N], nxt[N * 2], to[N * 2], len = 0;
int f[N][3][2], a[N], s[N];
void add(int x, int y) {
to[++len] = y;
nxt[len] = last[x];
last[x] = len;
}
void dfs(int k, int fa) {
int s0 = 0, s1 = 1e9, s2 = 1e9, s3 = 1e9, s4 = 1e9, s5 = 1e9, s6 = 1e9, s7 = 1e9;
int t0, t1;
if(!a[k]) s[k]++;
for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
int x = to[i];
dfs(x, k);
if(!s[to[i]]) continue;
s[k] += s[to[i]];
t0 = min(s7 + f[x][0][1] + 1, s6 + f[x][0][0] + 3), t1 = min(s6 + f[x][0][1] + 1, s7 + f[x][0][0] + 3);
s6 = t0, s7 = t1;
t0 = min(s2 + f[x][1][1], s3 + f[x][1][0] + 2), t1 = min(s3 + f[x][1][1], s2 + f[x][1][0] + 2);
s6 = min(s6, t0), s7 = min(s7, t1);
t0 = min(s5 + f[x][0][1] + 1, s4 + f[x][0][0] + 3), t1 = min(s4 + f[x][0][1] + 1, s5 + f[x][0][0] + 3);
s4 = t0, s5 = t1;
t0 = min(s1 + f[x][2][1] + 1, s0 + f[x][2][0] + 3), t1 = min(s0 + f[x][2][1] + 1, s1 + f[x][2][0] + 3);
s4 = min(s4, t0), s5 = min(s5, t1);
t0 = min(s3 + f[x][0][1] + 1, s2 + f[x][0][0] + 3), t1 = min(s2 + f[x][0][1] + 1, s3 + f[x][0][0] + 3);
s2 = t0, s3 = t1;
t0 = min(s0 + f[x][1][1], s1 + f[x][1][0] + 2), t1 = min(s1 + f[x][1][1], s0 + f[x][1][0] + 2);
s2 = min(s2, t0), s3 = min(s3, t1);
t0 = min(s1 + f[x][0][1] + 1, s0 + f[x][0][0] + 3), t1 = min(s0 + f[x][0][1] + 1, s1 + f[x][0][0] + 3);
s0 = t0, s1 = t1;
}
f[k][0][a[k]] = s1 + 1;
f[k][0][a[k] ^ 1] = s0 + 1;
f[k][1][a[k]] = min(s3, s1) + 1;
f[k][1][a[k] ^ 1] = min(s2, s0) + 1;
f[k][2][a[k]] = min(min(s0 + 2, s5 + 1), min(s6 + 2, s2 + 2));
f[k][2][a[k] ^ 1] = min(min(s1 + 2, s4 + 1), min(s7 + 2, s3 + 2));
}
int main() {
int n, i, x, y;
scanf("%d
", &n);
for(i = 1; i <= n; i++) {
a[i] = getchar() - '0';
}
for(i = 1; i < n; i++) {
scanf("%d%d", &x, &y);
add(x, y), add(y, x);
}
for(i = 1; i <= n; i++) if(!a[i]) break;
dfs(i, 0);
printf("%d
", f[i][2][0] - 1);
fclose(stdin);
fclose(stdout);
return 0;
}