跳转至

F - 多彩的线段

基本信息

题目出处2023 山东省大学生程序设计竞赛
队伍通过率5/276 (1.8%)

题解

错误做法

首先分析一种常见的错误做法。一些读者可能会将所有线段按右端点排序,然后维护 \(f(i)\) 表示只考虑前 \(i\) 条线段,且第 \(i\) 条线段必选有几种方法。枚举上一条选择的线段 \(j\),则递推式为

\[ f(i) = 1 + \sum\limits_{1 \le j < i, c_i = c_j} f(j) + \sum\limits_{1 \le j < i, c_i \ne c_j, r_j < l_i} f(j) \]

关于该做法的反例,考虑一条蓝色线段 \([2, 3]\) 以及两条红色线段 \([4, 5]\)\([1, 6]\),则 \(f(2) = 2\)(选择 \([4, 5]\) 时,\([2, 3]\) 可选可不选),\(f(3) = 1 + f(2) = 3\)。然而 \(f(3)\) 正确的值应该是 \(2\)(选择 \([1, 6]\) 时,\([4, 5]\) 可选可不选,\([2, 3]\) 一定不能选),这是因为 \(f(2)\) 包含的两种选法对于 \(f(3)\) 并非全部有效。

正确但复杂度较高的做法

首先仍然将所有线段按右端点排序。根据上一节的分析可知,我们不能将每条线段单独考虑,而是要考虑颜色相同的一整组线段。我们把所选线段分成若干连续段,每一段内的线段颜色都相同。

\(f(i, c)\) 表示只考虑前 \(i\) 条线段,且第 \(i\) 条线段必选,且第 \(i\) 条线段的颜色为 \(c\) 的方案数。转移时,我们枚举上一段的终点 \(j\)\(j\) 的颜色必须和 \(i\) 不同),则转移方程为

\[ f(i, c) = \sum\limits_{j = 0}^{p_i} f(j, 1 - c) \times 2^{g(i - 1, r_j + 1, c)} \]

其中 \(p_i\) 表示满足 \(r_{p_i} < l_i\) 的最大下标。\(g(i - 1, r_j + 1, c)\) 表示前 \((i - 1)\) 条线段中,左端点大于等于 \((r_j + 1)\),且颜色为 \(c\) 的线段数量。形象地理解,就是位于上一段终点右边,且颜色和当前线段一样的其它线段随便选。

方便起见,我们令 \(r_0 = 0\)\(f(0, 0) = f(0, 1) = 1\),这样就解决了初值的问题。答案就是 \(1 + \sum\limits_{i = 1}^n f(i, c_i)\)

优化复杂度

直接计算这个 dp 的复杂度是 \(\mathcal{O}(n^2)\) 的,考虑优化。

\(h(i, j, c) = f(j, c) \times 2^{g(i, r_j + 1, 1 - c)}\),则 dp 方程可以改写为

\[ f(i, c) = \sum\limits_{j = 0}^{p_i} h(i - 1, j, 1 - c) \]

注意到 \(g\) 有以下递推关系

\[ \begin{cases} g(i, r_j + 1, c) = g(i - 1, r_j + 1, c) + 1 & \text{if } 1 \le j \le p_i \text{ and } c = c_i \\ g(i, r_j + 1, c) = g(i - 1, r_j + 1, c) & \text{otherwise} \end{cases} \]

\[ \begin{cases} h(i, j, c) = 2h(i - 1, j, c) & \text{if } 1 \le j \le p_i \text{ and } c = 1 - c_i \\ h(i, j, c) = h(i - 1, j, c) & \text{otherwise} \end{cases} \]

形象地理解,就是处理完第 \(i\) 条线段之后,对于所有上一段的终点,可以任意选择的线段数量又增加了一条。

因此我们使用线段树维护 \(h(i, j, c)\) 的值。需要支持前缀乘 \(2\),单点赋值,以及前缀求和。复杂度 \(\mathcal{O}(n\log n)\)

参考代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
#include <bits/stdc++.h>
#define MAXN ((int) 1e5)
#define MOD 998244353
using namespace std;

