논리를 새워서 풀어본 문제 소개.
수직선에 n개의 점이 찍혀있을 때 두 점 사이의 모든 거리의 합을 구하라는 문제로 initial code를 찍어보았을 때는 O(n**2)
만큼의 시간이 걸리게 되는 문제이다.(아래 코드는 initial code)
from sys import stdin
n = int(stdin.readline().strip())
nums = list(map(int,stdin.readline().strip().split()))
result = 0
for i in range(n):
for j in range(i,n):
result += 2*abs(nums[i]-nums[j])
print(result)
어떻게 풀면 좋을까? 고민을 하다가 각각의 수에서 어떤 수가 result에 몇번이 더해지는지, 그리고 몇번이 빠지는지 생각을 해보게 되었다.
문제에서 주어진 예시를 생각해보자.
5
1 5 3 2 4
해당 예제에서 각각의 숫자는 총 10번 계산에 들어가게 된다.
가령 5
를 생각해보면
abs(5-1),abs(5-2),abs(5-3),abs(5-4),abs(5-5)
abs(1-5),abs(2-5),abs(3-5),abs(4-5)
로 앞에 5번, 뒤에 5번이 들어간다.
그리고 여기서 5는 재일 큰 수이기 때문에 같은 수가 들어가는 케이스를 제외한 9
번이 양수로 계산되고, 1
번이 음수로 계산된다.(즉, 총 9 - 1 == 8
번 더해짐)
다음으로 4를 생각해보면 두 번째로 큰 수이기 때문에 9-2 == 7
번이 양수로 계산되고 2+1 == 3
번 음수로 계산된다.(즉, 총 7 - 3 == 4
번 더해짐)
이를 통해 숫자를 정렬했을 때([1,2,3,4,5]
) 계산되는 수의 합을 따져본다면 차례대로 -8,-4,0,4,8
번(음수의 경우 빠지는 케이스)이 된다.
그리고 이를 계산해보면 (-8)*1+(-4)*2+0*3+4*4+8*5 == 40
.
이를 생각하며 짠 코드는
from sys import stdin
n = int(stdin.readline().strip())
nums = sorted(map(int,stdin.readline().strip().split()),reverse=True)
num = 0
for i in range(n):
num += 2*(n-1-2*i)*nums[i]
print(num)
그리고 이 코드의 big O time은 간단히 O(n)
이 된다는 것을 알 수 있다.