线段树

「观前提醒」

「文章仅供学习和参考,如有问题请在评论区提出」


引入


线段树(Segment Tree)是算法竞赛中常用的用来维护区间信息的数据结构。

线段树可以在 \(O(logN)\) 的时间复杂度内实现单点修改、区间修改、区间查询等操作。能够用来维护很多类型的信息,包括但不仅限于求区间和、求区间最值、求区间最大子段和、求区间最大公约数等。


基本原理


这里先给出线段树的结构图,

从结构图我们可以看到,线段树是一颗平衡二叉树,而且每个节点都对应了一个区间值。

然后我们发现,对于编号为 \(p\) 的节点,它的左右儿子节点编号分别为 \(2p\)\(2p + 1\) 。同理,我们也能够推出编号为 \(p\) 的节点的父节点编号为 $\left \lfloor \frac{p}{2} \right \rfloor $ 。

现在我们把图稍微改变一下,

我们发现,对于任意一对左右节点的区间范围,是通过其父节点的区间区间折半推导出来的。

\(tr_{p}\) 所维护的区间是 \([l, r]\) ,然后取 $mid = \left \lfloor \frac{l + r}{2} \right \rfloor $ ,那么 \(tr_{p}\) 的左儿子节点 \(tr_{2p}\) 表示区间就是 \([l, mid]\),右儿子节点 \(tr_{2p + 1}\) 表示的区间是 \([mid + 1, r]\)

这里我们注意到,对于编号为 \(p\) 的节点,它的左右节点所表示的区间长度不一定会相等。但是这样并不会影响线段树的整体结构,因为每个节点都会不断往下细分,直到区间长度为 \(1\)

所以这里只需要满足,一对左右儿子节点所表示的区间的加和就是其父亲节点所表示的区间就行,即 \(tr_{p} = tr_{2p} + tr_{2p + 1}\) (这里指的是区间相加)。


建树


根据上面所理解的线段树的存储逻辑,我们需要一层一层地向下折半拆分,直到区间长度为 \(1\) 时再对其进行赋值。那么我们就可以考虑用递归来实现这个过程。

而且又因为每个节点维护的是某一段的区间值,对于所要维护的区间值(区间和、区间最值等)不同,建树的逻辑也会有些许的不同,但是总体思路是相同的。那么我们这里就以求区间加和问题来举例。

我们每次分别进行递归遍历自己的左右节点,直到区间长度为 \(1\) 时,对自身进行赋值。然后每个节点再根据已经更新好后的左右节点值,对自身进行赋值。

这里我们可以专门写一个 pushup() 函数来进行父节点的更新,因为越往后看就会发现更新父节点是个重复操作,没有必要重复写。

这样的话,我们就需要分层遍历每个节点,所以建树的时间复杂度就是 \(O(N)\)

代码(以求区间和为例)

const int N = 1e5 + 10;	// 根据序列数据范围而定

// 对于存储每个节点需要 2n 的空间,但每次会 ×2 来获取子节点,所以一般还要再多开 2n 的空间来防止越界
struct Node {	// 每个节点所存储的信息
    int l, r;	// 当前节点维护的区间 [l, r]
    int sum;		// 当前节点所维护的区间值
} tr[N * 4];	// tr[] 存储线段树数据

int a[N];	// a[] 存储原数组数据

