4D Gaussian Splatting

김민솔·2024년 12월 12일

Gaussian-Splatting

목록 보기
5/6

Introduction


Gaussian splatting을 dynamic scene에 렌더링한 논문입니다. 가우시안 자체에 deformation 과정을 적용하였습니다. 또한, 4차원을 효율적으로 표현하기 위해 K-planes를 인코딩에 적용하였습니다.

Preliminary

1) Gaussian Splatting

2) Dynamic NeRFs with Deformation Fields

c,σ=M(x+Δx,d,λ)c, \sigma=\mathcal{M}(\mathbf{x}+\Delta \mathbf{x},d,\lambda)

모든 dymanic NeRF 알고리즘은 위의 식으로 정의할 수 있습니다. 8D space (x,d,t,λ)(\mathbf{x},d,t,\lambda)에서 4D space (c,σ)(c, \sigma)를 매핑합니다. 이때 deformation network ϕt:(x,t)Δx\phi_{t}: (\mathbf{x},t) \rightarrow \Delta \mathbf{x}을 사용하여 world-to-canonical mapping을 추정합니다. M\mathcal{M}은 NeRF network를 의미합니다.

Method

1) 4D Gaussian Splatting Framework

view matrix M=[R,T]M=[R,T]와 timestamp tt를 사용하여 3D Gaussians G\mathcal{G}와 Gaussian deformation field network F\mathcal{F}를 정의합니다. 이후 rendered image를 differential splatting으로부터 얻습니다. I^=S(M,G)\hat{I} = \mathcal{S}(M,\mathcal{G}')
Gaussian deformation field network는 다음의 과정으로 구해집니다. spatial-temporal encoder H\mathcal{H}로부터 feature fd=H(G,t)f_{d}=\mathcal{H}(\mathcal{G},t)를 획득합니다. 그리고, multi-head decoder D\mathcal{D}로부터 3D Gaussian의 deformation ΔG=D(f)\Delta \mathcal{G}=\mathcal{D}(f)를 구합니다. 이를 통해 deformed Gaussian G=G+ΔG\mathcal{G}'=\mathcal{G}+\Delta \mathcal{G}를 구합니다.

2) Gaussian Deformation Field Network

Spatial-Temporal Structure Encoder

붙어 있는 3D Gaussian은 비슷한 spatial, temporal 정보를 공유합니다. 이를 효율적으로 처리하기 위해 multi-resolution HexPlane R(i,j)R(i, j)와 tiny MLP ϕd\phi_{d}를 사용하였습니다.
4D K-Planes를 사용하여 4D neural voxel을 6 multi-resolution planes로 분해하였습니다. 모든 3D Gaussian들을 bounding plane voxels에 포함시키고, 근처 temporal voxel에 Gaussian deformation이 인코딩되게 하였습니다.

H(G,t)={Rl(i,j),ϕd(i,j){(x,y),(x,z),(y,z),(x,t),(y,t),(z,t)},l{1,2}}\mathcal{H}(\mathcal{G}, t)=\{ R_{l}(i,j),\phi_{d}|(i,j) \in \{ (x,y), (x,z),(y,z),(x,t),(y,t),(z,t) \}, l \in \{1,2\}\}

encoder H\mathcal{H}는 6 multi-resolution plane modules Rl(i,j)R_{l}(i, j)와 tiny MLP ϕd\phi_{d}로 표현됩니다. 각 voxel module은 R(i,j)Rh×lNi×lNjR(i,j)\in \mathbb{R}^{h\times lN_{i} \times lN_{j}}로 정의됩니다.

  • NN: basic resolution of voxel grid
  • ll: upsampling scale
