import os
import sys

# Walk four directory levels up from this file and put that directory on
# sys.path so the project-local imports below (utils, runners, modules)
# resolve when this module is executed directly as a script.
path = os.path.abspath(__file__)
for _ in range(4):
    path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)

import re
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from utils.omni_util import OmniUtil
from runners.preprocessors.rgb_feat.abstract_rgb_feat_preprocessor import RGBFeatPreprocessor
from modules.rgb_encoder.dinov2_encoder import Dinov2Encoder
from PIL import Image

from torch.utils.data import Dataset


class Dinov2InferenceDataset(Dataset):
    """Dataset yielding one transformed RGB image per frame found on disk.

    Frames are discovered under ``<data_dir>/<source>/<data_type>/<scene>/``:
    every file whose name starts with ``camera_params`` contributes one frame,
    identified by the first run of digits in that file name.
    """

    RGB_TEMPLATE = "rgb_{}.png"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        image_size=480,
    ):
        """Index all frames and build the image transform.

        Args:
            source: Sub-directory of ``data_dir`` naming the dataset.
            data_type: Sub-directory of ``source`` naming the split.
            data_dir: Root directory that contains the datasets.
            image_size: Size passed to ``transforms.Resize``; the centre crop
                uses this value rounded down to a multiple of 14.
        """
        self.data_dir = data_dir
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            # Crop side is rounded down to a multiple of 14 -- presumably the
            # encoder's patch size (TODO confirm against Dinov2Encoder).
            transforms.CenterCrop(int(image_size // 14) * 14),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.2),
        ])

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the sample dict for the frame at position ``index``."""
        frame_path = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path)
        return frame_data

    def get_datalist(self):
        """Collect the frame paths of every scene in ``self.scene_list``.

        Returns:
            A list of ``<scene_path>/<frame_index>`` strings, one per
            ``camera_params*`` file, covering all scenes. The list is empty
            when there are no scenes or no matching files.

        Raises:
            IndexError: If a ``camera_params*`` file name contains no digits.
        """
        # One list shared by all scenes: creating it inside the scene loop
        # would discard the frames of every scene but one and would leave the
        # name unbound when there are no scenes at all.
        frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            for file_name in os.listdir(scene_path):
                if not file_name.startswith("camera_params"):
                    continue
                frame_index = re.findall(r"\d+", file_name)[0]
                frame_list.append(os.path.join(scene_path, frame_index))
        return frame_list

    def load_frame_data(self, frame_path):
        """Load and transform the RGB image of one frame.

        Args:
            frame_path: Frame identifier as produced by ``get_datalist``.

        Returns:
            A dict with the transformed image under ``"rgb"`` and the
            unchanged ``frame_path`` under ``"frame_path"``.
        """
        rgb = OmniUtil.get_rgb(frame_path)
        rgb = Image.fromarray(rgb)
        rgb = self.transform(rgb)
        ret_dict = {"rgb": rgb, "frame_path": frame_path}
        return ret_dict


class Dinov2Preprocessor(RGBFeatPreprocessor):
    """RGB feature preprocessor that encodes frames with ``Dinov2Encoder``."""

    MODULE_NAME: str = "dinov2"

    def __init__(self, config_path):
        """Initialise the base preprocessor from the config at ``config_path``."""
        super().__init__(config_path)

    def get_dataloader(self, dataset_config):
        """Build a sequential, single-process DataLoader over the dataset.

        Args:
            dataset_config: Mapping that provides the keys ``"source"``,
                ``"data_type"``, ``"data_dir"``, ``"image_size"`` and
                ``"batch_size"``.

        Returns:
            A ``DataLoader`` over a ``Dinov2InferenceDataset``.
        """
        dataset = Dinov2InferenceDataset(
            source=dataset_config["source"],
            data_type=dataset_config["data_type"],
            data_dir=dataset_config["data_dir"],
            image_size=dataset_config["image_size"],
        )
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=dataset_config["batch_size"],
            shuffle=False,
            num_workers=0,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        """Create the encoder named in the config and move it to CUDA.

        Args:
            model_config: Mapping whose ``["general"]["model_name"]`` entry
                is passed to ``Dinov2Encoder``. Although the default is
                ``None``, a mapping is required: ``None`` is not subscriptable.

        Returns:
            The ``Dinov2Encoder`` instance after ``.to("cuda")``.
        """
        model = Dinov2Encoder(model_config["general"]["model_name"])
        model.to("cuda")
        return model

    def prediction(self, model, dataloader):
        """Encode every batch and collect per-frame features.

        Args:
            model: Object exposing ``encode_rgb(batch)``.
            dataloader: Iterable of batches with ``"rgb"`` and
                ``"frame_path"`` entries.

        Returns:
            A dict mapping each frame path to its feature as a NumPy array.
        """
        results = {}
        total = len(dataloader)
        for idx, batch_data in enumerate(dataloader):
            rgb = batch_data["rgb"].to("cuda")
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                rgb_feat = model.encode_rgb(rgb)
            frame_paths = batch_data["frame_path"]
            for i, frame_path in enumerate(frame_paths):
                results[frame_path] = rgb_feat[i].cpu().numpy()
            # idx is 0-based; report the count of batches completed so far.
            print(f"Processed {idx + 1}/{total} batches")
        return results

    def visualize_feature(self, rgb_feat, model_name, save_path=None):
        """Visualise a feature via a freshly constructed ``Dinov2Encoder``.

        Args:
            rgb_feat: Feature passed through to ``visualize_features``.
            model_name: Name used to construct the ``Dinov2Encoder``.
            save_path: Forwarded as the second argument of
                ``visualize_features``.
        """
        model = Dinov2Encoder(model_name)
        model.visualize_features(rgb_feat, save_path)


if __name__ == "__main__":
    rgb_preproc = Dinov2Preprocessor(config_path="configs/server_rgb_feat_preprocess_config.yaml")
    # rgb_preproc.run()
    rgb_feat = np.load("experiments/rgb_feat_preprocessor_test/data/nbv1/sample/scene_0/rgb_feat_0405.npy")
    rgb_preproc.visualize_feature(rgb_feat, "dinov2_vits14", './visualize.png')