import sys
input=sys.stdin.readline
sys.setrecursionlimit(10**9)
def DFS(Node,level):
global total
visit[Node]=True
if level<=K and Apple[Node]==1:
total+=1
for i in Tree[Node]:
if not visit[i]:
DFS(i,level+1)
N,K=map(int,input().split())
Tree=[ [] for _ in range(N+1) ]
for i in range(N-1):
p,c=map(int,input().split())
Tree[p].append(c)
Tree[c].append(p)
Apple=list(map(int,input().split()))
visit=[False]*(N+1)
total=0
DFS(0,0)
print(total)
📌 어떻게 접근할 것인가?
트리의 를 통해 풀었습니다. 의 매개변수에 level 을 추가함으로써 깊이를 측정해줍니다. 따라서 깊이가 보다 작으면서 사과가 있는 노드라면 total 변수에 을 추가해줍니다.
그리고 visit 배열을 통해 중복방문을 처리해줍니다.
문제에서 루트 노드는 이라고 선언했기 때문에 시작점은 으로 잡아줍니다.
또한 트리는 양방향이기 때문에 지점을 서로 두번 에 넣어줍니다.