NTT & FFT

From EOJ Wiki
Revision as of 13:08, 26 July 2018 by Ultmaster (talk | contribs) (→‎NTT)
Jump to navigation Jump to search

NTT

typedef long long LL;
const int MAXN = 300010;      // 3e5 + 10: capacity of the shared work buffers
const int MOD = 998244353;    // modulus for every reduction below
const int G = 3;              // base whose powers supply the roots of unity

// Number-theoretic transform over Z/MOD with a global-buffer convolution.
// Usage of conv(): clear(), write the operands into a[] and b[], set N to the
// (power-of-two) transform length, call conv(), read the product from a[].
namespace NTT {
    int N, a[MAXN], b[MAXN];

    // Binary exponentiation: x to the n-th power, reduced modulo MOD.
    int pown(LL x, LL n) {
        LL acc = (MOD != 1) ? 1 : 0;
        for (x %= MOD; n != 0; n >>= 1) {
            if (n & 1) acc = acc * x % MOD;
            x = x * x % MOD;
        }
        return static_cast<int>(acc);
    }

    // Zeroes both work buffers in full.
    void clear() {
        std::memset(a, 0, sizeof(a));
        std::memset(b, 0, sizeof(b));
    }

    // In-place transform of a[0..N-1].
    //   f == -1 : inverse transform (forward pass, index reversal, scale by 1/N)
    //   otherwise: forward transform
    // Values are combined with (x + y) % MOD, so they are expected to already
    // lie in [0, MOD).
    void ntt(int * a, int N, int f) {
        // Bit-reversal permutation; `rev` is advanced like a counter whose
        // carry propagates from the top bit downwards.
        int rev = 0;
        for (int idx = 1; idx < N - 1; ++idx) {
            int bit = N;
            do {
                bit >>= 1;
                rev ^= bit;
            } while (~rev & bit);
            if (idx < rev) std::swap(a[idx], a[rev]);
        }

        // Butterfly passes over blocks of length 2, 4, ..., N.
        for (int half = 1; half < N; half <<= 1) {
            const int len = half << 1;
            const int root = pown(G, (MOD - 1) / len);
            for (int base = 0; base < N; base += len) {
                int w = 1;
                for (int k = 0; k < half; ++k) {
                    const int lo = a[base + k];
                    const int hi = 1LL * w * a[base + k + half] % MOD;
                    a[base + k] = (lo + hi) % MOD;
                    a[base + k + half] = (lo - hi + MOD) % MOD;
                    w = 1LL * w * root % MOD;
                }
            }
        }

        if (f != -1) return;

        // Inverse: reverse entries 1..N-1 of the forward result, then
        // multiply everything by N^(MOD-2), i.e. the inverse of N.
        std::reverse(a + 1, a + N);
        const int inv_n = pown(N, MOD - 2);
        for (int idx = 0; idx < N; ++idx)
            a[idx] = 1LL * a[idx] * inv_n % MOD;
    }

    // Cyclic convolution of a[0..N-1] and b[0..N-1]; the result overwrites a[],
    // and b[] is left in its transformed state.
    void conv() {
        ntt(a, N, 1);
        ntt(b, N, 1);
        int idx = 0;
        while (idx < N) {
            a[idx] = 1LL * a[idx] * b[idx] % MOD;
            ++idx;
        }
        ntt(a, N, -1);
    }
}


// Number-theoretic transform over Z/MOD with a self-contained convolution
// entry point, conv(s, t, n, result). Relies on LL and MOD from the
// surrounding file.
namespace NTT {
    const int MAXN = 6E5 + 100;   // capacity of the work buffers
    int N, a[MAXN], b[MAXN];      // N: transform length used by the last conv()
    const int G = 3;              // base whose powers supply the roots of unity

    // Maps v to its canonical representative in [0, MOD).
    inline int canon(int v) {
        v %= MOD;
        return v < 0 ? v + MOD : v;
    }

    // Binary exponentiation: x to the n-th power modulo MOD, returned in
    // [0, MOD). A negative base is first reduced to its non-negative residue
    // so the result is never negative. n is expected to be non-negative.
    int bin(LL x, LL n) {
        LL ret = MOD != 1;
        x %= MOD;
        if (x < 0) x += MOD;
        for (; n; n >>= 1, x = x * x % MOD)
            if (n & 1) ret = ret * x % MOD;
        return (int) ret;
    }

