import sys
input=sys.stdin.readline
def init(node,start,end):
if start==end:
tree[node]=0
else:
init(node*2 , start , (start+end)//2)
init(node*2+1, (start+end)//2+1 , end)
tree[node]=tree[node*2]+tree[node*2+1]
def update(node,start,end,value,index):
if index<start or index>end: #범위 밖이라면
return
if start==end:
tree[node]=value
return
update(node*2 , start , (start+end)//2 ,value , index)
update(node*2+1 , (start+end)//2+1 , end , value , index)
tree[node]=tree[node*2]+tree[node*2+1]
return
def query(node,start,end,left,right):
if left>end or right<start:
return 0
if left<=start and right>=end:
return tree[node]
return query(node*2,start,(start+end)//2 , left,right) +query(node*2+1, (start+end)//2+1 , end , left , right)
N=int(input())
L=list(map(int,input().split()))
tree=[0]*(4*N)
total=0
L2=[]
init(1, 1, N)
for i in range(N):
L2.append( (L[i],i+1) )
L2.sort()
for i in range(N):
Q,index=L2[i]
total+=query(1,1,N,index+1,N)
update(1,1,N, 1 , index)
print(total)
📌 어떻게 접근할 것인가?
세그먼트 트리를 사용하였습니다.
L2 라는 새로운 배열을 만들고 배열의 값과 인덱스 값을 넣었습니다.
이후 배열의 값에 따라 정렬해주면 인덱스 값이 섞이게 됩니다.
query 함수를 호출해서 섞인 인덱스 값을 매개변수로 넣고 그 값에 따른 구간에 정렬되지 않은 원소가 몇개인지 합을 구합니다.
ks1ksi님의 블로그 를 참조하였습니다.

파이썬은 느려서 pypy 로 제출하셔야 합니다. 중간에 불필요한 연산이 있으면 최대한 줄여줘야 합니다.