from typing import Tuple, List, Dict, Any, Optional

import numpy as np
import torch
import torch.nn.functional as F

from utils.volume_render_util import VolumeRendererUtil


class UncertaintyGuideNeRF:
    """
    Active view selection strategy based on NeRF uncertainty.

    Computes the entropy of candidate views to guide the selection of the
    next-best view.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the uncertainty-guided strategy.

        Args:
            config: configuration dictionary with camera and sampling parameters
        """
        self.config = config
        self.device = torch.device(config.get("device", "cuda") if torch.cuda.is_available() else "cpu")

        # Camera parameters
        self.width = config.get("width", 800)
        self.height = config.get("height", 800)
        self.focal = config.get("focal", 1000.0)

        # Sampling parameters
        self.near = config.get("near", 2.0)
        self.far = config.get("far", 6.0)
        self.coarse_samples = config.get("coarse_samples", 64)
        self.fine_samples = config.get("fine_samples", 128)

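    # Example configuration (illustrative values only; every key is optional
    # and falls back to the defaults read above):
    #
    #   config = {
    #       "device": "cuda",
    #       "width": 800, "height": 800, "focal": 1000.0,
    #       "near": 2.0, "far": 6.0,
    #       "coarse_samples": 64, "fine_samples": 128,
    #   }
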
    def generate_rays(self, pose: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate rays from a camera pose.

        Args:
            pose: camera-to-world pose matrix [4, 4]

        Returns:
            rays_o: ray origins [H*W, 3]
            rays_d: ray directions [H*W, 3]
        """
        # Create pixel coordinates
        i, j = torch.meshgrid(
            torch.linspace(0, self.width - 1, self.width),
            torch.linspace(0, self.height - 1, self.height),
            indexing='ij'
        )
        i = i.t().to(self.device)
        j = j.t().to(self.device)

        # Pixel coordinates -> ray directions in the camera frame
        # (NeRF/OpenGL convention: x right, y up, camera looking along -z)
        dirs = torch.stack([
            (i - self.width * 0.5) / self.focal,
            -(j - self.height * 0.5) / self.focal,
            -torch.ones_like(i)
        ], dim=-1)

        # Rotate the directions into the world frame; the camera origin is the
        # translation column of the pose
        pose = torch.from_numpy(pose).float().to(self.device)
        rays_d = torch.sum(dirs[..., None, :] * pose[:3, :3], dim=-1)
        rays_o = pose[:3, -1].expand(rays_d.shape)

        # Flatten to batch format
        rays_o = rays_o.reshape(-1, 3)
        rays_d = rays_d.reshape(-1, 3)

        return rays_o, rays_d

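    # Note (illustrative): for a quicker, approximate uncertainty estimate, the
    # rays returned above can be thinned out with downsample_image() defined
    # below before rendering, e.g.:
    #
    #   rays_o, rays_d = guide.generate_rays(pose)
    #   rays_o, rays_d = guide.downsample_image(rays_o, rays_d, factor=4)
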
    def evaluate_view_uncertainty(self,
                                  nerf_model: torch.nn.Module,
                                  pose: np.ndarray) -> float:
        """
        Evaluate the uncertainty (entropy) of a given view.

        Args:
            nerf_model: NeRF model
            pose: camera pose matrix [4, 4]

        Returns:
            mean_entropy: mean entropy over all rays of the view
        """
        nerf_model.eval()
        with torch.no_grad():
            # Generate rays
            rays_o, rays_d = self.generate_rays(pose)

            # Larger images may have to be processed in batches
            batch_size = 4096  # adjust to the available GPU memory
            entropy_values = []

            # Process all rays batch by batch
            for i in range(0, rays_o.shape[0], batch_size):
                batch_rays_o = rays_o[i:i + batch_size]
                batch_rays_d = rays_d[i:i + batch_size]

                # Normalize direction vectors
                batch_rays_d = F.normalize(batch_rays_d, dim=-1)

                # Per-ray near and far planes
                near = torch.ones_like(batch_rays_o[..., 0]) * self.near
                far = torch.ones_like(batch_rays_o[..., 0]) * self.far

                # Render the rays and compute the per-ray entropy
                _, weights, _, entropy = VolumeRendererUtil.render_rays(
                    nerf_model,
                    batch_rays_o,
                    batch_rays_d,
                    near,
                    far,
                    self.coarse_samples,
                    self.fine_samples
                )

                entropy_values.append(entropy)

            # Concatenate the entropy values of all batches
            all_entropy = torch.cat(entropy_values, dim=0)

            # Average over all rays of the view
            mean_entropy = all_entropy.mean().item()

        return mean_entropy

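    # For reference only: VolumeRendererUtil.render_rays is assumed to return a
    # per-ray entropy computed from the normalized volume-rendering weights,
    # roughly along the lines of the sketch below; the actual definition lives
    # in utils/volume_render_util.py and may differ.
    #
    #   p = weights / (weights.sum(dim=-1, keepdim=True) + 1e-10)
    #   entropy = -(p * torch.log(p + 1e-10)).sum(dim=-1)
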
    def evaluate_candidate_views(self,
                                 nerf_model: torch.nn.Module,
                                 candidate_poses: np.ndarray) -> np.ndarray:
        """
        Evaluate the uncertainty (entropy) of a set of candidate views.

        Args:
            nerf_model: NeRF model
            candidate_poses: candidate camera pose matrices [N, 4, 4]

        Returns:
            entropy_values: entropy of each candidate view [N]
        """
        entropy_values = np.zeros(len(candidate_poses))

        for i, pose in enumerate(candidate_poses):
            entropy_values[i] = self.evaluate_view_uncertainty(nerf_model, pose)

        return entropy_values

    def downsample_image(self, rays_o, rays_d, factor=4):
        """
        Downsample rays to speed up processing.

        Assumes the rays were generated for a square image (H == W).

        Args:
            rays_o: ray origins [H*W, 3]
            rays_d: ray directions [H*W, 3]
            factor: downsampling factor

        Returns:
            downsampled_rays_o: downsampled ray origins
            downsampled_rays_d: downsampled ray directions
        """
        # Reshape to image format (square image assumed)
        H = W = int(np.sqrt(rays_o.shape[0]))
        rays_o = rays_o.reshape(H, W, 3)
        rays_d = rays_d.reshape(H, W, 3)

        # Keep every factor-th ray along both axes and flatten back
        downsampled_rays_o = rays_o[::factor, ::factor].reshape(-1, 3)
        downsampled_rays_d = rays_d[::factor, ::factor].reshape(-1, 3)

        return downsampled_rays_o, downsampled_rays_d
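

# Minimal usage sketch (illustrative only): assumes a trained `nerf_model`
# compatible with VolumeRendererUtil.render_rays and candidate camera-to-world
# poses of shape [N, 4, 4]; the next-best view is the candidate with the
# highest mean entropy.
#
# if __name__ == "__main__":
#     config = {"width": 400, "height": 400, "focal": 500.0}
#     guide = UncertaintyGuideNeRF(config)
#     entropies = guide.evaluate_candidate_views(nerf_model, candidate_poses)
#     next_best_view = candidate_poses[int(np.argmax(entropies))]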