- 「动态规划3」树型动态规划
树型dp - 三色二叉树 题解
- 2022-12-2 16:06:00 @
之所以来写这个题解,是因为思路真的太清晰啦((
题面
给出符合题目约定的一串数字,建树并标上红蓝绿三种颜色,相邻颜色不能重复,子节点颜色不能重复,求出这棵树中绿色节点的最大和最小数量。
>>我们把这道题分成两个问题来讨论>>
Q1 如何建树
根据题给条件,在给出根节点后,后面将会有一段数字作为根节点的子树,而其子树又可向右找到他的子树,以此类推。
但我们要如何确定下一个节点从哪里开始呢?显然,在上一个子节点遍历完后,下一个下标即为另一个子节点的开始下标。
于是乎,我们可以记录一下当前的下标在哪个位置,然后.....
这里也有两种写法。
1.最容易想到的是递归写法,我们只需向下一次调用传递当前下标的位置,并返回处理结束后的下标位置(也可以开一个全局变量存储下标,效果是一样的)即可。
下面给出使用递归的建树部分cpp代码,其中inputTree
为输入的那串数字。
void buildTree(int father){
if(curIndex == inputTree.size()) return;
cnt[father] = inputTree[curIndex ++] - '0';
for(int i=0;i<cnt[father];i++) {
nodes[father][i] = ++tot;
buildTree(tot);
}
}
2.STL解法
我们不妨这么想,在找到根节点后,我们需要寻找它的两个子节点,而因为需要建树,我们需要知道这两个子节点对应的父节点是什么。
所以,我们可以使用一个数据结构存储这个根节点,在嵌套寻找的时候能正确获取上一个根节点,并能在两个子节点处理完后移除这个根节点。
这个数据结构满足一个特点:先进先出
下面给出使用Stack的建树部分cpp代码,其中inputTree
为输入的那串数字。
void buildTree() {
stack<pair<int, int>> root; //index, sum
int cur = 0;
root.push(pair<int, int>(++tot, inputTree[0] - '0'));
while (!root.empty()) {
pair<int, int> father = root.top();
root.pop();
int now = inputTree[++cur] - '0';
father.second--;
nodes[father.first][cnt[father.first]++] = ++tot; //建树
if (father.second > 0) root.push(father);
if (now > 0) root.push(pair<int, int>(tot, now));
}
}
Q2 如何dp
在说之前,先吐槽一句我的代码,它看起来好蠢
树形dp的一种类型,第一维为下标,第二维为当前节点的一个状态。
显然,作为一个节点,他有三种状态——红蓝绿。
初始状态 我们将无子节点的节点的所有状态赋值1
.
状态转移
1.首先,如果一个父节点要成为绿色,那么他的子节点一定是红色、蓝色,或者蓝色、红色。当然如果只有一个子节点,那么这个子节点就是蓝色或者红色。
2.所以,对于一个父节点,对于一种颜色,它总会有两种取法,而又因为两种取法不影响父节点的颜色,所以dp的最大值就是两种情况的最大值,最小值同理。
这是最直接的思路,而按照这么写,代码会很冗长((
本段问题对应cpp代码
void dfs(int root) {
for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
if (cnt[root] == 0) {
dpMax[root][1] = 1;
dpMin[root][1] = 1;
} else if (cnt[root] == 1) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
} else if (cnt[root] == 2) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
}
}
是不是很蠢,我看着就想笑
对了,结果你应该会算吧? 根节点分别为红蓝绿中的最大值和最小值即为答案。
ok,解毕。
AC(cpp)代码
1.递归写法
#include <bits/stdc++.h>
using namespace std;
//你问我啥用cpp写,因为Java栈溢出了
int tot;
int nodes[500010][2], dpMin[500010][3], dpMax[500010][3]; //0是红,1是绿,2是蓝,dp的值是绿色点的个数
int cnt[500010];
string inputTree;
int curIndex = 0;
void buildTree(int father){
if(curIndex == inputTree.size()) return;
cnt[father] = inputTree[curIndex ++] - '0';
for(int i=0;i<cnt[father];i++) {
nodes[father][i] = ++tot;
buildTree(tot);
}
}
void dfs(int root) { //好蠢
for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
if (cnt[root] == 0) { //断子绝孙
dpMax[root][1] = 1;
dpMin[root][1] = 1;
} else if (cnt[root] == 1) { //一个节点
dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
} else if (cnt[root] == 2) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
}
}
int main() {
cin >> inputTree;
buildTree(++ tot);
dfs(1);
cout << max(dpMax[1][0], max(dpMax[1][1], dpMax[1][2])) << " " << min(dpMin[1][0], min(dpMin[1][1], dpMin[1][2]));
}
2.STL写法
#include <bits/stdc++.h>
using namespace std;
//你问我啥用cpp写,因为Java栈溢出了
int tot;
int nodes[500010][2], dpMin[500010][3], dpMax[500010][3]; //0是红,1是绿,2是蓝,dp的值是绿色点的个数
int cnt[500010];
string inputTree;
void buildTree() {
stack<pair<int, int>> root; //index, sum
int cur = 0;
root.push(pair<int, int>(++tot, inputTree[0] - '0'));
while (!root.empty()) {
pair<int, int> father = root.top();
root.pop();
int now = inputTree[++cur] - '0';
father.second--;
nodes[father.first][cnt[father.first]++] = ++tot; //建树
if (father.second > 0) root.push(father);
if (now > 0) root.push(pair<int, int>(tot, now));
}
}
void dfs(int root) { //好蠢
for (int i = 0; i < cnt[root]; i++) dfs(nodes[root][i]);
if (cnt[root] == 0) { //断子绝孙
dpMax[root][1] = 1;
dpMin[root][1] = 1;
} else if (cnt[root] == 1) { //一个节点
dpMax[root][0] = max(dpMax[nodes[root][0]][1], dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0], dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1], dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0], dpMin[nodes[root][0]][1]);
} else if (cnt[root] == 2) {
dpMax[root][0] = max(dpMax[nodes[root][0]][1] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][1] + dpMax[nodes[root][0]][2]);
dpMax[root][1] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][2],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][2]) + 1;
dpMax[root][2] = max(dpMax[nodes[root][0]][0] + dpMax[nodes[root][1]][1],
dpMax[nodes[root][1]][0] + dpMax[nodes[root][0]][1]);
dpMin[root][0] = min(dpMin[nodes[root][0]][1] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][1] + dpMin[nodes[root][0]][2]);
dpMin[root][1] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][2],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][2]) + 1;
dpMin[root][2] = min(dpMin[nodes[root][0]][0] + dpMin[nodes[root][1]][1],
dpMin[nodes[root][1]][0] + dpMin[nodes[root][0]][1]);
}
}
int main() {
cin >> inputTree;
buildTree();
dfs(1);
cout << max(dpMax[1][0], max(dpMax[1][1], dpMax[1][2])) << " " << min(dpMin[1][0], min(dpMin[1][1], dpMin[1][2]));
}
其实递归是写本题解的时候想到的,而有趣的是它反而是最优解。
欢迎大佬们指点。