Let's Build a Music Classification Deep Learning Model (28) - ENAS's controller and build_sampler

응큼한포도 · September 24, 2024

build_sampler

Now that we have the method that initializes the controller's parameters, let's write build_sampler(), which samples layers to assemble a model.

https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller
served as the reference; I ported it to PyTorch, paying extra attention to memory management.

First, the dependencies are as follows:

from nas.controller import Controller
import torch.nn.init as init
import torch.nn as nn
import torch
import torch.nn.functional as F
from common_ops import stack_lstm

Here, common_ops contains the helper methods for running the LSTM:

import torch

def lstm(x, prev_c, prev_h, w):
    # single LSTM cell: one fused matmul produces the input, forget,
    # output, and candidate pre-activations at once
    ifog = torch.matmul(torch.cat([x, prev_h], dim=1), w)
    i, f, o, g = torch.split(ifog, ifog.size(1) // 4, dim=1)
    i = torch.sigmoid(i)
    f = torch.sigmoid(f)
    o = torch.sigmoid(o)
    g = torch.tanh(g)
    next_c = i * g + f * prev_c
    next_h = o * torch.tanh(next_c)
    return next_c, next_h

def stack_lstm(x, prev_c, prev_h, w):
    # run the input through each stacked LSTM layer in turn,
    # feeding layer k's hidden state into layer k+1
    next_c, next_h = [], []
    for layer_id, (_c, _h, _w) in enumerate(zip(prev_c, prev_h, w)):
        inputs = x if layer_id == 0 else next_h[-1]
        curr_c, curr_h = lstm(inputs, _c, _h, _w)
        next_c.append(curr_c)
        next_h.append(curr_h)
    return next_c, next_h

We simply factored out the layer-stacking part of the LSTM so it can be managed separately; a quick shape check follows below.
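To get a feel for the shapes involved, here is a minimal sketch of driving stack_lstm by hand. The sizes are made up for illustration; inside the controller they come from self.lstm_size and self.lstm_num_layers.

import torch
from common_ops import stack_lstm

lstm_size, num_layers = 32, 2  # toy sizes for illustration

# one fused weight matrix per stacked layer: (input + hidden) -> 4 gates
w_lstm = [torch.empty(2 * lstm_size, 4 * lstm_size).uniform_(-0.1, 0.1)
          for _ in range(num_layers)]

prev_c = [torch.zeros(1, lstm_size) for _ in range(num_layers)]
prev_h = [torch.zeros(1, lstm_size) for _ in range(num_layers)]
x = torch.zeros(1, lstm_size)  # stand-in for the g_emb input token

next_c, next_h = stack_lstm(x, prev_c, prev_h, w_lstm)
print(next_h[-1].shape)  # torch.Size([1, 32]): the top layer's hidden state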

    def _create_params(self):
        with torch.no_grad():
            self.w_lstm = []
            for layer_id in range(self.lstm_num_layers):
                w = torch.empty(2 * self.lstm_size, 4 * self.lstm_size)
                self._uniform_initializer(w, minval=-0.1, maxval=0.1)
                self.w_lstm.append(w)

            self.g_emb = torch.empty(1, self.lstm_size)
            self._uniform_initializer(self.g_emb, minval=-0.1, maxval=0.1)

            if self.search_whole_channels:
                self.w_emb = torch.empty(self.num_branches, self.lstm_size)
                self._uniform_initializer(self.w_emb, minval=-0.1, maxval=0.1)
                self.w_soft = torch.empty(self.lstm_size, self.num_branches)
                self._uniform_initializer(self.w_soft, minval=-0.1, maxval=0.1)
            else:
                self.w_emb = {"start": [], "count": []}
                for branch_id in range(self.num_branches):
                    w_start = torch.empty(self.out_filters, self.lstm_size)
                    w_count = torch.empty(self.out_filters - 1, self.lstm_size)
                    self._uniform_initializer(w_start, minval=-0.1, maxval=0.1)
                    self._uniform_initializer(w_count, minval=-0.1, maxval=0.1)
                    self.w_emb["start"].append(w_start)
                    self.w_emb["count"].append(w_count)

                self.w_soft = {"start": [], "count": []}
                for branch_id in range(self.num_branches):
                    w_start = torch.empty(self.lstm_size, self.out_filters)
                    w_count = torch.empty(self.lstm_size, self.out_filters - 1)
                    self._uniform_initializer(w_start, minval=-0.1, maxval=0.1)
                    self._uniform_initializer(w_count, minval=-0.1, maxval=0.1)
                    self.w_soft["start"].append(w_start)
                    self.w_soft["count"].append(w_count)
            # attention weights for skip-connection scoring;
            # torch.Tensor allocates uninitialized memory, so initialize them too
            self.w_attn_1 = nn.Parameter(torch.Tensor(self.lstm_size, self.lstm_size))
            self.w_attn_2 = nn.Parameter(torch.Tensor(self.lstm_size, self.lstm_size))
            self.v_attn = nn.Parameter(torch.Tensor(self.lstm_size, 1))
            self._uniform_initializer(self.w_attn_1, minval=-0.1, maxval=0.1)
            self._uniform_initializer(self.w_attn_2, minval=-0.1, maxval=0.1)
            self._uniform_initializer(self.v_attn, minval=-0.1, maxval=0.1)

Change _create_params as above to add the attention-related weights. Note that nn.Parameter(torch.Tensor(...)) allocates uninitialized memory, so the three attention weights get an explicit uniform initialization as well.
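To see what the three new weights do: every earlier layer leaves behind an "anchor" (its hidden state projected by w_attn_1), the current hidden state is projected by w_attn_2, and v_attn collapses the tanh of their sum into one score per anchor, which then becomes a two-way skip/no-skip logit. A standalone sketch of that computation with toy tensors (shapes only, not the controller's real state):

import torch

lstm_size, layer_id = 32, 3  # toy sizes: sampling decisions for layer 3

w_attn_1 = torch.empty(lstm_size, lstm_size).uniform_(-0.1, 0.1)
w_attn_2 = torch.empty(lstm_size, lstm_size).uniform_(-0.1, 0.1)
v_attn = torch.empty(lstm_size, 1).uniform_(-0.1, 0.1)

# one pre-projected anchor per previous layer, as cached in anchors_w_1
anchors_w_1 = [torch.randn(1, lstm_size) @ w_attn_1 for _ in range(layer_id)]
h = torch.randn(1, lstm_size)  # stand-in for the current next_h[-1]

query = torch.cat(anchors_w_1, dim=0)      # (layer_id, lstm_size)
query = torch.tanh(query + h @ w_attn_2)   # broadcasts over all anchors
score = query @ v_attn                     # (layer_id, 1): one score per anchor
logit = torch.cat([-score, score], dim=1)  # (layer_id, 2): no-skip vs skip
print(logit.shape)                         # torch.Size([3, 2])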

Next, let's write the method that performs the sampling with the LSTM:

    def _build_sampler(self):
        with torch.no_grad():

            anchors = []
            anchors_w_1 = []

            arc_seq = []
            entropys = []
            log_probs = []
            skip_count = []
            skip_penalties = []

            prev_c = [torch.zeros(1, self.lstm_size) for _ in range(self.lstm_num_layers)]
            prev_h = [torch.zeros(1, self.lstm_size) for _ in range(self.lstm_num_layers)]
            inputs = self.g_emb
            # target distribution for skip connections: [1 - p_skip, p_skip]
            skip_targets = torch.tensor([1.0 - self.skip_target, self.skip_target], dtype=torch.float32)

            for layer_id in range(self.num_layers):
                if self.search_whole_channels:
                    next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
                    prev_c, prev_h = next_c, next_h
                    logit = torch.matmul(next_h[-1], self.w_soft)

                    if self.temperature is not None:
                        logit /= self.temperature
                    if self.tanh_constant is not None:
                        logit = self.tanh_constant * torch.tanh(logit)

                    if self.search_for == "macro" or self.search_for == "branch":
                        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1)
                        branch_id = branch_id.view(1)
                    elif self.search_for == "connection":
                        branch_id = torch.tensor([0], dtype=torch.long)  # cross_entropy needs long targets
                    else:
                        raise ValueError("Unknown search type")

                    arc_seq.append(branch_id)
                    log_prob = F.cross_entropy(logit, branch_id)
                    log_probs.append(log_prob)
                    entropy = log_prob.detach() * torch.exp(-log_prob.detach())
                    entropys.append(entropy)
                    inputs = self.w_emb[branch_id]

                else:
                    for branch_id in range(self.num_branches):
                        next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
                        prev_c, prev_h = next_c, next_h
                        logit = torch.matmul(next_h[-1], self.w_soft["start"][branch_id])

                        if self.temperature is not None:
                            logit /= self.temperature
                        if self.tanh_constant is not None:
                            logit = self.tanh_constant * torch.tanh(logit)

                        start = torch.multinomial(F.softmax(logit, dim=-1), 1)
                        start = start.view(1)
                        arc_seq.append(start)

                        log_prob = F.cross_entropy(logit, start)
                        log_probs.append(log_prob)
                        entropy = log_prob.detach() * torch.exp(-log_prob.detach())
                        entropys.append(entropy)
                        inputs = self.w_emb["start"][branch_id][start]

                        next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
                        prev_c, prev_h = next_c, next_h
                        logit = torch.matmul(next_h[-1], self.w_soft["count"][branch_id])
                        if self.temperature is not None:
                            logit /= self.temperature
                        if self.tanh_constant is not None:
                            logit = self.tanh_constant * torch.tanh(logit)
                        mask = torch.arange(0, self.out_filters-1, dtype=torch.int32).view(1, -1)
                        mask = mask <= (self.out_filters - 1 - start)
                        logit = torch.where(mask, logit, torch.full_like(logit, -float('inf')))
                        # sample from probabilities, not raw logits (which contain -inf)
                        count = torch.multinomial(F.softmax(logit, dim=-1), 1)
                        count = count.view(1)
                        arc_seq.append(count + 1)
                        log_prob = F.cross_entropy(logit, count)
                        log_probs.append(log_prob)
                        entropy = log_prob.detach() * torch.exp(-log_prob.detach())
                        entropys.append(entropy)
                        inputs = self.w_emb["count"][branch_id][count]

                # the skip-connection step runs for every layer regardless of
                # search mode, so it sits at the loop level, not inside the else
                next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
                prev_c, prev_h = next_c, next_h

                if layer_id > 0:
                    query = torch.cat(anchors_w_1, dim=0)
                    query = torch.tanh(query + torch.matmul(next_h[-1], self.w_attn_2))
                    query = torch.matmul(query, self.v_attn)
                    logit = torch.cat([-query, query], dim=1)

                    if self.temperature is not None:
                        logit /= self.temperature
                    if self.tanh_constant is not None:
                        logit = self.tanh_constant * torch.tanh(logit)

                    skip = torch.multinomial(F.softmax(logit, dim=-1), 1)
                    skip = skip.view(layer_id)
                    arc_seq.append(skip)

                    # KL divergence between the sampled skip probabilities
                    # and the target skip rate
                    skip_prob = torch.sigmoid(logit)
                    kl = skip_prob * torch.log(skip_prob / skip_targets)
                    kl = torch.sum(kl)
                    skip_penalties.append(kl)

                    # we need per-anchor losses, so disable the default mean reduction
                    log_prob = F.cross_entropy(logit, skip, reduction='none')
                    log_probs.append(log_prob.sum())

                    entropy = (log_prob * torch.exp(-log_prob)).sum().detach()
                    entropys.append(entropy)

                    skip = skip.float().view(1, layer_id)
                    skip_count.append(skip.sum())
                    inputs = torch.matmul(skip, torch.cat(anchors, dim=0))
                    inputs /= (1.0 + skip.sum())
                else:
                    inputs = self.g_emb

                anchors.append(next_h[-1].detach())
                anchors_w_1.append(torch.matmul(next_h[-1].detach(), self.w_attn_1))

            # aggregate only after the layer loop has finished
            arc_seq = torch.cat(arc_seq, dim=0)
            self.sample_arc = arc_seq.view(-1)

            entropys = torch.stack(entropys)
            self.sample_entropy = entropys.sum()

            log_probs = torch.stack(log_probs)
            self.sample_log_probs = log_probs.sum()

            skip_count = torch.stack(skip_count)
            self.skip_count = skip_count.sum()

            skip_penalties = torch.stack(skip_penalties)
            self.skip_penalties = skip_penalties.mean()

Build it as shown above. I'll walk through the code in detail next time.
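As a quick preview of how the result will be consumed: with search_whole_channels enabled, sample_arc is a flat sequence where layer 0 contributes one branch id and each later layer k contributes a branch id followed by k skip bits. A hypothetical decoder sketch (decode_arc is my own helper for illustration, not part of the original code):

def decode_arc(sample_arc, num_layers):
    # split the flat arc into per-layer (branch_id, skip bits) pairs;
    # assumes search_whole_channels=True
    arch, pos = [], 0
    for layer_id in range(num_layers):
        branch_id = int(sample_arc[pos]); pos += 1
        skips = [int(s) for s in sample_arc[pos:pos + layer_id]]; pos += layer_id
        arch.append((branch_id, skips))
    return arch

# e.g. decode_arc(controller.sample_arc, controller.num_layers) might return
# [(2, []), (0, [1]), (3, [0, 1])]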