int n;
long long ans;

struct Seg {
    int l, r, c;
} A[MAXN + 10];

long long sumo[2][MAXN * 4 + 10], lazy[2][MAXN * 4 + 10];

// 线段树初始化
void build(int tid, int id, int l, int r) {
    if (l != r) {
        int nxt = id << 1, mid = (l + r) >> 1;
        build(tid, nxt, l, mid); build(tid, nxt | 1, mid + 1, r);
    }
    sumo[tid][id] = 0; lazy[tid][id] = 1;
}

// 线段树 lazy 标志下推
void down(int tid, int id) {
    long long &p = lazy[tid][id];
    if (p == 1) return;
    int nxt = id << 1;
    sumo[tid][nxt] = sumo[tid][nxt] * p % MOD;
    lazy[tid][nxt] = lazy[tid][nxt] * p % MOD;
    sumo[tid][nxt | 1] = sumo[tid][nxt | 1] * p % MOD;
    lazy[tid][nxt | 1] = lazy[tid][nxt | 1] * p % MOD;
    p = 1;
}

// 线段树单点加法
void add(int tid, int id, int l, int r, int pos, long long val) {
    if (l == r) sumo[tid][id] = (sumo[tid][id] + val) % MOD;
    else {
        down(tid, id);
        int nxt = id << 1, mid = (l + r) >> 1;
        if (pos <= mid) add(tid, nxt, l, mid, pos, val);
        else add(tid, nxt | 1, mid + 1, r, pos, val);
        sumo[tid][id] = (sumo[tid][nxt] + sumo[tid][nxt | 1]) % MOD;
    }
}

// 线段树前缀 * 2
void mul2(int tid, int id, int l, int r, int pos) {
    if (r <= pos) {
        sumo[tid][id] = sumo[tid][id] * 2 % MOD;
        lazy[tid][id] = lazy[tid][id] * 2 % MOD;
    } else {
        down(tid, id);
        int nxt = id << 1, mid = (l + r) >> 1;
        mul2(tid, nxt, l, mid, pos);
        if (pos > mid) mul2(tid, nxt | 1, mid + 1, r, pos);
        sumo[tid][id] = (sumo[tid][nxt] + sumo[tid][nxt | 1]) % MOD;
    }
}

// 线段树询问前缀和
long long query(int tid, int id, int l, int r, int pos) {
    if (r <= pos) return sumo[tid][id];
    else {
        down(tid, id);
        int nxt = id << 1, mid = (l + r) >> 1;
        return (
            query(tid, nxt, l, mid, pos) +
            (pos > mid ? query(tid, nxt | 1, mid + 1, r, pos) : 0)
        ) % MOD;
    }
}

void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d%d%d", &A[i].l, &A[i].r, &A[i].c);
    // 所有区间按右端点排序
    sort(A + 1, A + n + 1, [](Seg &a, Seg &b) {
        return a.r < b.r;
    });
    A[0].l = A[0].r = 0;

    // dp 初值:f(0, 0) = f(0, 1) = 1
    for (int k = 0; k < 2; k++) build(k, 1, 0, n), add(k, 1, 0, n, 0, 1);
    ans = 1;
    for (int i = 1; i <= n; i++) {
        // 二分找到右端点小于当前线段左端点的、最右边的线段
        int head = 0, tail = i - 1;
        while (head < tail) {
            int mid = (head + tail + 1) >> 1;
            if (A[mid].r < A[i].l) head = mid;
            else tail = mid - 1;
        }
        // 算出以当前线段为结尾的方案数
        long long val = query(A[i].c ^ 1, 1, 0, n, head);
        // h(i, i, c_i) = f(i, c_i)
        add(A[i].c, 1, 0, n, i, val);
        // h(i, j, 1 - c_i) = 2h(i - 1, j, 1 - c_i) if 1 <= j <= p_i
        mul2(A[i].c ^ 1, 1, 0, n, head);
        ans = (ans + val) % MOD;
    }
    printf("%lld\n", ans);
}

int main() {
    int tcase; scanf("%d", &tcase);
    while (tcase--) solve();
    return 0;
}