原题:https://fjnuacm.top/d/junior/p/512?tid=633d6550d2fe705a3c4684c7之所以来写这个题解,是因为思路真的太清晰啦((
题意
给定一段由 \(0, 1, 2\) 组成的二叉树序列 \(S\),序列由下面三种元素构成:
- \(0\):表示该树没有节点;
- \(1 S_1\):表示该数有一个节点,\(S_1\) 为其子树的二叉树序列;
- \(2 S_1 S_2\):表示该树有两个子节点,\(S_1\) 和 \(S_2\) 分别表示其两个子树的二叉树序列。
根据上述序列建树,并标上红蓝绿三种颜色,相邻颜色不能重复,子节点颜色不能重复,求出这棵树中绿色节点的最大和最小数量。
思路
我们考虑下面的两个问题。
如何建树
根据题给条件,在给出根节点后,后面将会有一段数字作为根节点的子树,而其子树又可向右找到他的子树,以此类推。
但我们要如何确定下一个节点从哪里开始呢?显然,在上一个子节点遍历完后,下一个下标即为另一个子节点的开始下标。
于是乎,我们可以记录一下当前的下标在哪个位置,然后.....
这里也有两种写法。
递归写法
我们只需向下一次调用传递当前下标的位置,并返回处理结束后的下标位置(也可以开一个全局变量存储下标,效果是一样的)即可。
对应代码
void buildTree(int father){
if(curIndex == inputTree.size()) return;
cnt[father] = inputTree[curIndex ++] - '0';
for(int i=0;i<cnt[father];i++) {
nodes[father][i] = ++tot;
buildTree(tot);
}
}
STL写法
我们不妨这么想,在找到根节点后,我们需要寻找它的两个子节点,而因为需要建树,我们需要知道这两个子节点对应的父节点是什么。
所以,我们可以使用一个数据结构存储这个根节点,在嵌套寻找的时候能正确获取上一个根节点,并能在两个子节点处理完后移除这个根节点。
这个数据结构满足一个特点:先进先出。
没错,就是栈结构。
对应代码
void buildTree() {
stack<pair<int, int>> root; //index, sum
int cur = 0;
root.push(pair<int, int>(++tot, inputTree[0] - '0'));
while (!root.empty()) {
pair<int, int> father = root.top();
root.pop();
int now = inputTree[++cur] - '0';
father.second--;
nodes[father.first][cnt[father.first]++] = ++tot; //建树
if (father.second > 0) root.push(father);
if (now > 0) root.push(pair<int, int>(tot, now));
}
}
如何dp
在说之前,先吐槽一句我的代码,它看起来好蠢
我们不妨用状态机的写法,开一个二维 \(dp\) 数组第一维为下标,第二维为当前节点的一个状态。
显然,作为一个节点,他有三种状态——红蓝绿。
初始状态
将叶节点的所有状态赋值 \(1\).
状态转移
首先,如果一个父节点要成为绿色,那么他的子节点一定是红色、蓝色,或者蓝色、红色。当然如果只有一个子节点,那么这个子节点就是蓝色或者红色。
所以,对于一个父节点,对于一种颜色,它总会有两种取法,而又因为两种取法不影响父节点的颜色,所以 \(dp\) 的最大值就是两种情况的最大值,最小值同理。
这是最直接的思路,而按照这么写,代码会很冗长。
对应代码
void dfs(int root) {
for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
if (cnt[root] == 0) {
dpMax[root][1] = 1;
dpMin[root][1] = 1;
} else if (cnt[root] == 1) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
} else if (cnt[root] == 2) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
}
}
是不是很蠢,我看着就想笑
最终结果
根节点分别为红蓝绿时,所记录下来的最大值和最小值即为答案。
对应AC代码 (递归)
#include <bits/stdc++.h>
using namespace std;
//你问我啥用cpp写,因为Java栈溢出了
int tot;
int nodes[500010][2], dpMin[500010][3], dpMax[500010][3]; //0是红,1是绿,2是蓝,dp的值是绿色点的个数
int cnt[500010];
string inputTree;
int curIndex = 0;
void buildTree(int father){
if(curIndex == inputTree.size()) return;
cnt[father] = inputTree[curIndex ++] - '0';
for(int i=0;i<cnt[father];i++) {
nodes[father][i] = ++tot;
buildTree(tot);
}
}
void dfs(int root) { //好蠢
for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
if (cnt[root] == 0) { //断子绝孙
dpMax[root][1] = 1;
dpMin[root][1] = 1;
} else if (cnt[root] == 1) { //一个节点
dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
} else if (cnt[root] == 2) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
}
}
int main() {
cin >> inputTree;
buildTree(++ tot);
dfs(1);
cout << max(dpMax[1][0], max(dpMax[1][1], dpMax[1][2])) << " " << min(dpMin[1][0], min(dpMin[1][1], dpMin[1][2]));
}
对应AC代码 (STL)
#include <bits/stdc++.h>
using namespace std;
//你问我啥用cpp写,因为Java栈溢出了
int tot;
int nodes[500010][2], dpMin[500010][3], dpMax[500010][3]; //0是红,1是绿,2是蓝,dp的值是绿色点的个数
int cnt[500010];
string inputTree;
void buildTree() {
stack<pair<int, int>> root; //index, sum
int cur = 0;
root.push(pair<int, int>(++tot, inputTree[0] - '0'));
while (!root.empty()) {
pair<int, int> father = root.top();
root.pop();
int now = inputTree[++cur] - '0';
father.second--;
nodes[father.first][cnt[father.first]++] = ++tot; //建树
if (father.second > 0) root.push(father);
if (now > 0) root.push(pair<int, int>(tot, now));
}
}
void dfs(int root) { //好蠢
for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
if (cnt[root] == 0) { //断子绝孙
dpMax[root][1] = 1;
dpMin[root][1] = 1;
} else if (cnt[root] == 1) { //一个节点
dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
} else if (cnt[root] == 2) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
}
}
int main() {
cin >> inputTree;
buildTree();
dfs(1);
cout << max(dpMax[1][0], max(dpMax[1][1], dpMax[1][2])) << " " << min(dpMin[1][0], min(dpMin[1][1], dpMin[1][2]));
}
其实递归是写本题解的时候想到的,而有趣的是它反而是最优解。