import sys
input=sys.stdin.readline
def Find(x):
if x!=disjoint[x]:
disjoint[x]=Find(disjoint[x])
return disjoint[x]
def Union(a,b):
a=Find(a)
b=Find(b)
if a>b:
disjoint[a]=b
else:
disjoint[b]=a
def Answer(parent):
tmp=0
for i in range(1,N+1):
if Find(i)==parent:
tmp+=1
return tmp
N,M,K=map(int,input().split())
disjoint=[ i for i in range(N+1) ]
Node1,Node2=0,0
for i in range(M):
u,v=map(int,input().split())
if i!=K-1: # 가중치가 있는 지점은 합치지않는다.
Union(u,v)
else:
Node1,Node2=u,v
if Find(Node1)==Find(Node2): # 예제 2와 같다면
print(0)
else:
total=Answer(disjoint[1])
print(total*(N-total))
📌 어떻게 접근할 것인가?
일단 기본적으로 유니온 파인드를 사용해주었다.
다만 가중치가 1인 지점이 하나가 있는데 이는 을 사용하지 않았다.

왜냐하면 1-2 인 지점과 3-4-5 로 가는 지점 2개에 대해서 두 지점의 길이의 곱이 답이기 때문이다.
만약 예제 2처럼 가중치가 1인 두 노드의 최상위 부모가 같다면 사이클이 생기므로 답은 이다.
그렇지 않으면 유니온 하는 중에 가장 번호가 적은 노드를 우선시 했으므로 첫번째 값인 disjoint[1] 을 기점으로 같은 값의 개수를 구한 후에 나머지 집합의 개수를 곱해주었다.