import torch
import torch.nn as nn
import torch.nn.functional as F
from PytorchBoot.stereotype import stereotype


@stereotype.module("nerf")
class NeRF(nn.Module):
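    """NeRF MLP with separate coarse and fine networks.

    Both networks map a frequency-encoded 3D position to a volume density and,
    when `use_viewdirs` is enabled, combine an intermediate feature with the
    encoded viewing direction to predict view-dependent RGB.

    Required config keys: "pos_enc_dim" and "dir_enc_dim" (number of encoding
    frequencies for positions and directions). "netdepth_*", "netwidth_*",
    "skips" and "use_viewdirs" are optional and fall back to the defaults read
    in __init__.
    """
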
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Output dimensions of the positional and directional encodings
        pos_enc_out = 3 * (2 * config["pos_enc_dim"] + 1)
        dir_enc_out = 3 * (2 * config["dir_enc_dim"] + 1)

        # Network depth and width (configurable)
        netdepth_coarse = config.get("netdepth_coarse", 8)
        netwidth_coarse = config.get("netwidth_coarse", 256)
        netdepth_fine = config.get("netdepth_fine", 8)
        netwidth_fine = config.get("netwidth_fine", 256)

        # Layer indices after which the positional encoding is re-injected
        self.skips = config.get("skips", [4])

        # Whether to condition color on the viewing direction
        self.use_viewdirs = config.get("use_viewdirs", True)

        # Build the coarse and fine networks
        if self.use_viewdirs:
            # Positional encoding -> density + feature
            self.pts_linears_coarse = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_coarse,
                depth=netdepth_coarse,
                skips=self.skips
            )
            self.alpha_linear_coarse = nn.Linear(netwidth_coarse, 1)
            self.feature_linear_coarse = nn.Linear(netwidth_coarse, netwidth_coarse)

            # Feature + directional encoding -> RGB
            self.views_linears_coarse = nn.ModuleList([
                nn.Linear(netwidth_coarse + dir_enc_out, netwidth_coarse // 2)
            ])
            self.rgb_linear_coarse = nn.Linear(netwidth_coarse // 2, 3)

            # Same structure for the fine network
            self.pts_linears_fine = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_fine,
                depth=netdepth_fine,
                skips=self.skips
            )
            self.alpha_linear_fine = nn.Linear(netwidth_fine, 1)
            self.feature_linear_fine = nn.Linear(netwidth_fine, netwidth_fine)

            self.views_linears_fine = nn.ModuleList([
                nn.Linear(netwidth_fine + dir_enc_out, netwidth_fine // 2)
            ])
            self.rgb_linear_fine = nn.Linear(netwidth_fine // 2, 3)
        else:
            # Simplified variant that ignores the viewing direction
            self.pts_linears_coarse = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_coarse,
                depth=netdepth_coarse,
                skips=self.skips
            )
            self.output_linear_coarse = nn.Linear(netwidth_coarse, 4)

            self.pts_linears_fine = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_fine,
                depth=netdepth_fine,
                skips=self.skips
            )
            self.output_linear_fine = nn.Linear(netwidth_fine, 4)

    def _build_pts_mlp(self, input_dim, width, depth, skips):
        """Build the MLP that processes the positional encoding, with skip connections."""
        layers = nn.ModuleList()

        # First layer
        layers.append(nn.Linear(input_dim, width))

        # Hidden layers. forward_mlp re-injects the positional encoding *after*
        # layer i when i is in `skips`, so the following layer (index i + 1)
        # must accept the widened input.
        for i in range(1, depth):
            if (i - 1) in skips:
                layers.append(nn.Linear(input_dim + width, width))
            else:
                layers.append(nn.Linear(width, width))

        return layers

    def positional_encoding(self, x, L):
        """Frequency encoding: [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]."""
        encodings = [x]
        for i in range(L):
            encodings.append(torch.sin(2**i * x))
            encodings.append(torch.cos(2**i * x))
        return torch.cat(encodings, dim=-1)

    def forward_mlp(self, pts_embed, viewdirs_embed, is_coarse=True):
        """Run the MLP stage shared by the coarse and fine networks."""
        if is_coarse:
            pts_linears = self.pts_linears_coarse
            alpha_linear = self.alpha_linear_coarse if self.use_viewdirs else None
            feature_linear = self.feature_linear_coarse if self.use_viewdirs else None
            views_linears = self.views_linears_coarse if self.use_viewdirs else None
            rgb_linear = self.rgb_linear_coarse if self.use_viewdirs else None
            output_linear = self.output_linear_coarse if not self.use_viewdirs else None
        else:
            pts_linears = self.pts_linears_fine
            alpha_linear = self.alpha_linear_fine if self.use_viewdirs else None
            feature_linear = self.feature_linear_fine if self.use_viewdirs else None
            views_linears = self.views_linears_fine if self.use_viewdirs else None
            rgb_linear = self.rgb_linear_fine if self.use_viewdirs else None
            output_linear = self.output_linear_fine if not self.use_viewdirs else None

        # Run the positional MLP
        h = pts_embed
        for i, l in enumerate(pts_linears):
            h = l(h)
            h = F.relu(h)
            # Skip connection: re-inject the positional encoding
            if i in self.skips:
                h = torch.cat([pts_embed, h], -1)

        if self.use_viewdirs:
            # Branch 1: volume density
            sigma = alpha_linear(h)

            # Branch 2: color feature
            feature = feature_linear(h)

            # Append the directional encoding
            h = torch.cat([feature, viewdirs_embed], -1)

            # View-dependent MLP
            for l in views_linears:
                h = l(h)
                h = F.relu(h)

            # Output RGB
            rgb = rgb_linear(h)
            rgb = torch.sigmoid(rgb)  # squash to [0, 1]
        else:
            # Predict RGB and sigma from a single output head
            outputs = output_linear(h)
            rgb = torch.sigmoid(outputs[..., :3])  # squash to [0, 1]
            sigma = outputs[..., 3:]

        return rgb, sigma

    def forward(self, pos, dir, coarse=True):
        """
        Forward pass.

        Args:
            pos: 3D positions [batch_size, ..., 3]
            dir: viewing directions [batch_size, ..., 3]
            coarse: whether to use the coarse network

        Returns:
            sigma: volume density [batch_size, ..., 1]
            color: RGB color [batch_size, ..., 3]
        """
        # Encode the positions
        pos_enc = self.positional_encoding(pos, self.config["pos_enc_dim"])

        # Encode the directions only when view dependence is enabled
        if self.use_viewdirs:
            dir_normalized = F.normalize(dir, dim=-1)
            dir_enc = self.positional_encoding(dir_normalized, self.config["dir_enc_dim"])
        else:
            dir_enc = None

        # Run the coarse or fine network
        color, sigma = self.forward_mlp(pos_enc, dir_enc, coarse)

        return sigma, color
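

if __name__ == "__main__":
    # Minimal smoke-test sketch. The encoding frequencies below (10 positional,
    # 4 directional) are assumed standard-NeRF values rather than values from a
    # project config, and direct instantiation assumes the PytorchBoot stereotype
    # decorator returns the class unchanged.
    config = {"pos_enc_dim": 10, "dir_enc_dim": 4, "use_viewdirs": True}
    model = NeRF(config)

    pos = torch.rand(1024, 3)                        # random 3D sample points
    dirs = F.normalize(torch.rand(1024, 3), dim=-1)  # unit viewing directions

    sigma, color = model(pos, dirs, coarse=True)
    print(sigma.shape, color.shape)  # torch.Size([1024, 1]) torch.Size([1024, 3])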