真正的神题, 看来生成函数和多项式引入OI已经好久了。
%%% Picks
题目大意
给你长度为N的不重复数列C, 整数m, 求对于每一个正整数s<=m, 点权在数列C内并且点权和为s的二叉树数量。 答案对998244353取模。
解题报告
生成函数与多项式运算的题目。
首先我们设答案数列为[latex]\{F_n\}[/latex]
那么有递推式 $$ F_k = \sum_{i=1}^n\sum_{j=1}^{k-C_i}F_j\times F_{k-C_i-j} $$
其中, [latex]F_0 = 1[/latex]。
为了表示出这个数列的生成函数, 再搞一个生成函数定义如下:
$$ C(x) = \sum_{i=0}^{\infty}(\sum_{j=1}^n [C_j == i])\times x^i $$
很明显这个函数系数是可以处理出来的。
然后数列的生成函数
F(x)
就可以通过递推建立这样的关系:
$$ F(x) = C(x)F(x)^2+1 $$
解一波方程:
$$ F(x) = \frac{1\pm\sqrt{1-4C(x)}}{2C(x)} $$
很明显我们需要把不收敛的根舍掉:
$$ 取x\rightarrow 0, 那么有 \frac{1+\sqrt{1-4C(x)}}{2C(x)}\rightarrow \infty$$
然后只剩下:
$$ F(x) = \frac{1-\sqrt{1-4C(x)}}{2C(x)} $$
由于只需要前m个答案, 我们要求:
$$ \frac{1-\sqrt{1-4C(x)}}{2C(x)} \mod{x^{m+1}} $$
然后发现答案要求取余的素数[latex]998244353=7\times 17\times 2^{23}+1[/latex]
就可以拿来搞NTT进行多项式运算了。
为了计算答案, 我们需要以下技巧:
1.多项式求逆:
我们要求一个[latex]G(x)[/latex]使[latex]G(x)F(x) \equiv 1 \mod{x^n}[/latex]。
当
n == 1
时, 很明显直接返回F(x)的常数项的逆元就可以了。
当
n != 1
时, 递归地进行思考:
设[latex]E(x)F(x) \equiv 1 \mod{x^{\lceil \frac{n}{2} \rceil}}[/latex], 即E(x)为F(x)在模[latex]x^{\lceil \frac{n}{2} \rceil}[/latex]意义下的逆元。
由于[latex]G(x)F(x) \equiv 1 \mod{x^n}[/latex], 所以显然有[latex]G(x)F(x) \equiv 1 \mod{x^{\lceil \frac{n}{2} \rceil}}[/latex]。
那么两式相减,又因为F(x)不为0, 有:
$$ G(x)-E(x) \equiv 0 \mod{x^{\lceil \frac{n}{2} \rceil}} $$
平方(很明显模意义下的平方模数也可以为原来模数平方的任意因子)之后乘上F(x):
$$ G(x)-2 \times E(x)+E(x)^2 \times F(x) \equiv 0 \mod{x^n} $$
移项:
$$ G(x) \equiv 2\times E(x)-E(x)^2 \times F(x) \mod{x^n} $$
如果先求解E(x), 那么我们就可以通过这个方法算出G(x)了,
因为乘法原因, 需要注意这里DFT次数界要设置到2N
。
E(x)可以递归求解, 次数界减半, 终止条件为常数项逆元, 复杂度:
$$ T(n) = T(\frac{n}{2}) + n\log_2(n) = n\log_2(n) $$
2.多项式开平方
我们要求一个[latex]G(x)[/latex]使[latex]G(x)^2 \equiv F(x) \mod{x^n}[/latex]
首先当
n == 1
时, 由于这里需要开根的多项式常数项为1, 返回1即可(否然好像需要神奇算法? 没去了解)。
否然我们假设已经递归求得了[latex] H(x)^2 \equiv F(x) \mod{x^{\lceil \frac{n}{2} \rceil}} [/latex]
移项之后平方:
$$ (H(x)^2-F(x))^2 \equiv 0 \mod{x^n} $$
再移个项:
$$ (H(x)^2+F(x))^2 \equiv 4\times H(x)^2F(x) \mod{x^n} $$
把[latex]4 \times H(x)^2[/latex]除过去:
$$ (\frac{H(x)^2+F(x)}{2\times H(x)})^2 \equiv F(x) \mod{x^n} $$
然后发现[latex]G(x) \equiv \frac{H(x)^2+F(x)}{2\times H(x)} \mod{x^n}[/latex], 对于[latex]H(x)[/latex]递归求解即可。
同样因为乘法原因, 需要注意这里DFT次数界要设置到2N!!!
有了这两种方法, 答案就可以直接计算了。
代码
辣鸡OJ!卡我常数!
还是自己的NTT太弱了…
#include <cstdio> #include <cstring> #include <cstdlib> #include <climits> #include <algorithm> #include <complex> typedef long long LL; const int PMOD = 998244353; const int DROOT = 3; const int MAXL = 1100000; inline int advPow (int a, int b) { int ret = 1; while(b) { if(b&1) ret = (LL)ret*a%PMOD; a = (LL)a*a%PMOD; b >>= 1; } return ret; } const int INV2 = advPow(2, PMOD-2); LL temp[MAXL]; inline int bitReverse (int num, int lvl) { int bit = 1, ret = 0; while(bit < lvl) { ret <<= 1; ret += (bool)(bit&num); bit <<= 1; } return ret; } void NTT (int * A, int N, int flag) { int g = DROOT; if(flag < 0) g = advPow(g, PMOD-2); for(int i = 0; i<N; i++) temp[bitReverse(i, N)] = A[i]; for(int lvl = 2; lvl<=N; lvl<<=1) { int wn = advPow(g, (PMOD-1)/lvl); for(int i = 0; i<N; i+=lvl) { int w = 1, t1, t2; for(int j = 0; j<(lvl>>1); j++, w=(LL)wn*w%PMOD) { t1 = temp[i+j], t2 = (LL)temp[i+(lvl>>1)+j]*w%PMOD; temp[i+j] = t1+t2; if(temp[i+j] >= PMOD) temp[i+j] -= PMOD; temp[i+(lvl>>1)+j] = t1-t2; if(temp[i+(lvl>>1)+j] < 0) temp[i+(lvl>>1)+j] += PMOD; } } } if(flag < 0) { int inv = advPow(N, PMOD-2); for(int i = 0; i<N; i++) temp[i] = (LL)temp[i]*inv%PMOD; } for(int i = 0; i<N; i++) A[i] = temp[i]; } int invTemp[MAXL], top; void inverse (const int * A, int * B, int N) { if(N == 1) { B[0] = advPow(A[0], PMOD-2); return ; } inverse(A, B, N>>1); memset(B+N, 0, sizeof(A[0])*N); memcpy(invTemp, A, sizeof(A[0])*N); memset(invTemp+N, 0, sizeof(A[0])*N); NTT(invTemp, N<<1, 1), NTT(B, N<<1, 1); for(int i = 0; i<(N<<1); i++) B[i] = (LL)B[i]*(((2LL-(LL)invTemp[i]*B[i])%PMOD+PMOD)%PMOD)%PMOD; NTT(B, N<<1, -1); memset(B+N, 0, sizeof(A[0])*N); } int sqrTemp[MAXL]; int ATemp[MAXL]; void sqrt (int * A, int * B, int N) { if(N == 1) { B[0] = 1; return ; } sqrt(A, B, N>>1); memcpy(ATemp, A, sizeof(A[0])*N); memset(ATemp+N, 0, sizeof(A[0])*N); memset(B+N, 0, sizeof(A[0])*N); inverse(B, sqrTemp, N); memset(sqrTemp+N, 0, sizeof(A[0])*N); NTT(B, N<<1, 1); NTT(sqrTemp, N<<1, 1); NTT(ATemp, N<<1, 1); for(int i = 0; i<(N<<1); i++) B[i] = (LL)INV2*(((LL)B[i]+(LL)sqrTemp[i]*ATemp[i]%PMOD)%PMOD)%PMOD; NTT(B, N<<1, -1); memset(B+N, 0, sizeof(A[0])*N); } int tempA[MAXL], tempB[MAXL]; int multiple (int * A, int lenA, int * B, int lenB, int * C) { int lvl = 1; while(lvl <= lenA+lenB) lvl <<= 1; for(int i = 0; i<lenA; i++) tempA[i] = A[i]; for(int i = lenA; i<lvl; i++) tempA[i] = 0; for(int i = 0; i<lenB; i++) tempB[i] = B[i]; for(int i = lenB; i<lvl; i++) tempB[i] = 0; NTT (tempA, lvl, 1), NTT (tempB, lvl, 1); for(int i = 0; i<lvl; i++) tempA[i] = (LL)tempA[i]*tempB[i]%PMOD; NTT (tempA, lvl, -1); for(int i = 0; i<lvl; i++) C[i] = tempA[i]; return lvl; } int polC[MAXL]; int cInv[MAXL], debTemp[MAXL]; int sqrPol[MAXL]; int main () { int N, M; scanf("%d%d", &N, &M); for(int i = 1; i<=N; i++) { int a; scanf("%d", &a); polC[a] = -4; } polC[0] ++; int lvl = 1; while(lvl <= M) lvl<<=1; sqrt(polC, sqrPol, lvl); multiple(sqrPol, lvl, sqrPol, lvl, debTemp); sqrPol[0] ++; inverse(sqrPol, cInv, lvl); for(int i = 1; i<=M; i++) printf("%d\n", (int)(2LL*cInv[i]%PMOD)); }
Join the discussion