任意模数下快速数论变换的两种实现

我们实现NTT时,总是在 整数模质数 $p=2^kq+1$ 上进行的。原因很简单:为了在 $\mathbb{Z}/p\mathbb{Z}$ 上使用原根以套用复数域 $\mathbb{C}$ 上“单位根”的概念,其阶必然也应当是 $2^k$ 的倍数才行。这就为我们带来了很多不便。假如现在给定一质数 $p’$ 且不保证有 $p’=2^kq+1$,我们就需要另辟蹊径完成卷积。

实现一:中国剩余定理

中国剩余定理:若数 $n$ 的质因数分解为 $\sum_{i=1}^{k}p_i^{e_i}$,有整数模 $n$ 加法(或者,随便你怎么叫吧)

$$\mathbb{Z}/n\mathbb{Z}\cong \mathbb{Z}/p_1^{e1}\mathbb{Z}\times \mathbb{Z}/p_2^{e2}\mathbb{Z}\times \cdots \times\mathbb{Z}/p_k^{ek}\mathbb{Z}$$

或者通俗地讲,若 $m_1, m_2, \cdots, m_k$ 两两互质,则线性同余方程组
$$\left\{\begin{aligned}
x &\equiv x_1\pmod{m_1}\\
x &\equiv x_2\pmod{m_2}\\
&\quad \vdots\\
x &\equiv x_k\pmod{m_k}\end{aligned}\right.$$
在 $\bmod \prod_{i=1}^{k}m_i$ 意义下有唯一解。

于是,如果在卷积过程中不对 $p’$ 取模,我们会得到上限约为 $p’^2n$ 的系数,其中 $n$ 为多项式的次数。在实际应用中,大概为 $10^{23}$。故而我们择取 $3$ 个容易实现NTT的质数 $p_1,p_2,p_3$(我常用 $998244353,1004535809,985661441$,其三者的最小原根均为 $3$,偶因数均为 $2^{21}$),将原式的系数分别对其取模后在整数模 $p_i$ 乘法群下做卷积,最后将得到三个系数 $x_1,x_2,x_3$,分别对应在模 $p_i$ 意义下的实际系数 $x$。

则根据裴蜀定理中国剩余定理,又由于在整数模质数域上每个非零元素均存在逆元,我们对这些方程组两两合并:
$$\begin{aligned}
x_1+k_1p_1&\equiv x_2+k_2p_2\equiv x&\pmod{p_1p_2}\\
x_1+k_1p_1&\equiv x_2&\pmod{p_2}\\
k_1&\equiv \dfrac{x_2-x_1}{p_1}&\pmod{p_2}\\
x_4&\equiv x_1+k_1p_1&\pmod{p_1p_2}\\
x_4+k_4p_1p_2&\equiv x_3+k_3p_3\equiv x&\pmod{p_1p_2p_3}\\
x_4+k_4p_1p_2&\equiv x_3&\pmod{p_3}\\
k_4&\equiv\dfrac{x_3-x_4}{p_1p_2}&\pmod{p_3}\\
\end{aligned}$$

就可以求出在模 $p$ 意义下的系数了。

洛谷题库 P4345 R78201636 记录详情

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
  57. 57
  58. 58
  59. 59
  60. 60
  61. 61
  62. 62
  63. 63
  64. 64
  65. 65
  66. 66
  67. 67
  68. 68
  69. 69
  70. 70
  71. 71
  72. 72
  73. 73
  74. 74
  75. 75
  76. 76
  77. 77
  78. 78
  79. 79
  80. 80
  81. 81
  82. 82
  83. 83
  84. 84
  85. 85
  86. 86
  87. 87
  88. 88
  89. 89
  90. 90
  91. 91
#include <bits/stdc++.h> using namespace std; /* 快读已省略。 */ #define inl inline #define reint register int #define newl putchar('\n') typedef long long ll; // typedef unsigned long long ull; // typedef __int128 lll; // typedef long double llf; typedef pair <int, int> pint; #define fst first #define scd second #define all(p) begin (p), end (p) using vint = vector <int>; constexpr int N = 1<<18, INF = 0x3f3f3f3f, gen = 3, P1 = 998244353, P2 = 1004535809, P3 = 985661441, invP1P2 = 669690699, invP1P2P3 = 401569863; int n, m, P; vint f, g, pr1, pr2, pr3; inl ll fpow (ll a, ll b, ll mod) { ll res = 1; a %= mod; for (; b; b >>= 1) { if (b & 1) (res *= a) %= mod; (a *= a) %= mod; } return res; } inl void henkan (vint &f, int l) { static int tr[N], lst = tr[0] = 0; if (lst != l) { lst = l; for (int x = 1; x < 1<<l; ++x) tr[x] = tr[x>>1]>>1|((1<<l-1) * (x & 1)); } for (int x = 1; x < 1<<l; ++x) if (tr[x] < x) swap (f[tr[x]], f[x]); } #define clog2(x) ceil (log2 (x)) #define tomod(x) if (mod < INF) x.resize (mod ,0) inl ll inv (ll x, int mod) { return fpow (x, mod - 2, mod); } template <int P, int gen> inl void NTT (vint &f, int l, bool rev) { f.resize (1<<l, 0); henkan (f, l); for (int len = 2; len <= 1<<l; len <<= 1) { const ll w_n = fpow (gen, (P - 1)/len, P); ll w = 1, g, h; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = 0; i < len/2; ++i, (w *= w_n) %= P) g = f[i + st], h = f[i + st + len/2] * w % P, f[i + st] = (g + h) % P, f[i + st + len/2] = (g + P - h) % P; } if (!rev) return; const ll p = inv (1<<l, P); for (int x = 0; x < 1<<l; ++x) f[x] = f[x] * p % P; reverse (begin (f) + 1, end (f)); } template <int P, int gen> inl vint mul (vint f, vint g, int mod = INF) { int len = f.size () + g.size () - 1, l = clog2 (len); NTT <P, gen> (f, l, 0), NTT <P, gen> (g, l, 0); for (int x = 0; x < 1<<l; ++x) f[x] = 1ll * f[x] * g[x] % P; NTT <P, gen> (f, l, 1); return f.resize (min (mod, len)), f; } int main () { /* */ read (n, m, P); f.resize (n + 1), g.resize (m + 1); for (int x = 0; x <= n; ++x) read (f[x]); for (int x = 0; x <= m; ++x) read (g[x]); pr1 = mul <P1, gen> (f, g); pr2 = mul <P2, gen> (f, g); pr3 = mul <P3, gen> (f, g); for (int i = 0; i <= n + m; ++i) { ll x4 = (pr1[i] + 1ll * P1 * (ll (pr2[i] - pr1[i] + P2) % P2 * invP1P2 % P2)) % (1LL * P1 * P2), k4 = (pr3[i] - x4 % P3 + P3 + 0ll) % P3 * invP1P2P3 % P3; print ((x4 + k4 * P1 % P * P2 % P) % P), putchar (' '); } return 0; }

实现二:拆系数FFT

如你所见,如果我们无脑不取模直接卷积,造出来的系数在 $10^{23}$ 级别。如果使用FFT,在IDFT的过程中还要乘上 $n$——也就是 $10^{28}$,甚至更大。就算是 long double 也承受不了,况且还要考虑浮点误差。

因此我们将一个多项式拆成两个多项式分别相乘。现有常数 $M$($M$ 常取 $4\times 10^4$ 或者 $\sqrt{p}$),我们对 $f(x), g(x)$ 做卷积,则令
$$f(x)=Mf_0(x)+f_1(x),g(x)=Mg_0(x)+g_1(x)$$这样一来,四个多项式的系数均在 $M$ 以下。应用“三次转两次”优化提到的办法,我们应用两次FFT就可以求出其四者的点值表示。将其一一相乘后求得  $f_0(x)g_0(x), f_0(x)g_1(x)+f_1(x)g_0(x), f_1(x)g_1(x)$ 的点值表示,(在 $\bmod p$ 意义下)依次乘上 $M^2, M, 1$ 的系数后相加就是实际系数。这样我们应用了 $5$ 次FFT,但仍然要使用 long double (别忘了,IDFT完成之前系数乘 $n$——大约为 $10^{19}$ 级别,而 double 的有效数字位(fraction)仅有 $52$ 位),实际运行效率不比方法一更优。

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  45. 45
  46. 46
  47. 47
  48. 48
  49. 49
  50. 50
  51. 51
  52. 52
  53. 53
  54. 54
  55. 55
  56. 56
  57. 57
  58. 58
  59. 59
  60. 60
  61. 61
  62. 62
  63. 63
  64. 64
  65. 65
  66. 66
  67. 67
  68. 68
  69. 69
  70. 70
  71. 71
  72. 72
  73. 73
  74. 74
  75. 75
  76. 76
  77. 77
  78. 78
  79. 79
  80. 80
  81. 81
  82. 82
  83. 83
  84. 84
  85. 85
  86. 86
  87. 87
  88. 88
  89. 89
  90. 90
  91. 91
  92. 92
#include <bits/stdc++.h> using namespace std; /* 快读已省略。 */ #define inl inline #define reint register int #define newl putchar('\n') typedef long long ll; // typedef unsigned long long ull; // typedef __int128 lll; typedef long double llf; typedef pair <int, int> pint; #define fst first #define scd second #define all(p) begin (p), end (p) using comp = complex <llf>; using vcomp = vector <comp>; using vint = vector <int>; constexpr int N = 1<<18, INF = 0x3f3f3f3f, M = 3.2e4; int n, m, num, P; vint f, g; inl void henkan (vcomp &f, int l) { static int tr[N], lst = tr[0] = 0; if (lst != l) { lst = l; for (int x = 1; x < 1<<l; ++x) tr[x] = tr[x>>1]>>1|((1<<l-1) * (x & 1)); } for (int x = 1; x < 1<<l; ++x) if (tr[x] < x) swap (f[tr[x]], f[x]); } #define clog2(x) ceil (log2 (x)) #define tomod(p) if (mod < INF) p.resize (mod, 0) inl void FFT (vcomp &f, int l, bool rev) { f.resize (1<<l, 0); henkan (f, l); for (int len = 2; len <= 1<<l; len <<= 1) { const comp w_n = comp (cos (M_PI/len*2.0l), sin (M_PI/len*2.0l)); comp w = 1, g, h; for (int st = 0; st < 1<<l; st += len, w = 1) for (int i = st; i < st + len/2; ++i, w *= w_n) g = f[i], h = f[i + len/2] * w, f[i] = g + h, f[i + len/2] = g - h; } if (!rev) return; for (int x = 0; x < 1<<l; ++x) f[x] /= 1<<l; reverse (begin (f) + 1, end (f)); } inl void pair_DFT (vcomp &f, vcomp &g, int l) { vcomp p (1<<l); comp _q; for (int x = 0; x < 1<<l; ++x) p[x] = f[x] + 1il * g[x]; FFT (p, l, 0); // p(x)=f(x)+i g(x), q(x)=f(x)-i g(x) for (int x = 0; x < 1<<l; ++x) _q = conj (p[x ? (1<<l) - x : 0]), f[x] = (p[x] + _q) / 2.0l, g[x] = (p[x] - _q) / 2il; } inl vint mul (vint f, vint g, int mod = INF) { int len = f.size () + g.size () - 1, l = clog2 (len); vcomp f1 (1<<l), g1 (1<<l), f0 (1<<l), g0 (1<<l), p (1<<l), q (1<<l), t (1<<l); f.resize (1<<l, 0), g.resize (1<<l, 0); for (int x = 0; x < 1<<l; ++x) f1[x] = f[x] % M, f0[x] = f[x] / M, g1[x] = g[x] % M, g0[x] = g[x] / M; pair_DFT (f0, g0, l), pair_DFT (f1, g1, l); for (int x = 0; x < 1<<l; ++x) p[x] = f0[x] * g0[x], q[x] = f1[x] * g0[x] + f0[x] * g1[x], t[x] = f1[x] * g1[x]; FFT (p, l, 1), FFT (q, l, 1), FFT (t, l, 1); #define coef(x) ((ll) round (real (x)) % P) for (int x = 0; x < 1<<l; ++x) f[x] = (1ll * M * M % P * coef (p[x]) % P + 1ll * M * coef (q[x]) % P + coef (t[x])) % P; return f.resize (len), f; } int main () { /* */ read (n, m, P); f.resize (n + 1); g.resize (m + 1); for (int x = 0; x <= n; ++x) read (f[x]); for (int x = 0; x <= m; ++x) read (g[x]); for (const int x : mul (f, g)) print (x), putchar (' '); return 0; }
  • 2022年7月3日
  • 1