    // In-place transform of a[0..N-1].
    //   f == -1 : inverse transform (forward pass, index reversal, scale by 1/N)
    //   otherwise: forward transform
    // Entries must already lie in [0, MOD); N must be a power of two for the
    // butterfly indexing to stay inside the array.
    void ntt(int * a, int N, int f) {
        // Bit-reversal permutation; `rev` is advanced like a counter whose
        // carry propagates from the top bit downwards.
        int rev = 0;
        for (int i = 1; i < N - 1; i++) {
            int bit = N;
            do {
                bit >>= 1;
                rev ^= bit;
            } while (~rev & bit);
            if (i < rev) std::swap(a[i], a[rev]);
        }

        // Butterfly passes over blocks of length 2, 4, ..., N.
        for (int half = 1; half < N; half <<= 1) {
            const int len = half << 1;
            const int wn = bin(G, (MOD - 1) / len);
            for (int base = 0; base < N; base += len) {
                int w = 1;
                for (int k = 0; k < half; k++, w = 1LL * w * wn % MOD) {
                    const int x = a[base + k];
                    const int y = 1LL * w * a[base + k + half] % MOD;
                    a[base + k] = (x + y) % MOD;
                    a[base + k + half] = (x - y + MOD) % MOD;
                }
            }
        }

        if (f == -1) {
            // Reverse entries 1..N-1 of the forward result, then multiply by
            // N^(MOD-2), i.e. the inverse of N.
            std::reverse(a + 1, a + N);
            const int inv = bin(N, MOD - 2);
            for (int i = 0; i < N; i++)
                a[i] = 1LL * a[i] * inv % MOD;
        }
    }

    // Multiplies the polynomials s[0..n-1] and t[0..n-1] modulo MOD.
    //   result : receives N coefficients, where N is the smallest power of two
    //            that is >= 2 * n (at least 1); entries past the true product
    //            length are zero. The caller must provide room for N ints.
    // Preconditions: 2 * n rounded up to a power of two must not exceed MAXN.
    // Coefficients may be any int (negative included); they are reduced into
    // [0, MOD) before transforming. result may alias s or t, since both
    // inputs are copied into the work buffers first.
    void conv(int *s, int *t, int n, int *result) {
        if (n < 0) n = 0;   // a negative length would form an invalid range
        N = 1; while (N < n * 2) N *= 2;

        for (int i = 0; i < n; ++i) {
            a[i] = canon(s[i]);
            b[i] = canon(t[i]);
        }
        // Only the first N slots take part in the transform, so zero just the
        // padding instead of wiping both full buffers on every call.
        std::fill(a + n, a + N, 0);
        std::fill(b + n, b + N, 0);

        ntt(a, N, 1);
        ntt(b, N, 1);
        for (int i = 0; i < N; ++i)
            a[i] = 1LL * a[i] * b[i] % MOD;
        ntt(a, N, -1);
        std::copy(a, a + N, result);
    }
}

FFT

#include<bits/stdc++.h>
using namespace std;
typedef std::complex<double> E;
const double pi = std::acos(-1.0);
int n, m;                 // polynomial degrees read by main(); n later holds the padded length
const int N = 3e5 + 10;   // capacity of the coefficient buffers
E a[N], b[N];             // coefficient buffers of the two polynomials

// Recursive radix-2 transform of x[0..n-1], performed in place.
//   n    : transform length; expected to be a power of two. Lengths <= 1 are
//          a no-op (this also stops n == 0 from recursing without end).
//   type : +1 uses the root exp(+2*pi*i/n), -1 uses its conjugate.
// No 1/n scaling is applied in either direction; the caller divides by n
// after the inverse pass.
void FFT(E *x, int n, int type) {
    if (n <= 1) return;
    const int half = n >> 1;

    // Heap-backed scratch space. Variable-length arrays of a class type are a
    // compiler extension, and keeping them on the stack costs about 32 * n
    // bytes across the recursion, which can exhaust the stack for large n.
    std::vector<E> even(half), odd(half);
    for (int i = 0; i < half; i++) {
        even[i] = x[2 * i];
        odd[i] = x[2 * i + 1];
    }

    FFT(&even[0], half, type);
    FFT(&odd[0], half, type);

    // Combine the two half-size transforms; w walks through the powers of wn.
    E wn(std::cos(2 * pi / n), std::sin(type * 2 * pi / n)), w(1, 0);
    for (int i = 0; i < half; i++, w *= wn) {
        const E twisted = w * odd[i];
        x[i] = even[i] + twisted;
        x[i + half] = even[i] - twisted;
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0, x; i <= n; i++)
        scanf("%d", &x), a[i] = x;
    for (int i = 0, x; i <= m; i++)
        scanf("%d", &x), b[i] = x;
    m = n + m;
    for (n = 1; n <= m; n <<= 1);
    FFT(a, n, 1);
    FFT(b, n, 1);
    for (int i = 0; i <= n; i++)
        a[i] = a[i] * b[i];
    FFT(a, n, -1);
    for (int i = 0; i <= m; i++)
        printf("%d ", int(round(a[i].real() / n)));
    return 0;
}