// 计算 u 节点区间值并更新
void pushup(int u) {
    // 根据已经更新好的左右节点值来更新自身
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 建树函数
// u 是节点编号,[l, r] 是此时编号为 p 的节点所要代表的区间
void build(int u, int l, int r) {
    if (l == r) tr[u] = {l, r, a[l]};	// 此时区间长度为 1,进行赋值
	else {
        tr[u] = {l, r};	// 更新当前节点区间
        int mid = l + r >> 1;	// 位运算,相当于 (l + r) / 2
        
        // 分别向下递归左右节点
        build(u << 1, l, mid);			// u << 1,位运算,相当于 u * 2
        build(u << 1 | 1, mid + 1, r);	// u << 1 | 1, 位运算,相当于 u * 2 + 1
		
        // 计算当前节点区间值并更新
        pushup(u);
    }
}

// 建树

int n;	// 序列长度

build(1, 1, n);

区间查询


若我们想要查询区间 \([l, r]\) 内每个数的加和,我们只需要从根节点开始,用每个节点所维护的区间 \([a, b]\) 和要查询的区间 \([l, r]\) 进行对比。那么就会有三种情况:(我们取 $mid = \left \lfloor \frac{a + b}{2} \right \rfloor $ )

  1. 如果 \(l \le a \le b \le r\) ,直接返回当前节点值。

  2. 如果 \(l \le mid\) , 还需继续往下访问左节点。

  3. 如果 \(mid + 1\le r\) ,还需继续往下访问右节。

然后就是递归处理这三种情况,最终返回所要查询的区间值。

因为一个区间 \([l, r]\) ,最多可以将其拆成 \(logN\) 个极大的区间,所以每次查询的时间复杂度为 \(O(logN)\)

代码实现 (以求区间和为例)

struct Node {
    int l, r;	// 区间 [l, r]
    int sum;	// 区间值
} tr[N * 4];

// 查询操作, u 是节点编号, [l, r] 是要查询的区间
void query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) // 当前节点区间被 [l, r] 完全包含
        return tr[u].sum;	// 返回当前节点值
    
    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;

    if (l <= mid) sum += query(u << 1, l, r);	// 遍历左节点
    if (r >= mid + 1) sum += query(u << 1 | 1, l, r);	// 遍历右节点

    return sum;	// 返回加和的值
}

单点修改


对于线段的单点修改逻辑很简单,就是不断地向下拆分区间,直到区间长度为 \(1\) 的叶子节点。然后对其进行修改,同时在修改完回溯的时候,再用 pushup() 更新父节点的值。

这里的向下遍历的思路和区间查询差不多,就是判断要修改的位置是在左节点还是右节点。

这样的话,我们最多就会遍历 \(logN\) 层,所以单点修改的时间复杂度为 \(O(logN)\)

代码实现 (以求区间和为例)

struct Node {
    int l, r;	// 区间 [l, r]
    int sum;	// 区间加和值
} tr[N * 4];

// 更新节点 u 的值
void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 修改操作, u 是节点编号, x 是要修改的位置, v 是要修改成的值
void modify(int u, int x, int v) {
    if (tr[u].l == x && tr[u].r == x)	// 查询到节点
        tr[u].sum = v;	// 修改节点值
    else {
        int mid = tr[u].l + tr[u].r >> 1;
        // 遍历左右节点
        if (x <= mid) modify(u << 1, x, v);
        if (x >= mid + 1) modify(u << 1 | 1, x, v);
        
        pushup(u);	// 修改后, 在回溯时对节点值进行更新
    }
}

区间修改 + 懒惰标记


对于区间修改,我们不可能遍历每个节点,然后更新 \([l, r]\) 区间里的每个节点值,因为这样的时间复杂度是 \(O(N)\) 的。而这种时间复杂度对于线段树来说是不可取的。

那么我们应该怎么进行修改呢?接下来我们就需要用到懒惰标记

懒惰标记,就是对节点信息的延迟更新,即用多少,就更新多少的方式,以此来降低修改操作所耗费的时间。每次执行修改操作的时候,我们通过打标记的方式表明该节点对应的区间在某一次操作中被修改,但是不更新该节点的子节点信息。然后在下一次访问该节点时才会下放标记并且更新子节点。

同区间查询相同,对于一个要修改的区间 \([l, r]\) ,最多可以将其拆成 \(logN\) 个极大的区间,所以每次区间修改的时间复杂度就被降为了 \(O(logN)\)

这里我们也可以写一个 pushdown() 函数来进行子节点数值的更新和标记下放。

