树链剖分

Author Avatar
Axell 8月 06, 2019
  • 在其它设备中阅读本文章

难得有空，在 yrc 的怂恿下学了一下树剖（树链剖分）。

概念&重点

重儿子：一个节点的所有儿子中，子树大小（子树内的节点数）最大的那个儿子（有多个时任取其一）
其余为轻儿子
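重链：从某个节点出发、不断走向重儿子所连成的链；每个节点恰好属于一条重链
同理有轻边（连向轻儿子的边）与轻链。树剖借助重链，把树上任意一条路径拆成 O(log n) 段，每段在 DFS 序（优先遍历重儿子）上都是一段连续区间，再交给线段树维护，从而加快计算路径上的信息；同时每棵子树在 DFS 序上也是一段连续区间，子树操作可以直接转成区间操作（本质上是优化过的暴力）。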

实现细节

TODO

代码

/**
 * 树链剖分
 * luogu P3384
 * 线段树不可以再打错了!
 */
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed type used for sums and lazy tags.
using ll = long long;
// 64-bit unsigned type (not used in the visible code).
using ull = unsigned long long;
// Children of segment-tree node `rt` in the implicit 1-based heap layout.
// Both expand to expressions over a variable named `rt` in the caller's scope.
#define lc (rt * 2)
#define rc (rt * 2 + 1)
// Upper bound on the number of tree vertices; the arrays below are sized from it.
const int MAXN = 100005;

// One directed edge in a linked (forward-star) adjacency list.
struct node
{
    int Next; // index of the previously added edge with the same source (0 ends the list)
    int y;    // target vertex
} Pth[MAXN << 1]; // two directed edges are stored per undirected tree edge

// head[x]: index of the most recently added edge leaving x (0 = none).
// cnt: number of edges added so far; edges occupy indices 1..cnt.
int head[MAXN], cnt;

// Append the directed edge x -> y by pushing it onto the front of x's list.
void add(int x, int y)
{
    node &edge = Pth[++cnt];
    edge.Next = head[x];
    edge.y = y;
    head[x] = cnt;
}
// Segment tree over DFS positions: tr holds node sums, lazy holds pending range-add tags.
ll tr[MAXN << 2];
ll lazy[MAXN << 2];
// Modulus read from the input header.
ll p;
// n: vertex count, m: operation count, r: root vertex (all read in main).
int n, m, r;
// a[v]: initial value of vertex v; b[i]: value placed at DFS position i.
int a[MAXN], b[MAXN];
// siz: subtree size, son: heavy child (0 = none), dep: depth, fa: parent,
// id: DFS position, top: first vertex of the heavy chain containing the vertex.
int siz[MAXN], son[MAXN], dep[MAXN], fa[MAXN], id[MAXN], top[MAXN];

void build(int rt, int l, int r)
{
    if (l == r)
    {
        tr[rt] = b[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(lc, l, mid);
    build(rc, mid + 1, r);
    tr[rt] = tr[lc] + tr[rc];
}

inline void down(int rt, int l, int r)
{
    if (!lazy[rt])
        return;
    lazy[lc] += lazy[rt];
    lazy[rc] += lazy[rt];
    int mid = (l + r) >> 1;
    tr[lc] = (tr[lc] + lazy[rt] * (mid - l + 1)) % p;
    tr[rc] = (tr[rc] + lazy[rt] * (r - mid)) % p;
    lazy[rt] = 0;
}

inline void up(int rt)
{
    tr[rt] = tr[lc] + tr[rc];
}

void change(int rt, int l, int r, int wl, int wr, int v)
{
    if (l == wl && r == wr)
    {
        tr[rt] = (tr[rt] + (r - l + 1) * v) % p;
        lazy[rt] += v;
        lazy[rt] %= p;
        return;
    }
    down(rt, l, r);
    int mid = (l + r) >> 1;
    if (mid >= wr)
        change(lc, l, mid, wl, wr, v);
    else if (mid < wl)
        change(rc, mid + 1, r, wl, wr, v);
    else
        change(lc, l, mid, wl, mid, v), change(rc, mid + 1, r, mid + 1, wr, v);
    up(rt);
}

/**
 * Sum of DFS positions [wl, wr], modulo p.
 * Precondition: [wl, wr] lies inside node rt's range [l, r].
 *
 * The result is always in [0, p). The full-cover case reduces tr[rt]
 * defensively instead of assuming every stored sum is already reduced,
 * so callers that print the result directly get the correct value.
 *
 * @param rt node index
 * @param l  left end of rt's range
 * @param r  right end of rt's range
 * @param wl left end of the query range
 * @param wr right end of the query range
 * @return   the range sum modulo p
 */
ll query(int rt, int l, int r, int wl, int wr)
{
    if (l == wl && r == wr)
        return tr[rt] % p;
    down(rt, l, r);
    int mid = (l + r) >> 1;
    if (wr <= mid)
        return query(lc, l, mid, wl, wr);
    if (wl > mid)
        return query(rc, mid + 1, r, wl, wr);
    // The query range straddles mid: combine both halves.
    return (query(lc, l, mid, wl, mid) + query(rc, mid + 1, r, mid + 1, wr)) % p;
}

/**
 * Add z to every vertex on the tree path between x and y.
 * While the endpoints lie on different heavy chains, take the endpoint whose
 * chain head is deeper, update its segment [id[top], id[endpoint]], and jump
 * to the parent of that chain head. Once both share a chain, update the span
 * between them.
 */
void add1(int x, int y, int z)
{
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        change(1, 1, n, id[top[x]], id[x], z);
    }
    if (dep[x] > dep[y])
        std::swap(x, y);
    // x is now the shallower endpoint, so id[x] <= id[y] on the shared chain.
    change(1, 1, n, id[x], id[y], z);
}

/**
 * Sum of vertex values on the tree path between x and y, modulo p.
 * Uses the same chain-jumping walk as the path update: deeper-headed chain
 * segments are queried first, then the final span on the shared chain.
 */
ll query1(int x, int y)
{
    ll total = 0;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        total = (total + query(1, 1, n, id[top[x]], id[x])) % p;
    }
    if (dep[x] > dep[y])
        std::swap(x, y);
    // x is now the shallower endpoint on the shared chain.
    total = (total + query(1, 1, n, id[x], id[y])) % p;
    return total;
}

/**
 * Add z to every vertex in the subtree of x.
 * The DFS numbering gives the subtree the contiguous positions
 * id[x] .. id[x] + siz[x] - 1.
 */
inline void add2(int x, int z)
{
    const int first = id[x];
    const int last = first + siz[x] - 1;
    change(1, 1, n, first, last, z);
}

/**
 * Sum of vertex values in the subtree of x, modulo p.
 * The subtree occupies the contiguous DFS positions id[x] .. id[x] + siz[x] - 1.
 * The result is reduced here, matching query1, because it is printed directly.
 */
inline ll query2(int x)
{
    return query(1, 1, n, id[x], id[x] + siz[x] - 1) % p;
}

/**
 * First pass of the decomposition, on the subtree rooted at x (prv = parent).
 * Fills siz (subtree size), plus fa and dep for each child. It also picks
 * son[x]: the first child in adjacency-list order with the largest subtree
 * (a later child replaces it only if strictly larger).
 * The caller must set dep of the starting vertex.
 */
void dfs1(int x, int prv)
{
    siz[x] = 1;
    for (int e = head[x]; e != 0; e = Pth[e].Next)
    {
        const int child = Pth[e].y;
        if (child == prv)
            continue;
        fa[child] = x;
        dep[child] = dep[x] + 1;
        dfs1(child, x);
        const bool heavier = son[x] == 0 || siz[child] > siz[son[x]];
        if (heavier)
            son[x] = child;
        siz[x] += siz[child];
    }
}

/**
 * Second pass: assign DFS positions and chain heads.
 * x gets the next position (the global counter cnt is incremented), its value
 * a[x] is copied into b at that position, and top[x] = tf. The heavy child
 * is visited first and stays on the same chain. Every other child (except
 * the parent) starts a new chain headed by itself.
 */
void dfs2(int x, int tf)
{
    top[x] = tf;
    id[x] = ++cnt;
    b[id[x]] = a[x];
    if (son[x] != 0)
        dfs2(son[x], tf);
    for (int e = head[x]; e; e = Pth[e].Next)
    {
        const int child = Pth[e].y;
        if (child == fa[x] || child == son[x])
            continue;
        dfs2(child, child);
    }
}

/**
 * Reads n, m, r, p, then n vertex values and n-1 edges, and answers m operations:
 *   1 x y z : add z on the path x-y       2 x y : print query1(x, y)
 *   3 x z   : add z on the subtree of x   4 x   : print query2(x)
 * Any other operation code is read and ignored.
 */
int main()
{
    std::cin >> n >> m >> r >> p;
    for (int v = 1; v <= n; ++v)
        scanf("%d", a + v);
    for (int e = 1; e < n; ++e)
    {
        int u, w;
        scanf("%d%d", &u, &w);
        add(u, w);
        add(w, u);
    }

    // cnt served as the edge counter; reset it so dfs2 can reuse it as the DFS clock.
    cnt = 0;
    dep[r] = 1;
    dfs1(r, 0);
    dfs2(r, r);
    build(1, 1, n);

    for (int q = 0; q < m; ++q)
    {
        int op;
        scanf("%d", &op);
        switch (op)
        {
        case 1:
        {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            add1(x, y, z);
            break;
        }
        case 2:
        {
            int x, y;
            scanf("%d%d", &x, &y);
            printf("%lld\n", query1(x, y));
            break;
        }
        case 3:
        {
            int x, z;
            scanf("%d%d", &x, &z);
            add2(x, z);
            break;
        }
        case 4:
        {
            int x;
            scanf("%d", &x);
            printf("%lld\n", query2(x));
            break;
        }
        default:
            break;
        }
    }
    return 0;
}

知识共享许可协议
本作品采用知识共享署名-非商业性使用-相同方式共享 3.0 未本地化版本许可协议进行许可。

本文链接:https://hs-blog.axell.top/archives/%E6%A0%91%E9%93%BE%E5%89%96%E5%88%86/