이 macro 탐색방법을 구현하는 건 순전히 이해를 위해서 하는거고 실제로 사용은 안 할것이다. *표시로 된건 스킵해도 좋다.
강화학습을 기반으로 한 nas 알고리즘에선 컨트롤러가 대부분의 역할을 한다. 이번 시간에선 macro contoller를 이용해서 모델의 파라미터를 초기화하는 매서드를 만들어보자.
class MacroController(Controller):
def __init__(self,
search_whole_channels=False,
num_layers=4,
num_branches=6,
out_filters=48,
lstm_size=32,
lstm_num_layers=2,
lstm_keep_prob=1.0,
tanh_constant=None,
temperature=None,
lr_init=1.3,
lr_dec_start=0,
lr_dec_every=100,
lr_dec_rate=0.9,
l2_reg=0,
entropy_weight=None,
grad_bound=None,
use_critic=False,
bl_dec=0.999,
optim_algo='adam',
sync_replicas=False,
num_aggregate=None,
num_replicas=None,
skip_target=0.8,
skip_weight=0.5,
*args,
**kwargs):
self.search_whole_channels = search_whole_channels
self.num_layers = num_layers
self.num_branches = num_branches
self.out_filters = out_filters
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.lstm_keep_prob = lstm_keep_prob
self.tanh_constant = tanh_constant
self.temperature = temperature
self.lr_init = lr_init
self.lr_dec_start = lr_dec_start
self.lr_dec_every = lr_dec_every
self.lr_dec_rate = lr_dec_rate
self.l2_reg = l2_reg
self.entropy_weight = entropy_weight
self.grad_bound = grad_bound
self.use_critic = use_critic
self.bl_dec = bl_dec
self.use_critic = use_critic
self.optim_algo = optim_algo
self.sync_replicas = sync_replicas
self.num_aggregate = num_aggregate
self.num_replicas = num_replicas
self.skip_target = skip_target
self.skip_weight = skip_weight
self._create_params()
self._build_sample()
우선 컨트롤러에 필요한 걸 위와 같이 초기화 해주자. 많지만 포함되면 좋을걸 일단 위처럼 만들어놓고 나중에 지우던가 하자. 설명은 아래와 같다
search_whole_channels: 전체 채널 단위로 검색할지 여부
num_layers: 아키텍처의 레이어 수
num_branches: 각 레이어에서 사용할 수 있는 브랜치 수
out_filters: 출력 필터 수
lstm_size: LSTM의 크기
lstm_num_layers: LSTM 레이어 수
lstm_keep_prob: LSTM 드롭아웃 확률
tanh_constant: LSTM 출력을 스케일링하는 상수
temperature: 소프트맥스의 온도 파라미터
lr_init: 초기 학습률
lr_dec_start: 학습률 감소 시작 시점
lr_dec_every: 학습률 감소 주기
lr_dec_rate: 학습률 감소 비율
l2_reg: L2 정규화 강도
entropy_weight: 탐색의 엔트로피 보상 가중치
clip_mode: 그라디언트 클리핑 모드
grad_bound: 그라디언트 클리핑 상한
use_critic: 크리틱(critic) 사용 여부
bl_dec: 베이스라인 이동 평균 감쇠율
optim_algo: 최적화 알고리즘 (기본: Adam)
sync_replicas: 동기화된 복제본 사용 여부
num_aggregate: 집계할 샘플 수
num_replicas: 복제본 수
skip_target: 스킵 연결 허용 비율
skip_weight: 스킵 연결에 대한 패널티
학습할 때 필요한 것들을 위와 같이 초기화해준다. 그 다음 파라미터를 초기화해주는 매서드를 하나 만든다.
def _uniform_initializer(self, tensor, minval=-0.1, maxval=0.1):
return init.uniform_(tensor, minval, maxval)
init.uniform은 균등 분포 초기화로 minval부터 maxval까지 일정한 범위내에 랜덤하게 초기화해주는 방법이다. 위 -0.1, 0.1은 논문 저자의 초기화 값을 참고한 값으로 보수적이고 안전한 초기화 방법이다.
그 다음 파라미터를 초기화하는 매서드를 만들자.
def _create_params(self):
with torch.no_grad():
self.w_lstm = []
for layer_id in range(self.lstm_num_layers):
w = torch.empty(2 * self.lstm_size, 4 * self.lstm_size)
self._uniform_initializer(w, minval=-0.1, maxval=0.1)
self.w_lstm.append(w)
self.g_emb = torch.empty(1, self.lstm_size)
self._uniform_initializer(self.g_emb, minval=-0.1, maxval=0.1)
if self.search_whole_channels:
self.w_emb = torch.empty(self.num_branches, self.lstm_size)
self._uniform_initializer(self.w_emb, minval=-0.1, maxval=0.1)
self.w_soft = torch.empty(self.lstm_size, self.num_branches)
self._uniform_initializer(self.w_soft, minval=-0.1, maxval=0.1)
else:
self.w_emb = {"start": [], "count": []}
for branch_id in range(self.num_branches):
w_start = torch.empty(self.out_filters, self.lstm_size)
w_count = torch.empty(self.out_filters - 1, self.lstm_size)
self._uniform_initializer(w_start, minval=-0.1, maxval=0.1)
self._uniform_initializer(w_count, minval=-0.1, maxval=0.1)
self.w_emb["start"].append(w_start)
self.w_emb["count"].append(w_count)
self.w_soft = {"start": [], "count": []}
for branch_id in range(self.num_branches):
w_start = torch.empty(self.lstm_size, self.out_filters)
w_count = torch.empty(self.lstm_size, self.out_filters - 1)
self._uniform_initializer(w_start, minval=-0.1, maxval=0.1)
self._uniform_initializer(w_count, minval=-0.1, maxval=0.1)
self.w_soft["start"].append(w_start)
self.w_soft["count"].append(w_count)
우리의 목표는 극적인 자원 절약에 있다. 이 프로젝트는 크기가 큰 음악 분야 딥러닝의 데이터와 모델을 최대한 압축하고 가성비있게 만들자는 것이고 파라미터 단에서도 메모리 절약을 가져가야 한다.
논문에서 아케텍처의 경우의 수는 10^15로 내 8G짜리 메모리에 다 올리면 당연히 터질거다.
메모리에 대해서 최대한 절약하는 아이디어는 크게 2가지이다.
모델의 경우의 수를 모두 메모리에 올리는 게 아닌 파일로 저장해놨다. 필요할 때 하나씩 꺼내 사용한다 -> 파이썬 제너레이터의 사용
안쓰는 텐서를 최대한 해제하여 메모리 누수를 막는다 -> with문 사용
파이토치에선 텐서의 소유권을 가진 객체가 해제되거나 detach() 함수를 이용하지 않으면 해당 텐서는 메모리에서 해제되지 않는다.
크고 좋은 환경에선 상관없겠지만 8g는 얄짤없이 터진다. 그럴때 사용하면 좋은게 파이썬의 with문이다.
with문은 작은 컨텍스트를 생성하여 스코프 안에 자원을 관리한다.
with문은 크게 __enter()__ , __exit()__
두 가지 매직 매서드로 이루어져 있다.
enter는 리소스를 초기화하고 exit는 리소스를 메모리에서 해제한다
주목할점은 enter에서 항상 획득한 resource descriptor를 반환해야 된다는점이다.
리소스 설명자(Resource Descriptor)는 컴퓨터 시스템에서 사용되는 리소스를 기술하는 데이터 구조로 다음과 같은 정보를 포함한다
리소스의 종류: 파일, 메모리, CPU 시간, 네트워크 연결 등
리소스의 위치: 파일 시스템에서의 경로, 메모리 주소 등
리소스의 크기: 파일의 크기, 메모리의 크기 등
리소스의 사용 가능 여부: 리소스가 사용 중인지, 사용 가능한지 등
즉 우리가 파이토치에서 with을 사용하면 텐서의 정보만 반환해서 얻을 수 있다는 점이다.
코드에서
with torch.no_grad():
을 걸어주면 필요한 파라미터의 정보를 제외한 모든 정보와 텐서들을 exit()에서 자동으로 해제해줘서 메모리 누수를 막아준다.
torch.no_grad()는 그레디언트를 추론하지 않는다는 뜻으로 단순히 파라미터를 초기화 하기때문에 메모리를 절약해준다.
self.search_whole_channels는 전체 search space를 탐색할지 아니면 일부분만 탐색할 지를 결정하는 매개변수고
나머지 코드들은 평이해서 이해할거라고 믿는다