这周就做了两题,一题就要调三天。。。。
分治:要求 $f(x)$ 在 $x_l, x_{l+1}, \cdots, x_r$ 处的值,构造 $p(x) = \prod_{i=l}^{mid}(x - x_i)$,设 $f(x)$ 对 $p(x)$ 取模后为 $g(x)$,即 $f(x) = q(x)p(x) + g(x)$。由于对任意 $i \in [l, mid]$ 都有 $p(x_i) = 0$,故 $f(x_i) = g(x_i)$;$i \in [mid+1, r]$ 时同理(取 $p(x) = \prod_{i=mid+1}^{r}(x - x_i)$),递归求解即可。
所有的 $p(x)$ 都可以用分治 FFT 预处理出来(即代码中的 `pre`:在线段树的每个节点上存该区间的 $\prod (x - x_i)$)。
显然 $g(x)$ 的次数小于 $p(x)$ 的次数,约为 $f(x)$ 次数的一半,故每次递归使问题规模减半。由主定理($T(n) = 2T(n/2) + O(n\log n)$)得复杂度为 $O(n\log^2 n)$。
但常数极大…
#include <iostream> #include <cstdio> #include <cstring> #include <vector> using namespace std; const int maxn = 6.5e4,maxm = 1.32e5,mod = 998244353,g = 3; //2 ^ 17 > 1.3e5 !!!!! int n,m,N,a[maxn],rev[maxm],F[maxm],T[maxm],T1[maxm],T2[maxm],Q[maxm],R[maxm]; vector <int> P[4 * maxn]; int read(){ int x = 0; char c = getchar(); while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar(); return x; } int qpow(int x,int k){ long long d = 1,t = x; while(k){ if(k & 1) d = d * t % mod; t = t * t % mod,k >>= 1; } return d; } void init(int n){ N = 1; int cnt = 0; while(N <= n) N <<= 1,cnt ++; for(int i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1)); } void NTT(int *F,int n,int p){ for(int i = 0; i < n; i ++) if(i < rev[i]) swap(F[i],F[rev[i]]); for(int i = 1; i < n; i <<= 1){ int w1 = qpow(g,(mod - 1) / (i << 1)); for(int j = 0; j < n; j += i << 1){ int w = 1; for(int k = j; k < j + i; k ++){ int t1 = F[k],t2 = 1ll * w * F[k + i] % mod; F[k] = (t1 + t2) % mod,F[k + i] = (t1 - t2 + mod) % mod; w = 1ll * w * w1 % mod; } } } if(p == -1){ int inv = qpow(n,mod - 2); for(int i = 0; i < n; i ++) F[i] = 1ll * F[i] * inv % mod; for(int i = n / 2; i >= 1; i --) swap(F[i],F[n - i]); } } void Mul(int *F,int *G,int n,int p = 0){ init(n << 1); NTT(F,N,1),NTT(G,N,1); for(int i = 0; i < N; i ++) F[i] = 1ll * F[i] * G[i] % mod; NTT(F,N,-1); if(!p) for(int i = n + 1; i < N; i ++) F[i] = 0; } void get_Inv(int n,int *F,int *G){ if(n == 1){ G[0] = qpow(F[0],mod - 2); return; } get_Inv((n + 1) >> 1,F,G); init(n << 1); for(int i = 0; i < n; i ++) T[i] = F[i]; NTT(T,N,1),NTT(G,N,1); for(int i = 0; i < N; i ++) G[i] = 1ll * (2 - 1ll * G[i] * T[i] % mod + mod) % mod * G[i] % mod; NTT(G,N,-1); for(int i = n; i < N; i ++) G[i] = 0; for(int i = 0; i < N; i ++) T[i] = 0; } void Mod(int *F,int *G,int n,int m){ for(int i = 0; i <= n - m; i ++) R[i] = G[m - i]; for(int i = n - m + 1; i <= m; i ++) R[i] = 0; get_Inv(n - m + 1,R,Q); for(int 
i = 0; i <= n - m; i ++) R[i] = F[n - i]; for(int i = n - m + 1; i <= n; i ++) R[i] = 0; Mul(Q,R,n - m); for(int i = 0; i <= m; i ++) R[i] = G[m - i]; for(int i = m + 1; i < N; i ++) R[i] = 0; Mul(Q,R,max(m,n - m),1); for(int i = 0; i <= n; i ++) Q[i] = (F[n - i] - Q[i] + mod) % mod; for(int i = 0; i < m; i ++) F[i] = Q[n - i]; for(int i = m; i <= n; i ++) F[i] = 0; for(int i = 0; i < N; i ++) R[i] = 0,Q[i] = 0; } void pre(int x,int l,int r){ if(l == r){ P[x].push_back(mod - a[l]),P[x].push_back(1); return; } int mid = (l + r) / 2,ls = x << 1,rs = x << 1 | 1; pre(ls,l,mid),pre(rs,mid + 1,r); for(int i = 0; i <= mid - l + 1; i ++) T1[i] = P[ls][i]; for(int i = 0; i <= r - mid; i ++) T2[i] = P[rs][i]; Mul(T1,T2,mid - l + 1,1); for(int i = 0; i <= r - l + 1; i ++) P[x].push_back(T1[i]); for(int i = 0; i < N; i ++) T1[i] = T2[i] = 0; } void solve(int x,int l,int r,int *F){ if(l == r){ printf("%d\n",F[0]); return; } int mid = (l + r) / 2,T[r - l + 10],ls = x << 1,rs = x << 1 | 1; for(int i = 0; i <= r - l; i ++) T[i] = F[i]; for(int i = 0; i <= mid - l + 1; i ++) T1[i] = P[ls][i]; Mod(T,T1,r - l,mid - l + 1); for(int i = 0; i <= mid - l + 1; i ++) T1[i] = 0; solve(ls,l,mid,T); for(int i = 0; i <= r - l; i ++) T[i] = F[i]; for(int i = 0; i <= r - mid; i ++) T1[i] = P[rs][i]; Mod(T,T1,r - l,r - mid); for(int i = 0; i <= r - mid; i ++) T1[i] = 0; solve(rs,mid + 1,r,T); } int main(){ n = read(),m = read(); if(!m) return 0; for(int i = 0; i <= n; i ++) F[i] = read(); for(int i = 1; i <= m; i ++) a[i] = read(); pre(1,1,m); if(n >= m){ for(int i = 0; i <= m; i ++) T1[i] = P[1][i]; Mod(F,T1,n,m); for(int i = 0; i <= m; i ++) T1[i] = 0; } solve(1,1,m,F); return 0; }由拉格朗日插值公式 拉格朗·日插值 ,有 f ( x ) = ∑ i = 1 n ( ∏ i ≠ j x − x i x i − x j y i ) = ∑ i = 1 n ( y i ∏ i ≠ j ( x i − x j ) ∏ i ≠ j ( x − x j ) ) f(x) = \sum_{i = 1}^n\bigg(\prod_{i \not=j}\frac{x - x_i}{x_i - x_j}y_i\bigg) = \sum_{i = 1}^n\bigg(\frac{y_i}{\prod_{i \not=j}(x_i - x_j)}\prod_{i\not=j}(x - x_j)\bigg) 
首先考虑式子中的系数 $\frac{y_i}{\prod_{j \ne i}(x_i - x_j)}$ 如何快速计算,也就是对于每个 $i \in [1,n]$,要求出 $\prod_{j \ne i}(x_i - x_j)$ 的值。为此设 $g(x) = \prod_{i=1}^n(x - x_i)$。
则有 $\prod_{j \ne i}(x_i - x_j) = \left.\frac{g(x)}{x - x_i}\right|_{x = x_i}$(即先约去 $g(x)$ 中的因子 $x - x_i$,再代入 $x = x_i$)。
由于 $g(x_i) = 0$,根据洛必达法则易得 $\lim_{x \to x_i}\frac{g(x)}{x - x_i} = g'(x_i)$,即 $\prod_{j \ne i}(x_i - x_j) = g'(x_i)$。
那么问题就变成求 $g'(x)$ 在 $x_1, x_2, \cdots, x_n$ 处的值,直接套用上面的多点求值即可。$g(x)$ 与多点求值中要预处理的式子是同一个,所以可以先算出来,多点求值时就不用再算一遍。
下面回到正题,考虑如何计算 f ( x ) f(x) f(x)
用分治的思想,设 $f_{l,r}(x) = \sum_{i=l}^r\left(\frac{y_i}{\prod_{1 \leqslant j \leqslant n,\, j \ne i}(x_i - x_j)}\prod_{l \leqslant j \leqslant r,\, j \ne i}(x - x_j)\right)$
这里要注意,$f_{l,r}(x)$ 并不代表第 $l$ 到 $r$ 个点所确定的 $r - l$ 次函数。因为系数的分母中 $j$ 的取值仍然是 $1$ 到 $n$,所以它只是为了分治而引入的辅助函数,本身没有实际意义(只有 $f_{1,n}(x) = f(x)$)。
则有
$$\begin{aligned} f_{l,r}(x) &= \sum_{i=l}^r\left(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant r,\, j \ne i}(x - x_j)\right) \\ &= \sum_{i=l}^{mid}\left(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant r,\, j \ne i}(x - x_j)\right) + \sum_{i=mid+1}^{r}\left(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant r,\, j \ne i}(x - x_j)\right) \\ &= \left(\prod_{i=mid+1}^{r}(x - x_i)\right)\sum_{i=l}^{mid}\left(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant mid,\, j \ne i}(x - x_j)\right) \\ &\quad + \left(\prod_{i=l}^{mid}(x - x_i)\right)\sum_{i=mid+1}^{r}\left(\frac{y_i}{g'(x_i)}\prod_{mid+1 \leqslant j \leqslant r,\, j \ne i}(x - x_j)\right) \\ &= \left(\prod_{i=mid+1}^{r}(x - x_i)\right)f_{l,mid}(x) + \left(\prod_{i=l}^{mid}(x - x_i)\right)f_{mid+1,r}(x) \end{aligned}$$
可以直接分治计算,边界为 $f_{i,i}(x) = \frac{y_i}{g'(x_i)}$。由主定理得复杂度为 $O(n\log^2 n)$,但常数比上面更大,代码也更长…
非常巧合的是(也可能是发明者有意构造),分治中的“系数多项式” $\prod_{i=l}^r(x - x_i)$ 在多点求值的预处理(也即分治计算 $g(x)$)时已经计算过,不需要另外计算。
细节看代码
// Fast Lagrange interpolation in O(n log^2 n) using NTT.
// Input: n, then n pairs (x_i, y_i). Output: the n coefficients of f.
// Some parts look unusual because they were tuned for constant-factor speed;
// the two most important optimizations are marked with comments.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
const int maxn = 1e5 + 50, maxm = 2.63e5, mod = 998244353, g = 3;
int n, m, N, x[maxn], y[maxn], F[maxn], G[maxn], H[maxn], rev[maxm], T[maxm], T1[maxm], T2[maxm], Q[maxm], R[maxm], W[20][maxm], inv_W[20][maxm], inv[maxm];
// P[x]: coefficients of prod (X - x_i) over the interval of segment-tree node x.
vector<int> P[4 * maxn];

// Reads a non-negative decimal integer from stdin (stops scanning at EOF
// instead of looping forever).
int read() {
    int x = 0, c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    return x;
}

// Optimization 1: modular add/subtract via one conditional correction instead
// of `%`. This was very effective (at least 800 ms faster per test).
// Both assume their operands are already reduced.
inline int add(int x, int y) {
    if (x + y < mod) return x + y;
    else return x + y - mod;
}
inline int dec(int x, int y) {
    if (x - y >= 0) return x - y;
    else return x - y + mod;
}

// Returns x^k mod `mod` by binary exponentiation.
int qpow(int x, int k) {
    long long d = 1, t = x;
    while (k) {
        if (k & 1) d = d * t % mod;
        t = t * t % mod, k >>= 1;
    }
    return d;
}

// Sets N to the smallest power of two greater than n and fills the
// bit-reversal permutation rev[0..N-1] for that length.
void init(int n) {
    N = 1;
    int cnt = 0;
    while (N <= n) N <<= 1, cnt++;
    // When N == 1, cnt == 0 and `<< (cnt - 1)` would be a shift by -1 (UB).
    for (int i = 0; i < N; i++)
        rev[i] = (rev[i >> 1] >> 1) | (cnt ? (i & 1) << (cnt - 1) : 0);
}

// In-place NTT of length n. p == 1: forward; p == -1: inverse (with 1/n scale).
// Optimization 2: roots of unity (W / inv_W) and inverse lengths (inv) are
// precomputed in main; computing them per call adds almost another log factor.
void NTT(int *F, int n, int p) {
    for (int i = 0; i < n; i++)
        if (i < rev[i]) swap(F[i], F[rev[i]]);
    for (int i = 1, cnt = 0; i < n; i <<= 1) {
        cnt++;
        // Row cnt holds powers of a primitive 2^cnt-th root of unity (or its inverse).
        const int *w_row = (p == 1 ? W[cnt] : inv_W[cnt]);
        for (int j = 0; j < n; j += i << 1) {
            for (int k = j; k < j + i; k++) {
                int t1 = F[k], t2 = 1ll * w_row[k - j] * F[k + i] % mod;
                F[k] = add(t1, t2), F[k + i] = dec(t1, t2);
            }
        }
    }
    if (p == -1)
        for (int i = 0; i < n; i++) F[i] = 1ll * F[i] * inv[n] % mod;
}

// F <- F * G. Callers pass n >= the degree of both inputs, so the full product
// fits in length N > 2n. G is overwritten with its transform. When p == 0 the
// result is truncated to degree n.
void Mul(int *F, int *G, int n, int p = 0) {
    init(n << 1);
    NTT(F, N, 1), NTT(G, N, 1);
    for (int i = 0; i < N; i++) F[i] = 1ll * F[i] * G[i] % mod;
    NTT(F, N, -1);
    if (!p)
        for (int i = n + 1; i < N; i++) F[i] = 0;
}

// G <- F^{-1} mod x^n by Newton iteration. Assumes G is zero on entry; uses T
// as scratch and clears it afterwards.
void get_Inv(int n, int *F, int *G) {
    if (n == 1) {
        G[0] = qpow(F[0], mod - 2);
        return;
    }
    get_Inv((n + 1) >> 1, F, G);
    init(n << 1);
    for (int i = 0; i < n; i++) T[i] = F[i];
    NTT(T, N, 1), NTT(G, N, 1);
    for (int i = 0; i < N; i++) G[i] = 1ll * dec(2, 1ll * G[i] * T[i] % mod) * G[i] % mod;
    NTT(G, N, -1);
    for (int i = n; i < N; i++) G[i] = 0;
    for (int i = 0; i < N; i++) T[i] = 0;
}

// H <- F mod G, where deg F = n >= deg G = m and G[m] != 0. H may alias F.
// G is taken by const reference; the original by-value parameter copied a
// vector on every call. Uses R and Q as scratch and clears them afterwards.
void Mod(int *F, const vector<int> &G, int n, int m, int *H) {
    const int k = n - m;  // degree of the quotient
    // rev(G) truncated to x^k. Coefficients above x^m are zero (guards G[m - i]).
    for (int i = 0; i <= k; i++) R[i] = i <= m ? G[m - i] : 0;
    for (int i = k + 1; i <= m; i++) R[i] = 0;
    get_Inv(k + 1, R, Q);
    for (int i = 0; i <= k; i++) R[i] = F[n - i];
    for (int i = k + 1; i <= n; i++) R[i] = 0;
    Mul(Q, R, k);  // Q = rev(quotient)
    for (int i = 0; i <= m; i++) R[i] = G[m - i];
    for (int i = m + 1; i < N; i++) R[i] = 0;
    Mul(Q, R, max(m, k), 1);  // Q[j] = (quotient * G)[n - j]
    for (int i = 0; i < m; i++) H[i] = dec(F[i], Q[n - i]);
    for (int i = m; i <= n; i++) H[i] = 0;
    for (int i = 0; i < N; i++) R[i] = 0, Q[i] = 0;
}

// Builds P[x] = prod_{i=l..r} (X - a_i) for every node x of the segment tree.
void pre(int x, int l, int r, int *a) {
    if (l == r) {
        // (mod - a) % mod keeps the constant term in [0, mod) even when a == 0,
        // which add/dec rely on.
        P[x].push_back((mod - a[l]) % mod), P[x].push_back(1);
        return;
    }
    int mid = (l + r) / 2, ls = x << 1, rs = x << 1 | 1;
    pre(ls, l, mid, a), pre(rs, mid + 1, r, a);
    for (int i = 0; i <= mid - l + 1; i++) T1[i] = P[ls][i];
    for (int i = 0; i <= r - mid; i++) T2[i] = P[rs][i];
    Mul(T1, T2, mid - l + 1, 1);
    P[x].reserve(r - l + 2);
    for (int i = 0; i <= r - l + 1; i++) P[x].push_back(T1[i]);
    for (int i = 0; i < N; i++) T1[i] = T2[i] = 0;
}

// Writes F(a_i) into G[i] for every i in [l, r]; F has degree at most r - l.
// Ranges with r - l <= 256 switch to direct Horner evaluation (a constant-factor cutoff).
void calc(int x, int l, int r, int *a, int *F, int *G) {
    if (r - l <= 256) {
        for (int i = l; i <= r; i++) {
            int s = 0;
            for (int j = r - l; j >= 0; j--) s = add(1ll * s * a[i] % mod, F[j]);
            G[i] = s;
        }
        return;
    }
    int mid = (l + r) / 2, ls = x << 1, rs = x << 1 | 1;
    // Heap buffer instead of a (non-standard) VLA; Mod touches indices 0..r-l.
    vector<int> buf(r - l + 10);
    Mod(F, P[ls], r - l, mid - l + 1, buf.data());
    calc(ls, l, mid, a, buf.data(), G);
    Mod(F, P[rs], r - l, r - mid, buf.data());
    calc(rs, mid + 1, r, a, buf.data(), G);
}

// Multi-point evaluation: G[i] = F(a_i) for i = 1..m, where deg F = n.
// Requires P[] to already hold the product tree of a[1..m]; main calls pre
// before this, so it is not rebuilt here.
void Eva(int *F, int *a, int n, int m, int *G) {
    if (n >= m) Mod(F, P[1], n, m, F);
    calc(1, 1, m, a, F, G);
}

// Computes f_{l,r}(X) into F (degree r - l), where
//   f_{l,r} = P[rs] * f_{l,mid} + P[ls] * f_{mid+1,r},  f_{i,i} = y_i / g'(x_i),
// and H[i] must already hold g'(x_i).
void solve(int x, int l, int r, int *F) {
    if (l == r) {
        F[0] = 1ll * y[l] * qpow(H[l], mod - 2) % mod;
        return;
    }
    int mid = (l + r) / 2, ls = x << 1, rs = x << 1 | 1;
    // Mul transforms Fl/Fr in place at a length N <= 2 * (r - l + 2), so the
    // zero-initialized buffers are sized past that (heap, not a stack VLA).
    vector<int> Fl(2 * (r - l + 10)), Fr(2 * (r - l + 10));
    solve(ls, l, mid, Fl.data()), solve(rs, mid + 1, r, Fr.data());
    for (int i = r - mid; i >= 0; i--) T1[i] = P[rs][i];
    for (int i = mid - l + 1; i >= 0; i--) T2[i] = P[ls][i];
    Mul(T1, Fl.data(), mid - l + 1, 1), Mul(T2, Fr.data(), mid - l + 1, 1);
    for (int i = r - l; i >= 0; i--) F[i] = add(T1[i], T2[i]);
    for (int i = 0; i < N; i++) T1[i] = T2[i] = 0;
}

int main() {
    n = read();
    // Precompute root-of-unity tables and inverse lengths for every NTT length
    // up to N (the largest used later).
    init(n << 1);
    for (int i = 0; i <= 19; i++) {
        W[i][0] = inv_W[i][0] = 1;
        W[i][1] = qpow(g, (mod - 1) / (1 << i)), inv_W[i][1] = qpow(W[i][1], mod - 2);
        for (int j = 2; j < N; j++)
            W[i][j] = 1ll * W[i][j - 1] * W[i][1] % mod,
            inv_W[i][j] = 1ll * inv_W[i][j - 1] * inv_W[i][1] % mod;
    }
    for (int i = 1; i <= N; i <<= 1) inv[i] = qpow(i, mod - 2);
    for (int i = 1; i <= n; i++) x[i] = read(), y[i] = read();
    pre(1, 1, n, x);
    // G = g'(X), where g(X) = prod (X - x_i) is the product-tree root P[1].
    for (int i = 0; i <= n; i++) G[i] = P[1][i];
    for (int i = 0; i < n; i++) G[i] = 1ll * G[i + 1] * (i + 1) % mod;
    G[n] = 0;
    // H[i] = g'(x_i). Bug fix: the original evaluated F, which is still all
    // zeros at this point, so every H[i] (and every output coefficient) was 0.
    Eva(G, x, n - 1, n, H);
    solve(1, 1, n, F);
    for (int i = 0; i < n; i++) printf("%d ", F[i]);
    printf("\n");
    return 0;
}