Post

BOJ_14258_XOR 그룹 (Java)

BOJ_14258_XOR 그룹 (Java)

[Platinum V] XOR 그룹 - 14258

문제 링크

성능 요약

메모리: 56844 KB, 시간: 1544 ms

분류

자료 구조, 분리 집합, 오프라인 쿼리

제출 일자

2025년 2월 19일 16:39:34

문제 설명

N*M 격자에 서로 다른 수가 하나씩 들어가 있다. XOR 그룹이라는 것을 정의를 하여 합을 최대로 하려한다. XOR 그룹이란 위, 아래, 오른쪽, 왼쪽으로 인접한 칸에 수가 있다면, 그 칸과 연결되어 그 수를 모두 XOR한 값을 가지는 그룹이 된다. 만약, 중간에 수가 빠져있으면, 연결이 되지 않으므로, 한 격자판에 여러 XOR 그룹이 있을 수 있다.

이제 격자판에서 작은 수부터 제거해 나갈 것이다. 하나를 지울 때 마다 XOR 그룹이 변하는데, XOR그룹의 값의 합의 최대가 될 때, 그 값을 구하여라.

입력

첫째 줄에 n, m이 주어진다.(1 ≤ n, m ≤ 1,000)

다음 n줄에는 격자의 i번째 줄의 수 m개가 주어진다. 수는 1,000,000보다 크지 않은 음이 아닌 정수이다.

출력

XOR 그룹의 값의 합이 최대가 되는 값을 구하여라

문제 풀이

제일 작은수를 제거해나가는 대신 빈 칸에서 제일 큰수를 넣어가며 진행했다. 이때 그룹을 이루는 것끼리 union-find로 진행할 수 있다.

코드

코드 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/**
 * Author: nowalex322, Kim HyeonJae
 */

import java.io.*;
import java.util.*;

public class Main {
    class Cell implements Comparable<Cell> {
        int r, c, val;
        public Cell(int r, int c, int val) {
            this.r = r;
            this.c = c;
            this.val = val;
        }

        @Override
        public int compareTo(Cell o) {
            return o.val - this.val;
        }
    }
    static BufferedReader br;
    static BufferedWriter bw;
    static StringTokenizer st;
    static int N, M;
    static long res;
    static int[] dr = {-1, 1, 0, 0}, dc = {0, 0, 1, -1};
    static int[] parent; // parent[i] = i값의 대표
    static long[] XOR; // XOR[px] = 대표px그룹의 XOR값
    static int[][] board;
    static boolean[][] visited;
    static Cell[] cells;
    public static void main(String[] args) throws Exception {
        new Main().solution();
    }

    public void solution() throws Exception {
        br = new BufferedReader(new InputStreamReader(System.in));
//        br = new BufferedReader(new InputStreamReader(new FileInputStream("src/main/java/BOJ_14258_XOR그룹/input.txt")));
        bw = new BufferedWriter(new OutputStreamWriter(System.out));

        st = new StringTokenizer(br.readLine());
        N = Integer.parseInt(st.nextToken());
        M = Integer.parseInt(st.nextToken());
        board = new int[N][M];
        visited = new boolean[N][M];

        parent = new int[1000001];
        XOR = new long[1000001];
        for(int i=0; i< parent.length; i++){
            parent[i] = i;
        }

        cells = new Cell[N*M];
        for(int i=0; i<N; i++) {
            st = new StringTokenizer(br.readLine());
            for(int j=0; j<M; j++) {
                board[i][j] = Integer.parseInt(st.nextToken());
                cells[i*M + j] = new Cell(i, j, board[i][j]);
            }
        }
        Arrays.sort(cells);

        long sum = 0;
        for(Cell c : cells) {
            int currR = c.r, currC = c.c;
            int currVal = c.val;

            visited[currR][currC] = true;
            XOR[currVal] = currVal;
            sum += currVal;

            for(int k=0; k<4; k++) {
                int nextR = currR + dr[k];
                int nextC = currC + dc[k];
                if(nextR >= 0 && nextR < N && nextC >= 0 && nextC < M && visited[nextR][nextC]) {
                    int nextVal = board[nextR][nextC];

                    int px = find(currVal);
                    int py = find(nextVal);

                    // 부모가 다르면 병합해야함
                    if(px != py){
                        sum -= XOR[px];
                        sum -= XOR[py];

                        union(currVal, nextVal);

                        sum += XOR[find(currVal)];
                    }
                }
            }
            res = Math.max(sum, res);
        }
        System.out.println(res);
        bw.flush();
        bw.close();
        br.close();
    }

