Distance in Tree CodeForces - 161D(树形dp,点分治,路径长度为k)

    技术2024-11-17  11

    A tree is a connected graph that doesn’t contain any cycles.

    The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.

    You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.

    Input: The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.

    Next n - 1 lines describe the edges as "a_i b_i" (without the quotes) (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i), where a_i and b_i are the vertices connected by the i-th edge. All given edges are different.

    Output: Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.

    Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.

    Examples

    Input:
    5 2
    1 2
    2 3
    3 4
    2 5
    Output: 4

    Input:
    5 3
    1 2
    2 3
    3 4
    4 5
    Output: 2

    Note: In the first sample the pairs of vertices at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).

    题意: 求一棵树上距离为k的路径数

    思路: 定义 $dp[i][j]$ 为在以 $i$ 为根的子树中、以 $i$ 为起点且长度为 $j$ 的路径数目(即子树中与 $i$ 距离为 $j$ 的点数)。把子节点 $v$ 合并进父节点 $u$ 时,答案加上 $\sum_{j=1}^{k} dp[v][j-1]\cdot dp[u][k-j]$,然后令 $dp[u][j] \mathrel{+}= dp[v][j-1]$。

    // Solution 1: O(n * k) tree DP.
    //
    // dp[u][j] = number of vertices in u's subtree (as merged so far) whose
    // distance from u is exactly j. When a finished child v is merged into its
    // parent u, a vertex at depth j-1 below v and a vertex at depth k-j below u
    // form a path of length k through u, so
    //     ans += sum_{j=1..k} dp[v][j-1] * dp[u][k-j],  then  dp[u][j] += dp[v][j-1].
    // Every pair is counted exactly once, at its lowest common ancestor.
    //
    // The tree is processed in reverse BFS order instead of by recursion, so a
    // path-shaped tree with n = 50000 cannot overflow the call stack.
    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <vector>
    using namespace std;
    typedef long long ll;

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int n, k;
        if (!(cin >> n >> k)) return 0;

        vector<vector<int>> adj(n + 1);
        for (int i = 1; i < n; ++i) {
            int x, y;
            cin >> x >> y;
            adj[x].push_back(y);
            adj[y].push_back(x);
        }

        // BFS from vertex 1: every parent appears in `order` before its children.
        // parent[1] = 0 is a sentinel (vertices are numbered from 1).
        vector<int> order(1, 1);
        vector<int> parent(n + 1, 0);
        order.reserve(n);
        for (size_t i = 0; i < order.size(); ++i) {
            const int u = order[i];
            for (int v : adj[u]) {
                if (v == parent[u]) continue;
                parent[v] = u;
                order.push_back(v);
            }
        }

        // Flattened table: row u holds dp[u][0..k].
        const size_t width = static_cast<size_t>(k) + 1;
        vector<int> dp((static_cast<size_t>(n) + 1) * width, 0);
        for (int u = 1; u <= n; ++u) dp[u * width] = 1;  // u itself, at distance 0

        ll ans = 0;
        // Reverse BFS order finishes every child before its parent; index 0 is the root.
        for (int i = static_cast<int>(order.size()) - 1; i >= 1; --i) {
            const int v = order[i];
            const int u = parent[v];
            const int* dv = &dp[v * width];
            int* du = &dp[u * width];
            for (int j = 1; j <= k; ++j) ans += 1ll * dv[j - 1] * du[k - j];
            for (int j = 1; j <= k; ++j) du[j] += dv[j - 1];
        }

        // 64-bit output through cout, as the statement recommends over %lld.
        cout << ans << '\n';
        return 0;
    }

    补上一个点分治写法。我用的双指针找小于等于k的路径数,减去小于k的路径数,结果就是等于k的路径数了。

    实测还是比 dp 慢了很多(点分治每层都要排序,复杂度为 $O(n\log^2 n)$,而 dp 为 $O(nk)$)。

    // Solution 2: centroid decomposition (点分治), O(n log^2 n).
    //
    // For a centroid c, a path of length exactly k through c is a pair of
    // vertices taken from different branches of c (or c itself) whose distances
    // to c sum to k. Lengths are integers, so "== k" equals "<= k" minus
    // "<= k - 1", and both are counted with a two-pointer sweep over sorted
    // distances. Pairs whose vertices lie in the same branch do not pass through
    // c, so they are subtracted per branch; the recursion finds them later.
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <iostream>
    #include <vector>
    #include <unordered_map>
    using namespace std;
    typedef long long ll;

    namespace {

    struct Edge {
        int to;
        int len;
    };

    vector<vector<Edge>> adj;  // 1-indexed adjacency list
    vector<char> removed;      // 1 once a vertex has been used as a centroid
    vector<int> sub;           // subtree sizes from the latest compute_sizes()
    int k;                     // required path length
    ll ans;                    // pair counts are O(n^2), so keep them 64-bit

    // Fill sub[] for the live component containing x, rooted at x; return its size.
    int compute_sizes(int x, int parent) {
        sub[x] = 1;
        for (const Edge& e : adj[x])
            if (e.to != parent && !removed[e.to]) sub[x] += compute_sizes(e.to, x);
        return sub[x];
    }

    // Walk from x toward the child holding more than half of the component.
    // Uses sub[] as computed by compute_sizes() rooted at the same vertex.
    int find_centroid(int x, int parent, int total) {
        for (const Edge& e : adj[x])
            if (e.to != parent && !removed[e.to] && sub[e.to] * 2 > total)
                return find_centroid(e.to, x, total);
        return x;
    }

    // Append the distance (from the current centroid) of every vertex under x.
    // Vertices farther than k cannot end a length-k path; edge lengths are
    // positive, so their whole subtree is skipped.
    void collect_depths(int x, int parent, int depth, vector<int>& out) {
        if (depth > k) return;
        out.push_back(depth);
        for (const Edge& e : adj[x])
            if (e.to != parent && !removed[e.to])
                collect_depths(e.to, x, depth + e.len, out);
    }

    // Number of index pairs i < j of a sorted list with dist[i] + dist[j] <= limit.
    ll count_within(const vector<int>& dist, int limit) {
        ll pairs = 0;
        int lo = 0;
        int hi = static_cast<int>(dist.size()) - 1;
        while (lo < hi) {
            if (dist[lo] + dist[hi] <= limit) {
                pairs += hi - lo;  // dist[lo] pairs with every element in (lo, hi]
                ++lo;
            } else {
                --hi;
            }
        }
        return pairs;
    }

    // Number of pairs in `dist` whose distances sum to exactly k (sorts `dist`).
    ll count_exact(vector<int>& dist) {
        sort(dist.begin(), dist.end());
        return count_within(dist, k) - count_within(dist, k - 1);
    }

    // Count length-k paths in the live component containing `entry`, then recurse.
    void decompose(int entry) {
        const int total = compute_sizes(entry, 0);
        const int c = find_centroid(entry, 0, total);
        removed[c] = 1;

        vector<int> all(1, 0);  // the centroid itself, at distance 0
        for (const Edge& e : adj[c]) {
            if (removed[e.to]) continue;
            vector<int> branch;
            collect_depths(e.to, c, e.len, branch);
            all.insert(all.end(), branch.begin(), branch.end());
            ans -= count_exact(branch);  // same-branch pairs do not pass through c
        }
        ans += count_exact(all);

        for (const Edge& e : adj[c])
            if (!removed[e.to]) decompose(e.to);
    }

    // Read one tree with n vertices (unit-length edges) and print its answer.
    void solve(int n) {
        adj.assign(n + 1, vector<Edge>());
        removed.assign(n + 1, 0);
        sub.assign(n + 1, 0);
        for (int i = 1; i < n; ++i) {
            int x, y;
            cin >> x >> y;
            adj[x].push_back({y, 1});
            adj[y].push_back({x, 1});
        }
        ans = 0;
        decompose(1);
        cout << ans << '\n';  // cout rather than %lld, as the statement recommends
    }

    }  // namespace

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        int n;
        // Multiple test cases, until EOF or a line where n or k is 0.
        while (cin >> n >> k && n && k) solve(n);
        return 0;
    }
    Processed: 0.037, SQL: 9