2024-10-09 16:13:22 +00:00

128 lines
4.2 KiB
Python
Executable File

import os
import sys
# Resolve the directory four levels above this file and register it on
# sys.path (presumably so the project-local packages imported below resolve).
# abspath() normalises the joined path lexically, which collapses the four
# parent-directory steps in one go.
path = os.path.abspath(os.path.join(__file__, *([os.pardir] * 4)))
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)
import re
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from utils.omni_util import OmniUtil
from runners.preprocessors.rgb_feat.abstract_rgb_feat_preprocessor import RGBFeatPreprocessor
from modules.rgb_encoder.dinov2_encoder import Dinov2Encoder
from PIL import Image
from torch.utils.data import Dataset
class Dinov2InferenceDataset(Dataset):
    """Dataset that yields transformed RGB frames for feature extraction.

    Frames are discovered by scanning every scene directory under
    ``<data_dir>/<source>/<data_type>`` for files whose name starts with
    ``camera_params``. The first run of digits in such a file name is used as
    the frame index, and a frame is identified by ``<scene_dir>/<frame_index>``.
    """

    # Not referenced inside this class; presumably the file-name pattern of the
    # RGB images that OmniUtil resolves from a frame path -- TODO confirm.
    RGB_TEMPLATE = "rgb_{}.png"

    # Matches the numeric frame index embedded in a camera_params file name.
    _FRAME_INDEX_PATTERN = re.compile(r"\d+")

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        image_size=480,
    ):
        """Index all frames and build the image transform.

        Args:
            source: Name of the dataset folder inside ``data_dir``.
            data_type: Name of the split folder inside ``source``.
            data_dir: Root directory that contains the datasets.
            image_size: Size passed to ``transforms.Resize``; the centre crop
                uses the largest multiple of 14 not exceeding it.
        """
        self.data_dir = data_dir
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            # Crop to a multiple of 14 (presumably the encoder's patch size).
            transforms.CenterCrop(int(image_size // 14) * 14),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.2),
        ])

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the data dict of the frame at position ``index``."""
        frame_path = self.data_list[index]
        return self.load_frame_data(frame_path=frame_path)

    def get_datalist(self):
        """Collect the frame paths of every scene in ``self.scene_list``.

        Scenes and the files inside them are visited in sorted order so the
        resulting list is reproducible regardless of ``os.listdir`` ordering.
        Entries of the split folder that are not directories, and
        ``camera_params`` files without any digits in their name, are skipped.

        Returns:
            A list of ``<scene_dir>/<frame_index>`` strings covering all
            scenes; empty when no frame is found.
        """
        # A single accumulator shared by all scenes: re-creating the list per
        # scene would discard the frames of every scene but one.
        frame_paths = []
        for scene in sorted(self.scene_list):
            scene_path = os.path.join(self.data_path, scene)
            if not os.path.isdir(scene_path):
                continue
            for file_name in sorted(os.listdir(scene_path)):
                if not file_name.startswith("camera_params"):
                    continue
                match = self._FRAME_INDEX_PATTERN.search(file_name)
                if match is None:
                    continue
                frame_paths.append(os.path.join(scene_path, match.group(0)))
        return frame_paths

    def load_frame_data(self, frame_path):
        """Load one frame's RGB image and apply the transform.

        Args:
            frame_path: Frame identifier as produced by ``get_datalist``.

        Returns:
            A dict with the transformed image under ``"rgb"`` and the
            unchanged ``frame_path`` under ``"frame_path"``.
        """
        # OmniUtil.get_rgb is expected to return an array accepted by
        # Image.fromarray -- TODO confirm its exact shape/dtype.
        rgb = OmniUtil.get_rgb(frame_path)
        rgb = Image.fromarray(rgb)
        rgb = self.transform(rgb)
        return {"rgb": rgb, "frame_path": frame_path}
class Dinov2Preprocessor(RGBFeatPreprocessor):
    """RGB feature preprocessor that encodes frames with ``Dinov2Encoder``."""

    MODULE_NAME: str = "dinov2"

    def __init__(self, config_path):
        """Initialise the base preprocessor from the config at ``config_path``."""
        super().__init__(config_path)

    def get_dataloader(self, dataset_config):
        """Build a sequential, single-process dataloader over the frames.

        Args:
            dataset_config: Mapping that provides the keys ``"source"``,
                ``"data_type"``, ``"data_dir"``, ``"image_size"`` and
                ``"batch_size"``.

        Returns:
            A ``DataLoader`` over a ``Dinov2InferenceDataset``.
        """
        dataset = Dinov2InferenceDataset(
            source=dataset_config["source"],
            data_type=dataset_config["data_type"],
            data_dir=dataset_config["data_dir"],
            image_size=dataset_config["image_size"],
        )
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=dataset_config["batch_size"],
            shuffle=False,
            num_workers=0,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        """Create the encoder named by ``model_config["general"]["model_name"]``.

        The model is moved to CUDA when available and to the CPU otherwise.
        ``model_config`` must be provided despite its ``None`` default;
        subscripting ``None`` raises ``TypeError``.
        """
        model = Dinov2Encoder(model_config["general"]["model_name"])
        # Fall back to the CPU instead of failing on hosts without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model

    def prediction(self, model, dataloader):
        """Encode every batch and collect the per-frame features.

        Args:
            model: Encoder exposing ``encode_rgb``.
            dataloader: Loader whose batches contain ``"rgb"`` and
                ``"frame_path"`` entries.

        Returns:
            A dict mapping each frame path to its feature as a NumPy array.
        """
        # Same device choice as get_model so inputs and model stay together.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        results = {}
        total = len(dataloader)
        # Inference only: no gradients are needed for any batch.
        with torch.no_grad():
            # Count from 1 so the final message reads "total/total".
            for batch_number, batch_data in enumerate(dataloader, start=1):
                rgb = batch_data["rgb"].to(device)
                rgb_feat = model.encode_rgb(rgb)
                frame_paths = batch_data["frame_path"]
                for i, frame_path in enumerate(frame_paths):
                    results[frame_path] = rgb_feat[i].cpu().numpy()
                print(f"Processed {batch_number}/{total} batches")
        return results

    def visualize_feature(self, rgb_feat, model_name, save_path=None):
        """Visualise ``rgb_feat`` via a freshly built ``Dinov2Encoder``.

        Args:
            rgb_feat: Feature to visualise, forwarded unchanged.
            model_name: Name used to construct the encoder.
            save_path: Forwarded to ``visualize_features``.
        """
        model = Dinov2Encoder(model_name)
        model.visualize_features(rgb_feat, save_path)
if __name__ == "__main__":
    # Manual check: visualise one previously saved feature file.
    preprocessor = Dinov2Preprocessor(
        config_path="configs/server_rgb_feat_preprocess_config.yaml"
    )
    # The full preprocessing run is currently disabled:
    # preprocessor.run()
    saved_feature = np.load(
        "experiments/rgb_feat_preprocessor_test/data/nbv1/sample/scene_0/rgb_feat_0405.npy"
    )
    preprocessor.visualize_feature(saved_feature, "dinov2_vits14", './visualize.png')