Heavy-light Decomposition

From EOJ Wiki
Revision as of 06:57, 10 June 2018 by Zerol (talk | contribs)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to navigation Jump to search

Weights are stored on the vertices: the example below adds a value to every vertex on a path and queries the current weight of a single vertex.

// Vertex-indexed arrays shared with the caller:
//   fa[u]   parent of u; fa[root] must be set by the caller (e.g. 0 for 1-based ids)
//           before predfs, since predfs never writes it
//   dep[u]  depth of u (the root gets the depth passed to predfs)
//   idx[u]  DFS timestamp of u (1-based); every heavy chain has consecutive timestamps
//   out[u]  largest timestamp inside u's subtree, so subtree(u) = [idx[u], out[u]]
//   ridx[t] vertex whose timestamp is t (inverse of idx)
int fa[maxn], dep[maxn], idx[maxn], out[maxn], ridx[maxn];
namespace hld {
    // sz: subtree size, son: heavy child (-1 for a leaf), top: head of the heavy
    // chain containing the vertex, clk: timestamp counter (reset by the caller).
    int sz[maxn], son[maxn], top[maxn], clk;

    // First pass: fills dep, sz, fa (for non-root vertices) and son.
    void predfs(int u, int d) {
        dep[u] = d; sz[u] = 1;
        int& maxs = son[u] = -1;
        for (int& v: G[u]) {
            if (v == fa[u]) continue;
            fa[v] = u;
            predfs(v, d + 1);
            sz[u] += sz[v];
            if (maxs == -1 || sz[v] > sz[maxs]) maxs = v;
        }
    }

    // Second pass: assigns top/idx/ridx/out. The heavy child is visited first so
    // each heavy chain occupies a contiguous timestamp range.
    void dfs(int u, int tp) {
        top[u] = tp; idx[u] = ++clk; ridx[clk] = u;
        if (son[u] != -1) dfs(son[u], tp);
        for (int& v: G[u])
            if (v != fa[u] && v != son[u]) dfs(v, v);
        out[u] = clk;
    }

    // Callback used when go() is called without one, e.g. only to obtain the LCA.
    struct NoOp { void operator()(int, int) const {} };

    // Walks the path u--v chain by chain, calling f(l, r) with l <= r for every
    // timestamp range [l, r] the path covers, and returns the LCA of u and v.
    // The default template argument makes go(u, v) compile: a default *function*
    // argument alone cannot be used to deduce T.
    template<typename T = NoOp>
    int go(int u, int v, T&& f = T()) {
        int uu = top[u], vv = top[v];
        while (uu != vv) {
            // Lift the endpoint whose chain head is deeper.
            if (dep[uu] < dep[vv]) { swap(uu, vv); swap(u, v); }
            f(idx[uu], idx[u]);
            u = fa[uu]; uu = top[u];
        }
        // u and v are now on one chain; v is the shallower one, i.e. the LCA.
        if (dep[u] < dep[v]) swap(u, v);
        // choose one
        // f(idx[v], idx[u]);                      // weights on vertices (includes LCA)
        // if (u != v) f(idx[v] + 1, idx[u]);      // weights on edges stored at the child
        return v;
    }
}
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
// Capacity of every vertex-indexed array: room for n up to about 5e4 plus slack
// (BIT::_add / BIT::query shift positions by +5).
const int maxn = 5e4 + 100;
// G[u]: adjacency list of vertex u (main stores each edge in both endpoints' lists).
vector<int> G[maxn];
// dep: depth of node (the root gets the depth passed to predfs)
// sz: size of the subtree rooted at the node
// son: the heavy son (child with the largest subtree), -1 for a leaf
// fa: father (fa of the root is set by the caller before predfs)
// idx: the position of the node in the Fenwick tree (BIT); dfs gives every
//      heavy chain consecutive positions
// top: the top (shallowest) vertex of the heavy chain containing the node
// clk: timestamp counter used by dfs to assign idx (main resets it per test case)
int dep[maxn], sz[maxn], son[maxn], fa[maxn], idx[maxn], top[maxn], clk;

// Fenwick tree over a difference array: adds d to every position in a range
// and reads the current value at one position. Positions are shifted by +5
// internally, so position 0 is usable.
struct BIT {
    int c[maxn];

    // Zero every cell.
    void init() {
        memset(c, 0, sizeof c);
    }

    // Adds d to the difference array at position x (affects every position >= x).
    void _add(int x, int d) {
        for (int pos = x + 5; pos < maxn; pos += pos & -pos)
            c[pos] += d;
    }

    // Range update: value[a..b] += d. Requires a <= b.
    void add(int a, int b, int d) {
        assert (a <= b);
        _add(a, d);
        _add(b + 1, -d);
    }

    // Point query: prefix sum of the difference array up to position x.
    int query(int x) {
        int sum = 0;
        for (int pos = x + 5; pos != 0; pos -= pos & -pos)
            sum += c[pos];
        return sum;
    }
} bit;

// First HLD pass over the subtree of u (at depth d): records depth, subtree size
// and parent of each vertex, and sets son[u] to the child with the largest
// subtree (-1 when u has no children). fa[u] must already be set.
void predfs(int u, int d) {
    dep[u] = d;
    sz[u] = 1;
    son[u] = -1;
    for (int v: G[u]) {
        if (v == fa[u]) continue;
        fa[v] = u;
        predfs(v, d + 1);
        sz[u] += sz[v];
        // Strict '>' keeps the first child among equally large subtrees.
        if (son[u] == -1 || sz[v] > sz[son[u]])
            son[u] = v;
    }
}

// Second HLD pass: gives u the next timestamp and chain head tp, then visits
// the heavy child first so each heavy chain gets consecutive timestamps.
void dfs(int u, int tp) {
    top[u] = tp;
    idx[u] = ++clk;
    const int heavy = son[u];
    if (heavy != -1) {
        dfs(heavy, tp);                 // heavy child continues the current chain
        for (int v: G[u])
            if (v != fa[u] && v != heavy)
                dfs(v, v);              // each light child starts its own chain
    }
}

// Adds `add` to every vertex on the tree path u--v, both endpoints included.
// The path is split into heavy-chain segments, each a contiguous BIT range.
void update(int u, int v, int add) {
    while (top[u] != top[v]) {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        bit.add(idx[top[u]], idx[u], add);
        u = fa[top[u]];
    }
    // Same chain now: v becomes the shallower endpoint (the LCA).
    if (dep[u] < dep[v]) swap(u, v);
    bit.add(idx[v], idx[u], add);
}

int a[maxn];
int n, m, q;

// Reads test cases until EOF. Each case: n, m, q; initial weights a[1..n];
// m undirected edges (the tree is rooted at vertex 1); then q operations:
//   "I l r d" - add d to every vertex on the path l--r,
//   "D l r d" - subtract d from every vertex on the path l--r,
//   otherwise read k and print the current weight of vertex k.
int main() {
    int u, v, l, r, k, d;
    char s[10];
    while (~scanf("%d%d%d", &n, &m, &q)) {
        clk = 0;
        fa[1] = 0;   // the root has no parent; 0 is not a vertex id
        bit.init();
        for (int i = 0; i <= n; ++i)
            G[i].clear();

        for (int i = 1; i <= n; ++i)
            scanf("%d", &a[i]);
        for (int i = 0; i < m; ++i) {
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }

        predfs(1, 1);
        dfs(1, 1);

        // Seed each vertex's initial weight as a one-element range update.
        for (int i = 1; i <= n; ++i)
            bit.add(idx[i], idx[i], a[i]);

        while (q--) {
            // The width limit keeps an over-long token from overflowing s[10].
            scanf("%9s", s);
            if (s[0] == 'I') {
                scanf("%d%d%d", &l, &r, &d);
                update(l, r, d);
            } else if (s[0] == 'D') {
                scanf("%d%d%d", &l, &r, &d);
                update(l, r, -d);
            } else {
                scanf("%d", &k);
                printf("%d\n", bit.query(idx[k]));
            }
        }
    }
}

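Wrapper version (the decomposition is wrapped in a struct, so several trees can coexist; the example below uses two):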

// One rooted tree together with its heavy-light decomposition.
struct Tree {
    vector<int> G[maxn];
    int dep[maxn], sz[maxn], son[maxn], fa[maxn], idx[maxn], top[maxn], clk;

    // First pass: depth, subtree size, parent and heavy child (-1 for a leaf)
    // for every vertex in the subtree of u. fa[u] must already be set.
    void predfs(int u, int d) {
        dep[u] = d;
        sz[u] = 1;
        son[u] = -1;
        for (int v: G[u]) {
            if (v == fa[u]) continue;
            fa[v] = u;
            predfs(v, d + 1);
            sz[u] += sz[v];
            if (son[u] == -1 || sz[v] > sz[son[u]])
                son[u] = v;
        }
    }

    // Second pass: timestamps with the heavy child first, so each heavy chain
    // is numbered consecutively; top[] records every vertex's chain head.
    void dfs(int u, int tp) {
        top[u] = tp;
        idx[u] = ++clk;
        const int heavy = son[u];
        if (heavy != -1) {
            dfs(heavy, tp);
            for (int v: G[u])
                if (v != fa[u] && v != heavy)
                    dfs(v, v);
        }
    }

    // Clears the tree and reads it from stdin: for each vertex 2..n, one
    // integer giving its parent. Then decomposes it with root 1 at depth 0.
    void read_tree(int n) {
        clk = 0;
        fa[1] = 0;
        for (int i = 0; i <= n; ++i)
            G[i].clear();
        for (int child = 2; child <= n; ++child) {
            int parent;
            scanf("%d", &parent);
            G[child].push_back(parent);
            G[parent].push_back(child);
        }
        predfs(1, 0);
        dfs(1, 1);
    }
};

int n, q;
Tree husband, wife;
// cache[{h, w}]: ids of the vertices whose heavy-chain head is h in `husband`
// and w in `wife`, in increasing order (filled in main).
map<pair<int, int>, vector<int> > cache;

// For vertex x of `husband` and y of `wife`, visits every pair of heavy chains
// on the paths x->root and y->root. For each pair it takes the largest cached id
// not exceeding min(i, j), and keeps the candidate deepest in `husband`
// (starting from vertex 1).
// NOTE(review): "id <= i" standing for "ancestor of i on this chain" assumes
// parents always have smaller ids than children - confirm with the problem input.
// Returns (ans, dep difference + 1 in husband from x, same in wife from y).
tuple<int, int, int> get_ans(int x, int y) {
    int ans = 1;
    for (int i = x; i; i = husband.fa[husband.top[i]]) {
        for (int j = y; j; j = wife.fa[wife.top[j]]) {
            // Single lookup; reuse the iterator instead of a second operator[].
            auto found = cache.find({husband.top[i], wife.top[j]});
            if (found == cache.end()) continue;
            const vector<int>& ids = found->second;
            auto it = upper_bound(ids.begin(), ids.end(), min(i, j));
            if (it == ids.begin()) continue;
            --it;
            if (husband.dep[*it] > husband.dep[ans])
                ans = *it;
        }
    }
    return make_tuple(ans, husband.dep[x] - husband.dep[ans] + 1,
                           wife.dep[y] - wife.dep[ans] + 1);
}

// Reads test cases until EOF. Each case: n and q, two rooted trees on vertices
// 1..n (each given as the parents of vertices 2..n), then q queries "x y".
// Each query is decoded with the previous answer K before get_ans is called.
int main() {
    while (~scanf("%d%d", &n, &q)) {
        husband.read_tree(n);
        wife.read_tree(n);
        cache.clear();
        // Bucket every vertex by its pair of chain heads. Ids are pushed in
        // increasing order; get_ans needs each bucket sorted for upper_bound.
        for (int i = 1; i <= n; ++i) {
            pair<int, int> top_pair = make_pair(husband.top[i], wife.top[i]);
            cache[top_pair].push_back(i);
        }
        // Iterate by reference so the buckets themselves are sorted
        // (a by-value loop would only sort throw-away copies).
        for (auto& entry: cache)
            sort(entry.second.begin(), entry.second.end());
        int K = 0, x, y;
        while (q--) {
            scanf("%d%d", &x, &y);
            x = (x + K) % n + 1;
            y = (y + K) % n + 1;
            tie(K, x, y) = get_ans(x, y);
            printf("%d %d %d\n", K, x, y);
        }
    }
}