import torch
|
|||
|
import torch.nn as nn
|
|||
|
import torch.nn.functional as F
|
|||
|
from PytorchBoot.stereotype import stereotype
|
|||
|
|
|||
|
@stereotype.module("nerf")
|
|||
|
class NeRF(nn.Module):
    """NeRF MLP with coarse and fine networks.

    Maps positionally-encoded 3D points (and optionally view directions) to
    volume density (sigma) and RGB color, following the architecture of
    Mildenhall et al. (2020): a position MLP with skip connections produces
    density plus a feature vector; the feature is concatenated with the
    encoded view direction and passed through a small head to produce color.

    Expected ``config`` keys:
        pos_enc_dim (int): number of positional-encoding frequency bands for positions.
        dir_enc_dim (int): number of frequency bands for view directions.
        netdepth_coarse / netdepth_fine (int, default 8): MLP depth.
        netwidth_coarse / netwidth_fine (int, default 256): MLP width.
        skips (list[int], default [4]): layer indices after which the encoded
            input is re-concatenated to the hidden state.
        use_viewdirs (bool, default True): condition color on view direction.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoding output size: identity + (sin, cos) per frequency band,
        # for each of the 3 input channels -> 3 * (2L + 1).
        pos_enc_out = 3 * (2 * config["pos_enc_dim"] + 1)
        dir_enc_out = 3 * (2 * config["dir_enc_dim"] + 1)

        # Network depth/width (configurable, NeRF-paper defaults).
        netdepth_coarse = config.get("netdepth_coarse", 8)
        netwidth_coarse = config.get("netwidth_coarse", 256)
        netdepth_fine = config.get("netdepth_fine", 8)
        netwidth_fine = config.get("netwidth_fine", 256)

        # Store once so forward_mlp() uses exactly the skip list the layers
        # were built with (previously re-read from config on every forward).
        self.skips = config.get("skips", [4])

        # Whether color is conditioned on the viewing direction.
        self.use_viewdirs = config.get("use_viewdirs", True)

        if self.use_viewdirs:
            # Position encoding -> density + feature (coarse).
            self.pts_linears_coarse = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_coarse,
                depth=netdepth_coarse,
                skips=self.skips,
            )
            self.alpha_linear_coarse = nn.Linear(netwidth_coarse, 1)
            self.feature_linear_coarse = nn.Linear(netwidth_coarse, netwidth_coarse)

            # Feature + direction encoding -> RGB (coarse).
            self.views_linears_coarse = nn.ModuleList([
                nn.Linear(netwidth_coarse + dir_enc_out, netwidth_coarse // 2)
            ])
            self.rgb_linear_coarse = nn.Linear(netwidth_coarse // 2, 3)

            # Same structure for the fine network.
            self.pts_linears_fine = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_fine,
                depth=netdepth_fine,
                skips=self.skips,
            )
            self.alpha_linear_fine = nn.Linear(netwidth_fine, 1)
            self.feature_linear_fine = nn.Linear(netwidth_fine, netwidth_fine)

            self.views_linears_fine = nn.ModuleList([
                nn.Linear(netwidth_fine + dir_enc_out, netwidth_fine // 2)
            ])
            self.rgb_linear_fine = nn.Linear(netwidth_fine // 2, 3)
        else:
            # Simplified variant without view directions: a single head
            # emits RGBA (RGB + density) directly.
            self.pts_linears_coarse = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_coarse,
                depth=netdepth_coarse,
                skips=self.skips,
            )
            self.output_linear_coarse = nn.Linear(netwidth_coarse, 4)

            self.pts_linears_fine = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_fine,
                depth=netdepth_fine,
                skips=self.skips,
            )
            self.output_linear_fine = nn.Linear(netwidth_fine, 4)

    def _build_pts_mlp(self, input_dim, width, depth, skips):
        """Build the position MLP (``depth`` linear layers) with skip connections.

        forward_mlp() concatenates the encoded input back onto the hidden
        state *after* applying layer ``i`` when ``i in skips``, so the layer
        that must accept the widened ``input_dim + width`` tensor is layer
        ``i + 1`` — i.e. layer ``i`` is widened when ``(i - 1) in skips``.

        BUGFIX: the previous version widened layer ``i`` when ``i in skips``,
        which is off by one and crashed with a matmul shape mismatch at the
        skip layer on every forward pass.
        """
        layers = nn.ModuleList()

        # First layer: encoded input -> hidden width.
        layers.append(nn.Linear(input_dim, width))

        # Remaining layers; widen the layer that follows a skip concat.
        for i in range(1, depth):
            if (i - 1) in skips:
                layers.append(nn.Linear(input_dim + width, width))
            else:
                layers.append(nn.Linear(width, width))

        return layers

    def positional_encoding(self, x, L):
        """Return [x, sin(2^i x), cos(2^i x) for i in 0..L-1] along the last dim.

        Output size is ``x.shape[-1] * (2L + 1)``.
        """
        encodings = [x]
        for i in range(L):
            encodings.append(torch.sin(2 ** i * x))
            encodings.append(torch.cos(2 ** i * x))
        return torch.cat(encodings, dim=-1)

    def forward_mlp(self, pts_embed, viewdirs_embed, is_coarse=True):
        """Run the coarse or fine MLP on pre-encoded inputs.

        Args:
            pts_embed: positionally-encoded points [..., pos_enc_out].
            viewdirs_embed: encoded view directions [..., dir_enc_out],
                or None when ``use_viewdirs`` is False.
            is_coarse: select the coarse (True) or fine (False) network.

        Returns:
            (rgb, sigma): color in [0, 1] [..., 3] and raw density [..., 1].
        """
        if is_coarse:
            pts_linears = self.pts_linears_coarse
        else:
            pts_linears = self.pts_linears_fine

        # Position MLP with skip connections: the encoded input is
        # re-concatenated after each layer listed in self.skips (the next
        # layer was built wide enough to accept it — see _build_pts_mlp).
        h = pts_embed
        for i, layer in enumerate(pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([pts_embed, h], dim=-1)

        if self.use_viewdirs:
            if is_coarse:
                alpha_linear = self.alpha_linear_coarse
                feature_linear = self.feature_linear_coarse
                views_linears = self.views_linears_coarse
                rgb_linear = self.rgb_linear_coarse
            else:
                alpha_linear = self.alpha_linear_fine
                feature_linear = self.feature_linear_fine
                views_linears = self.views_linears_fine
                rgb_linear = self.rgb_linear_fine

            # Branch 1: density from position features only.
            sigma = alpha_linear(h)

            # Branch 2: color feature, conditioned on view direction.
            feature = feature_linear(h)
            h = torch.cat([feature, viewdirs_embed], dim=-1)

            for view_layer in views_linears:
                h = F.relu(view_layer(h))

            rgb = torch.sigmoid(rgb_linear(h))  # clamp color to [0, 1]
        else:
            output_linear = (self.output_linear_coarse if is_coarse
                             else self.output_linear_fine)
            # Single head emits RGBA directly.
            outputs = output_linear(h)
            rgb = torch.sigmoid(outputs[..., :3])  # clamp color to [0, 1]
            sigma = outputs[..., 3:]

        return rgb, sigma

    def forward(self, pos, dir, coarse=True):
        """Encode inputs and run the selected network.

        Args:
            pos: 3D positions [batch_size, ..., 3].
            dir: view directions [batch_size, ..., 3]; ignored when
                ``use_viewdirs`` is False. (Name kept for caller
                compatibility even though it shadows the builtin.)
            coarse: use the coarse network (True) or the fine network (False).

        Returns:
            sigma: volume density [batch_size, ..., 1].
            color: RGB in [0, 1] [batch_size, ..., 3].
        """
        pos_enc = self.positional_encoding(pos, self.config["pos_enc_dim"])

        # Only encode directions when the color head actually consumes them.
        if self.use_viewdirs:
            dir_normalized = F.normalize(dir, dim=-1)
            dir_enc = self.positional_encoding(dir_normalized, self.config["dir_enc_dim"])
        else:
            dir_enc = None

        color, sigma = self.forward_mlp(pos_enc, dir_enc, coarse)

        return sigma, color