import os

import numpy as np
import requests
from scipy.spatial.transform import Rotation as R  # only needed by the commented-out DEBUG code below

from PytorchBoot.runners import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log

from utils.pose_util import PoseUtil
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.view_util import ViewUtil


@stereotype.runner("inference_runner")
class InferenceRunner(Runner):
    """Closed-loop scanning runner: repeatedly captures a view, sends the
    accumulated point cloud and camera poses to an inference server, and moves
    the robot to the predicted next view until a stop criterion is met."""

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("inference")
        self.inference_config = ConfigManager.get("runner", "inference")
        self.server_url = self.inference_config["server_url"]
        self.max_iter = self.inference_config["max_iter"]
        self.max_fail = self.inference_config["max_fail"]
        self.max_no_new_pts = self.inference_config["max_no_new_pts"]
        self.min_delta_pts_num = self.inference_config["min_delta_pts_num"]

    def check_stop(self, cnt, fail, no_new_pts):
        """Stop when any budget is exhausted: total iterations, failed view
        captures, or consecutive views that contributed too few new points."""
        if cnt > self.max_iter:
            return True
        if fail > self.max_fail:
            return True
        if no_new_pts > self.max_no_new_pts:
            return True
        return False

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0.005):
        """Split a world-frame point cloud at z_threshold: points below the
        threshold are treated as the scanning surface, points at or above it
        as the object."""
        scan_pts = world_pts[world_pts[:, 2] < z_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts

    def run(self):
        ControlUtil.connect_robot()
        ControlUtil.init()
        scanned_pts_list = []
        scanned_n_to_world_pose = []
        cnt = 0
        fail = 0
        no_new_pts = 0

        # Initial view: capture once from the starting pose to seed the
        # scanned point cloud.
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_to_real_world = ControlUtil.get_pose()
        if view_data is None:
            Log.error("No view data received")
            fail += 1
            return
        cam_shot_pts = ViewUtil.get_pts(view_data)

        # ########################################### DEBUG ###########################################
        # sensor_pts = PtsUtil.transform_point_cloud(cam_shot_pts, np.linalg.inv(ControlUtil.CAMERA_TO_LEFT_CAMERA))
        # np.savetxt('/home/yan20/Downloads/left_pts_0.txt', cam_shot_pts)
        # np.savetxt('/home/yan20/Downloads/sensor_pts_0.txt', sensor_pts)
        # #############################################################################################

        world_shot_pts = PtsUtil.transform_point_cloud(
            cam_shot_pts, first_cam_to_real_world
        )
        # import ipdb; ipdb.set_trace()
        _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(
            world_shot_pts
        )
        curr_pts = world_splitted_shot_pts
        curr_pose = first_cam_to_real_world
        curr_pose_6d = PoseUtil.matrix_to_rotation_6d_numpy(curr_pose[:3, :3])
        curr_pose_9d = np.concatenate([curr_pose_6d, curr_pose[:3, 3]], axis=0)
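        # The pose is serialized as a 9D vector: the 6 values of the
        # continuous 6D rotation representation produced by
        # PoseUtil.matrix_to_rotation_6d_numpy, followed by the 3D translation:
        #   [r0, r1, r2, r3, r4, r5, tx, ty, tz]
        # This is the "scanned_n_to_world_pose_9d" format consumed by the
        # inference server in the loop below.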
        scanned_pts_list.append(curr_pts.tolist())
        scanned_n_to_world_pose.append(curr_pose_9d.tolist())
        combined_pts = np.concatenate(scanned_pts_list, axis=0)
        downsampled_combined_pts = PtsUtil.voxel_downsample_point_cloud(combined_pts, 0.003)
        last_downsampled_combined_pts_num = downsampled_combined_pts.shape[0]
        Log.info(f"First downsampled combined pts: {last_downsampled_combined_pts_num}")

        ####################################### DEBUG #######################################
        # scan_count = 0
        # save_path = "/home/yan20/Downloads/pts"
        # if not os.path.exists(save_path):
        #     os.makedirs(save_path)
        #####################################################################################

        while not self.check_stop(cnt, fail, no_new_pts):

            data = {
                "scanned_pts": scanned_pts_list,
                "scanned_n_to_world_pose_9d": scanned_n_to_world_pose
            }
            # pts = np.array(data['scanned_pts'][-1])
            # np.savetxt(f'{save_path}/pts_{scan_count}.txt', pts)
            # scan_count += 1
            response = requests.post(self.server_url, json=data)
            result = response.json()
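            # The response schema is inferred from how it is used below (the
            # server at server_url is a separate service, not part of this
            # file): it should contain "pred_pose_9d", a list whose first
            # entry is a 9D pose [r0..r5, tx, ty, tz] for the next view.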
            pred_pose_9d = np.array(result["pred_pose_9d"])
            pred_rot_6d = pred_pose_9d[0, :6]
            pred_trans = pred_pose_9d[0, 6:]
            pred_rot_mat = PoseUtil.rotation_6d_to_matrix_numpy(pred_rot_6d)
            pred_pose = np.eye(4)
            pred_pose[:3, :3] = pred_rot_mat
            pred_pose[:3, 3] = pred_trans
            target_camera_pose = pred_pose @ ControlUtil.CAMERA_CORRECTION

            ControlUtil.move_to(target_camera_pose)
            cnt += 1

            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                fail += 1
                continue
            cam_shot_pts = ViewUtil.get_pts(view_data)
            left_cam_to_first_left_cam = ViewUtil.get_camera_pose(view_data)
            curr_pose = first_cam_to_real_world @ left_cam_to_first_left_cam @ np.linalg.inv(ControlUtil.CAMERA_CORRECTION)
            # curr_pose = pred_pose
            # curr_pose = first_cam_to_real_world @ ViewUtil.get_camera_pose(view_data)
            print('pred_pose:', pred_pose)
            print('curr_pose:', curr_pose)

            ##################################### DEBUG #####################################
            # print(curr_pose)
            # rot = R.from_matrix(curr_pose[:3, :3])
            # quat_xyzw = rot.as_quat()
            # translation = curr_pose[:3, 3]
            # print(quat_xyzw, translation)
            # # from ipdb import set_trace; set_trace()
            #################################################################################

            world_shot_pts = PtsUtil.transform_point_cloud(
                cam_shot_pts, first_cam_to_real_world
            )
            _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(
                world_shot_pts
            )
            curr_pts = world_splitted_shot_pts
            # Debug-only: drop into a debugger and visualize the new view.
            # import ipdb; ipdb.set_trace()
            # from utils.vis import visualizeUtil
            # visualizeUtil.visualize_pts_and_camera(world_splitted_shot_pts, pred_pose)
            curr_pose_6d = PoseUtil.matrix_to_rotation_6d_numpy(curr_pose[:3, :3])
            curr_pose_9d = np.concatenate([curr_pose_6d, curr_pose[:3, 3]], axis=0)
            scanned_pts_list.append(curr_pts.tolist())
            scanned_n_to_world_pose.append(curr_pose_9d.tolist())
            combined_pts = np.concatenate(scanned_pts_list, axis=0)
            downsampled_combined_pts = PtsUtil.voxel_downsample_point_cloud(combined_pts, 0.003)

            curr_downsampled_combined_pts_num = downsampled_combined_pts.shape[0]
            Log.info(f"Downsampled combined pts: {curr_downsampled_combined_pts_num}")
            if curr_downsampled_combined_pts_num < last_downsampled_combined_pts_num + self.min_delta_pts_num:
                no_new_pts += 1
                Log.info(f"No new points, cnt: {cnt}, fail: {fail}, no_new_pts: {no_new_pts}")
                continue
            # Enough new points were added: update the reference count so the
            # next iteration is compared against the latest cloud size.
            last_downsampled_combined_pts_num = curr_downsampled_combined_pts_num

        Log.success("Inference finished")
        # self.save_inference_result(scanned_pts_list, downsampled_combined_pts)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)
        self.inference_result_dir = os.path.join(self.experiment_path, "inference_result")
        os.makedirs(self.inference_result_dir)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        self.inference_result_dir = os.path.join(self.experiment_path, "inference_result")

    def save_inference_result(self, scanned_pts_list, downsampled_combined_pts):
        import time
        dir_name = time.strftime("inference_result_%Y_%m_%d_%Hh%Mm%Ss", time.localtime())
        dir_path = os.path.join(self.inference_result_dir, dir_name)
        # The timestamped directory does not exist yet; create it before saving.
        os.makedirs(dir_path, exist_ok=True)
        for i in range(len(scanned_pts_list)):
            np.savetxt(os.path.join(dir_path, f"{i}.txt"), np.array(scanned_pts_list[i]))

        np.savetxt(os.path.join(dir_path, "downsampled_combined_pts.txt"), np.array(downsampled_combined_pts))

        Log.success("Inference result saved")
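

# Minimal usage sketch. Assumption: the runner can be driven directly; in a
# typical PytorchBoot project the framework's own entry point instantiates the
# registered runner from a config file instead. The config path below is a
# hypothetical placeholder.
if __name__ == "__main__":
    runner = InferenceRunner("configs/inference_config.yaml")
    runner.run()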