题目描述 一个数列{an},满足an+2=x⋅an+1+y⋅an,求 答案对998244353取模。
输入 第一行一个整数T,即数据组数。 下面T行,每行5个整数,即n,a1,a2,x,y,含义如上。
输出 共T行,每行1个整数,即每组数据的答案。
样例输入 3 3 0 2 0 0 3 1 2 0 1 3 1 0 0 1
样例输出 8 10 2
提示 对于100%的数据,1≤T≤104,1≤n≤1018,0≤a1,a2,x,y≤998244352。
思路 这题要求的对象是递推式的三次方和,因此可以通过矩阵快速幂的方法加速求递推 设 ∑ n i = 1 a i 3 在 第 n 项 为 s n \sum_{n}^{i=1} a_{i}^{3}在第n项为s_{n} ∑ni=1ai3在第n项为sn, 那么可以得到 s n = ( x ∗ a n − 1 + y ∗ a n − 2 ) 3 + s n − 1 s_{n}=(x*a_{n-1}+y*a_{n-2})^{3}+s_{n-1} sn=(x∗an−1+y∗an−2)3+sn−1, 通过an+2=x⋅an+1+y⋅an, 将 a n 分 解 为 a n − 1 a n − 2 的 形 式 a_{n}分解为a_{n-1} a_{n-2}的形式 an分解为an−1an−2的形式, 得到 s n = x 3 a n − 1 3 + 3 x 2 y a n − 1 2 a n − 2 + 3 x y 2 a n − 1 a n − 2 2 + y 3 a n − 2 3 s_{n}=x^{3}a_{n-1}^{3}+3x^{2}ya_{n-1}^{2}a_{n-2}+3xy^{2}a_{n-1}a_{n-2}^{2}+y^{3}a_{n-2}^{3} sn=x3an−13+3x2yan−12an−2+3xy2an−1an−22+y3an−23, 就可以得到递推矩阵 x n − 1 = [ a n − 1 3 a n − 1 2 a n − 2 a n − 1 a n − 2 2 a n − 2 3 s n ] x_{n-1}= \begin{bmatrix} a_{n-1}^{3} \\ a_{n-1}^{2}a_{n-2} \\ a_{n-1}a_{n-2}^{2} \\ a_{n-2}^{3} \\ s_{n} \end{bmatrix} xn−1=⎣⎢⎢⎢⎢⎡an−13an−12an−2an−1an−22an−23sn⎦⎥⎥⎥⎥⎤ 推出 x n = [ a n 3 a n 2 a n − 1 a n a n − 1 2 a n − 1 3 s n ] x_{n}= \begin{bmatrix} a_{n}^{3} \\ a_{n}^{2}a_{n-1} \\ a_{n}a_{n-1}^{2} \\ a_{n-1}^{3} \\ s_{n} \end{bmatrix} xn=⎣⎢⎢⎢⎢⎡an3an2an−1anan−12an−13sn⎦⎥⎥⎥⎥⎤ 设 x n = A x n − 1 x_{n}=Ax_{n-1} xn=Axn−1 通过分解得到系数矩阵 A = [ x 3 3 x 2 y 3 x y 2 y 3 0 x 2 2 x y y 2 0 0 x y 0 0 0 1 0 0 0 0 x 3 3 x 2 y 3 x y 2 y 3 1 ] A= \begin{bmatrix} x^{3} & 3x^{2}y & 3xy^{2} & y^{3} & 0 \\ x^{2} & 2xy & y^{2} & 0 & 0 \\ x & y & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 & 0 \\ x^{3} & 3x^{2}y & 3xy^{2} & y^{3} & 1 \end{bmatrix} A=⎣⎢⎢⎢⎢⎡x3x2x1x33x2y2xyy03x2y3xy2y2003xy2y3000y300001⎦⎥⎥⎥⎥⎤ 因为 A n − 2 x 1 = x n A^{n-2}x_{1}=x_{n} An−2x1=xn,所以通过矩阵快速幂即可得解
代码实现
#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; const int N=1005; const int M=2e5+5; const int INF=0x3f3f3f3f; const ll LINF=1e18; const ull sed=31; const ll mod= 998244353; const double eps=1e-6; const double PI=acos(-1.0); typedef pair<int,int>P; typedef pair<double,double>Pd; typedef pair<ll,int> plt; typedef pair<ll,ll>pll; template<class T>void read(T &x) { x=0;int f=0;char ch=getchar(); while(ch<'0'||ch>'9') {f|=(ch=='-');ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} x=f?-x:x; return; } struct mat { ll m[5][5]; mat() { memset(m,0,sizeof(m)); } }; mat Mul(mat a,mat b) { mat t; for(int i=0;i<5;i++) for(int j=0;j<5;j++) for(int k=0;k<5;k++) t.m[i][j]=(t.m[i][j]+a.m[i][k]*b.m[k][j])%mod; return t; } mat qpow(mat a,ll b) { mat ret; for(int i=0;i<5;i++) ret.m[i][i]=1; while (b) { if(b&1) ret=Mul(ret,a); a=Mul(a,a); b>>=1; } return ret; } int T; ll n,a1,a2,x,y; int main() { // freopen("a.txt","r",stdin); read(T); while (T--) { read(n);read(a1);read(a2);read(x);read(y); if(n==1) { printf("%lld\n",a1*a1%mod*a1%mod); continue; } else if(n==2) { ll ans=(a1*a1%mod*a1%mod+a2*a2%mod*a2%mod)%mod; printf("%lld\n",ans); continue; } mat t; t.m[0][0]=a2*a2%mod*a2%mod; t.m[1][0]=a2*a2%mod*a1%mod; t.m[2][0]=a2*a1%mod*a1%mod; t.m[3][0]=a1*a1%mod*a1%mod; t.m[4][0]=(a1*a1%mod*a1%mod+a2*a2%mod*a2%mod)%mod; t.m[1][1]=1; t.m[2][2]=1; t.m[3][3]=1; t.m[4][4]=1; mat a; a.m[4][0]=a.m[0][0]=x*x%mod*x%mod; a.m[1][0]=x*x%mod; a.m[2][0]=x%mod; a.m[3][0]=1; a.m[4][1]=a.m[0][1]=3*x*x%mod*y%mod; a.m[1][1]=2*x*y%mod; a.m[2][1]=y%mod; a.m[0][2]=3*x*y%mod*y%mod; a.m[1][2]=y*y%mod; a.m[4][2]=3*x*y%mod*y%mod; a.m[0][3]=a.m[4][3]=y*y%mod*y%mod; a.m[4][4]=1; t=Mul(qpow(a,n-2),t); printf("%lld\n",t.m[4][0]); } return 0; }