파이토치 멀티 프로세싱

JTDK·2021년 6월 27일
0

딥러닝에서 멀티프로세싱은 연산량이 방대한 backpropagation에 주로 활용된다.

일단 pseudo code로 대충 시나리오를 보자

 model 생성
 model.share_memory()
 프로세스들 생성하고, 각각 다른 환경에서 훈련하되 shared_memory에 있는 model
 의 매개변수는 공유한다.

사실 너무 당연한 소리지만, 이렇게 해야 효율적으로 학습이 가능하다. model의 parameter 들을 공유하지 않으면 프로세스 갯수만큼의 model들이 각자 따로 학습될것이다.

이제 실제로 어떻게 쓰는지 보자


MasterNode = ActorCritic()
MasterNode.share_memory()	#모델은 공유
processes = []
params = {	
    'epochs':1000,
    'n_workers':7,
}
counter = mp.Value('i',0)	#프로세스간 공유하는 'i'nteger 변수 0으로 선언
for i in range(params['n_worker']):	
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
    p.start()			#spawn하는 시점
    processes.append(p)
for p in processes:	
    p.join()			#프로세스들이 완전히 끝날때까지 기달
for p in processes:	
    p.terminate()		#프로세스 강제종료

위 코드에서 MasterNode 객체의 값들은 전부 공유되므로, parameter들도 공유된다.
params는 프로세스들간에 공유하지 않는 변수인데 mp.process 에 인자로 넘겨준것을 볼 수 있는데, 이는 MasterNode와 다르게 각각의 프로세스에 값만 복사돼서 들어간다. (코드 맥락상 그럴일은 없겠지만) 만약 어떤 프로세스에서 params 값을 변경해도 다른 프로세스들에는 영향을 끼치지 않는다.

profile
RL, 퀀트 투자 공부 정리

0개의 댓글

관련 채용 정보