fh=linterp(Rl(i,j)),(i,j){(x,y),(x,z),(y,z),(x,t),(y,t),(z,t)}f_{h}=\bigcup_{l}\prod\text{interp}(R_{l}(i,j)), \quad (i,j) \in \{ (x,y), (x,z),(y,z),(x,t),(y,t),(z,t) \}
  • fhRhlf_{h}\in \mathbb{R}^{h*l}: neural voxel feature
  • 'interp': bilinear interpolation -> querying the voxel features located at 4 vertices of the grid
    위의 과정으로 얻은 features를 tiny MLP로 모두 묶습니다. fd=ϕd(fh)f_{d}=\phi_{d}(f_{h})

Code

class HexPlaneField(nn.Module):
	###
	"""
	init 등의 함수는 생략하였습니다!
	"""
	###
    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Computes and returns the densities."""
        # breakpoint()
        pts = normalize_aabb(pts, self.aabb)
        pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]

        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)

        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):

        features = self.get_density(pts, timestamps)

        return features

aabb로 normalize한 points를 interpolation 과정을 거쳐 feature로 변환합니다. HexPlane, K-plane의 과정과 동일합니다.


Multi-head Gaussian Deformation Decoder

모든 3D Gaussian feature들이 인코딩되면, deformation decoder D={ϕx,ϕr,ϕs}\mathcal{D}=\{\phi_{x}, \phi_{r}, \phi_{s}\}를 통하여 각각의 deformation을 구하게 됩니다. 따라서 얻게 되는 deformed 3D Gaussian은 다음과 같습니다.

G={X,s,r,σ,C}\mathcal{G}' = \{\mathcal{X}', s', r', \sigma,\mathcal{C} \}
  • ΔX=ϕx(fd)\Delta\mathcal{X}=\phi_{x}(f_{d}): defomation of position
  • Δr=ϕr(fd)\Delta r = \phi_{r}(f_{d}): defomation of rotation
  • Δs=ϕs(fd)\Delta s=\phi_{s}(f_{d}): defomation of scaling

Code

class Deformation(nn.Module):
    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, grid_pe=0, skips=[], args=None):
        super(Deformation, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_time = input_ch_time
        self.skips = skips
        self.grid_pe = grid_pe
        self.no_grid = args.no_grid
        self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires)
        # breakpoint()
        self.args = args
        # self.args.empty_voxel=True
        if self.args.empty_voxel:
            self.empty_voxel = DenseGrid(channels=1, world_size=[64,64,64])
        if self.args.static_mlp:
            self.static_mlp = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))
        
        self.ratio=0
        self.create_net()
    @property
    def get_aabb(self):
        return self.grid.get_aabb
    def set_aabb(self, xyz_max, xyz_min):
        print("Deformation Net Set aabb",xyz_max, xyz_min)
        self.grid.set_aabb(xyz_max, xyz_min)
        if self.args.empty_voxel:
            self.empty_voxel.set_aabb(xyz_max, xyz_min)
    def create_net(self):
        mlp_out_dim = 0
        if self.grid_pe !=0:
            
            grid_out_dim = self.grid.feat_dim+(self.grid.feat_dim)*2 
        else:
            grid_out_dim = self.grid.feat_dim
        if self.no_grid:
            self.feature_out = [nn.Linear(4,self.W)]
        else:
            self.feature_out = [nn.Linear(mlp_out_dim + grid_out_dim ,self.W)]
        
        for i in range(self.D-1):
            self.feature_out.append(nn.ReLU())
            self.feature_out.append(nn.Linear(self.W,self.W))
        self.feature_out = nn.Sequential(*self.feature_out)
        self.pos_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3))
        self.scales_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3))
        self.rotations_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4))
        self.opacity_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))
        self.shs_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 16*3))

    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb):
        if self.no_grid:
            h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1)
        else:
            grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1])
            # breakpoint()
            if self.grid_pe > 1:
                grid_feature = poc_fre(grid_feature,self.grid_pe)
            hidden = torch.cat([grid_feature],-1) 
        
        hidden = self.feature_out(hidden)   

        return hidden

    def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None,shs_emb=None, time_feature=None, time_emb=None):
        return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, shs_emb, time_feature, time_emb)

    def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, shs_emb, time_feature, time_emb):
        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb)
        if self.args.static_mlp:
            mask = self.static_mlp(hidden)
        elif self.args.empty_voxel:
            mask = self.empty_voxel(rays_pts_emb[:,:3])
        else:
            mask = torch.ones_like(opacity_emb[:,0]).unsqueeze(-1)
        # breakpoint()
        if self.args.no_dx:
            pts = rays_pts_emb[:,:3]
        else:
            dx = self.pos_deform(hidden)
            pts = torch.zeros_like(rays_pts_emb[:,:3])
            pts = rays_pts_emb[:,:3]*mask + dx
        if self.args.no_ds :
            
            scales = scales_emb[:,:3]
        else:
            ds = self.scales_deform(hidden)

            scales = torch.zeros_like(scales_emb[:,:3])
            scales = scales_emb[:,:3]*mask + ds
            
        if self.args.no_dr :
            rotations = rotations_emb[:,:4]
        else:
            dr = self.rotations_deform(hidden)

            rotations = torch.zeros_like(rotations_emb[:,:4])
            if self.args.apply_rotation:
                rotations = batch_quaternion_multiply(rotations_emb, dr)
            else:
                rotations = rotations_emb[:,:4] + dr

        if self.args.no_do :
            opacity = opacity_emb[:,:1] 
        else:
            do = self.opacity_deform(hidden) 
          
            opacity = torch.zeros_like(opacity_emb[:,:1])
            opacity = opacity_emb[:,:1]*mask + do
        if self.args.no_dshs:
            shs = shs_emb
        else:
            dshs = self.shs_deform(hidden).reshape([shs_emb.shape[0],16,3])

            shs = torch.zeros_like(shs_emb)
            # breakpoint()
            shs = shs_emb*mask.unsqueeze(-1) + dshs

        return pts, scales, rotations, opacity, shs

코드 내용은 비교적 간단합니다. pos부터 sh 계수까지 deformation net은 모두 MLP 구조로 구성되어 있습니다. 각 MLP로 뽑은 defomation 값을 원래 변수에 더하여 deformed gaussian을 구합니다. 해당 가우시안을 splatting 최적화하는 것으로 랜더링이 이루어집니다.

3) Optimization

3D Gaussian Initialization

초기 3000 iteration에서는 3D Gaussians로 최적화합니다. 3D GS의 sfm으로부터 초기화하는 과정을 이용하기 위함입니다. image rendering을 4D가 아닌, 3D Gaussian으로부터 적용합니다.

Loss function

L=I^I+Ltv\mathcal{L}=|\hat{I}-I|+\mathcal{L}_{tv}

L1 color loss와 grid-based total-variational loss를 같이 적용하였습니다.
total-variational loss는 이미지에서 이웃 픽셀 간의 차이를 절대값으로 치환하여 summation하는 loss입니다. deformation 성능 향상을 위해 사용되었습니다.

Experiments

Qualitative

  • 3D GS와 비교 시, 3D GS는 floater가 매우 많이 발생하는 것을 확인할 수 있습니다.
  • Broom 데이터셋으로 확인했을 때, 4D GS의 성능이 더 잘 드러나는 듯 합니다.

Quantitative

  • 기존 dynamic scene rendering에서 SOTA를 기록하였습니다. (논문 기준)
  • FPS와 저장 공간 측면에서도 이점을 크게 얻습니다.

Ablations

HexPlane이 성능에 가장 큰 영향을 미친 것으로 파악할 수 있습니다. HexPlane 적용 시, time과 FPS에서의 성능은 감소하지만 이미지 퀄리티가 향상합니다.

Reference

[1] 4D Gaussian Splatting for Real-Time Dynamic Scene Rendering, Guanjun Wu, Wei Wei, https://arxiv.org/pdf/2310.08528
[2] 4D Gaussian Splatting Github, https://github.com/hustvl/4DGaussians

profile
Interested in Vision, Generative, Neural Rendering

0개의 댓글