
Gaussian splatting을 dynamic scene에 렌더링한 논문입니다. 가우시안 자체에 deformation 과정을 적용하였습니다. 또한, 4차원을 효율적으로 표현하기 위해 K-planes를 인코딩에 적용하였습니다.
모든 dymanic NeRF 알고리즘은 위의 식으로 정의할 수 있습니다. 8D space 에서 4D space 를 매핑합니다. 이때 deformation network 을 사용하여 world-to-canonical mapping을 추정합니다. 은 NeRF network를 의미합니다.

view matrix 와 timestamp 를 사용하여 3D Gaussians 와 Gaussian deformation field network 를 정의합니다. 이후 rendered image를 differential splatting으로부터 얻습니다.
Gaussian deformation field network는 다음의 과정으로 구해집니다. spatial-temporal encoder 로부터 feature 를 획득합니다. 그리고, multi-head decoder 로부터 3D Gaussian의 deformation 를 구합니다. 이를 통해 deformed Gaussian 를 구합니다.
붙어 있는 3D Gaussian은 비슷한 spatial, temporal 정보를 공유합니다. 이를 효율적으로 처리하기 위해 multi-resolution HexPlane 와 tiny MLP 를 사용하였습니다.
4D K-Planes를 사용하여 4D neural voxel을 6 multi-resolution planes로 분해하였습니다. 모든 3D Gaussian들을 bounding plane voxels에 포함시키고, 근처 temporal voxel에 Gaussian deformation이 인코딩되게 하였습니다.
encoder 는 6 multi-resolution plane modules 와 tiny MLP 로 표현됩니다. 각 voxel module은 로 정의됩니다.
class HexPlaneField(nn.Module):
###
"""
init 등의 함수는 생략하였습니다!
"""
###
def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
"""Computes and returns the densities."""
# breakpoint()
pts = normalize_aabb(pts, self.aabb)
pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4]
pts = pts.reshape(-1, pts.shape[-1])
features = interpolate_ms_features(
pts, ms_grids=self.grids, # noqa
grid_dimensions=self.grid_config[0]["grid_dimensions"],
concat_features=self.concat_features, num_levels=None)
if len(features) < 1:
features = torch.zeros((0, 1)).to(features.device)
return features
def forward(self,
pts: torch.Tensor,
timestamps: Optional[torch.Tensor] = None):
features = self.get_density(pts, timestamps)
return features
aabb로 normalize한 points를 interpolation 과정을 거쳐 feature로 변환합니다. HexPlane, K-plane의 과정과 동일합니다.
모든 3D Gaussian feature들이 인코딩되면, deformation decoder 를 통하여 각각의 deformation을 구하게 됩니다. 따라서 얻게 되는 deformed 3D Gaussian은 다음과 같습니다.
class Deformation(nn.Module):
def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, grid_pe=0, skips=[], args=None):
super(Deformation, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_time = input_ch_time
self.skips = skips
self.grid_pe = grid_pe
self.no_grid = args.no_grid
self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires)
# breakpoint()
self.args = args
# self.args.empty_voxel=True
if self.args.empty_voxel:
self.empty_voxel = DenseGrid(channels=1, world_size=[64,64,64])
if self.args.static_mlp:
self.static_mlp = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))
self.ratio=0
self.create_net()
@property
def get_aabb(self):
return self.grid.get_aabb
def set_aabb(self, xyz_max, xyz_min):
print("Deformation Net Set aabb",xyz_max, xyz_min)
self.grid.set_aabb(xyz_max, xyz_min)
if self.args.empty_voxel:
self.empty_voxel.set_aabb(xyz_max, xyz_min)
def create_net(self):
mlp_out_dim = 0
if self.grid_pe !=0:
grid_out_dim = self.grid.feat_dim+(self.grid.feat_dim)*2
else:
grid_out_dim = self.grid.feat_dim
if self.no_grid:
self.feature_out = [nn.Linear(4,self.W)]
else:
self.feature_out = [nn.Linear(mlp_out_dim + grid_out_dim ,self.W)]
for i in range(self.D-1):
self.feature_out.append(nn.ReLU())
self.feature_out.append(nn.Linear(self.W,self.W))
self.feature_out = nn.Sequential(*self.feature_out)
self.pos_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3))
self.scales_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3))
self.rotations_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4))
self.opacity_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))
self.shs_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 16*3))
def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb):
if self.no_grid:
h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1)
else:
grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1])
# breakpoint()
if self.grid_pe > 1:
grid_feature = poc_fre(grid_feature,self.grid_pe)
hidden = torch.cat([grid_feature],-1)
hidden = self.feature_out(hidden)
return hidden
def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None,shs_emb=None, time_feature=None, time_emb=None):
return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, shs_emb, time_feature, time_emb)
def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, shs_emb, time_feature, time_emb):
hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb)
if self.args.static_mlp:
mask = self.static_mlp(hidden)
elif self.args.empty_voxel:
mask = self.empty_voxel(rays_pts_emb[:,:3])
else:
mask = torch.ones_like(opacity_emb[:,0]).unsqueeze(-1)
# breakpoint()
if self.args.no_dx:
pts = rays_pts_emb[:,:3]
else:
dx = self.pos_deform(hidden)
pts = torch.zeros_like(rays_pts_emb[:,:3])
pts = rays_pts_emb[:,:3]*mask + dx
if self.args.no_ds :
scales = scales_emb[:,:3]
else:
ds = self.scales_deform(hidden)
scales = torch.zeros_like(scales_emb[:,:3])
scales = scales_emb[:,:3]*mask + ds
if self.args.no_dr :
rotations = rotations_emb[:,:4]
else:
dr = self.rotations_deform(hidden)
rotations = torch.zeros_like(rotations_emb[:,:4])
if self.args.apply_rotation:
rotations = batch_quaternion_multiply(rotations_emb, dr)
else:
rotations = rotations_emb[:,:4] + dr
if self.args.no_do :
opacity = opacity_emb[:,:1]
else:
do = self.opacity_deform(hidden)
opacity = torch.zeros_like(opacity_emb[:,:1])
opacity = opacity_emb[:,:1]*mask + do
if self.args.no_dshs:
shs = shs_emb
else:
dshs = self.shs_deform(hidden).reshape([shs_emb.shape[0],16,3])
shs = torch.zeros_like(shs_emb)
# breakpoint()
shs = shs_emb*mask.unsqueeze(-1) + dshs
return pts, scales, rotations, opacity, shs
코드 내용은 비교적 간단합니다. pos부터 sh 계수까지 deformation net은 모두 MLP 구조로 구성되어 있습니다. 각 MLP로 뽑은 defomation 값을 원래 변수에 더하여 deformed gaussian을 구합니다. 해당 가우시안을 splatting 최적화하는 것으로 랜더링이 이루어집니다.

초기 3000 iteration에서는 3D Gaussians로 최적화합니다. 3D GS의 sfm으로부터 초기화하는 과정을 이용하기 위함입니다. image rendering을 4D가 아닌, 3D Gaussian으로부터 적용합니다.
L1 color loss와 grid-based total-variational loss를 같이 적용하였습니다.
total-variational loss는 이미지에서 이웃 픽셀 간의 차이를 절대값으로 치환하여 summation하는 loss입니다. deformation 성능 향상을 위해 사용되었습니다.



HexPlane이 성능에 가장 큰 영향을 미친 것으로 파악할 수 있습니다. HexPlane 적용 시, time과 FPS에서의 성능은 감소하지만 이미지 퀄리티가 향상합니다.
[1] 4D Gaussian Splatting for Real-Time Dynamic Scene Rendering, Guanjun Wu, Wei Wei, https://arxiv.org/pdf/2310.08528
[2] 4D Gaussian Splatting Github, https://github.com/hustvl/4DGaussians