BOJ_1067_이동 (Java)
[Platinum I] 이동 - 1067
성능 요약
메모리: 28436 KB, 시간: 444 ms
분류
고속 푸리에 변환, 수학
제출 일자
2025년 2월 2일 09:03:38
문제 설명
N개의 수가 있는 X와 Y가 있다. 이때 X나 Y를 순환 이동시킬 수 있다. 순환 이동이란 마지막 원소를 제거하고 그 수를 맨 앞으로 다시 삽입하는 것을 말한다. 예를 들어, {1, 2, 3}을 순환 이동시키면 {3, 1, 2}가 될 것이고, {3, 1, 2}는 {2, 3, 1}이 된다. 순환 이동은 0번 또는 그 이상 할 수 있다. 이 모든 순환 이동을 한 후에 점수를 구하면 된다. 점수 S는 다음과 같이 구한다.
S = X[0]×Y[0] + X[1]×Y[1] + ... + X[N-1]×Y[N-1]
이때 S를 최대로 하면 된다.
입력
첫째 줄에 N이 주어진다. 둘째 줄에는 X에 들어있는 N개의 수가 주어진다. 셋째 줄에는 Y에 있는 수가 모두 주어진다. N은 60,000보다 작거나 같은 자연수이고, X와 Y에 들어있는 모든 수는 100보다 작은 자연수 또는 0이다.
출력
첫째 줄에 S의 최댓값을 출력한다.
문제 풀이
예전에 풀어보고싶었던 주제인 고속 푸리에 변환을 공부해보았다. PS에서 FFT를 convolution을 O(nlogn)에 계산해야 할 상황에 필요한 지식이다. 사실 이 코드는 코테처럼 안보고 푼 문제가 아니라 여러 자료들을 찾아가며 고치고 고쳐 구현한 공부용 코드다. koosaga님, kundol님 등 다양한 블로그들을 보고 공부했으며 그 중 가장 도움 되었던 PPT 자료를 최하단 reference에 첨부하겠다.
나름 이론 공부한 내용도 차후에 다른 게시글에 차근차근 정리하겠다. 이해하면 쉽지만 이를 증명하는 과정을 모두 머리에 넣기가 어려웠다.
일단 문제에 대해 알아보자면, 이 문제는 순환 컨볼루션 개념으로 접근할 수 있다. X를 이동시키면서 S를 구하는 과정은 신호 처리에서의 컨볼루션과 동일한 형태를 가진다. 이를 위해 FFT를 활용한 빠른 곱셈을 사용한다.
핵심 개념
- 순환 이동 (Circular Shift): 배열의 원소를 순환시키는 연산.
-
ex) 배열 {1, 2, 3}을 순환 이동시키면 {3, 1, 2}, 그 다음엔 {2, 3, 1}이 된다. 이러한 순환 이동을 여러 번 한 뒤 각 이동에 대해 점수 S를 계산해야 한다.
-
점수 계산 (S): 점수 S는 두 배열 X와 Y의 대응되는 원소들의 곱의 합이다. 즉, S = X[0] * Y[0] + X[1] * Y[1] + … + X[N-1] * Y[N-1].
- 고속 푸리에 변환 (FFT): FFT는 이 문제를 해결하기 위한 주요 기법.
- X와 Y를 주파수 영역으로 변환하고, 그 후 각각의 원소들을 곱한 뒤, 역 FFT를 수행하여 점수 S를 계산하는 방법이다. 이 방법을 사용하면 순환 이동을 고려한 점수 계산을 O(N log N)의 시간 복잡도로 해결할 수 있다.
풀이 과정
- 배열 크기 확장:
- 주어진 배열 X와 Y를 두 배 크기로 확장한다. X는 자기 자신을 이어붙여서 원형 이동을 구현하고, Y는 역순으로 저장한다. 이건 배열 순환 이동을 자연스럽게 처리하려고 하는 방법이다.
- FFT 변환:
- X와 Y 배열을 각각 FFT로 주파수 영역으로 변환한다. FFT는 주파수 성분을 빠르게 추출하는 방법으로, 이를 이용해 순환 이동된 배열의 점수를 빠르게 계산할 수 있다.
- 원소별 곱셈:
- 변환된 주파수 영역의 X와 Y 배열을 원소별로 곱한다. 이 곱셈은 순환 이동된 배열에 대한 점수 계산과 같은 역할을 한다.
- 역 FFT (IFFT):
- 곱셈을 끝낸 후, 역 FFT를 사용해서 다시 시간 영역으로 변환한다. 이 변환된 값들이 순환 이동된 배열의 점수 S를 나타낸다.
- 최댓값 계산:
- 계산된 점수들 중에서 최댓값을 구한다.
구현 세부 사항
-
배열 확장 및 Y의 역순 저장: X는 두 배 크기로 확장하고 Y는 역순으로 저장한다. 이 방법으로 FFT 결과가 순환 이동을 반영하게 된다.
-
NTT (Number Theoretic Transform): NTT는 FFT와 비슷하지만 모듈러 연산을 사용하는 특징이 있다. 이 문제에서는 큰 수의 곱셈을 빠르게 계산하기 위해 NTT를 활용한다.
-
최댓값 계산: 점수 S는 배열 간의 곱을 더한 결과인데, 음수 값이 나올 수 있어서 모듈러 연산을 이용해 양수로 만든다. 최댓값을 찾을 때 음수가 나오면 MOD를 더해 양수로 변환한다.
코드
BasicFFT 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
/**
* Author: nowalex322, Kim HyeonJae
*/
import java.io.*;
import java.util.*;
public class Main {
static class Complex {
double real, imag;
static final double EPS = 1e-9;
Complex(double real, double imag) {
this.real = real;
this.imag = imag;
}
Complex add(Complex o) {
return new Complex(real + o.real, imag + o.imag);
}
Complex subtract(Complex o) {
return new Complex(real - o.real, imag - o.imag);
}
Complex multiply(Complex o) {
double r = real * o.real - imag * o.imag;
double i = real * o.imag + imag * o.real;
if (Math.abs(r) < EPS) r = 0;
if (Math.abs(i) < EPS) i = 0;
return new Complex(r, i);
}
Complex divide(double d) {
return new Complex(real / d, imag / d);
}
}
static void fft(Complex[] a, boolean invert) {
int n = a.length;
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
while (j >= bit) {
j -= bit;
bit >>= 1;
}
j += bit;
if (i < j) {
Complex temp = a[i];
a[i] = a[j];
a[j] = temp;
}
}
for (int len = 2; len <= n; len <<= 1) {
double ang = 2 * Math.PI / len * (invert ? -1 : 1);
Complex wlen = new Complex(Math.cos(ang), Math.sin(ang));
for (int i = 0; i < n; i += len) {
Complex w = new Complex(1, 0);
for (int j = 0; j < len/2; j++) {
Complex u = a[i + j];
Complex v = a[i + j + len/2].multiply(w);
a[i + j] = u.add(v);
a[i + j + len/2] = u.subtract(v);
w = w.multiply(wlen);
}
}
}
if (invert) {
for (int i = 0; i < n; i++) {
a[i] = a[i].divide(n);
}
}
}
static BufferedReader br;
static BufferedWriter bw;
static StringTokenizer st;
static final double EPS = 1e-9;
public static void main(String[] args) throws Exception {
new Main().solution();
}
public void solution() throws Exception {
br = new BufferedReader(new InputStreamReader(System.in));
// br = new BufferedReader(new InputStreamReader(new FileInputStream("src/main/java/BOJ_1067_이동/input.txt")));
bw = new BufferedWriter(new OutputStreamWriter(System.out));
int n = Integer.parseInt(br.readLine());
int size = 1;
while (size < 2 * n) size <<= 1;
Complex[] x = new Complex[size];
Complex[] y = new Complex[size];
st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) {
long val = Long.parseLong(st.nextToken());
x[i] = new Complex(val, 0);
x[i + n] = new Complex(val, 0);
}
for (int i = 2*n; i < size; i++) {
x[i] = new Complex(0, 0);
}
st = new StringTokenizer(br.readLine());
long[] temp = new long[n];
for (int i = 0; i < n; i++) {
temp[i] = Long.parseLong(st.nextToken());
}
for (int i = 0; i < n; i++) {
y[i] = new Complex(temp[n-1-i], 0);
}
for (int i = n; i < size; i++) {
y[i] = new Complex(0, 0);
}
fft(x, false);
fft(y, false);
for (int i = 0; i < size; i++) {
x[i] = x[i].multiply(y[i]);
}
fft(x, true);
long max = Long.MIN_VALUE;
for (int i = n-1; i < 2*n-1; i++) {
double val = Math.abs(x[i].real) < EPS ? 0 : x[i].real;
max = Math.max(max, Math.round(val));
}
bw.write(String.valueOf(max));
bw.flush();
bw.close();
br.close();
}
}
기본 FFT 버전
이 버전은 FFT(고속 푸리에 변환) 을 사용해서 문제를 해결한다. FFT는 실수와 복소수를 다룬다는 점에서 NTT와 다르고, 모듈러 연산을 필요로 하지 않는다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
Complex 클래스:
- 복소수를 나타내는 클래스다. 실수(real)와 허수(imag) 값을 가지고 있으며, 기본적인 연산인 더하기, 빼기, 곱셈, 나누기 등이 구현되어 있다.
- multiply() 함수는 복소수 곱셈을, add()와 subtract()는 복소수 덧셈과 뺄셈을 구현한다.
fft() 함수:
- 주어진 복소수 배열에 대해 고속 푸리에 변환을 수행하는 함수다.
- 이진 반사 정렬을 하고, 이후 계단식으로 길이를 확장하면서 FFT를 계산한다.
- 역변환을 위한 invert 처리 부분도 있다.
main() 메소드:
- x와 y 배열을 복소수 배열로 변환한 후 fft()를 적용한다.
- x와 y의 각 값들을 곱한 후 역 푸리에 변환을 다시 한다.
- 변환된 값들 중에서 최댓값을 찾아 출력한다.
핵심 아이디어:
- 배열 X와 Y를 복소수로 변환하고, FFT를 적용한다.
- FFT로 얻은 배열을 원소별로 곱하고, 역 FFT를 적용하여 최댓값을 찾는다.
NTT 활용 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/**
* Author: nowalex322, Kim HyeonJae
*/
import java.io.*;
import java.util.*;
public class Main {
static class NTT {
static final long MOD = 998244353;
static final long PRIMITIVE_ROOT = 3;
static long pow(long a, long b) {
long res = 1;
while (b > 0) {
if ((b & 1) == 1) {
res = res * a % MOD;
}
a = a * a % MOD;
b >>= 1;
}
return res;
}
static void ntt(long[] a, boolean invert) {
int n = a.length;
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
while (j >= bit) {
j -= bit;
bit >>= 1;
}
j += bit;
if (i < j) {
long temp = a[i];
a[i] = a[j];
a[j] = temp;
}
}
for (int len = 2; len <= n; len <<= 1) {
long wlen = pow(PRIMITIVE_ROOT, (MOD - 1) / len);
if (invert) {
wlen = pow(wlen, MOD - 2);
}
for (int i = 0; i < n; i += len) {
long w = 1;
for (int j = 0; j < len/2; j++) {
long u = a[i + j];
long v = a[i + j + len/2] * w % MOD;
a[i + j] = (u + v) % MOD;
a[i + j + len/2] = (u - v + MOD) % MOD;
w = w * wlen % MOD;
}
}
}
if (invert) {
long inv_n = pow(n, MOD - 2);
for (int i = 0; i < n; i++) {
a[i] = a[i] * inv_n % MOD;
}
}
}
static long[] multiply(long[] a, long[] b) {
int n = 1;
while (n < a.length + b.length) n <<= 1;
long[] fa = Arrays.copyOf(a, n);
long[] fb = Arrays.copyOf(b, n);
ntt(fa, false);
ntt(fb, false);
for (int i = 0; i < n; i++) {
fa[i] = fa[i] * fb[i] % MOD;
}
ntt(fa, true);
return fa;
}
}
static BufferedReader br;
static BufferedWriter bw;
static StringTokenizer st;
public static void main(String[] args) throws Exception {
new Main().solution();
}
public void solution() throws Exception {
br = new BufferedReader(new InputStreamReader(System.in));
// br = new BufferedReader(new InputStreamReader(new FileInputStream("src/main/java/BOJ_1067_이동/input.txt")));
bw = new BufferedWriter(new OutputStreamWriter(System.out));
int n = Integer.parseInt(br.readLine());
int size = 1;
while (size < 2 * n) size <<= 1;
long[] x = new long[size];
long[] y = new long[size];
st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) {
x[i] = Long.parseLong(st.nextToken());
x[i + n] = x[i];
}
st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) {
long val = Long.parseLong(st.nextToken());
y[(2*n-1-i) % n] = val;
}
NTT.ntt(x, false);
NTT.ntt(y, false);
for (int i = 0; i < size; i++) {
x[i] = x[i] * y[i] % NTT.MOD;
}
NTT.ntt(x, true);
long max = Long.MIN_VALUE;
for (int i = n-1; i < 2*n-1; i++) {
long val = x[i];
if (val < 0) val += NTT.MOD;
max = Math.max(max, val);
}
bw.write(String.valueOf(max));
bw.flush();
bw.close();
br.close();
}
}
NTT 활용 버전 이 코드는 Number Theoretic Transform (NTT) 를 이용해서 문제를 해결한다. NTT는 정수 수학에서 빠른 변환을 제공하는 방법이고, 모듈러 연산을 고려하기 때문에 큰 수의 연산을 할 때 유리하다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
NTT 클래스:
- MOD: 모든 계산에서 모듈러 연산을 할 값이야. 이 문제에서는 998244353이 사용된다.
- PRIMITIVE_ROOT: 원시 근을 나타내는 값인데, 여기서는 3을 사용한다.
- pow(): 큰 수의 거듭제곱을 빠르게 계산하는 함수로, 모듈러 거듭제곱을 구현한다. 시간 복잡도는 O(log b)로 효율적이다.
- ntt(): NTT 변환을 실제로 수행하는 함수다. 여기서 이진 반사 정렬과 고속 변환이 이루어진다.
- multiply(): 두 배열을 곱하는 함수다. 이 함수에서 중요한 점은 NTT를 두 번 사용해서 변환 후 곱셈을 하고 다시 변환하여 원래의 값으로 돌아오게 만든다는 것이다.
main() 메소드:
- x와 y 배열을 순환 이동 가능하게 만들기 위해 x는 두 배 크기로 만들어서 자기 자신을 반복시킨다.
- y는 역순으로 저장해서, 실제로 X와 Y의 순환 이동을 구현하는 효과를 낸다.
- ntt() 함수 두 번 호출하고, 그 결과로 나온 값들 중에서 최댓값을 찾는다.
핵심 아이디어:
- 배열 X와 Y를 순환 이동을 고려해 확장한다.
- NTT로 변환하고 곱셈을 한 후 다시 NTT로 역변환해서 최댓값을 찾는다.