COCI 2020/2021 Svjetlo

题目大意

求最短的树上路径(可以重复经过点或边)长度使得经过每个点的次数满足给定的奇偶性。树的大小为

N

N

N

N

500000

Nle 500000

N500000

题解

路径是可以重复的,简单的树形DP可能难以处理,考虑路径的拼接。设

f

i

,

j

,

k

f_{i,j,k}

fi,j,k表示第

i

i

i个点的子树内(除了自己)的奇偶性已经满足,且子树内(包括自己)的路径端点数有

j

j

j个,第

i

i

i个点的奇偶性为

k

k

k的最短路径长度,其中

j

{

0

,

1

,

2

}

k

{

0

,

1

}

jin{0,1,2},kin{0,1}

j{0,1,2}k{0,1}。转移的时候有很多种情况,但它们都是类似的,端点个数(状态第二维)的转移有:1、儿子子树内均为

0

0

0个端点 –> 自己子树

0

0

0个端点2、儿子子树内均为

0

0

0个端点 + 自己作为某一个端点 –> 自己子树内

1

1

1个端点3、儿子子树内均为

0

0

0个端点 + 自己作为两个端点 –> 自己子树内

2

2

2个端点4、一个儿子子树内

1

1

1个端点 + 其他儿子子树内均为

0

0

0个端点 –> 自己子树内

1

1

1个端点5、一个儿子子树内

1

1

1个端点 + 其他儿子子树内均为

0

0

0个端点 + 自己作为某个端点 –> 自己子树内

2

2

2个端点6、一个儿子子树内

2

2

2个端点 + 其他儿子子树内均为

0

0

0个端点 –> 自己子树内

2

2

2个端点7、两个儿子子树内各

1

1

1个端点 + 其他儿子子树内均为

0

0

0个端点 –> 自己子树内

2

2

2个端点第二维的

j

j

j可以理解为是伸出了多少个“头”,然后每个子树相连拼接上,再用剩下的“头”继续往上转移。

0

0

0

2

2

2都是两个“头”,

1

1

1是一个“头”。儿子之间合并的时候要注意答案所求的是路径点的个数,所以不能把每个儿子所有“2”都延长

2

2

2的长度到父亲,不然儿子之间相接时会算重,而应该少延长一个”头“,最后更新答案时再只加多

1

1

1。如何保证儿子子树内的奇偶性都满足条件?如果从儿子节点尚不满足奇偶性的点转移时,需要多加上

2

2

2的长度,表示到了父亲再往下到儿子走一个来回。至于第三维是转移到当前节点的

0

0

0还是

1

1

1,需要看转移上来的偶儿子(指

j

j

j为偶数的儿子)个数的奇偶性。还要注意,根节点剩下的两个“头”会相连,不仅答案会减

1

1

1,而且对他而言奇偶性还会再多变一次。这样就做完了吗?写到后面可能会很容易忽略的是,这样写会默认每个节点都至少被经过一次,而其实并不然,所以根节点设为任意一个奇偶性条件为

1

1

1的点,同时在枚举儿子转移时,若整个子树都已经满足了就直接跳过。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 500010
int last[N], nxt[N * 2], to[N * 2], len = 0;
int f[N][3][2], a[N], s[N];
void add(int x, int y) {
	to[++len] = y;
	nxt[len] = last[x];
	last[x] = len;
}
void dfs(int k, int fa) {
	int s0 = 0, s1 = 1e9, s2 = 1e9, s3 = 1e9, s4 = 1e9, s5 = 1e9, s6 = 1e9, s7 = 1e9; 
	int t0, t1;
	if(!a[k]) s[k]++;
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
		int x = to[i];
		dfs(x, k);
		if(!s[to[i]]) continue;
		s[k] += s[to[i]];
		t0 = min(s7 + f[x][0][1] + 1, s6 + f[x][0][0] + 3), t1 = min(s6 + f[x][0][1] + 1, s7 + f[x][0][0] + 3);
		s6 = t0, s7 = t1;
		t0 = min(s2 + f[x][1][1], s3 + f[x][1][0] + 2), t1 = min(s3 + f[x][1][1], s2 + f[x][1][0] + 2);
		s6 = min(s6, t0), s7 = min(s7, t1);
		
		t0 = min(s5 + f[x][0][1] + 1, s4 + f[x][0][0] + 3), t1 = min(s4 + f[x][0][1] + 1, s5 + f[x][0][0] + 3);
		s4 = t0, s5 = t1;
		t0 = min(s1 + f[x][2][1] + 1, s0 + f[x][2][0] + 3), t1 = min(s0 + f[x][2][1] + 1, s1 + f[x][2][0] + 3);
		s4 = min(s4, t0), s5 = min(s5, t1);
		
		t0 = min(s3 + f[x][0][1] + 1, s2 + f[x][0][0] + 3), t1 = min(s2 + f[x][0][1] + 1, s3 + f[x][0][0] + 3);
		s2 = t0, s3 = t1;
		t0 = min(s0 + f[x][1][1], s1 + f[x][1][0] + 2), t1 = min(s1 + f[x][1][1], s0 + f[x][1][0] + 2);
		s2 = min(s2, t0), s3 = min(s3, t1);
		
		t0 = min(s1 + f[x][0][1] + 1, s0 + f[x][0][0] + 3), t1 = min(s0 + f[x][0][1] + 1, s1 + f[x][0][0] + 3);
		s0 = t0, s1 = t1;
	
	}
	f[k][0][a[k]] = s1 + 1;
	f[k][0][a[k] ^ 1] = s0 + 1;
	f[k][1][a[k]] = min(s3, s1) + 1;
	f[k][1][a[k] ^ 1] = min(s2, s0) + 1;
	f[k][2][a[k]] = min(min(s0 + 2, s5 + 1), min(s6 + 2, s2 + 2));
	f[k][2][a[k] ^ 1] = min(min(s1 + 2, s4 + 1), min(s7 + 2, s3 + 2));
}
int main() {
	int n, i, x, y;
	scanf("%d
", &n);
	for(i = 1; i <= n; i++) {
		a[i] = getchar() - '0';
	}
	for(i = 1; i < n; i++) {
		scanf("%d%d", &x, &y);
		add(x, y), add(y, x);
	}
	for(i = 1; i <= n; i++) if(!a[i]) break;
	dfs(i, 0);
	printf("%d
", f[i][2][0] - 1);
	fclose(stdin);
	fclose(stdout);
	return 0;
}