66 lines
2.8 KiB
Python
Executable File
66 lines
2.8 KiB
Python
Executable File
import os
|
|
import json
|
|
import numpy as np
|
|
from abc import abstractmethod, ABC
|
|
|
|
from runners.preprocessor import Preprocessor
|
|
from utils.omni_util import OmniUtil
|
|
|
|
class GraspingPreprocessor(Preprocessor, ABC):
    """Run the GSNet experiment over each configured dataset and save per-frame score labels.

    For every entry of ``preprocess_config["dataset_list"]`` the pipeline
    builds a dataloader, runs the configured model over it, attaches
    per-object score statistics (sum and mean) to each predicted frame, and
    writes one JSON label file per frame under
    ``<experiment_path>/<Preprocessor.DATA>/<source>/<data_type>/<scene>/``.
    """

    def __init__(self, config_path):
        """Initialise the base preprocessor and select the ``GSNet`` experiment.

        Args:
            config_path: Configuration path, forwarded unchanged to
                ``Preprocessor.__init__``.
        """
        super().__init__(config_path)
        self.load_experiment("GSNet")
        # Both sections are read with [] so a missing one fails at construction
        # time rather than midway through run().
        self.dataset_list_config = self.preprocess_config["dataset_list"]
        self.model_config = self.preprocess_config["model"]

    def run(self):
        """Run the full preprocessing pipeline.

        - for each dataset
        --- get its dataloader
        --- run prediction over it
        --- preprocess the collected results
        --- save the processed results

        The model is built once, on the first dataset, and reused for the rest.
        ``self.model_config`` does not change between iterations, so rebuilding
        it per dataset only repeated the model construction. No model is built
        when the dataset list is empty.
        NOTE(review): assumes ``prediction()`` leaves no per-dataset state on
        the model that would require a fresh instance — confirm.
        """
        model = None
        for dataset_config in self.dataset_list_config:
            dataloader = self.get_dataloader(dataset_config)
            if model is None:
                model = self.get_model(self.model_config)
            predicted_data = self.prediction(model, dataloader)
            processed_data = self.preprocess(predicted_data)
            self.save_processed_data(processed_data, dataset_config)

    def preprocess(self, predicted_data, require_gripper=False):
        """Attach per-object score statistics to every predicted frame.

        Args:
            predicted_data: Mapping of frame path -> per-frame dict. Each
                per-frame dict must contain ``"predicted_results"``, a mapping
                of object name -> dict holding a ``"scores"`` sequence/array.
            require_gripper: When True, every frame must also carry a
                ``"gripper"`` entry.

        Returns:
            ``predicted_data`` itself, modified in place. Each frame gains
            ``"sum_score"`` and ``"avg_score"`` dicts keyed by object name.
            The values are plain Python floats, so they stay JSON-serializable;
            NumPy ``float32`` scalars would not be. An object with no scores
            gets ``0.0`` for both values instead of a NaN average.

        Raises:
            KeyError: if a frame lacks ``"predicted_results"``, an object lacks
                ``"scores"``, or ``require_gripper`` is set and a frame lacks
                ``"gripper"``.
        """
        for frame_path, frame_data in predicted_data.items():
            frame_obj_info = frame_data["predicted_results"]
            if require_gripper and "gripper" not in frame_data:
                raise KeyError(f"frame {frame_path!r} has no 'gripper' entry")

            sum_scores = {}
            avg_scores = {}
            for obj_name, obj_info in frame_obj_info.items():
                scores = obj_info["scores"]
                if np.size(scores) == 0:
                    # np.mean of an empty input warns and returns NaN, which
                    # json.dump would emit as the non-standard token NaN.
                    sum_scores[obj_name] = 0.0
                    avg_scores[obj_name] = 0.0
                else:
                    sum_scores[obj_name] = float(np.sum(scores))
                    avg_scores[obj_name] = float(np.mean(scores))

            frame_data["sum_score"] = sum_scores
            frame_data["avg_score"] = avg_scores

        return predicted_data

    @staticmethod
    def _json_default(obj):
        """Convert NumPy values that ``json`` cannot encode natively.

        Intended as the ``default=`` hook of ``json.dumps``. NumPy scalars
        (e.g. ``float32``, ``int64``) become Python scalars, and arrays become
        nested lists. Any other type raises ``TypeError``, as ``json`` does.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def save_processed_data(self, processed_data, data_config=None):
        """Write one JSON score-label file per frame.

        Files go to
        ``<experiment_path>/<Preprocessor.DATA>/<source>/<data_type>/<scene>/``.
        ``scene`` is the name of the frame path's parent directory. The file
        name is ``OmniUtil.SCORE_LABEL_TEMPLATE`` formatted with the frame
        path's last component.

        Args:
            processed_data: Mapping of frame path -> frame dict. NumPy scalars
                and arrays inside it are converted to JSON values on output.
            data_config: Dataset config providing ``"source"`` and
                ``"data_type"``. It is required; the ``None`` default is kept
                only for signature compatibility.

        Raises:
            TypeError: if ``data_config`` is None, or a frame holds a value
                that is neither JSON-serializable nor a NumPy scalar/array.
        """
        if data_config is None:
            raise TypeError(
                "save_processed_data() requires a data_config providing "
                "'source' and 'data_type'"
            )
        data_path = os.path.join(
            str(self.experiment_path),
            Preprocessor.DATA,
            data_config["source"],
            data_config["data_type"],
        )
        for frame_path, data_item in processed_data.items():
            scene = os.path.basename(os.path.dirname(frame_path))
            idx = os.path.basename(frame_path)
            target_scene_path = os.path.join(data_path, scene)
            # exist_ok avoids the race of checking exists() before makedirs().
            os.makedirs(target_scene_path, exist_ok=True)
            label_save_path = os.path.join(
                target_scene_path, OmniUtil.SCORE_LABEL_TEMPLATE.format(idx)
            )
            # Serialize before opening, so a serialization error cannot leave
            # a truncated or half-written label file behind.
            payload = json.dumps(data_item, default=self._json_default)
            with open(label_save_path, "w", encoding="utf-8") as f:
                f.write(payload)
        print("Processed data saved to: ", data_path)
|