Difference between revisions of "NTT & FFT"
Jump to navigation
Jump to search
(Created page with "== NTT == <syntaxhighlight lang='cpp'> typedef long long LL; const int MAXN = 3e5 + 10; const int MOD = 998244353; const int G = 3; namespace NTT { int N, a[MAXN], b[MAX...") |
(→NTT) |
||
(2 intermediate revisions by the same user not shown) | |||
Line 60: | Line 60: | ||
} | } | ||
}; | }; | ||
+ | |||
+ | |||
+ | namespace NTT { | ||
+ | const int MAXN = 6E5 + 100; | ||
+ | int N, a[MAXN], b[MAXN]; | ||
+ | const int G = 3; | ||
+ | |||
+ | int bin(LL x, LL n) { | ||
+ | LL ret = MOD != 1; | ||
+ | for (x %= MOD; n; n >>= 1, x = x * x % MOD) | ||
+ | if (n & 1) ret = ret * x % MOD; | ||
+ | return (int) ret; | ||
+ | } | ||
+ | |||
+ | void ntt(int * a, int N, int f) { | ||
+ | int i, j = 0, t, k; | ||
+ | for (i = 1; i < N - 1; i++) { | ||
+ | for (t = N; j ^= t >>= 1, ~j & t;); | ||
+ | if (i < j) { | ||
+ | swap(a[i], a[j]); | ||
+ | } | ||
+ | } | ||
+ | for (i = 1; i < N; i <<= 1) { | ||
+ | t = i << 1; | ||
+ | int wn = bin(G, (MOD - 1) / t); | ||
+ | for (j = 0; j < N; j += t) { | ||
+ | int w = 1; | ||
+ | for (k = 0; k < i; k++, w = 1LL * w * wn % MOD) { | ||
+ | int x = a[j + k], y = 1LL * w * a[j + k + i] % MOD; | ||
+ | a[j + k] = (x + y) % MOD, a[j + k + i] = (x - y + MOD) % MOD; | ||
+ | } | ||
+ | } | ||
+ | } | ||
+ | if (f == -1) { | ||
+ | reverse(a + 1, a + N); | ||
+ | int inv = bin(N, MOD - 2); | ||
+ | for (i = 0; i < N; i++) | ||
+ | a[i] = 1LL * a[i] * inv % MOD; | ||
+ | } | ||
+ | } | ||
+ | |||
+ | void conv(int *s, int *t, int n, int *result) { | ||
+ | memset(a, 0, sizeof a); memset(b, 0, sizeof b); | ||
+ | copy(s, s + n, a); copy(t, t + n, b); | ||
+ | N = 1; while (N < n * 2) N *= 2; | ||
+ | ntt(a, N, 1); | ||
+ | ntt(b, N, 1); | ||
+ | for (int i = 0; i < N; ++i) | ||
+ | a[i] = 1LL * a[i] * b[i] % MOD; | ||
+ | ntt(a, N, -1); | ||
+ | copy(a, a + N, result); | ||
+ | } | ||
+ | }; | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | 任意模: | ||
+ | |||
+ | <syntaxhighlight lang='cpp'> | ||
+ | namespace NTTPlus { | ||
+ | /* | ||
+ | * 任意模 NTT by kuangbin | ||
+ | * 求A和B的卷积,结果对P取模, 做长度为N1的变换,选取两个质数P1和P2 | ||
+ | * P1-1 和 P2-1 必须是 N1 的倍数 | ||
+ | * E1 和 E2 分别是 P1, P2 的原根 | ||
+ | * F1 = inv(E1, P1), F2 = inv(E2, P2) | ||
+ | * I1 = inv(N1, P1), I2 = inv(N1, P2) | ||
+ | * 然后使用中国剩余定理,保证了结果是小于 MM=P1*P2 的 | ||
+ | * M1 = inv(P2, P1) * P2, M2 = inv(P1, P2) * P1 | ||
+ | */ | ||
+ | |||
+ | const int P = 200003; | ||
+ | const int N1 = 1 << 18; | ||
+ | const int N2 = N1 + 1; | ||
+ | const int P1 = 998244353; //P1 = 2^{23}*7*17 + 1 | ||
+ | const int P2 = 995622913; //P2 = 2^{19}*3*3*211 + 1 | ||
+ | const int E1 = 996173970; | ||
+ | const int E2 = 88560779; | ||
+ | const int F1 = 121392023; //E1*F1 = 1(mod P1) | ||
+ | const int F2 = 840835547; //E2*F2 = 1(mod P2) | ||
+ | const int I1 = 998240545; //I1*N1 = 1(mod P1) | ||
+ | const int I2 = 995619115; //I2*N1 = 1(mod P2) | ||
+ | const LL M1 = 397550359381069386LL; | ||
+ | const LL M2 = 596324591238590904LL; | ||
+ | const LL MM = 993874950619660289LL; //MM = P1*P2 | ||
+ | |||
+ | LL mul(LL u, LL v, LL p) { | ||
+ | return (u * v - LL((long double) u * v / p) * p + p) % p; | ||
+ | } | ||
+ | |||
+ | int trf(int x1, int x2) { | ||
+ | return (mul(M1, x1, MM) + mul(M2, x2, MM)) % MM % P; | ||
+ | } | ||
+ | |||
+ | int A[N2], B[N2], C[N2], A1[N2], B1[N2], C1[N2]; | ||
+ | |||
+ | void fft(int *A, int PM, int PW) { | ||
+ | for (int m = N1, h; h = m / 2, m >= 2; PW = (LL) PW * PW % PM, m = h) | ||
+ | for (int i = 0, w = 1; i < h; i++, w = (LL) w * PW % PM) | ||
+ | for (int j = i; j < N1; j += m) { | ||
+ | int k = j + h, x = (A[j] - A[k] + PM) % PM; | ||
+ | (A[j] += A[k]) %= PM; | ||
+ | A[k] = (LL) w * x % PM; | ||
+ | } | ||
+ | for (int i = 0, j = 1; j < N1 - 1; j++) { | ||
+ | for (int k = N1 / 2; k > (i ^= k); k /= 2); | ||
+ | if (j < i) swap(A[i], A[j]); | ||
+ | } | ||
+ | } | ||
+ | |||
+ | void conv(int *a, int *b, int *res) { | ||
+ | memset(C, 0, sizeof C); | ||
+ | copy(a, a + N1, A1); copy(a, a + N1, A); | ||
+ | copy(b, b + N1, B1); copy(b, b + N1, B); | ||
+ | fft(A1, P1, E1); | ||
+ | fft(B1, P1, E1); | ||
+ | for (int i = 0; i < N1; i++) | ||
+ | C1[i] = (LL) A1[i] * B1[i] % P1; | ||
+ | fft(C1, P1, F1); | ||
+ | for (int i = 0; i < N1; i++) | ||
+ | C1[i] = (LL) C1[i] * I1 % P1; | ||
+ | fft(A, P2, E2); | ||
+ | fft(B, P2, E2); | ||
+ | for (int i = 0; i < N1; i++) | ||
+ | C[i] = (LL) A[i] * B[i] % P2; | ||
+ | fft(C, P2, F2); | ||
+ | for (int i = 0; i < N1; i++) | ||
+ | C[i] = (LL) C[i] * I2 % P2; | ||
+ | for (int i = 0; i < N1; i++) | ||
+ | C[i] = trf(C1[i], C[i]); | ||
+ | copy(C, C + N1, res); | ||
+ | } | ||
+ | }; | ||
+ | </syntaxhighlight> | ||
+ | |||
+ | == FFT == | ||
+ | |||
+ | <syntaxhighlight lang='cpp'> | ||
+ | #include<bits/stdc++.h> | ||
+ | using namespace std; | ||
+ | typedef complex<double> E; | ||
+ | const double pi = acos(-1.0); | ||
+ | int n, m; | ||
+ | const int N = 3e5 + 10; | ||
+ | E a[N], b[N]; | ||
+ | |||
+ | void FFT(E *x, int n, int type) { | ||
+ | if (n == 1)return; | ||
+ | E l[n >> 1], r[n >> 1]; | ||
+ | for (int i = 0; i < n; i += 2) | ||
+ | l[i >> 1] = x[i], r[i >> 1] = x[i + 1]; | ||
+ | FFT(l, n >> 1, type); | ||
+ | FFT(r, n >> 1, type); | ||
+ | E wn(cos(2 * pi / n), sin(type * 2 * pi / n)), w(1, 0); | ||
+ | for (int i = 0; i < n >> 1; i++, w *= wn) | ||
+ | x[i] = l[i] + w * r[i], x[i + (n >> 1)] = l[i] - w * r[i]; | ||
+ | } | ||
+ | |||
+ | int main() { | ||
+ | scanf("%d%d", &n, &m); | ||
+ | for (int i = 0, x; i <= n; i++) | ||
+ | scanf("%d", &x), a[i] = x; | ||
+ | for (int i = 0, x; i <= m; i++) | ||
+ | scanf("%d", &x), b[i] = x; | ||
+ | m = n + m; | ||
+ | for (n = 1; n <= m; n <<= 1); | ||
+ | FFT(a, n, 1); | ||
+ | FFT(b, n, 1); | ||
+ | for (int i = 0; i <= n; i++) | ||
+ | a[i] = a[i] * b[i]; | ||
+ | FFT(a, n, -1); | ||
+ | for (int i = 0; i <= m; i++) | ||
+ | printf("%d ", int(round(a[i].real() / n))); | ||
+ | return 0; | ||
+ | } | ||
</syntaxhighlight> | </syntaxhighlight> |
Latest revision as of 12:42, 27 July 2018
NTT
typedef long long LL;
const int MAXN = 3e5 + 10;
const int MOD = 998244353;
const int G = 3;
namespace NTT {
int N, a[MAXN], b[MAXN];
int pown(LL x, LL n) {
LL ret = MOD != 1; x %= MOD;
while (n) {
if (n & 1) ret = ret * x % MOD;
x = x * x % MOD;
n >>= 1;
}
return int(ret);
}
void clear() {
memset(a, 0, sizeof a);
memset(b, 0, sizeof b);
}
void ntt(int * a, int N, int f) {
int i, j = 0, t, k;
for (i = 1; i < N - 1; i++) {
for (t = N; j ^= t >>= 1, ~j & t;);
if (i < j) {
swap(a[i], a[j]);
}
}
for (i = 1; i < N; i <<= 1) {
t = i << 1;
int wn = pown(G, (MOD - 1) / t);
for (j = 0; j < N; j += t) {
int w = 1;
for (k = 0; k < i; k++, w = 1LL * w * wn % MOD) {
int x = a[j + k], y = 1LL * w * a[j + k + i] % MOD;
a[j + k] = (x + y) % MOD, a[j + k + i] = (x - y + MOD) % MOD;
}
}
}
if (f == -1) {
reverse(a + 1, a + N);
int inv = pown(N, MOD - 2);
for (i = 0; i < N; i++)
a[i] = 1LL * a[i] * inv % MOD;
}
}
void conv() {
ntt(a, N, 1);
ntt(b, N, 1);
for (int i = 0; i < N; ++i)
a[i] = 1LL * a[i] * b[i] % MOD;
ntt(a, N, -1);
}
};
namespace NTT {
const int MAXN = 6E5 + 100;
int N, a[MAXN], b[MAXN];
const int G = 3;
int bin(LL x, LL n) {
LL ret = MOD != 1;
for (x %= MOD; n; n >>= 1, x = x * x % MOD)
if (n & 1) ret = ret * x % MOD;
return (int) ret;
}
void ntt(int * a, int N, int f) {
int i, j = 0, t, k;
for (i = 1; i < N - 1; i++) {
for (t = N; j ^= t >>= 1, ~j & t;);
if (i < j) {
swap(a[i], a[j]);
}
}
for (i = 1; i < N; i <<= 1) {
t = i << 1;
int wn = bin(G, (MOD - 1) / t);
for (j = 0; j < N; j += t) {
int w = 1;
for (k = 0; k < i; k++, w = 1LL * w * wn % MOD) {
int x = a[j + k], y = 1LL * w * a[j + k + i] % MOD;
a[j + k] = (x + y) % MOD, a[j + k + i] = (x - y + MOD) % MOD;
}
}
}
if (f == -1) {
reverse(a + 1, a + N);
int inv = bin(N, MOD - 2);
for (i = 0; i < N; i++)
a[i] = 1LL * a[i] * inv % MOD;
}
}
void conv(int *s, int *t, int n, int *result) {
memset(a, 0, sizeof a); memset(b, 0, sizeof b);
copy(s, s + n, a); copy(t, t + n, b);
N = 1; while (N < n * 2) N *= 2;
ntt(a, N, 1);
ntt(b, N, 1);
for (int i = 0; i < N; ++i)
a[i] = 1LL * a[i] * b[i] % MOD;
ntt(a, N, -1);
copy(a, a + N, result);
}
};
任意模:
namespace NTTPlus {
/*
* 任意模 NTT by kuangbin
* 求A和B的卷积,结果对P取模, 做长度为N1的变换,选取两个质数P1和P2
* P1-1 和 P2-1 必须是 N1 的倍数
* E1 和 E2 分别是 P1, P2 的原根
* F1 = inv(E1, P1), F2 = inv(E2, P2)
* I1 = inv(N1, P1), I2 = inv(N1, P2)
* 然后使用中国剩余定理,保证了结果是小于 MM=P1*P2 的
* M1 = inv(P2, P1) * P2, M2 = inv(P1, P2) * P1
*/
const int P = 200003;
const int N1 = 1 << 18;
const int N2 = N1 + 1;
const int P1 = 998244353; //P1 = 2^{23}*7*17 + 1
const int P2 = 995622913; //P2 = 2^{19}*3*3*211 + 1
const int E1 = 996173970;
const int E2 = 88560779;
const int F1 = 121392023; //E1*F1 = 1(mod P1)
const int F2 = 840835547; //E2*F2 = 1(mod P2)
const int I1 = 998240545; //I1*N1 = 1(mod P1)
const int I2 = 995619115; //I2*N1 = 1(mod P2)
const LL M1 = 397550359381069386LL;
const LL M2 = 596324591238590904LL;
const LL MM = 993874950619660289LL; //MM = P1*P2
LL mul(LL u, LL v, LL p) {
return (u * v - LL((long double) u * v / p) * p + p) % p;
}
int trf(int x1, int x2) {
return (mul(M1, x1, MM) + mul(M2, x2, MM)) % MM % P;
}
int A[N2], B[N2], C[N2], A1[N2], B1[N2], C1[N2];
void fft(int *A, int PM, int PW) {
for (int m = N1, h; h = m / 2, m >= 2; PW = (LL) PW * PW % PM, m = h)
for (int i = 0, w = 1; i < h; i++, w = (LL) w * PW % PM)
for (int j = i; j < N1; j += m) {
int k = j + h, x = (A[j] - A[k] + PM) % PM;
(A[j] += A[k]) %= PM;
A[k] = (LL) w * x % PM;
}
for (int i = 0, j = 1; j < N1 - 1; j++) {
for (int k = N1 / 2; k > (i ^= k); k /= 2);
if (j < i) swap(A[i], A[j]);
}
}
void conv(int *a, int *b, int *res) {
memset(C, 0, sizeof C);
copy(a, a + N1, A1); copy(a, a + N1, A);
copy(b, b + N1, B1); copy(b, b + N1, B);
fft(A1, P1, E1);
fft(B1, P1, E1);
for (int i = 0; i < N1; i++)
C1[i] = (LL) A1[i] * B1[i] % P1;
fft(C1, P1, F1);
for (int i = 0; i < N1; i++)
C1[i] = (LL) C1[i] * I1 % P1;
fft(A, P2, E2);
fft(B, P2, E2);
for (int i = 0; i < N1; i++)
C[i] = (LL) A[i] * B[i] % P2;
fft(C, P2, F2);
for (int i = 0; i < N1; i++)
C[i] = (LL) C[i] * I2 % P2;
for (int i = 0; i < N1; i++)
C[i] = trf(C1[i], C[i]);
copy(C, C + N1, res);
}
};
FFT
#include<bits/stdc++.h>
using namespace std;
typedef complex<double> E;
const double pi = acos(-1.0);
int n, m;
const int N = 3e5 + 10;
E a[N], b[N];
void FFT(E *x, int n, int type) {
if (n == 1)return;
E l[n >> 1], r[n >> 1];
for (int i = 0; i < n; i += 2)
l[i >> 1] = x[i], r[i >> 1] = x[i + 1];
FFT(l, n >> 1, type);
FFT(r, n >> 1, type);
E wn(cos(2 * pi / n), sin(type * 2 * pi / n)), w(1, 0);
for (int i = 0; i < n >> 1; i++, w *= wn)
x[i] = l[i] + w * r[i], x[i + (n >> 1)] = l[i] - w * r[i];
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0, x; i <= n; i++)
scanf("%d", &x), a[i] = x;
for (int i = 0, x; i <= m; i++)
scanf("%d", &x), b[i] = x;
m = n + m;
for (n = 1; n <= m; n <<= 1);
FFT(a, n, 1);
FFT(b, n, 1);
for (int i = 0; i <= n; i++)
a[i] = a[i] * b[i];
FFT(a, n, -1);
for (int i = 0; i <= m; i++)
printf("%d ", int(round(a[i].real() / n)));
return 0;
}