import sys
class Tree:
    """Parent-pointer table for node ids 0..1009.

    ``parent`` is a class attribute: every ``Tree()`` instance shares the
    very same list, so a write made through one instance is visible through
    all of them (and through ``Tree.parent`` itself).
    """

    # parent[v] holds the parent of node v; entries never written stay 0.
    parent = [0 for _ in range(1010)]
def distance(tree, x, y):
    """Return the number of edges on the path between nodes ``x`` and ``y``.

    ``tree.parent[v]`` must give the parent of node ``v``; node 0 is treated
    as the root (the upward walk stops there).  The answer is the number of
    steps from ``x`` up to the lowest common ancestor (LCA) plus the number
    of steps from ``y`` up to it.
    """
    def upward_path(node):
        # The node itself first, then each ancestor, ending at the root 0.
        # Including the node itself is what lets an ancestor of the other
        # node be recognised as the LCA.
        path = [node]
        while node != 0:
            node = tree.parent[node]
            path.append(node)
        return path

    # Steps from y to every node on y's upward path.
    y_steps = {node: i for i, node in enumerate(upward_path(y))}

    # Upward paths in a tree share a common tail once they meet, so the
    # first node on x's path that is also on y's path is the LCA.  (Picking
    # the largest label among common ancestors is wrong: labels say nothing
    # about depth.)  Both paths end at 0, so a match is always found.
    for x_steps, node in enumerate(upward_path(x)):
        if node in y_steps:
            return x_steps + y_steps[node]
if __name__ == '__main__':
    n, x, y = map(int, sys.stdin.readline().split())
    # Create the tree once, before reading edges, so it exists even when
    # n == 1 and no edge lines follow.
    tree = Tree()
    # Each of the n-1 edge lines "a b" makes a the parent of b.
    for _ in range(n - 1):
        a, b = map(int, sys.stdin.readline().split())
        tree.parent[b] = a
    # Node ids outside 0..n-1 are answered with 0.
    if x <= n - 1 and y <= n - 1:
        print(distance(tree, x, y))
    else:
        print(0)
Code 2 (수정 — revised)
import sys
# parent[v] is the parent of node v; -1 means "no parent" (root or unused id).
parent = [-1] * 1010


def findParent(node, result):
    """Append ``node`` and all of its ancestors to ``result``, nearest first.

    Follows ``parent`` upward until it reaches -1; passing ``node == -1``
    appends nothing.  Written as a loop rather than recursion so that a deep
    (path-shaped) tree does not exceed Python's recursion limit.
    """
    while node != -1:
        result.append(node)
        node = parent[node]
if __name__ == '__main__':
    n, x, y = map(int, sys.stdin.readline().split())
    # Each of the n-1 edge lines "a b" makes a the parent of b.
    for _ in range(n - 1):
        a, b = map(int, sys.stdin.readline().split())
        parent[b] = a

    # Upward paths from x and from y (the node itself first, then ancestors).
    resultX = []
    resultY = []
    findParent(x, resultX)
    findParent(y, resultY)

    # Map each node on a path to its step count from the path's start, so
    # membership tests and counting are dict lookups instead of list scans.
    stepsX = {node: i for i, node in enumerate(resultX)}
    stepsY = {node: i for i, node in enumerate(resultY)}

    # In a tree, two upward paths share a common tail once they meet, so the
    # first node of x's path that also lies on y's path is the lowest common
    # ancestor. If the paths never meet, fall back to node 0.
    commonParent = next((node for node in resultX if node in stepsY), 0)

    # Edges from x up to the common ancestor plus edges from y up to it; a
    # path that does not contain that node contributes its full length.
    print(stepsX.get(commonParent, len(resultX))
          + stepsY.get(commonParent, len(resultY)))