NeRF code review - def get_embedder (작성중)

HeyHo·2022년 11월 3일
0

NeRF code Review

목록 보기
4/7
def get_embedder(multires, i=0):
    if i == -1:
        return nn.Identity(), 3
    
    embed_kwargs = {
                'include_input' : True,
                'input_dims' : 3,
                'max_freq_log2' : multires-1,
                'num_freqs' : multires,
                'log_sampling' : True,
                'periodic_fns' : [torch.sin, torch.cos],
    }
    
    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim
  • multires는 encoding되는 frequency의 max frequency를 의미한다.
  • NeRF paper에서 positon 정보(rays_o)가 encoding 될 때는 multires는 L=10, direction 정보(rays_d)가 encoding 될 때는 multires는 L=4가 된다.
  • positional encoding이 기본적으로 sin, cos로 encoding 되기 때문에 'periodic_fns' : [torch.sin, torch.cos]로 표현되었다.
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()
        
    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            out_dim += d
            
        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
        
        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
            
        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:    #   torch.sin, torch.cos
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))   # sin(2^freq * x), cos(2^freq * x)
                out_dim += d
                    
        self.embed_fns = embed_fns
        self.out_dim = out_dim
        
    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
  • rays_o와 rays_d를 positional encoding 해주는 Embedder class이다.
  • kwargs는 dictionary 형태로 코드 초기에 parser로 argument들이 저장되어있다.
  • rays_o와 rays_d는 channel이 3개이므로, 'input_dims'는 3으로 저장되어 있다.
  • include_input:True -> positional encoding으로 embedding된 function들을 embed_fns에 appnd로 저장할 때, input function을 저장하는 용도로 사용된다.
  • max_freq는 paper에서 encoding 된 function의 마지막 freq에 해당하는 L-1이 된다.
  • N_freq는 paper에서 encoding된 function들의 갯수이다.

- positional Encoding with code

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:    #   torch.sin, torch.cos
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))   # sin(2^freq * x), cos(2^freq * x)
                out_dim += d
profile
Coputer vision, AI

0개의 댓글