NeRF Code Review - def raw2outputs

HeyHo·2022년 11월 4일
0

NeRF code Review

목록 보기
7/7
post-thumbnail

run_nerf.py 파일 안에 있다.

Input

raw: [N_rand, N_samples, 3+1], NeRF network로 부터 estimation 된 RGBσ\sigma output
z_vals: [N_rand, N_samples] 코드에서는 Integration time이라고 하는데 이게 뭐지...
rays_d: [N_rand, 3] direction of each ray
white_bkgd: 흰색 배경 flag
pytest -> 이건 뭐?

변수 부가 설명

rgb_map: [N_rand, 3] Estimated RGB color of a ray
disp_map: [N_rand

전체 코드

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]

    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

    rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[...,3].shape) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            noise = torch.Tensor(noise)

    alpha = raw2alpha(raw[...,3] + noise, dists)  # [N_rays, N_samples]
    # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, acc_map, weights, depth_map

1. alpha, dists 구하기

    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]

    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

    rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]

  • raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

-> 5.2 Hierarchical volume sampling 에서 등장하는 weight term과 alpha term이다.
σi\sigma_i는 Volume density를, δi\delta_i는 하나의 ray에서 sampling point들 사이의 distance(ti+1tit_{i+1} - t_i)를 뜻한다.
raw2alpha는 말 그대로 raw data에서 paper의 alpha로 값을 mapping한다.

raw2alpha:=1exp(ReLU(x)×dists)raw2alpha := 1-exp{(-ReLU(x) \times dists)}
수식코드
αi=1exp(σiδi)\alpha_i = 1 - exp(-\sigma_i\delta_i)raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
  • dists = z_vals[...,1:] - z_vals[...,:-1]

출처: https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666

ray에서 startified sampling을 통해 뽑은 point들 사이의 거리.
뭐 대충 이런 느낌인 것 같다...

  • dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]
    
     dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

앞에서 구한 stratified sampling point들 사이의 거리인 dists에서 1e10을 concatenate 해주고, ray의 방향 정보인 rays_d의 norm을 구해서 이를 dists에 곱해준다.

  • rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]
    NeRF model의 RGB + volume density estimation 값을 sigmoid를 통해서 0~1 사이로 mapping 해준다. 왜 굳이 또 mapping 해주는거지?
    • 그래서 NeRF model의 min, max값을 찍어보았다,

    • torch.sigmoid로 mapping 이후 rgb min, max 값

      0.5 근처로 mapping 된 것을 확인할 수 있다.

  • alpha = raw2alpha(raw[...,3] + noise, dists)
    • raw2alpha function을 통해 alpha 계산.

2. TiT_i, RGB, Depth, Disp, acc map 구하기.

2.1 Weight식 깔끔하게 정리하기.

우선, Paper에서 TiT_i는 다음과 같이 정의되어 있다.

이때, alpha는

TiT_iαi\alpha_i 두 수식이 매우 비슷하게 생기지 않았는가? 이 식을 정리하면 다음과 같다.

1αi=exp(σiδi)1-\alpha_i = exp(-\sigma_i\delta_i)

이를 log sum 형태로 나타내면,

Ti=j=1i1(1αj)=exp(j=1i1σjδj)T_i = \prod_{j = 1}^{i-1}( 1-\alpha_j) = exp(-\sum_{j = 1}^{i-1} \sigma_j\delta_j)

다음과 같이 나타낼 수 있다.
코드에서는 이렇게 alpha term을 변형하여, torch.cumprod를 통해서 Pi 연산을 한다. 최종적으로 TiT_i를 계산한다.

    alpha = raw2alpha(raw[...,3] + noise, dists)  # [N_rays, N_samples]
    # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, acc_map, weights, depth_map
  •  weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    • 위 코드를 수식으로 나타내면 다음과 같다.
    • wiw_i =Ti(1exp(σiδi)=Tiαi=j=1i1(1αi)αi= T_i(1-exp(-\sigma_i\delta_i) = T_i\alpha_i = \prod_{j = 1}^{i-1}( 1-\alpha_i)\alpha_i
    • torch.cumprod 계산을 통해서 weights를 계산한다.
    • 1e-10 maybe for prevent Nan?

  • rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
    • 위 코드를 수식으로 나타내면 다음과 같다.
    • Cc^(r)=i=1Ncwici\hat{C_c}(r) = \sum_{i = 1}^{N_c} w_ic_i
  • depth_map = torch.sum(weights * z_vals, -1)
    • 그냥 weith에 ray를 startify sampling한 point들에 곱한다.
  • disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    • disparity map을 구한다. inverse depth라고 생각하면 된다.
  • acc_map = torch.sum(weights, -1)
    • 이건 뭐지? acc_map의 역할이 뭔지 좀 더 알아봐야겠다.
profile
Coputer vision, AI

0개의 댓글