[JAVA] BOJ_10427 빚

이진중·2024년 1월 19일
0

알고리즘

목록 보기
57/76

문제 이해

최종적으로 답은 S(1) ~ S(n) 까지의 합이다.
S(m) 은 민균이가 추가로 갚아야하는 돈이고, 해당 금액을 계산하는 방법은
A 배열에서 m개의 원소를 추출해서 m*(그 원소들중 최댓값) - 원소의 합이다.

즉 최댓값에서 각원소를 뺀 값들을 모두 더한것이다.

Sol 1

가장 쉽게 생각해볼만한건 배열을 임의로 정해서 해당 배열에 대해 S(m)을 구하는 것이다.

n개 배열중 1개로 뽑을 수 있는 경우의 수 nC1 +... nCn 까지 하면 2^n -1 이다.

최대 n은 4000이므로 불가능하다.

Sol 2

S[m] 일때 S[m-1] 을 구하기 위해서는 정렬된 배열에서 가장 작은 원소를 제거하거나, 가장 큰 원소를 제거하거나 이다. (중간에서 제거를 할 경우 가장 작은 원소를 제거한 것 보다 반드시 갚아야할 금액이 커진다)

사실 이 방법으로 문제에서 제시해준 테스트케이스가 모두 통과했어서 뭐가 문젠지 고생했다..

결론적으로 S[m] 을 알고 S[m-1]을 알더라도 S[m-2] 는 해당 방법으로 구할 수 없다.

반례는
#Input
1
5 10 14 5 3 17

#Answer
70
이다.

S[m-1]에 쓰이지 않았던 원소가 S[m-2]에 사용될 수 있기 때문에 계속해서 원소를 삭제해나가는건 틀린답이다.

Sol 3

최댓값이 정해진 배열에서는 가장 차이가 작도록 원소를 정해야한다.
즉, 1 3 4 5 8 에서는

m=3 이고 5가 정해져있을때 1 3 5 가 아니라 5에 가까운 3 4 5로 정해야한다.

해당 포인트를 캐치했다면 누적합을 이용하여 각 S(m)을 구할때 start index를 바꿔가면서 O(n) 시간에 S(m)을 구할 수 있다. 총 O(n^2) 시간이 소요된다.

누적합과 배열의 인덱스

우리는 start index로 부터 end index까지의 합을 구하고 싶다.
여기서 누적합을 어덯게 빼야 해당 값을 구할 수 있을까?

3번부터 5번 index까지의 합을 구하고 싶다면 : 5번까지의 누적합 - 2번까지의 누적합
1번까지 index까지의 합을 구하고 싶다면 : 1번까지의 누적합 - 0번까지의 누적합
0번까지 index까지의 합을 구하고 싶다면 : 0번까지의 누적합 - (-1번까지의 누적합)

이렇게 계산이 된다.

즉, 누적합 배열은 preSum[0] = 0 으로 만들어 둬야 한다.
presum[0] = 0
presum[1] = 0번 index까지의 합
presum[2] = 1번 index까지의 합
...

결국
3번부터 5번 index까지의 합을 구하고 싶다면 : 5번까지의 누적합 - 2번까지의 누적합 : preSum[5+1] - preSum[3]이 된다.

0번부터 1번까지의 합은 : preSum[1+1]-preSum[0]
0번부터 0번까지 합은 : preSum[0+1] - preSum[0]

start ~ end 까지 합 : preSum[end+1] - preSum[start] 로 기억하자.
이렇게 모든 변수를 나타낼 수 있게 된다.

최종 코드

import java.util.*;


public class Main {

    static class Pair {
        Integer from;
        Integer to;

        public Pair(Integer from, Integer to) {
            this.from = from;
            this.to = to;
        }

        public Pair() {
        }
    }

    static class Three {
        Integer from;
        Integer to;
        Integer distance;

        public Three(Integer from, Integer to, Integer distance) {
            this.from = from;
            this.to = to;
            this.distance = distance;
        }

        public Three() {
        }
    }

    public static void main(String[] args) {
        final int MAX = 10001;
        Scanner sc = new Scanner(System.in);

        int t = Integer.parseInt(sc.nextLine());

        for (int i = 0; i < t; i++) {
            Long ret = game(sc);
            System.out.println(ret);
        }
    }

    public static Long game(Scanner sc) {
        String[] s = sc.nextLine().split(" ");
        Integer n = Integer.parseInt(s[0]);

        Long[] AList = new Long[n];
        Long[] preSum = new Long[n+1];

        preSum[0]=0L;

        for (int i = 0; i < n; i++) {
            Long num = Long.parseLong(s[i + 1]);
            AList[i]=num;

        }


        Arrays.sort(AList);

        for (int i = 0; i < n; i++) {
            preSum[i+1] = preSum[i]+AList[i];
        }


        Long ans = 0L;

        for (int m = 2; m <= n; m++) {
            // 2번부터
            // S[m] = ?
            Long minResult = Long.MAX_VALUE;
            for (int start = 0; start + m - 1 < n ; start++) {
                int end = start + m - 1;

                Long result = m* AList[end] - getSum(start,end,preSum); // 0 ~ 1

                minResult = Math.min(result,minResult);
            }
            ans += minResult;
        }
        return ans;
    }

    public static Long getSum(int start, int end , Long[] preSum){
        return preSum[end+1] - preSum[start];
    }

}

0개의 댓글