[MST] 1251. 하나로 (Java)

안수진·2024년 8월 29일

SWEA

목록 보기
8/17
post-thumbnail

[SWEA] 1251. 하나로

📌 풀이 과정

각 섬 사이의 간선에 대한 정보를 주지 않았기 때문에
모든 섬 간의 거리를 계산해서 정점과 간선 정보를 만들어야 한다.

  1. 모든 섬 간 거리 계산
    각 섬의 x, y 좌표를 이용해 두 섬 사이의 거리를 계산하고, 이를 간선으로 처리한다.
    거리에 환경 부담 세율 E를 곱해 가중치를 계산합니다.

  2. 계산된 간선들을 가중치 순으로 정렬한다. → PriorityQueue 활용
    최소 스패닝 트리(MST)를 구성하여 모든 섬을 최소 비용으로 연결한다.

Collections.sort와 PriorityQueue 정렬 차이

// 모든 간선의 거리를 계산하여 리스트에 추가
for (int i = 0; i < N-1; i++) {
    for (int j = i + 1; j < N; j++) {
          double distance = E * (Math.pow(X[i] - X[j], 2) + Math.pow(Y[i] - Y[j], 2));
          edges.add(new Edge(i, j, distance));
    }
}

            // 간선을 가중치에 따라 정렬
            Collections.sort(edges);

607ms → Collections.sort()
297ms → PriorityQueue

어마무시.. PriorityQueue를 애용해야겠다.


✨ 제출 코드

Kruskal 알고리즘 사용

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class 하나로_1251 {

    static int N;
    static int[] parent;

    static class Edge implements Comparable<Edge>{
        int start, end;
        double weight;

        public Edge(int start, int end, double weight) {
            this.start = start;
            this.end = end;
            this.weight = weight;
        }

        @Override
        public int compareTo(Edge o) {
            return Double.compare(this.weight, o.weight);
        }

    }

    

    static void make() {
        parent = new int[N];
        for(int i = 0; i < N; i++) {
            parent[i] = i;
        }
    }

    static int findParent(int x) {
        if(parent[x] == x) return x;
        return parent[x] = findParent(parent[x]);
    }

    static boolean union(int a, int b) {
        int rootA = findParent(a);
        int rootB = findParent(b);

        if(rootA == rootB) return false;
        parent[rootB] = rootA;
        return true;
    }

    public static void main(String[] args) throws IOException{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;
        int T = Integer.parseInt(br.readLine());

        for(int t = 1; t <= T; t++) {
            N = Integer.parseInt(br.readLine()); // 섬의 개수
            double[] X = new double[N];
            double[] Y = new double[N];
            PriorityQueue<Edge> edges = new PriorityQueue<>();

            st = new StringTokenizer(br.readLine());
            for(int i = 0; i < N; i++) {
                X[i] = Double.parseDouble(st.nextToken()); // 섬의 x좌표
            }
            
            st = new StringTokenizer(br.readLine());
            for(int i = 0; i < N; i++) {
            	Y[i] = Double.parseDouble(st.nextToken()); // 섬의 y좌표
            }

            double E = Double.parseDouble(br.readLine()); // 환경 부담 세율

            // 모든 간선의 거리를 계산하여 리스트에 추가
            for (int i = 0; i < N-1; i++) {
                for (int j = i + 1; j < N; j++) {
                    double distance = E * (Math.pow(X[i] - X[j], 2) + Math.pow(Y[i] - Y[j], 2));
                    edges.add(new Edge(i, j, distance));
                }
            }

            make();
            double cost = 0;
            int cnt = 0;

            while(!edges.isEmpty()) {
                Edge edge = edges.poll();

                if(union(edge.start, edge.end)) {
                    cost += edge.weight;
                    if(++cnt == N - 1) break;
                }
            }

            System.out.println("#" + t + " " + Math.round(cost)); // 정수로 반올림하여 출력
        }

    }

}

Prim 알고리즘 사용

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class 하나로_1251 {
	
	static class Node implements Comparable<Node>{
		int to;
		long cost;
		
		Node(int to, long cost){
			this.to = to;
			this.cost = cost;
		}
		
		@Override
		public int compareTo(Node o) {
			return Long.compare(this.cost, o.cost);
		}
	}
	
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st;
        
		int T = Integer.parseInt(br.readLine());
		for (int tc = 1; tc <= T; tc++) {
			int N = Integer.parseInt(br.readLine());
			long[] X = new long[N];
			long[] Y = new long[N];
			boolean[] visit = new boolean[N];

			st = new StringTokenizer(br.readLine());
			for (int i = 0; i < N; i++) {
				X[i] = Long.parseLong(st.nextToken());
			}
			st = new StringTokenizer(br.readLine());
			for (int i = 0; i < N; i++) {
				Y[i] = Long.parseLong(st.nextToken());
			}
			double E = Double.parseDouble(br.readLine());

			LinkedList<Node>[] list = new LinkedList[N];    // 가능한 모든 간선의 비용을 저장

			for (int i = 0; i < N; i++) {
				list[i] = new LinkedList<>();
				for (int j = 0; j < N; j++) {
					if(i == j)	continue;
					long L = (X[i]-X[j])*(X[i]-X[j]) + (Y[i]-Y[j])*(Y[i]-Y[j]);
					list[i].add(new Node(j, L));
				}
			}
			
			PriorityQueue<Node> pq = new PriorityQueue<>();
			pq.add(new Node(0, 0));
			long ans = 0;
			int cnt = 0;
			
			while(!pq.isEmpty()) {
				Node n = pq.poll();
				
				if(visit[n.to])	continue;
				
				visit[n.to] = true;
				ans += n.cost;
				
				if(++cnt == N)	break;
				
				for (Node node : list[n.to]) {
					if(!visit[node.to])	pq.add(node);
				}
			}
			
			System.out.println("#"+tc+" "+Math.round(ans*E));
		}
	}
}
profile
항상 궁금해하기

0개의 댓글