import os
import json
import numpy as np
from abc import abstractmethod, ABC

from runners.preprocessor import Preprocessor
from utils.omni_util import OmniUtil


class GraspingPreprocessor(Preprocessor, ABC):
    """Preprocessor that turns per-frame grasp predictions into score labels.

    For every configured dataset the class runs a model over the dataset,
    aggregates the predicted per-object scores (sum and mean) for each frame
    and writes one JSON label file per frame.

    NOTE(review): ``get_dataloader``, ``get_model`` and ``prediction`` are not
    defined in this class; they are presumably supplied by ``Preprocessor``
    or by concrete subclasses -- confirm.
    """

    def __init__(self, config_path):
        """Load the "GSNet" experiment and cache the dataset/model configs.

        Args:
            config_path: Forwarded unchanged to ``Preprocessor.__init__``.
        """
        super().__init__(config_path)
        self.load_experiment("GSNet")
        self.dataset_list_config = self.preprocess_config["dataset_list"]
        self.model_config = self.preprocess_config["model"]

    def run(self):
        """Run the full pipeline for every configured dataset.

        For each entry of ``self.dataset_list_config``:

        1. get its dataloader,
        2. get the model and run prediction over the dataloader,
        3. preprocess the collected results,
        4. save the processed results.
        """
        for dataset_config in self.dataset_list_config:
            dataloader = self.get_dataloader(dataset_config)
            # NOTE(review): the model is requested once per dataset; it looks
            # loop-invariant, but it is kept here in case get_model() returns
            # a stateful object -- confirm before hoisting.
            model = self.get_model(self.model_config)
            predicted_data = self.prediction(model, dataloader)
            processed_data = self.preprocess(predicted_data)
            self.save_processed_data(processed_data, dataset_config)

    def preprocess(self, predicted_data, require_gripper=False):
        """Add per-object summed and averaged scores to every frame entry.

        ``predicted_data`` is modified in place: each frame entry gains a
        ``"sum_score"`` and an ``"avg_score"`` dict keyed by object name.

        Args:
            predicted_data: Mapping from frame path to a frame entry. Each
                entry must contain ``"predicted_results"``, a mapping from
                object name to a mapping holding a ``"scores"`` collection.
            require_gripper: When true, every frame entry must also contain
                a ``"gripper"`` key.

        Returns:
            The same ``predicted_data`` object, after the in-place update.

        Raises:
            KeyError: If a required key is missing, including ``"gripper"``
                when ``require_gripper`` is true.
        """
        for frame_path in predicted_data:
            frame_data = predicted_data[frame_path]
            frame_obj_info = frame_data["predicted_results"]

            # The value itself is left untouched; with require_gripper set,
            # the only observable effect is failing when the key is absent.
            if require_gripper and "gripper" not in frame_data:
                raise KeyError("gripper")

            sum_scores = {}
            avg_scores = {}
            frame_data["sum_score"] = sum_scores
            frame_data["avg_score"] = avg_scores

            for obj_name in frame_obj_info:
                scores = frame_obj_info[obj_name]["scores"]
                # Convert NumPy scalars to built-in floats: json.dump() in
                # save_processed_data() cannot serialise np.float32 or NumPy
                # integer scalars.
                sum_scores[obj_name] = float(np.sum(scores))
                avg_scores[obj_name] = float(np.mean(scores))

        return predicted_data

    def save_processed_data(self, processed_data, data_config=None):
        """Write one JSON label file per processed frame.

        Files are written below
        ``<experiment_path>/<Preprocessor.DATA>/<source>/<data_type>/<scene>/``,
        where ``scene`` is the name of the frame's parent directory. The file
        name is ``OmniUtil.SCORE_LABEL_TEMPLATE`` formatted with the last
        component of the frame path.

        Args:
            processed_data: Mapping from frame path to a JSON-serialisable
                frame entry (as returned by ``preprocess``).
            data_config: Mapping with the keys ``"source"`` and
                ``"data_type"``. Required despite the ``None`` default,
                which is kept only for signature compatibility.

        Raises:
            TypeError: If ``data_config`` is ``None``, or if a frame entry
                contains values ``json.dump`` cannot serialise.
        """
        if data_config is None:
            raise TypeError(
                "save_processed_data() requires a data_config mapping with "
                "'source' and 'data_type' keys"
            )

        data_path = os.path.join(
            str(self.experiment_path),
            Preprocessor.DATA,
            data_config["source"],
            data_config["data_type"],
        )

        for frame_path in processed_data:
            data_item = processed_data[frame_path]
            scene = os.path.basename(os.path.dirname(frame_path))
            idx = os.path.basename(frame_path)

            target_scene_path = os.path.join(str(data_path), scene)
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(target_scene_path, exist_ok=True)

            label_save_path = os.path.join(
                target_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(idx)
            )
            with open(label_save_path, "w") as f:
                json.dump(data_item, f)

        print("Processed data saved to: ", data_path)