因为一旦访问到被打上标记的节点,就需要执行 pushdown() 操作,从而使访问能够继续向下传递。所以在执行区间查询区间修改的时候,都需要在访问左右节点前进行 pushdown() 操作来更新左右节点的信息。

代码实现 (维护区间和,修改操作是增加 \(d\)

struct Node {
    int l, r;	// 区间 [l, r]
    int sum;	// 区间值
    int add;	// 懒惰标记, 表示要增加的值
} tr[N * 4];

// 更新节点 u 的值
void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// 更新节点 u 左右节点的值并下放标记
void pushdown(int u) {
    Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
    
    if (root.add) {	// 当前节点有标记才执行
        // 更新左右节点的值
        // 因为维护的是区间和, 区间内每个节点加上 root.add, 区间值就相当于加了 区间长度 * root.add
        left.sum += (left.r - left.l + 1) * root.add;
        rignt.sum = (right.r - right.l + 1) * root.add;
        
        // 给左右节点下方标记
        left.add += root.add;
        right.add += root.add;
        
        // 清空当前节点标记
        root.add = 0;
    }
}

// 修改操作, u 是节点编号, [l, r] 是要修改的区间, d 是要增加的值
void modify(int u, int l, int r, int d) {
    if (tr[u].l >= l && tr[u].r <= r) {	// 区间被完全包含, 更新当前节点并打上标记
        tr[u].sum += (tr[u].r - tr[u].l  + 1) * d;	// 更新节点值
        tr[u].add += d;	// 打标记
    } else {
        pushdown(u);	// 先更新左右节点值后再进行访问
        
        int mid = tr[u].l + tr[u].r >> 1;
        
        // 访问左右节点
        if (l <= mid) modify(u << 1, l, r, d);
        if (r >= mid + 1) modiy(u << 1 | 1, l, r, d);
        
        pushup(u);	// 有节点被修改, 需要回溯时进行更新
    }
}

// 查询操作, u 是节点编号, [l, r] 是要查找的区间
void query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r)	// 区间被完全包含, 直接返回当前值
        return tr[u].sum;
    
    pushdown(u);	// 先更新左右节点值后再进行访问
    
    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;
    
    // 访问左右节点, 求加和值
    if (l <= mid) sum += query(u << 1, l, r);
    if (r >= mid + 1) sum += query(u << 1 | 1, l, r);
    
    return sum;	// 返回加和值
}

例题


P3372 【模板】线段树 1 - 洛谷

题目描述

  1. 1 x y k :将区间 \([x, y]\) 内每个数加上 \(k\)

  2. 2 x y :输出区间 \([x, y]\) 内每个数的和。

代码

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

struct Node {
    int l, r;
    LL sum;
    LL add;
} tr[N * 4];

int a[N];

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u) {
    Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];

    if (root.add) {
        left.add += root.add, right.add += root.add;

        left.sum += (LL)(left.r - left.l + 1) * root.add;
        right.sum += (LL)(right.r - right.l + 1) * root.add;

        root.add = 0;
    }
}

void build(int u, int l, int r) {
    if (l == r) tr[u] = {l, r, a[l]};
    else {
        tr[u] = {l, r};
        
        int mid = l + r >> 1;

        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);

        pushup(u);
    }
}

void modify(int u, int l, int r, int d) {
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
        tr[u].add += d;
    } else {
        pushdown(u);

        int mid = tr[u].l + tr[u].r >> 1;

        if (l <= mid) modify(u << 1, l, r, d);
        if (r > mid) modify(u << 1 | 1, l, r, d);

        pushup(u);
    }
}

LL query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) 
        return tr[u].sum;

    pushdown(u);

    int mid = tr[u].l + tr[u].r >> 1;
    LL sum = 0;

    if (l <= mid) sum += query(u << 1, l, r);
    if (r > mid) sum += query(u << 1 | 1, l, r);

    return sum;
}

