原题:https://fjnuacm.top/d/junior/p/512?tid=633d6550d2fe705a3c4684c7

之所以来写这个题解,是因为思路真的太清晰啦((

题意

给定一段由 \(0, 1, 2\) 组成的二叉树序列 \(S\),序列由下面三种元素构成:

  1. \(0\):表示该树没有节点;
  2. \(1 S_1\):表示该数有一个节点,\(S_1\) 为其子树的二叉树序列;
  3. \(2 S_1 S_2\):表示该树有两个子节点,\(S_1\)\(S_2\) 分别表示其两个子树的二叉树序列。

根据上述序列建树,并标上红蓝绿三种颜色,相邻颜色不能重复,子节点颜色不能重复,求出这棵树中绿色节点的最大最小数量。

思路

我们考虑下面的两个问题。

如何建树

根据题给条件,在给出根节点后,后面将会有一段数字作为根节点的子树,而其子树又可向右找到他的子树,以此类推。

但我们要如何确定下一个节点从哪里开始呢?显然,在上一个子节点遍历完后,下一个下标即为另一个子节点的开始下标。

于是乎,我们可以记录一下当前的下标在哪个位置,然后.....

这里也有两种写法。

递归写法

我们只需向下一次调用传递当前下标的位置,并返回处理结束后的下标位置(也可以开一个全局变量存储下标,效果是一样的)即可。

对应代码

void buildTree(int father){
    if(curIndex == inputTree.size()) return;
    cnt[father] = inputTree[curIndex ++] - '0';
    for(int i=0;i<cnt[father];i++) {
        nodes[father][i] = ++tot;
        buildTree(tot);
    }
}

STL写法

我们不妨这么想,在找到根节点后,我们需要寻找它的两个子节点,而因为需要建树,我们需要知道这两个子节点对应的父节点是什么。

所以,我们可以使用一个数据结构存储这个根节点,在嵌套寻找的时候能正确获取上一个根节点,并能在两个子节点处理完后移除这个根节点。

这个数据结构满足一个特点:先进先出。

没错,就是栈结构。

对应代码

void buildTree() {
    stack<pair<int, int>> root; //index, sum
    int cur = 0;
    root.push(pair<int, int>(++tot, inputTree[0] - '0'));
    while (!root.empty()) {
        pair<int, int> father = root.top();
        root.pop();
        int now = inputTree[++cur] - '0';
        father.second--;
        nodes[father.first][cnt[father.first]++] = ++tot; //建树
        if (father.second > 0) root.push(father);
        if (now > 0) root.push(pair<int, int>(tot, now));
    }
}

如何dp

在说之前,先吐槽一句我的代码,它看起来好蠢

我们不妨用状态机的写法,开一个二维 \(dp\) 数组第一维为下标,第二维为当前节点的一个状态。

显然,作为一个节点,他有三种状态——红蓝绿。

初始状态

将叶节点的所有状态赋值 \(1\).

状态转移

  1. 首先,如果一个父节点要成为绿色,那么他的子节点一定是红色、蓝色,或者蓝色、红色。当然如果只有一个子节点,那么这个子节点就是蓝色或者红色。

  2. 所以,对于一个父节点,对于一种颜色,它总会有两种取法,而又因为两种取法不影响父节点的颜色,所以 \(dp\) 的最大值就是两种情况的最大值,最小值同理。

这是最直接的思路,而按照这么写,代码会很冗长。

对应代码

void dfs(int root) {
    for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
    if (cnt[root] == 0) {
        dpMax[root][1] = 1;
        dpMin[root][1] = 1;
    } else if (cnt[root] == 1) {
        dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
        dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
        dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
        dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
        dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
        dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
    } else if (cnt[root] == 2) {
        dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
                             dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
        dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
                             dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
        dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
                             dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
        dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
                             dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
        dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
                             dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
        dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
                             dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
    }
}

是不是很蠢,我看着就想笑

最终结果

根节点分别为红蓝绿时,所记录下来的最大值和最小值即为答案。

对应AC代码 (递归)

#include <bits/stdc++.h>

using namespace std;

//你问我啥用cpp写,因为Java栈溢出了
int tot;
int nodes[500010][2], dpMin[500010][3], dpMax[500010][3]; //0是红,1是绿,2是蓝,dp的值是绿色点的个数
int cnt[500010];
string inputTree;
int curIndex = 0;

void buildTree(int father){
    if(curIndex == inputTree.size()) return;
    cnt[father] = inputTree[curIndex ++] - '0';
    for(int i=0;i<cnt[father];i++) {
        nodes[father][i] = ++tot;
        buildTree(tot);
    }
}

void dfs(int root) { //好蠢
    for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
    if (cnt[root] == 0) { //断子绝孙
        dpMax[root][1] = 1;
        dpMin[root][1] = 1;
    } else if (cnt[root] == 1) { //一个节点
        dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
        dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
        dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
        dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
        dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
        dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
    } else if (cnt[root] == 2) {
        dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
                             dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
        dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
                             dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
        dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
                             dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
        dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
                             dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
        dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
                             dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
        dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
                             dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
    }
}

int main() {
    cin >> inputTree;
    buildTree(++ tot);
    dfs(1);
    cout << max(dpMax[1][0], max(dpMax[1][1], dpMax[1][2])) << " " << min(dpMin[1][0], min(dpMin[1][1], dpMin[1][2]));
}

对应AC代码 (STL)

#include <bits/stdc++.h>

using namespace std;

//你问我啥用cpp写,因为Java栈溢出了
int tot;
int nodes[500010][2], dpMin[500010][3], dpMax[500010][3]; //0是红,1是绿,2是蓝,dp的值是绿色点的个数
int cnt[500010];
string inputTree;

void buildTree() {
    stack<pair<int, int>> root; //index, sum
    int cur = 0;
    root.push(pair<int, int>(++tot, inputTree[0] - '0'));
    while (!root.empty()) {
        pair<int, int> father = root.top();
        root.pop();
        int now = inputTree[++cur] - '0';
        father.second--;
        nodes[father.first][cnt[father.first]++] = ++tot; //建树
        if (father.second > 0) root.push(father);
        if (now > 0) root.push(pair<int, int>(tot, now));
    }
}

void dfs(int root) { //好蠢
    for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
    if (cnt[root] == 0) { //断子绝孙
        dpMax[root][1] = 1;
        dpMin[root][1] = 1;
    } else if (cnt[root] == 1) { //一个节点
        dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
        dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
        dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
        dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
        dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
        dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
    } else if (cnt[root] == 2) {
        dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
                             dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
        dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
                             dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
        dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
                             dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
        dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
                             dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
        dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
                             dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
        dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
                             dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
    }
}

int main() {
    cin >> inputTree;
    buildTree();
    dfs(1);
    cout << max(dpMax[1][0], max(dpMax[1][1], dpMax[1][2])) << " " << min(dpMin[1][0], min(dpMin[1][1], dpMin[1][2]));
}

其实递归是写本题解的时候想到的,而有趣的是它反而是最优解。