树统计 (dfs序+线段树)

    技术2025-06-04  29

    时间限制: 1 Sec  内存限制: 128 MB

    题目描述

    然而,这一切宛如一度揉过的复写纸,无不同原来有着少许然而却是无可挽回的差异。—— 村上春树 关于树的算法有一大堆,样样都是毒瘤。 比如说 2019 CSP-S 的树论题,如果擅长树形数据结构马上想到正解,但是 3edc2wsx1qaz 并不擅长,就只好骗分了。 3edc2wsx1qaz 当时数组开小了,惨遭 RE,3edc2wsx1qaz 一想起这事,不禁夙夜忧叹,辗转反侧。 现在他又遇到一道毒瘤的树上问题了,他下定决心:这次一定要写出正解! 题目是这样的: 有一颗有n个点的树,每条边有一个权值ai。树的根节点为1号节点。定义一对点对(u,v)的距离dist(u,v)为在u到v的简单路径上的所有边的边权的异或。 你需要进行q次操作,操作分为两种: 1.将x点与它父亲所连的边的边权异或w。 2.询问以节点y为根的子树中所有点对的距离之和,答案对998244353取模 也就是说,对于每次 2 操作,设以节点y为根的子树的节点集合为subtrss(y), 你需要求出以下式子的值:  

    输入

    为了方便你获取部分分,我们会告诉你测试点编号。 第一行输入三个正整数n,q,r(2≤n≤10^5,2≤q≤10^5,1≤r≤50),表示树的节点数,操作数,该测试点编号。 接下来n-1行每行三个正整数u,v,w,表示有一条连接u,v,权值为w的边。(1≤u≤n,1≤v≤n,0≤w<2^10) 接下来q行,每行开头输入一个数opt(opt=1 或opt=2 ),表示操作类型。 若opt=1,则再输入两个数x,w(1<x≤n,0≤w<2^10),表示将x号点与它父亲所连的边的边权异或w。 若opt=2 ,则再输入一个数y,表示一次询问,你需要输出以节点y为根的子树中所有点对的距离之和。

    输出

    输出若干行,对于每次 2 操作,输出一个正整数,表示答案。

    样例输入 Copy

    8 8 0 2 1 0 3 1 0 4 3 0 5 2 1 6 5 1 7 5 0 8 1 0 1 4 0 2 7 1 3 0 2 5 1 5 1 1 4 0 1 5 0 2 1

    样例输出 Copy

    0 4 14

    提示

    样例解释: 由于这组数据为样例,所以r=0。 保证测试数据中1≤r≤50

     

    对于一颗子树内的任意两点(x,y)之间距离(如题所述的距离)为dis(1,x)^dis(1,y)

    那么我们知道统计子树按位拆开后对应二进制为1的个数即可。

    对于更新操作,修改边权为w,枚举二进制位,当且仅当w对应的二进制为1时必发生当前点及子树对应位0和1个数的互换,用线段树维护即可。

     

    /**/ #include <cstdio> #include <cstring> #include <cmath> #include <cctype> #include <iostream> #include <algorithm> #include <map> #include <set> #include <vector> #include <string> #include <stack> #include <queue> typedef long long LL; using namespace std; const long long mod = 998244353; const int maxn = 1e5 + 5; int n, q, r, tot, cnt; int head[maxn], dfn[maxn], sz[maxn], son[maxn], top[maxn], f[maxn], id[maxn], w[maxn]; int tr[11][maxn << 2], lzy[11][maxn << 2]; struct node { int v, w, next; }a[maxn << 1]; void dfs(int x, int pre){ f[x] = pre; sz[x] = 1; for (int i = head[x]; i != -1; i = a[i].next){ int v = a[i].v; if(v == pre) continue; w[v] = w[x] ^ a[i].w; dfs(v, x); sz[x] += sz[v]; if(sz[son[x]] < sz[v]) son[x] = v; } } void dfs1(int x, int topf){ top[x] = topf; dfn[x] = ++cnt; id[cnt] = x; if(son[x]) dfs1(son[x], topf); for (int i = head[x]; i != -1; i = a[i].next){ int v = a[i].v; if(v == f[x] || v == son[x]) continue; dfs1(v, v); } } void up(int rt){ for (int i = 0; i < 10; i++){ tr[i][rt] = tr[i][rt << 1] + tr[i][rt << 1 | 1]; } } void up(int rt, int i){ tr[i][rt] = tr[i][rt << 1] + tr[i][rt << 1 | 1]; } void down(int rt, int l, int r){ int mid = (l + r) >> 1; for (int i = 0; i < 10; i++){ if(lzy[i][rt]){ lzy[i][rt << 1] ^= 1; lzy[i][rt << 1 | 1] ^= 1; tr[i][rt << 1] = (mid - l + 1) - tr[i][rt << 1]; tr[i][rt << 1 | 1] = (r - mid) - tr[i][rt << 1 | 1]; lzy[i][rt] = 0; } } } void build(int rt, int l, int r){ if(l == r){ for (int i = 0; i < 10; i++){ if(1 << i & w[id[l]]) tr[i][rt] = 1; else tr[i][rt] = 0; } return ; } int mid = (l + r) >> 1; build(rt << 1, l, mid); build(rt << 1 | 1, mid + 1, r); up(rt); } void update(int rt, int l, int r, int L, int R, int i){ if(L <= l && r <= R){ tr[i][rt] = r - l + 1 - tr[i][rt]; lzy[i][rt] ^= 1; return ; } down(rt, l, r); int mid = (l + r) >> 1; if(mid >= L) update(rt << 1, l, mid, L, R, i); if(mid < R) update(rt << 1 | 1, mid + 1, r, L, R, i); up(rt, i); } int query(int rt, int l, int r, int L, int R, int i){ if(L <= l && r <= R) return tr[i][rt]; down(rt, l, r); int mid = (l + r) >> 1, ans = 0; if(mid >= L) ans += query(rt << 1, l, mid, L, R, i); if(mid < R) ans += query(rt << 1 | 1, mid + 1, r, L, R, i); return ans; } void modify(int x, int W){ for (int i = 0; i < 10; i++){ if(W >> i & 1) update(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, i); } w[x] ^= W; } LL sum(int x){ LL ans = 0; for (int i = 0; i < 10; i++){ int num = query(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, i); ans = (ans + 1LL * num * (sz[x] - num) % mod * (1 << i) % mod); } return ans; } int main() { //freopen("in.txt", "r", stdin); //freopen("out.txt", "w", stdout); memset(head, -1, sizeof(head)); scanf("%d %d %d", &n, &q, &r); for (int i = 1, u, v, w; i < n; i++){ scanf("%d %d %d", &u, &v, &w); a[tot] = node{v, w, head[u]}, head[u] = tot++; a[tot] = node{u, w, head[v]}, head[v] = tot++; } dfs(1, 0); dfs1(1, 1); build(1, 1, n); for (int i = 1, op, x, y, w; i <= q; i++){ scanf("%d", &op); if(op == 1){ scanf("%d %d", &x, &w); modify(x, w); }else{ scanf("%d", &y); printf("%lld\n", (sum(y) << 1) % mod); } } return 0; } /**/

     

    Processed: 0.011, SQL: 9