nbv_grasping/runners/preprocessors/grasping/abstract_grasping_preprocessor.py

66 lines
2.8 KiB
Python
Raw Permalink Normal View History

2024-10-09 16:13:22 +00:00
import os
import json
import numpy as np
from abc import abstractmethod, ABC
from runners.preprocessor import Preprocessor
from utils.omni_util import OmniUtil
class GraspingPreprocessor(Preprocessor, ABC):
    """Run a grasp model over configured datasets and save per-frame score labels.

    ``run`` relies on ``get_dataloader``, ``get_model`` and ``prediction``,
    which are not defined in this class (presumably provided by
    ``Preprocessor`` or a concrete subclass -- TODO confirm).
    """

    def __init__(self, config_path):
        """Load the config and the "GSNet" experiment, and cache sub-configs.

        Args:
            config_path: Path to the preprocessor configuration, forwarded to
                ``Preprocessor.__init__``.
        """
        super().__init__(config_path)
        self.load_experiment("GSNet")
        # self.preprocess_config is populated by the base class.
        self.dataset_list_config = self.preprocess_config["dataset_list"]
        self.model_config = self.preprocess_config["model"]

    def run(self):
        """Predict, preprocess and save results for every configured dataset.

        For each entry of ``dataset_list``:
          - build its dataloader,
          - run the model over it (``prediction``),
          - aggregate per-object scores (``preprocess``),
          - write the results to disk (``save_processed_data``).

        The model is built once, lazily on the first dataset, and reused:
        ``model_config`` is the same for every dataset, so rebuilding it per
        dataset would only repeat identical (potentially expensive) setup.
        """
        model = None
        for dataset_config in self.dataset_list_config:
            dataloader = self.get_dataloader(dataset_config)
            if model is None:
                model = self.get_model(self.model_config)
            predicted_data = self.prediction(model, dataloader)
            processed_data = self.preprocess(predicted_data)
            self.save_processed_data(processed_data, dataset_config)

    def preprocess(self, predicted_data, require_gripper=False):
        """Add per-object sum and mean scores to every frame, in place.

        Args:
            predicted_data: Mapping ``frame_path -> frame dict``. Each frame
                dict must contain ``"predicted_results"``, a mapping
                ``obj_name -> {"scores": <array-like>}``.
            require_gripper: If True, every frame dict must also contain a
                ``"gripper"`` entry.

        Returns:
            The same ``predicted_data`` object, with ``"sum_score"`` and
            ``"avg_score"`` dicts (``obj_name -> built-in number``) added to
            each frame.

        Raises:
            KeyError: If a frame lacks ``"predicted_results"``, an object lacks
                ``"scores"``, or ``require_gripper`` is True and a frame lacks
                ``"gripper"``.
        """
        for frame_path, frame_data in predicted_data.items():
            frame_obj_info = frame_data["predicted_results"]
            if require_gripper and "gripper" not in frame_data:
                raise KeyError(f"frame {frame_path!r} has no 'gripper' entry")
            sum_score = {}
            avg_score = {}
            for obj_name, obj_info in frame_obj_info.items():
                scores = np.asarray(obj_info["scores"])
                # .item() turns NumPy scalars (float32, int64, ...) into
                # built-in Python numbers so json.dump can serialize them.
                sum_score[obj_name] = scores.sum().item()
                # NOTE: an empty score list gives NaN (plus a NumPy
                # RuntimeWarning); json.dump writes it as the token NaN.
                avg_score[obj_name] = scores.mean().item()
            frame_data["sum_score"] = sum_score
            frame_data["avg_score"] = avg_score
        return predicted_data

    def save_processed_data(self, processed_data, data_config=None):
        """Write one JSON score-label file per frame.

        Each file is written to
        ``<experiment_path>/<Preprocessor.DATA>/<source>/<data_type>/<scene>/``
        where ``scene`` is the name of the frame path's parent directory. The
        filename is ``OmniUtil.SCORE_LABEL_TEMPLATE`` formatted with the frame
        path's basename. Existing files are overwritten.

        Args:
            processed_data: Mapping ``frame_path -> JSON-serializable dict``
                (as returned by ``preprocess``).
            data_config: Dataset config providing ``"source"`` and
                ``"data_type"``. Required; the ``None`` default is kept only
                for signature compatibility.

        Raises:
            TypeError: If ``data_config`` is None.
        """
        if data_config is None:
            raise TypeError(
                "save_processed_data() requires a data_config with "
                "'source' and 'data_type' keys"
            )
        data_path = os.path.join(
            str(self.experiment_path),
            Preprocessor.DATA,
            data_config["source"],
            data_config["data_type"],
        )
        for frame_path, data_item in processed_data.items():
            # Presumably frame_path looks like ".../<scene>/<idx>" -- confirm
            # against the prediction() implementation.
            scene = os.path.basename(os.path.dirname(frame_path))
            idx = os.path.basename(frame_path)
            target_scene_path = os.path.join(data_path, scene)
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(target_scene_path, exist_ok=True)
            label_save_path = os.path.join(
                target_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(idx)
            )
            with open(label_save_path, "w", encoding="utf-8") as f:
                json.dump(data_item, f)
        print("Processed data saved to: ", data_path)