지난 푸리에 변환 글에서 헛소리를 남발했던 부분들을 몇개 발견하였고, 수정도 할 겸 최근에 조금 더 공부한 내용들을 정리해보고자 한다. 이번 글은 수학적인 내용보다는 실제 PS에서 활용되는 부분 위주이다.
우선 푸리에 변환 글 밑에 추가된 몇가지의 자문자답 내용들을 참고하면 좋을 것 같다.
이제 대충 푸리에 변환이 무엇인지 이해했고, 푸리에 급수에서 푸리에 변환, CTFT에서 DTFT, DTFT에서 DFT까지 넘어오는 과정을 어느 정도 맞출 수 있는 정도까지 온 것 같다. 이제 본격적으로 FFT의 비트 재배열부터 Online, Multi-Dimension FFT까지 한 번 살펴보고자 한다.
비트 재배열과 비재귀 구현
우리가 일반적으로 사용하는 FFT 알고리즘은 Cooley-Tukey Algorithm 이다. 흔히 알고있는 분할정복 기법을 활용한 것인데, 다음 수식을 바탕으로 전개된다. 유의해야할 점은 분할 정복이라는 특징 때문에 배열의 크기가 $2^n$ 꼴이어야 한다. 부족한 공간은 $0$으로 채워주면 된다.
\[X_k =E_k+e^{-\frac{2\pi i}{N}k}O_k \\ X_{k+\frac{N}{2}} = E_k - e^{-\frac{2\pi i}{N}k}O_k\]그래서 코드를 잘 살펴보면 짝수 index와 홀수 index를 각각 다른 벡터에 넣어주고, $w^2$를 인자로 넣어 FFT를 한 번 더 돌릴 것이다. $w$가 아니라 $w^2$인 이유는 짝수, 홀수 차수만 가지고 있기 때문이다. (식을 스스로 전개해보면 알 수 있다)
하지만 위 재귀 방식은 메모리도 잡아먹을 뿐만 아니라 생각보다 오래 걸린다. 재귀를 풀어 비재귀 방식으로 접근할 수는 없을까?
이는 재귀 과정에서 각 항들이 어떻게 움직이는지 잘 살펴보면 된다! 항상 새로 들어간 배열에서 인덱스가 짝수이면 앞으로, 홀수면 뒤로 간다. 원래 인덱스 $k=(b_1b_2\cdots b_n)_ 2$를 가지고 있었다면 이제 n번의 이동 과정을 거치게 되는데, i번째 이동 시 $b_{n-i}$에 따라 앞으로 가는지 뒤로 가는지 결정된다. 정성적인 설명을 조금 덧붙이면 앞으로 간다는 것은 $b_i=0$이라는 뜻이고, 인덱스가 작아진다는 느낌이므로 새로운 인덱스의 $n-i$번째 자리가 0이 된다는 뜻이다. 즉, 결론적으로 현재 인덱스를 이진수로 표현하고, 뒤집으면 새로운 인덱스가 된다. 이를 코드로 나타내면 다음과 같다.
1
2
3
4
5
6
for(int i=1,j=0;i<n;i++)
{
int b=(n>>1);
while(!((j^=b)&b)) b>>=1;
if(i<j) swap(a[i],a[j]);
}
우선 $i$와 $j$는 인덱스를 나타낸다. $b$는 $n/2$, 즉 인덱스들을 이진수로 표현했을 때 가장 높은 자리의 비트를 의미한다. while문이 하는 역할은 뒤에서부터 받아올림을 고려하며 한자리씩 올리는 것이다. $b$부터, 즉 최고자리 비트부터 $j$의 비트를 뒤집고, 뒤집은 후의 비트가 $0$이라면 받아올림을 해주어야 하므로 $b/2$ 이후 위 과정을 반복해주는 것이다. $1$이라면 원래 $0$이 었다는 것이므로 뒤집어주는 것으로 마무리 해주면 된다. 말로 풀어서 설명하려니 약간 어려운데, 비트를 뒤집어주는 것이 덧셈을 대신할 수 있다는 사실에 유의하면서 손으로 몇번 해보면 금방 이해할 수 있을 것이다.
이를 적용해준 비재귀 FFT 코드는 다음과 같다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
void FFT(vector<cpx> &f, bool is_rev)
{
int n = (int)f.size();
for(int i=1,j=0;i<n;i++)
{
int b=(n>>1);
while(!((j^=b)&b)) b>>=1;
if(i<j) swap(f[i],f[j]);
}
for(int s=1;s<n;s<<=1)
{
double t = PI/s * (is_rev?1:-1);
cpx w(cos(t),sin(t));
for(int i=0;i<n;i+=(s<<1))
{
cpx wp(1,0);
for(int j=0;j<s;j++)
{
cpx tp = f[s|i|j] * wp;
f[s|i|j]=f[i|j]-tp;
f[i|j]+=tp;
wp *= w;
}
}
}
if(is_rev) for(int i=0;i<n;i++) f[i]/=n;
}
코드를 잘 보면 비트 재배열이 추가되었을 뿐만 아니라, 아래 반복문도 형태가 조금 달라진 것을 확인해볼 수 있다. 우선 $s$는 분할정복 과정에서 계산하고 있는 일종의 단계, 레벨을 의미한다. 앞서 비재귀 FFT를 설명할 때, 분할정복 기법을 활용하여 짝수부분과 홀수부분을 나누어 계산한다고 하였는데, 이러한 성질 때문에 현재 계산하고 있는 범위가 계속 절반씩 줄어든다. 즉, 전체 배열의 크기를 $N=2^n$이라고 하면 우리가 재귀 과정에서 계산하는 범위가 $2^n, 2^{n-1}, \cdots, 4,2,1$이 된다는 것이다. 따라서 위 코드에서는 $2s$가 현재 계산하고 있는 블록의 크기라고 할 수 있다.
이후 $i$ 반복문을 통해 현재 단계의 모든 블록에 대해 반복해주는 것이며 ($i+(s\ll 1)$), $j$ 반복문을 통해 각 블록 내에서 위에서 언급한 FFT 수식을 적용해준다 ($j=0,j<s$) 이때, 나머지 절반은 수식을 통해 채울 수 있으므로 $j<(s\ll 1)$이 아니라 $j<s$인 점을 유의하자. 일종의 최적화로 블록의 크기가 항상 2의 거듭제곱임을 활용하여 덧셈 연산들을 or 연산으로 대체할 수 있다는 부분 또한 알고 가면 좋다. 본인의 취향에 따라서 $s=2$부터 시작하는 등 블록의 개념을 혼용하지 않는 부분 내에서는 적당히 코드를 수정해도 문제 없다.
큰 수 곱셈 3 문제를 비재귀 구현을 통해 해결할 수 있다. FFT 부분만 비재귀로 바꿔주면 되므로 코드는 생략한다.
실수 오차
이제 비재귀 구현까지 알아보았다. 다음으로 알아볼 것은 실수 오차를 줄이기 위한 테크닉들이다. 실수 오차가 발생하는데는 여러 요인이 있겠으나, 우선적으로 실수 자료형인 complex를 사용한다는 점과, 실수와 실수의 곱셈 연산이 굉장히 많이 일어난다는 점, 그리고 실수 자료형이 정확하게 표현하기 어려울 정도로 수가 커진다는 점들이 있겠다. 실수 오차를 어떻게 하면 극복할 수 있을까?
Roots
가장 단순한 방법으로는, 계산 과정에서 가장 많이 쓰이는 복소수 $w^k$를 미리 저장해두고 불러와서 계산하는 방법이 있겠다. 분명히 같은 값이어야 하지만 실수 연산의 한계로 인해 소수 부분의 값이 달라지는 등의 불상사를 예방할 뿐만 아니라 연산량을 줄이는 이점을 가져온다.
1
2
3
4
vector<cpx> roots(n/2);
double ang = 2*PI/n * (inv?-1:1);
for(int i=0; i<n/2; i++) roots[i] = cpx(cos(ang * i), sin(ang * i));
쪼개기
일단 이 훌륭한 블로그를 읽고 오자. 훌륭한 블로그 2도 읽어보면 도움이 된다.
또 하나의 방법은 쪼개기이다. 우선 코드를 먼저 보자. 아마 FFT 관련 문제를 풀고나서 상위권에 위치한 코드들을 몇개 뜯어보면 대부분 아래처럼 conj 연산과 « 연산들로 구성된 multiply 함수를 보았을 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
vector<ll> multiply(vector<ll> &v, vector<ll> &w, ll mod){
int n = 2; while(n < v.size() + w.size()) n <<= 1;
vector<base> v1(n), v2(n), r1(n), r2(n);
for(int i=0; i<v.size(); i++) //비트를 쪼갬(앞에 비트, 뒤에 비트) V1(Z) = V_{F}(Z) + iV_{R}(Z) (V_{F}=V front 즉, 앞 비트 V_{R}=V rear 즉, 뒷 비트)
v1[i] = base(v[i] >> 15, v[i] & 32767); //(a,b)
for(int i=0; i<w.size(); i++)
v2[i] = base(w[i] >> 15, w[i] & 32767); //(c,d) V2(Z) = W_{F}(Z) + iW_{R}(Z)
FFT(v1, 0);
FFT(v2, 0);
for(int i=0; i<n; i++){
int j = (i ? (n - i) : i); //i=0이면 j=0, i>0이면 j= n-i 왜냐면 (a+bi)w^{i}의 conjugate는 (a-bi)w^{-i}=(a-bi)w^{n-i}이기 때문.
base ans1 = (v1[i] + conj(v1[j])) * base(0.5, 0); //V_{F}(Z)
base ans2 = (v1[i] - conj(v1[j])) * base(0, -0.5); //V_{R}(Z)
base ans3 = (v2[i] + conj(v2[j])) * base(0.5, 0); //W_{F}(Z)
base ans4 = (v2[i] - conj(v2[j])) * base(0, -0.5); //W_{R}(Z)
r1[i] = (ans1 * ans3) + (ans1 * ans4) * base(0, 1); //V_{F}(Z)V2(Z)
r2[i] = (ans2 * ans3) + (ans2 * ans4) * base(0, 1); //V_{R}(Z)V2(Z)
}
FFT(r1, 1);
FFT(r2, 1);
vector<ll> ret(n);
for(int i=0; i<n; i++){
ll av = (ll)round(r1[i].real()); //V_{F}W_{F}(Z) (앞비트 2개가 곱해졌으니 2^30곱하기)
ll bv = (ll)round(r1[i].imag()) + (ll)round(r2[i].real()); //V_{F}W_{R}(Z) + W_{F}V_{R}(Z) (앞비트와 뒷비트가 곱해졌으니 2^15 곱하기)
ll cv = (ll)round(r2[i].imag()); //W_{R}(Z)V_{R}(Z) (뒷비트 끼리의 연산)
av %= mod, bv %= mod, cv %= mod;
ret[i] = ((av << 30) + (bv << 15) + cv) % mod;
ret[i] = (ret[i] + mod) % mod;
}
return ret;
}
사실 과거의 내가 그랬듯이, DFT를 겉핥기 식으로만 알고 있다면 위 코드는 이해하기 쉽지 않다. 글에서 언급되지는 않았지만, DFT의 몇가지 성질들을 알고 있어야 하는데, 바로 선형성과 공액 복소수에 대한 성질이다.
선형성이라는 성질은 DFT 과정이 사실 행렬곱과 같은 과정임을 알고 있다면 선형 대수적으로 볼때 자명하게 이해할 수 있다. 아니더라도 단순히 식 전개를 해보면 쉽게 알 수 있다.
\[x[n]=ax_1[n]+bx_2[n] \leftrightarrow X(w)=aX_1(w)+bX_2(w)\]공액 복소수에 대한 성질은 수식으로 표현하면 다음과 같다.
\[x[k] \Leftrightarrow X[k] \\ x^\ast[k] \Leftrightarrow X^\ast[N-k]\]즉, 공액 복소수의 DFT는 기존 DFT의 Time Reversal의 공액 복소수라는 것이다. 이 또한 수식적 전개를 통해 증명할 수 있다.
또 하나의 성질로 Conjugate Symmetry가 있는데, 이는 실수 수열 $x[n]$에 대해 다음이 성립한다.
\[X[n-k]=X*[k]\]실수 수열일 때만 성립한다는 점을 염두에 두길 바란다.
이제 전체적인 흐름부터 알아보자.
다항식을 곱하는데, 그 과정에서 수들이 너무 크다는 것이 문제이다. 그래서 일단 작은 다항식으로 나눠서 계산한 후, 다시 합치는 테크닉을 활용할 것이다. 즉 $A(x)\times B(x)$를 구하는데, $A(x)=A_1(x)+A_2(x)\times C, B(x)=B_1(x)+B_2(x)\times C$로 쪼갠 후, $A(x)B(x) = A_1(x)B_1(x)+C\times(A_1(x)B_2(x)+A_2(x)B_1(x))+C^2\times(A_2(x)B(x))$인 점을 활용하여 각 항들을 계산하고 다시 합칠 것이다.
다시 코드로 돌아와서, 벡터 $V$와 $W$가 입력되었다고 해보자. 이제 우리는 적당한 $C=2^{15}$를 잡고 벡터를 쪼개줄 것이다.
1
2
3
4
for(int i=0; i<v.size(); i++)
v1[i] = base(v[i] >> 15, v[i] & 32767);
for(int i=0; i<w.size(); i++)
v2[i] = base(w[i] >> 15, w[i] & 32767);
즉 $V_1$과 $V_2$에 기존에는 큰 정수였던 아이들을 복소수로 쪼개서 넣어둔 것이다. 이들을 다항식의 계수 벡터로 본다면, $V_1(x)=V_F(x)+iV_R(x), V_2(x)=W_F(x)+iW_R(x)$라고 할 수 있겠다. 이때 아래첨자 $F$와 $R$은 각각 앞 15비트 부분, 뒤 15비트 부분을 의미한다. 이후 $V_1,V_2$를 FFT 시켜준다.
그리고 앞서 설명한 DFT의 성질들을 활용하여 DFT된 $V_F,V_R,W_F,W_R$을 구해준 뒤, 주파수 공간에서의 곱 $V_FW_F,V_FW_R,V_RW_F,V_RW_R$을 계산하고, IDFT하여 복소수 공간에서의 합성곱 $VW$를 구할 것이다.
1
2
3
4
5
6
7
int j = (i ? (n - i) : i);
base ans1 = (v1[i] + conj(v1[j])) * base(0.5, 0);
base ans2 = (v1[i] - conj(v1[j])) * base(0, -0.5);
base ans3 = (v2[i] + conj(v2[j])) * base(0.5, 0);
base ans4 = (v2[i] - conj(v2[j])) * base(0, -0.5);
r1[i] = (ans1 * ans3) + (ans1 * ans4) * base(0, 1);
r2[i] = (ans2 * ans3) + (ans2 * ans4) * base(0, 1);
DFT의 선형성에 의해 $FFT(V_1[i])=FFT(V_F[i])+FFT(iV_R[i])$이 성립한다. 자 그리고 앞서 설명한 공액 복소수의 성질을 적용하면, $V_1[i] \Leftrightarrow FFT(V_1[i])$, $\overline{V_1[i]} \Leftrightarrow \overline{FFT(V_1[n-i])}$ 가 성립한다. 이때, 선형성에 의해 $FFT(\overline{V_1[i]})=FFT(V_F[i])-FFT(iV_R[i])=\overline{FFT(V_1[n-i])}$이다! 즉, ans1에 저장되는 값은 것은 $FFT(V_F[i])$인 것이다! 마찬가지로 ans2=$V_R$, ans3=$W_F$, ans4=$W_R$가 FFT 된 값이 저장된다. 구현상 $i=0$일때 $FFT(V_1[n-0])=FFT(V_1[0])$임에 유의하자.
이제 주파수 공간에서의 곱을 구해주자. 그렇다면 $R_1$에 저장되는 것은 무엇일까? 그렇다, 바로 $FFT(V_FV_2)$이다. 마찬가지로 $R_2=FFT(V_RV_2)$이다. IDFT를 해주면 결과적으로 $R_1,R_2$에 저장되는 것은 $V_FV_2, V_RV_2$이다. 놀라운 점은 $R_1$의 실수부에는 $V_FW_F$가, 허수부에는 $V_FW_R$이 저장되어 있으며, $R_2$의 경우에는 $V_RW_F$와 $V_RW_R$이 저장되어 있다는 점이다! 그리고 이는 쉽게 분리해낼 수 있는 값들이다.
1
2
3
4
5
6
ll av = (ll)round(r1[i].real());
ll bv = (ll)round(r1[i].imag()) + (ll)round(r2[i].real());
ll cv = (ll)round(r2[i].imag());
av %= mod, bv %= mod, cv %= mod;
ret[i] = ((av << 30) + (bv << 15) + cv) % mod;
ret[i] = (ret[i] + mod) % mod;
따라서 앞에서 $A(x)B(x)$의 예시를 들었던 것처럼 실수부 허수부로 분리를 해주고, $C=2^{15}$인 점을 고려하여 $2^{30}$, $2^{15}$를 곱해 원래의 목표, $V(x)W(x)$를 복원해준다.
여기서 생각해보고 넘어가면 좋을 점은 왜 주어진 실수 벡터를 복소수로 쪼개느냐이다. 공액 복소수의 대칭성을 이용하기 위함과, 실수부와 허수부의 값들을 쉽게 분리해낼 수 있다는 점이 답이 될 수 있겠다.
쪼개기 최적화
쪼개기 방법을 잘 이해했는가? 하지만 위 코드를 그대로 제출한다면 시간이 꽤나 많이 소모되는 것을 확인할 수 있을 것이다. 최적화 할 수 있는 방법이 없을까? 잘 생각해보면 ans3과 ans4를 굳이 나눠서 계산할 필요가 없는 것 같다. 결국 우리가 필요한건 분리된 $V_F$와 $V_R$이고, 이들을 $W$와 곱해서 IDFT 해주면 되는 것이기에 조금 더 간단하게 만들 수 있다. 더 나아가 반복문의 범위도 살펴보면 처음 $0$을 제외하고는 $i$와 $n-i$가 매칭되기 때문에 범위를 절반으로 줄일 수 있을 것이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
vector<ll> multiply(vector<ll> &X, vector<ll> &Y) {
ll i, j, n = 1;
while (n < X.size() + Y.size()) n <<= 1;
vector<cpx> P(n), Q(n);
vector<ll> Z(X.size() + Y.size() - 1);
for (i = 0; i < X.size(); ++i) P[i] = cpx(X[i] >> 15, X[i] & 32767);
for (i = 0; i < Y.size(); ++i) Q[i] = cpx(Y[i] >> 15, Y[i] & 32767);
fft(P, 1); fft(Q, 1);
for (i = 0; i*2 <= n; ++i) {
j = i ? n-i : 0;
cpx v1 = P[i] + conj(P[j]), v2 = conj(P[i]) - P[j];
tie(P[i], Q[i], P[j], Q[j]) = make_tuple(v1 * Q[i], conj(v2) * conj(Q[j]), conj(v1) * Q[j], -v2 * conj(Q[i]));
}
fft(P, -1); fft(Q, -1);
for (i = 0; i < Z.size(); ++i) Z[i] = ((ll)round(P[i].real()) << 29) + ((ll)round(Q[i].real()) >> 1) + ((ll)round(P[i].imag() + Q[i].imag()) << 14);
return Z;
}
코드 내의 변수명들이 조금 달려졌음에 유의하여 잘 살펴보자. 우선 첫번째로 Conjugate가 대칭적임을 이용하여 반복문의 범위를 절반으로 줄였고, FFT 후 Front와 Rear 파트의 곱들을 구하는 과정이 훨씬 단순해졌다.
조금 어렵게 느껴질 수 있으니 직접 손으로 써보면서 확실히 이해하고 넘어가는 것이 중요하다. v1에는 $2\text{FFT}(X_f[i])$가 저장되고, v2에는 $-2\text{FFT}(iX_r[i])$가 저장된다. 공액 복소수에 대한 성질을 적용하여 식을 풀어보면 알 수 있는 부분이다. 그렇다면 v1과 v2의 Conjugate는 무엇을 의미하는 걸까? 이 또한 생각해보면 conj(v1)에는 $2\text{FFT}(X_f[j])$가, conj(v2)에는 $2\text{FFT}(iX_r[i])$가 저장됨을 알 수 있다.
그렇다면 반복문이 끝난 후 P와 Q에는 어떤 값들이 저장되어 있을까?
$v1\ast Q[i]=2\text{FFT}(X_f[i])\ast \text{FFT}(Q[i])$, $conj(v1)\ast Q[j]=2\text{FFT}(X_f[j])\ast \text{FFT}(Q[j])$가 되므로, IFFT 하면 $P = 2(X_fQ_f+iX_fQ_r)$이 될 것이다!
Q 값을 업데이트하는 과정에서 Q에 conj 연산을 취하는 부분이 조금 당황스러울 수 있는데, 이 또한 공액 복소수의 성질을 생각해보면 자연스럽게 넘어갈 수 있다. $Q$의 DFT를 $FFT(Q[k])$라고 한다면, $\overline{Q}$의 DFT는 $\overline{FFT(Q[n-k])}$가 될 것이다.
$conj(v2)\ast conj(Q[j])=2\text{FFT}(iX_r[i])\ast \text{FFT}(\overline{Q[i]})$, $-v2\ast conj(Q[i])=2\text{FFT}(iX_r[j])\ast \text{FFT}(\overline{Q[j]})$이므로, 이 또한 IFFT하면 $Q = 2(X_rQ_r+iX_rQ_f)$가 된다. 위에서 했던것과 실수부 허수부가 뒤바뀌었음에 주의하자.
묶기 in 다항식의 곱셈
사실 위 방식처럼 정확도를 위해 여러번의 DFT를 해주는 경우 시간복잡도 손실이 상당하다. 그럴 수 밖에 없는 것이, 기존에는 2번의 DFT만으로 합성곱을 계산할 수 있었지만, 정확도를 위해 4번의 DFT가 필요해진 상황이기 때문이다. 그렇다면 정확도를 유지하는 선에서 시간복잡도를 더 단축할 수 있는 방법은 없을까?
우리는 정확도를 높이기 위해서 기존 수열의 값들을 반으로 쪼갰었다. 이제는 시간복잡도를 위해 수열의 값들을 묶어볼 것이다. 쉽게 말해서 $[1,2,3,4]$과 $[6,7,8,9]$의 합성곱을 구해야한다고 해보자. 이때, 수열을 2개씩 묶어보면 $[12,34], [67,89]$가 된다. 당연하게도 두 합성곱의 결과는 완전히 다르겠지만, 이들을 다항식의 계수로 보고 간단한 처리만 해준다면 같은 결과를 내놓을 것이다.
쉽게 말해서 $[1,2,3,4]=1x^3+2x^2+3x+4, [5,6,7,8]=5x^3+6x^2+7x+8$로 취급할 때 $x=10$을 대입하는 것이고, $[12,34]=12x+34, [56,78]=56x+78$로 볼 때 $x=100$을 대입하는 것이다. 이때 묶는 단위에 따라서 연산량이 크게 감소하는 이점을 얻을 수 있지만 처리해야할 수의 단위가 커진다는 단점이 따라 오기에 적당한 크기로 묶으면 된다. 구현 또한 간단한 편이니 직접 해보자.
만약 큰 수 곱셈 2 문제를 풀었다면 이 코드를 참고해보는 것도 좋아 보인다.
NTT
FFT의 가장 큰 한계는 바로 실수 연산으로 인한 오차라고 할 수 있겠다. 뿐만 아니라 문제를 풀다보면 특정 수로 나눈 값을 구하라는 등 실수 연산만으로는 해결하기 많이 어려워보이는 작업들이 보이곤 한다. 이러한 점들을 보완해주는 것이 바로 정수만을 사용하는 FFT, Number Theoretic Transform NTT이다.
FFT -> NTT
FFT에서 가장 중요한 점이 무엇일까? 바로 복소수 $w$이다. $w^n=1$을 만족하면서 $w^i\neq w^j$라는 것이다. 이 조건들은 각각 길이가 $n$인 대칭군을 만들 수 있음과 직교성을 기반으로 한 기저의 역할을 수행할 수 있음을 의미한다.
이를 그대로 정수론에 대입시키면, FFT를 할 수 있는 길이는 $2^n$꼴인점을 고려하여 $p=a*2^b+1$꼴의 소수를 뽑고, 소수의 원시근 $w$를 찾으면 된다. 페르마의 소정리에 의해 $w^{p-1}\equiv 1 \mod p$가 성립하므로 결국 $w^\frac{p-1}{n}$로 FFT에서의 복소수를 대신해주면 된다. 원시근과 관련된 내용은 이글을 참고하자. 이렇게 되면 $n\leq 2^b$인 배열(수열)들에 대해 NTT를 해줄 수 있게된다.
대표적으로 쓰이는 쌍으로 $998244353=119 \cdot 2^{23} + 1$가 있겠다. 구현은 다음과 같다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const ll w = 3;
const ll mod = 998244353;
ll po(ll x, ll y)
{
ll ret = 1;
while(y)
{
if(y&1) ret=ret*x%mod;
y>>=1; x=x*x%mod;
}
return ret;
}
void NTT(poly &f, bool is_rev=0)
{
int n = (int)f.size();
for(int i=1,j=0;i<n;i++)
{
int b=(n>>1);
while(!((j^=b)&b)) b>>=1;
if(i<j) swap(f[i],f[j]);
}
for(int i=1;i<n;i<<=1)
{
ll x = po(w,(mod/i)>>1); if(is_rev) x = po(x,mod-2);
for(int j=0;j<n;j+=(i<<1))
{
ll y=1;
for(int k=0;k<i;k++)
{
ll z = f[i|j|k] * y % mod;
f[i|j|k] = f[j|k] - z;
if(f[i|j|k]<0) f[i|j|k] += mod;
f[j|k] += z;
if(f[j|k]>=mod) f[j|k] -= mod;
y = y * x % mod;
}
}
}
if(is_rev)
{
ll t = po(n,mod-2);
for(int i=0;i<n;i++) f[i]=f[i]*t%mod;
}
}
FFT와 근본적으로 같은 알고리즘이기에 $w$ 부분만 위에서 설명한 내용들로 바꿔주면 된다. 구현하는데 있어 아래 내용들은 한 번 확인하고 넘어가는 것이 좋겠다.
1
2
3
ll x = po(w,(mod/i)>>1); if(is_rev) x = po(x,mod-2);
~
ll t = po(n,mod-2);
앞서 $w^\frac{p-1}{n}$을 $w$ 대신 사용해준다고 했는데, 998244353의 경우 $\lfloor\frac{p-1}{n}\rfloor=\lfloor\frac{p}{n}\rfloor$인 점을 고려하고, 우리가 보고 있는 블록의 크기가 $2i$이므로 $\lfloor\frac{\lfloor\frac{p}{q}\rfloor}{r}\rfloor=\lfloor\frac{p}{qr}\rfloor$임을 이용하면 위와 같이 코드를 작성할 수 있다. IFFT의 경우 $x^{-1}$을 구해주어야 하는데, NTT는 체 $\Z_p$에서 FFT를 해주는 것이므로, 역수 계산은 모두 역원을 이용해주면 된다. 페르마의 소정리에 의해 $x$는 소수 $p$와 서로소이므로 $x^{p-1}=1 \mod p$이고, $x^{-1}\equiv x^{p-2} \mod p$이기에 IFFT를 할 때 $N$으로 나눠주는 것 또한 $N^{p-2}$를 곱해주면 된다.
Arbitary Modular
NTT의 한계라고 한다면 바로 특정 소수 모듈러에 대해서만 적용이 가능하다는 것이다. 하지만 모듈러와 관련된 매우 유명한 정리인 CRT, 중국인의 나머지 정리를 사용한다면 임의의 수 또한 처리할 수 있다. 쉽게 말해 적당히 NTT가 가능한 여러 소수에 대해 결과값들을 구해놓고, 중국인의 나머지 정리를 활용해 결과를 합쳐주면 된다는 것이다.