int main() {
    int n, m;
    cin >> n >> m;

    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);

    build(1, 1, n);

    while (m--) {
        int op, x, y, k;
        scanf("%d", &op);

        if (op == 1) {
            scanf("%d%d%d", &x, &y, &k);
            modify(1, x, y, k);
        } else {
            scanf("%d%d", &x, &y);
            printf("%lld\n", query(1, x, y));
        }
    }
    return 0;
}

P3373 【模板】线段树 2 - 洛谷

题目描述

1 x y k:将区间 \([x, y]\) 内每个数乘上 \(k\)

2 x y k :将区间 \([x, y]\) 内每个数加上 \(k\)

3 x y :输出区间 \([x, y]\) 内每个数的和对 \(m\) 取模所得的结果。

代码

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

struct Node {
    int l, r;
    LL sum;
    LL mul;
    LL add;
} tr[N * 4];

int a[N];
int n, q, p;

void pushup(int u) {
    tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}

void pushdown(int u) {
    Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];

    left.sum = (left.sum * root.mul + (left.r - left.l + 1) * root.add) % p;
    right.sum = (right.sum * root.mul + (right.r - right.l + 1) * root.add) % p;

    left.mul = (left.mul * root.mul) % p;
    right.mul = (right.mul * root.mul) % p;

    left.add = (left.add * root.mul + root.add) % p;
    right.add = (right.add * root.mul + root.add) % p;

    root.mul = 1, root.add = 0;
}

void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r, tr[u].mul = 1, tr[u].add = 0;
    
    if (l == r) tr[u].sum = a[l];
    else {
        int mid = l + r >> 1;

        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);

        pushup(u);
    }
}

void modify(int u, int l, int r, int d) {
    if (tr[u].l >= l && tr[u].r <= r) {
        if (d >= 0) {	// 加法操作
            tr[u].sum = (tr[u].sum + (LL)(tr[u].r - tr[u].l + 1) * d) % p;
            tr[u].add = (tr[u].add + d) % p;
        } else {	// 乘法操作
            d = -d;
            tr[u].sum = (tr[u].sum * d) % p;
            tr[u].mul = (tr[u].mul * d) % p;
            tr[u].add = (tr[u].add * d) % p;
        }
    } else {
        pushdown(u);

        int mid = tr[u].l + tr[u].r >> 1;

        if (l <= mid) modify(u << 1, l, r, d);
        if (r > mid) modify(u << 1 | 1, l, r, d);

        pushup(u);
    }
}

LL query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) 
        return tr[u].sum;

    pushdown(u);

    int mid = tr[u].l + tr[u].r >> 1;
    LL sum = 0;

    if (l <= mid) sum = (sum + query(u << 1, l, r)) % p;
    if (r > mid) sum = (sum + query(u << 1 | 1, l, r)) % p;

    return sum;
}

int main() {
    int n, m;
    cin >> n >> m >> p;

    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);

    build(1, 1, n);

    while (m--) {
        int op, x, y, k;
        scanf("%d", &op);

        if (op == 1) {
            scanf("%d%d%d", &x, &y, &k);
            modify(1, x, y, -k);
        } else if (op == 2) {
            scanf("%d%d%d", &x, &y, &k);
            modify(1, x, y, k);
        } else {
            scanf("%d%d", &x, &y);
            printf("%lld\n", query(1, x, y));
        }
    }
    return 0;
}

小结


这里只是讲了线段树的基本用法,例题也都是模板题。其实对于各种线段树的题目来说,最难的就是如何去维护懒惰标记。因为平时遇到的线段树题目都会有着相应的变形,所以都会比较难。而且因为代码量很大,很容易出错,在重新检查、找bug的时候也比较麻烦。

所以还是得多做题,多熟悉,才能更熟练地使用。


参考资料


线段树 - OI Wiki:https://oi-wiki.org/ds/seg/

算法学习笔记(14): 线段树 - 知乎:https://zhuanlan.zhihu.com/p/106118909


热门相关:有个人爱你很久   闺范   闺范   3对1是第一次吗   大妆