우선 정렬하고자 하는 데이터에서 pivot을 설정한다. pivot보다 큰 값은 pivot보다 우측에, 더 작은 값은 pivot보다 좌측에 위치한다. 그 다음, pivot을 기준으로 분할 된 각 영역에서 새로운 pivot을 찾은 후 또 다시 영역을 분할하는 것을 반복한다.
arr = [1,4,2,7,5,8,9,3,6]
pivot = 5
✔ step1
5를 기준으로 [1,4,2,3], [7,8,9,6]
✔ step2
[1,4,2,3]
pivot = 2
2를 기준으로 [1], [4,3]
✔ step3
[7,8,9,6]
pivot = 9
9를 기준으로 [8,7,6],[]
✔ step4
[4,3]
pivot = 3
3을 기준으로 [], [4]
✔ step5
[8,7,6]
pivot = 7
7을 기준으로 [6], [7]
✔ step6
좌우로 분할한 값들을 모두 합쳐주면,
[1,2,3,4,5,6,7,8,9]
def quick_sort(arr):
if len(arr) < 2:
return arr
pivot = arr[len(arr)//2]
smaller_arr, equal_arr, bigger_arr = [], [], []
print("arr ", arr)
print("pivot ", pivot)
for i in arr:
if i<pivot:
smaller_arr.append(i)
elif i>pivot:
bigger_arr.append(i)
else:
equal_arr.append(i)
return quick_sort(smaller_arr) + quick_sort(equal_arr) + quick_sort(bigger_arr)
if __name__ == '__main__' :
arr = list(map(int, input().split()))
print(quick_sort(arr))
pivot값에 따라 시간 복잡도가 달라진다.
가장 이상적인 상황은 pivot값을 기준으로 나눠준 작은 값과 큰 값 영역 데이터의 갯수가 같아서 o(nlogn)의 시간 복잡도가 나오는 것이다.
만일 pivot값을 기준으로 나눠준 값이 고르지 못하고 한쪽으로 치우칠 경우, 성능이 저하되어 최악의경우 o(N^2)의 복잡도를 가지게 된다.😬