    public int find(int x){
        if(parent[x] != x) return parent[x] = find(parent[x]);
        return parent[x];
    }

    public void union(int x, int y){
        int px = find(x);
        int py = find(y);
        if(px == py) return;

        if(px < py){
            parent[py] = px;
            XOR[px] ^= XOR[py];
        }
        else{
            parent[px] = py;
            XOR[py] ^= XOR[px];
        }
    }
}

코드 2 (빠른 입출력)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
/**
 * Author: nowalex322, Kim HyeonJae
 */

import java.io.*;
import java.util.*;

public class Main {
    class Cell implements Comparable<Cell> {
        int r, c, val;
        public Cell(int r, int c, int val) {
            this.r = r;
            this.c = c;
            this.val = val;
        }

        @Override
        public int compareTo(Cell o) {
            return o.val - this.val;
        }
    }
    static BufferedReader br;
    static BufferedWriter bw;
    static StringTokenizer st;
    static int N, M;
    static long res;
    static int[] dr = {-1, 1, 0, 0}, dc = {0, 0, 1, -1};
    static int[] parent; // parent[i] = i값의 대표
    static long[] XOR; // XOR[px] = 대표px그룹의 XOR값
    static int[][] board;
    static boolean[][] visited;
    static Cell[] cells;
    static FastReader fr;
    public class FastReader {
        private final DataInputStream din;
        private final byte[] buffer;
        private int bufferPointer, bytesRead;

        public FastReader() {
            din = new DataInputStream(System.in);
            buffer = new byte[16384];
            bufferPointer = bytesRead = 0;
        }

        private byte read() throws IOException {
            if (bufferPointer == bytesRead)
                fillBuffer();
            return buffer[bufferPointer++];
        }

        private void fillBuffer() throws IOException {
            bytesRead = din.read(buffer, bufferPointer = 0, buffer.length);
            if (bytesRead == -1)
                buffer[0] = -1;
        }

        public int nextInt() throws IOException {
            int ret = 0;
            byte c = read();
            while (c <= ' ')
                c = read();
            boolean neg = (c == '-');
            if (neg)
                c = read();
            do {
                ret = ret * 10 + c - '0';
            } while ((c = read()) >= '0' && c <= '9');
            return neg ? -ret : ret;
        }
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }

    public void solution() throws Exception {
        fr = new FastReader();
        N = fr.nextInt();
        M = fr.nextInt();
        board = new int[N][M];
        visited = new boolean[N][M];

        parent = new int[1000001];
        XOR = new long[1000001];
        for(int i=0; i< parent.length; i++){
            parent[i] = i;
        }

        cells = new Cell[N*M];
        for(int i=0; i<N; i++) {
            for(int j=0; j<M; j++) {
                board[i][j] = fr.nextInt();
                cells[i*M + j] = new Cell(i, j, board[i][j]);
            }
        }
        Arrays.sort(cells);

        long sum = 0;
        for(Cell c : cells) {
            int currR = c.r, currC = c.c;
            int currVal = c.val;

            visited[currR][currC] = true;
            XOR[currVal] = currVal;
            sum += currVal;

            for(int k=0; k<4; k++) {
                int nextR = currR + dr[k];
                int nextC = currC + dc[k];
                if(nextR >= 0 && nextR < N && nextC >= 0 && nextC < M && visited[nextR][nextC]) {
                    int nextVal = board[nextR][nextC];

                    int px = find(currVal);
                    int py = find(nextVal);

                    // 부모가 다르면 병합해야함
                    if(px != py){
                        sum -= XOR[px];
                        sum -= XOR[py];

                        union(currVal, nextVal);

                        sum += XOR[find(currVal)];
                    }
                }
            }
            res = Math.max(sum, res);
        }
        System.out.println(res);
    }

    public int find(int x){
        if(parent[x] != x) return parent[x] = find(parent[x]);
        return parent[x];
    }

    public void union(int x, int y){
        int px = find(x);
        int py = find(y);
        if(px == py) return;

        if(px < py){
            parent[py] = px;
            XOR[px] ^= XOR[py];
        }
        else{
            parent[px] = py;
            XOR[py] ^= XOR[px];
        }
    }
}
This post is licensed under CC BY 4.0 by the author.