最终得对抗自己

[BZOJ 3625] 小朋友和二叉树 多项式

真正的神题, 看来生成函数和多项式引入OI已经好久了。
%%% Picks

题目大意

给你长度为N的不重复数列C, 整数m, 求对于每一个正整数s<=m, 点权在数列C内并且点权和为s的二叉树数量。 答案对998244353取模。

解题报告

生成函数与多项式运算的题目。
首先我们设答案数列为[latex]\{F_n\}[/latex]
那么有递推式 $$ F_k = \sum_{i=1}^n\sum_{j=1}^{k-C_i}F_j\times F_{k-C_i-j} $$
其中, [latex]F_0 = 1[/latex]。
为了表示出这个数列的生成函数, 再搞一个生成函数定义如下:
$$ C(x) = \sum_{i=0}^{\infty}(\sum_{j=1}^n [C_j == i])\times x^i $$
很明显这个函数系数是可以处理出来的。
然后数列的生成函数 F(x) 就可以通过递推建立这样的关系:
$$ F(x) = C(x)F(x)^2+1 $$
解一波方程:
$$ F(x) = \frac{1\pm\sqrt{1-4C(x)}}{2C(x)} $$
很明显我们需要把不收敛的根舍掉:
$$ 取x\rightarrow 0, 那么有 \frac{1+\sqrt{1-4C(x)}}{2C(x)}\rightarrow \infty$$
然后只剩下:
$$ F(x) = \frac{1-\sqrt{1-4C(x)}}{2C(x)} $$
由于只需要前m个答案, 我们要求:
$$ \frac{1-\sqrt{1-4C(x)}}{2C(x)} \mod{x^{m+1}} $$
然后发现答案要求取余的素数[latex]998244353=7\times 17\times 2^{23}+1[/latex]
就可以拿来搞NTT进行多项式运算了。

为了计算答案, 我们需要以下技巧:

1.多项式求逆:

我们要求一个[latex]G(x)[/latex]使[latex]G(x)F(x) \equiv 1 \mod{x^n}[/latex]。
n == 1 时, 很明显直接返回F(x)的常数项的逆元就可以了。
n != 1 时, 递归地进行思考:
设[latex]E(x)F(x) \equiv 1 \mod{x^{\lceil \frac{n}{2} \rceil}}[/latex], 即E(x)为F(x)在模[latex]x^{\lceil \frac{n}{2} \rceil}[/latex]意义下的逆元。
由于[latex]G(x)F(x) \equiv 1 \mod{x^n}[/latex], 所以显然有[latex]G(x)F(x) \equiv 1 \mod{x^{\lceil \frac{n}{2} \rceil}}[/latex]。
那么两式相减,又因为F(x)不为0, 有:
$$ G(x)-E(x) \equiv 0 \mod{x^{\lceil \frac{n}{2} \rceil}} $$
平方(很明显模意义下的平方模数也可以为原来模数平方的任意因子)之后乘上F(x):
$$ G(x)-2 \times E(x)+E(x)^2 \times F(x) \equiv 0 \mod{x^n} $$
移项:
$$ G(x) \equiv 2\times E(x)-E(x)^2 \times F(x) \mod{x^n} $$
如果先求解E(x), 那么我们就可以通过这个方法算出G(x)了, 因为乘法原因, 需要注意这里DFT次数界要设置到2N
E(x)可以递归求解, 次数界减半, 终止条件为常数项逆元, 复杂度:
$$ T(n) = T(\frac{n}{2}) + n\log_2(n) = n\log_2(n) $$

2.多项式开平方

我们要求一个[latex]G(x)[/latex]使[latex]G(x)^2 \equiv F(x) \mod{x^n}[/latex]
首先当 n == 1 时, 由于这里需要开根的多项式常数项为1, 返回1即可(否然好像需要神奇算法? 没去了解)。
否然我们假设已经递归求得了[latex] H(x)^2 \equiv F(x) \mod{x^{\lceil \frac{n}{2} \rceil}} [/latex]
移项之后平方:
$$ (H(x)^2-F(x))^2 \equiv 0 \mod{x^n} $$
再移个项:
$$ (H(x)^2+F(x))^2 \equiv 4\times H(x)^2F(x) \mod{x^n} $$
把[latex]4 \times H(x)^2[/latex]除过去:
$$ (\frac{H(x)^2+F(x)}{2\times H(x)})^2 \equiv F(x) \mod{x^n} $$
然后发现[latex]G(x) \equiv \frac{H(x)^2+F(x)}{2\times H(x)} \mod{x^n}[/latex], 对于[latex]H(x)[/latex]递归求解即可。
同样因为乘法原因, 需要注意这里DFT次数界要设置到2N!!!

有了这两种方法, 答案就可以直接计算了。

代码

