파이토치 멀티 프로세싱

JTDK·2021년 6월 27일
0

딥러닝에서 멀티프로세싱은 연산량이 방대한 backpropagation에 주로 활용된다.

일단 pseudo code로 대충 시나리오를 보자

 model 생성
 model.share_memory()
 프로세스들 생성하고, 각각 다른 환경에서 훈련하되 shared_memory에 있는 model
 의 매개변수는 공유한다.

사실 너무 당연한 소리지만, 이렇게 해야 효율적으로 학습이 가능하다. model의 parameter 들을 공유하지 않으면 프로세스 갯수만큼의 model들이 각자 따로 학습될것이다.

이제 실제로 어떻게 쓰는지 보자


MasterNode = ActorCritic()
MasterNode.share_memory()	#모델은 공유
processes = []
params = {	
    'epochs':1000,
    'n_workers':7,
}
counter = mp.Value('i',0)	#프로세스간 공유하는 'i'nteger 변수 0으로 선언
for i in range(params['n_worker']):	
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
    p.start()			#spawn하는 시점
    processes.append(p)
for p in processes:	
    p.join()			#프로세스들이 완전히 끝날때까지 기달
for p in processes:	
    p.terminate()		#프로세스 강제종료

위 코드에서 MasterNode 객체의 값들은 전부 공유되므로, parameter들도 공유된다.
params는 프로세스들간에 공유하지 않는 변수인데 mp.process 에 인자로 넘겨준것을 볼 수 있는데, 이는 MasterNode와 다르게 각각의 프로세스에 값만 복사돼서 들어간다. (코드 맥락상 그럴일은 없겠지만) 만약 어떤 프로세스에서 params 값을 변경해도 다른 프로세스들에는 영향을 끼치지 않는다.

profile
RL, 퀀트 투자 공부 정리

0개의 댓글