# https://www.acmicpc.net/problem/2526
def cycle_length(n: int, p: int) -> int:
    """Return how many distinct numbers lie in the repeating part of the sequence.

    The sequence starts at ``n``; every following term is the previous term
    multiplied by ``n`` and reduced modulo ``p``.  Because every term after
    the first is a residue modulo ``p``, some value must eventually repeat;
    the distance between the two occurrences of the first repeated value is
    the length of the cycle.

    Args:
        n: Starting value and multiplier.
        p: Modulus; must be positive.

    Returns:
        The number of distinct values inside the cycle.

    Raises:
        ValueError: If ``p`` is not positive.

    >>> cycle_length(67, 31)
    3
    >>> cycle_length(3, 3)
    1
    """
    if p <= 0:
        raise ValueError("p must be a positive integer")

    # Map each value to the index at which it first appeared.  The first term
    # is ``n`` itself (not ``n % p``), so when ``n >= p`` it can never recur
    # and is correctly treated as part of the tail only.
    first_seen = {n: 0}
    value = n
    step = 0
    while True:
        step += 1
        value = value * n % p
        if value in first_seen:
            # Always subtract the index of the first occurrence; the cycle
            # length must not depend on which iteration found the repeat.
            return step - first_seen[value]
        first_seen[value] = step


def main() -> None:
    """Read ``n`` and ``p`` from standard input and print the cycle length."""
    n, p = map(int, input().split())
    print(cycle_length(n, p))


if __name__ == "__main__":
    main()