NeRF Code Review - def train() 내부 ray sampling (작성중)

HeyHo·2022년 11월 2일
0

NeRF code Review

목록 보기
2/7

전체 코드

            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                        ), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")                
                else:
                    coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1,2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                # print(target_s)

1. Cropping 부분.(lego)

if N_rand is not None:
   rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

   if i < args.precrop_iters:
        dH = int(H//2 * args.precrop_frac)
        dW = int(W//2 * args.precrop_frac)
        coords = torch.stack(
                 torch.meshgrid(
                 torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 
                 torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                 ), -1)
        if i == start:
            print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")                
   else:
       coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)

ray_o, rays_d를 get_rays를 통해서 return 받는다.
그 다음, i < args.precrop_iters가 true일 경우
[H,W,3]에 해당하는 이미지에서 정 중앙에 위치하고, 면적이 기존 면적에서 1/4에 해당하는 이미지를 cropping하여 indexing 해준다.

ex) lego.blend의 경우, 학습 초기에 center cropping을 진행한다. 기존 이미지는 [400,400,3]
i < args.precrop_iters 일 경우( precrop_iter 전까지 초기 학습에서는 image에서 center를 중점적으로 학습한다.)
다음과 같이 [400,400,3] 이미지에서 가운데 사각형 영역으로 coords가 indexing 된 것을 확인할 수 있다!

  • 하단의 그림 처럼 cropping된 [200 ×\times 200]영역에서 random으로 N_rand개의 pixel을 선택해서 ray로 만들어 nerf Network에 input으로 넣어준다.

2. Ray random sampling.

coords = torch.reshape(coords, [-1,2])  # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
select_coords = coords[select_inds].long()  # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                

parser에서 학습 초기에 N_rand 값을 외부에서 사용자 지정 값으로 입력 받았다. N_rand는 random ray의 갯수인데, 이는 위 코드에서 활용된다.
torch.reshape(coords, [-1,2])를 통해서 [H,W,3]이였던 shape을 [H * W, 2] shape으로 변경해주고, np.random.choid를 통해서 (0 ~ H * W) 숫자 중에서 random으로 N_rand 갯수 만큼 숫자를 뽑는다. 해당 숫자들은 ray의 index에 해당되고, 선정된 index로 부터 rays_o, rays_d, target을 지정해준다.

rays_o: ray 시작점 위치
rays_d: ray 방향
target_s: Index에 해당하는 이미지의 pixel RGB value

profile
Coputer vision, AI

0개의 댓글