树统计(虚树)

    技术2023-09-20  79

    时间限制: 1 Sec  内存限制: 128 MB

    题目描述

    骗分过样例,暴力出奇迹。 关于树的算法有一大堆,样样都是毒瘤。 比如说 NOIP2018 提高组的 D2T3,如果会动态 DP 的做法那么就马上想到正解,但是 Tweetuzki 不会动态 DP,就只好骗分了。 可惜树题的码量也是超级大的。听说好多学长都会动态 DP,但是考场上调不出来,只好暴力分收场了。疯狂暗示 Tweetuzki 当时暴力写挂了,有 4 个点写成了死循环……于是分数白白少了 16 分。Tweetuzki 一想起这事,不禁夙夜忧叹,辗转反侧。 现在他又遇到一道毒瘤树上问题了,他下定决心:这次一定要把暴力分写满! 题目是这样的: 有一棵 n 个点的树,边有边权,每个点有颜色 ci。求所有颜色不同的点对的距离之和。由于答案可能很大,你只需要输出其对 998,244,353 取模的结果即可。 形式化地讲,记 u 号点和 v 号点在树上的距离为 dist(u,v),求:

    输入

    输入文件将会遵循以下格式: n type c1 c2 ⋯ cn u1 v1 w1 u2 v2 w2 ⋮ un−1 vn−1 wn−1 第一行两个正整数 n,type(2≤n≤2×105,1≤type≤6),其中 n 表示点数,type为部分分类型,详见数据范围,type=0 表示样例数据。 第二行输入 n 个正整数 ci(1≤ci≤109),表示每个点的颜色。 接下来n−1 行,每行输入三个正整数 ui,vi,wi(1≤ui<vi≤n,1≤wi≤109),描述这棵树。

    输出

    输出一行一个非负整数,表示答案对 998,244,353 取模的结果。

    样例输入 Copy

    4 0 1 2 3 3 1 2 5 2 3 4 3 4 7

    样例输出 Copy

    90

    提示

    满足条件的点对有 (1,2),(1,3),(1,4),(2,1),(2,3),(2,4),(3,1),(3,2),(4,1),(4,2),故答案为 5+9+16+5+4+11+9+4+16+11=90。 Subtask #1:n≤300, type=1。 Subtask #2:n≤2 000, type≤2。 Subtask #3:n≤10 000, type≤3。 Subtask #4:对于第 i (1≤i≤n) 号点,ci=i。type=4。 Subtask #5 :对于第 i(1≤i<n)条边,ui+1=vi。type=5。 Subtask #6:无特殊性质,type≤6。

     

    题目要求不同颜色顶点间的距离和,我们转化为所有顶点间的距离和-相同颜色点间的距离和

    对于所有顶点间的距离和,我们跑一遍图,求出每条边左右的顶点对数即可求出每条边的贡献,最终得到所有边的贡献

    将颜色相同的顶点分别建立一棵虚树,每一颗虚树类似上面跑一遍图即可

    最终答案<<1即可

    /**/ #include <cstdio> #include <cstring> #include <cmath> #include <cctype> #include <iostream> #include <algorithm> #include <map> #include <set> #include <vector> #include <string> #include <stack> #include <queue> typedef long long LL; using namespace std; const long long mod = 998244353; const int maxn = 200005; int n, type, tot, cnt, top, len; int c[maxn], b[maxn]; int head[maxn], sz[maxn], son[maxn], topf[maxn], f[maxn], dep[maxn], dfn[maxn]; LL ans, res, dis[maxn]; int e[maxn], s[maxn], dp[maxn]; bool vis[maxn]; vector<int> v[maxn]; vector<pair<int, LL> > g[maxn]; struct node { int v, w, next; }a[maxn << 1]; bool cmp(int x, int y){ return dfn[x] < dfn[y]; } void dfs(int x, int pre){ sz[x] = 1; dep[x] = dep[pre] + 1; f[x] = pre; for (int i = head[x]; i != -1; i = a[i].next){ int v = a[i].v; if(v == pre) continue; dis[v] = (dis[x] + a[i].w) % mod; dfs(v, x); ans = (ans + 1LL * sz[v] * (n - sz[v]) % mod * a[i].w % mod) % mod; sz[x] += sz[v]; if(sz[son[x]] < sz[v]) son[x] = v; } } void dfs1(int x, int topfa){ topf[x] = topfa; dfn[x] = ++cnt; if(!son[x]) return ; dfs1(son[x], topfa); for (int i = head[x]; i != -1; i = a[i].next){ int v = a[i].v; if(topf[v]) continue; dfs1(v, v); } } int LCA(int x, int y){ while(topf[x] != topf[y]){ if(dep[topf[x]] < dep[topf[y]]) swap(x, y); x = f[topf[x]]; } if(dep[x] > dep[y]) swap(x, y); return x; } void add_edge(int u, int v){ if(u == n + 1) g[u].emplace_back(make_pair(v, 0)); else g[u].emplace_back(make_pair(v, (dis[v] - dis[u] + mod) % mod)); } void insert(int x){ if(top <= 1){ s[++top] = x; return ; } int lca = LCA(s[top], x); if(lca == s[top]){ s[++top] = x; return ; } while(top > 1 && dfn[lca] <= dfn[s[top - 1]]){ add_edge(s[top - 1], s[top]); top--; } if(lca != s[top]) add_edge(lca, s[top]), s[top] = lca; s[++top] = x; } void dfs2(int u){ dp[u] = vis[u]; for (auto x : g[u]){ int v = x.first; LL w = x.second; dfs2(v); dp[u] += dp[v]; res = (res + 1LL * dp[v] * (len - dp[v]) % mod * w % mod) % mod; } g[u].clear(); } int main() { //freopen("in.txt", "r", stdin); //freopen("out.txt", "w", stdout); memset(head, -1, sizeof(head)); scanf("%d %d", &n, &type); for (int i = 1; i <= n; i++) scanf("%d", &c[i]), b[i] = c[i]; sort(b + 1, b + 1 + n); int num = unique(b + 1, b + 1 + n) - b - 1; for (int i = 1; i <= n; i++) c[i] = lower_bound(b + 1, b + 1 + num, c[i]) - b; for (int i = 1; i <= n; i++) v[c[i]].emplace_back(i); for (int i = 1, u, v, w; i < n; i++){ scanf("%d %d %d", &u, &v, &w); a[tot] = node{v, w, head[u]}, head[u] = tot++; a[tot] = node{u, w, head[v]}, head[v] = tot++; } dfs(1, 0); dfs1(1, 1); for (int i = 1; i <= num; i++){ if(v[i].empty()) continue; len = v[i].size(); for (int j = 0; j < len; j++) e[j + 1] = v[i][j], vis[e[j + 1]] = true; sort(e + 1, e + 1 + len, cmp); s[top = 1] = n + 1; for (int j = 1; j <= len; j++) insert(e[j]); while(top > 1) add_edge(s[top - 1], s[top]), top--; res = 0; dfs2(n + 1); for (int j = 1; j <= len; j++) vis[e[j]] = false; ans = (ans - res + mod) % mod; } printf("%lld\n", (ans << 1) % mod); return 0; } /* 8 3 1 2 3 1 3 3 1 2 1 2 1 2 4 2 2 5 2 5 6 3 5 7 3 1 3 4 3 8 4 */

     

    Processed: 0.013, SQL: 9