跳转至

D - 红黑树

基本信息

题目出处2023 ICPC 亚洲区域赛南京站
队伍通过率19/342 (5.6%)

题解

\(f(u, x)\) 表示以 \(u\) 为根的子树是完美的,且从 \(u\) 到任意后代叶子节点的路径上都有 \(x\) 个黑色点需要的最少修改次数。有朴素的 dp 方程:

\[ f(u, x) = \min\limits_{i \in [0, 1]} (g(u, i) + \sum\limits_{v \in \text{son}(u)} f(v, x - i)) \]

其中 \(g(u, 0/1)\) 是让节点 \(u\) 变红/黑的代价。节点 \(u\) 的答案就是 \(\min\limits_x f(u, x)\)

可以归纳证明 \(f(u, x)\) 是关于 \(x\) 的凸序列:

  • \(g(u, 0/1)\) 是凸序列,因为这个序列只有两个点。
  • 根据归纳假设,\(f(v, x - i)\) 是凸序列。由于凸序列的和还是凸序列,因此 \(\sum\limits_{v \in \text{son}(u)} f(v, x - i)\) 也是凸序列。
  • 凸序列的 \((\min, +)\) 卷积也还是凸序列,因此 \(f(u, x)\) 是凸序列。

凸序列常用单调的差分数组进行维护。我们维护 \(h(u) = \{f(u, 1) - f(u, 0), f(u, 2) - f(u, 1), \cdots\}\),这个序列是单调递增的,因此 \(\min\limits_x f(u, x) = f(u, 0) + \sum\limits_{h(u, x) < 0} h(u, x)\),我们还要顺便维护差分数组 \(h\) 的负值之和。接下来我们看看差分数组 \(h\) 如何加快上述 dp 方程的计算。

首先是 \(f\) 求和的部分。原数组求和,差分数组也是求和。注意到 \(x\) 的取值范围是 \(u\) 所有子节点的最小深度 \(d\) 加一,因此我们只要暴力地把所有子节点长度为 \((d + 1)\) 的前缀加起来即可。这样做的复杂度是多少呢?大家可能知道,如果每个点是计算是把其它链合并到最长的链上,那么复杂度是线性的(对树进行长链剖分,每个点只会在长链的顶端被合并一次),而本题甚至是把其它链合并到最短的链上,因此复杂度肯定不会高于线性。

接下来考虑与 \(g(u, i)\) 进行 \((\min, +)\) 卷积。同样考虑 \(g\) 的差分,注意到 \(g(u, 1) - g(u, 0) = \pm 1\),因此我们还要支持往差分数组里插入一个 \(1\)\(-1\),并维持差分数组的单调性。

如果直接使用 set 维护差分数组,复杂度是 \(\mathcal{O}(n\log n)\) 的。这里注意到我们每次插入的数都是固定的 \(1\)\(-1\),因此可以考虑这样的数据结构:维护两个 vector,一个 vector 保存所有负数,一个 vector 保存所有正数,再开一个变量记录有几个 \(0\)。这样插入 \(1\) 就往正数 vector 的开头插,插入 \(-1\) 就往负数 vector 的末尾插。这样复杂度仍然是 \(\mathcal{O}(n)\) 的。

参考代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
#define MAXN ((int) 1e5)
using namespace std;

int n, ans[MAXN + 10];
char s[MAXN + 10];

vector<int> e[MAXN + 10];

// 维护差分数组的数据结构
struct Magic {
    // neg 保存所有负数的差分值,pos 保存所有正数的差分值
    vector<int> neg, pos;
    // zero 表示有几个差分值是 0,negSm = sum(neg)
    int zero, negSm;

    Magic(): zero(0), negSm(0) {};

    // 用差分数组 vec 初始化数据结构
    Magic(vector<int> &vec): zero(0), negSm(0) {
        for (int x : vec) {
            if (x < 0) neg.push_back(x), negSm += x;
            else if (x == 0) zero++;
            else pos.push_back(x);
        }
        reverse(pos.begin(), pos.end());
    }

    int size() {
        return neg.size() + pos.size() + zero;
    }

    // 取出差分数组中下标为 idx 的元素
    int at(int idx) {
        if (idx < neg.size()) return neg[idx];
        else if (idx < neg.size() + zero) return 0;
        else {
            idx -= neg.size() + zero;
            return pos[pos.size() - 1 - idx];
        }
    }

    // 往差分数组里插入 1 或 -1
    void insert(int x) {
        assert(x == 1 || x == -1);
        if (x == 1) pos.push_back(1);
        else neg.push_back(-1), negSm--;
    }
};

typedef pair<int, Magic> pim;

pim dfs(int sn) {
    int v = 0;
    Magic magic;

    for (int fn : e[sn]) {
        pim tmp = dfs(fn);
        v += tmp.first;
        if (magic.size() == 0) magic = move(tmp.second);
        else {
            int sz = min(magic.size(), tmp.second.size());
            // 只保留两个差分数组较短的前缀
            vector<int> vec(sz);
            for (int i = 0; i < sz; i++) vec[i] = magic.at(i) + tmp.second.at(i);
            magic = Magic(vec);
        }
    }

    // 根据 sn 原来的颜色,往差分数组里插入 1 或 -1
    v += s[sn] - '0';
    if (s[sn] == '0') magic.insert(1);
    else magic.insert(-1);

    ans[sn] = v + magic.negSm;
    return pim(v, move(magic));
}

void solve() {
    scanf("%d%s", &n, s + 1);
    for (int i = 1; i <= n; i++) e[i].clear();
    for (int i = 2; i <= n; i++) {
        int x; scanf("%d", &x);
        e[x].push_back(i);
    }
    dfs(1);
    for (int i = 1; i <= n; i++) printf("%d%c", ans[i], "\n "[i < n]);
}

int main() {
    int tcase; scanf("%d", &tcase);
    while (tcase--) solve();
    return 0;
}