import sys
def count_leaves_after_removal(parents, removed):
    """Return the number of leaves left after deleting a node and its subtree.

    Args:
        parents: ``parents[i]`` is the parent index of node ``i``; a root is
            marked with ``-1``.
        removed: index of the node to delete together with all of its
            descendants.

    Returns:
        The count of surviving nodes that have no surviving children. A parent
        whose only child was ``removed`` therefore becomes a leaf, and the
        result is 0 when the root itself is removed.
    """
    n = len(parents)
    children = [[] for _ in range(n)]
    for node, parent in enumerate(parents):
        # Never attach `removed` to its parent, so that parent is judged only
        # on the children that survive the deletion.
        if parent != -1 and node != removed:
            children[parent].append(node)

    # Iterative DFS over the removed subtree, so deep (chain-shaped) trees
    # cannot exceed Python's recursion limit.
    deleted = [False] * n
    stack = [removed]
    while stack:
        node = stack.pop()
        if deleted[node]:  # guards against malformed input containing a cycle
            continue
        deleted[node] = True
        stack.extend(children[node])

    # A surviving node never lists a deleted node as a child, so an empty
    # child list means it is a leaf of the pruned tree.
    return sum(
        1 for node in range(n) if not deleted[node] and not children[node]
    )


def main():
    """Read N, the parent list and the node to delete from stdin; print the leaf count."""
    readline = sys.stdin.readline
    n = int(readline())
    # Only the first n entries describe nodes 0..n-1.
    parents = list(map(int, readline().split()))[:n]
    removed = int(readline())
    print(count_leaves_after_removal(parents, removed))


if __name__ == "__main__":
    main()