辣鸡OJ!卡我常数!
还是自己的NTT太弱了…

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <climits>
#include <algorithm>
#include <complex>
typedef long long LL;
const int PMOD = 998244353;
const int DROOT = 3;
const int MAXL = 1100000;
inline int advPow (int a, int b) {
    int ret = 1;
    while(b) {
        if(b&1) ret = (LL)ret*a%PMOD;
        a = (LL)a*a%PMOD;
        b >>= 1;
    }
    return ret;
}
const int INV2 = advPow(2, PMOD-2);
LL temp[MAXL];
inline int bitReverse (int num, int lvl) {
    int bit = 1, ret = 0;
    while(bit < lvl) {
        ret <<= 1;
        ret += (bool)(bit&num);
        bit <<= 1;
    }
    return ret;
}

void NTT (int * A, int N, int flag) {
    int g = DROOT;
    if(flag < 0) g = advPow(g, PMOD-2);
    for(int i = 0; i<N; i++) temp[bitReverse(i, N)] = A[i];
    for(int lvl = 2; lvl<=N; lvl<<=1) {
        int wn = advPow(g, (PMOD-1)/lvl);
        for(int i = 0; i<N; i+=lvl) {
            int w = 1, t1, t2;
            for(int j = 0; j<(lvl>>1); j++, w=(LL)wn*w%PMOD) {
                t1 = temp[i+j], t2 = (LL)temp[i+(lvl>>1)+j]*w%PMOD;
                temp[i+j] = t1+t2;
                if(temp[i+j] >= PMOD) temp[i+j] -= PMOD;
                temp[i+(lvl>>1)+j] = t1-t2;
                if(temp[i+(lvl>>1)+j] < 0) temp[i+(lvl>>1)+j] += PMOD;
            }
        }
    }
    if(flag < 0) {
        int inv = advPow(N, PMOD-2);
        for(int i = 0; i<N; i++) temp[i] = (LL)temp[i]*inv%PMOD;
    }
    for(int i = 0; i<N; i++) A[i] = temp[i];
}
int invTemp[MAXL], top;
void inverse (const int * A, int * B, int N) {
    if(N == 1) {
        B[0] = advPow(A[0], PMOD-2);
        return ;
    }
    inverse(A, B, N>>1);
    memset(B+N, 0, sizeof(A[0])*N);
    memcpy(invTemp, A, sizeof(A[0])*N);
    memset(invTemp+N, 0, sizeof(A[0])*N);
    NTT(invTemp, N<<1, 1), NTT(B, N<<1, 1);
    for(int i = 0; i<(N<<1); i++)
        B[i] = (LL)B[i]*(((2LL-(LL)invTemp[i]*B[i])%PMOD+PMOD)%PMOD)%PMOD;
    NTT(B, N<<1, -1);
    memset(B+N, 0, sizeof(A[0])*N);
}

int sqrTemp[MAXL];
int ATemp[MAXL];
void sqrt (int * A, int * B, int N) {
    if(N == 1) {
        B[0] = 1;
        return ;
    }
    sqrt(A, B, N>>1);
    memcpy(ATemp, A, sizeof(A[0])*N);
    memset(ATemp+N, 0, sizeof(A[0])*N);
    memset(B+N, 0, sizeof(A[0])*N);
    inverse(B, sqrTemp, N);
    memset(sqrTemp+N, 0, sizeof(A[0])*N);
    NTT(B, N<<1, 1);
    NTT(sqrTemp, N<<1, 1);
    NTT(ATemp, N<<1, 1);
    for(int i = 0; i<(N<<1); i++)
        B[i] = (LL)INV2*(((LL)B[i]+(LL)sqrTemp[i]*ATemp[i]%PMOD)%PMOD)%PMOD;
    NTT(B, N<<1, -1);
    memset(B+N, 0, sizeof(A[0])*N);
}
int tempA[MAXL], tempB[MAXL];
int multiple (int * A, int lenA, int * B, int lenB, int * C) {
    int lvl = 1;
    while(lvl <= lenA+lenB) lvl <<= 1;
    for(int i = 0; i<lenA; i++) tempA[i] = A[i];
    for(int i = lenA; i<lvl; i++) tempA[i] = 0;
    for(int i = 0; i<lenB; i++) tempB[i] = B[i];
    for(int i = lenB; i<lvl; i++) tempB[i] = 0;
    NTT (tempA, lvl, 1), NTT (tempB, lvl, 1);
    for(int i = 0; i<lvl; i++) tempA[i] = (LL)tempA[i]*tempB[i]%PMOD;
    NTT (tempA, lvl, -1);
    for(int i = 0; i<lvl; i++) C[i] = tempA[i];
    return lvl;
}
int polC[MAXL];
int cInv[MAXL], debTemp[MAXL];
int sqrPol[MAXL];
int main () {

    int N, M;
    scanf("%d%d", &N, &M);
    for(int i = 1; i<=N; i++) {
        int a;
        scanf("%d", &a);
        polC[a] = -4;
    }
    polC[0] ++;
    int lvl = 1;
    while(lvl <= M) lvl<<=1;
    sqrt(polC, sqrPol, lvl);
    multiple(sqrPol, lvl, sqrPol, lvl, debTemp);
    sqrPol[0] ++;
    inverse(sqrPol, cInv, lvl);
    for(int i = 1; i<=M; i++)
        printf("%d\n", (int)(2LL*cInv[i]%PMOD));
}

Join the discussion

Your email address will not be published. Required fields are marked *