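"""Utilities for fetching rendered RGB-D views and object-pose estimates from a
local HTTP server, caching them, and converting depth and segmentation maps
into per-object point clouds."""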
import json

import numpy as np
import requests
import torch
from PIL import Image

from utils.cache_util import LRUCache


class ViewUtil:
    # LRU cache of rendered views, keyed by (camera pose, source, data type, scene, port).
    view_cache = LRUCache(1024)

    @staticmethod
    def load_camera_pose_from_frame(camera_params_path):
        with open(camera_params_path, "r") as f:
            camera_params = json.load(f)

        # "cameraViewTransform" is a flat 16-element world-to-camera matrix.
        view_transform = camera_params["cameraViewTransform"]
        view_transform = np.resize(view_transform, (4, 4))
        # Invert and transpose to get the camera-to-world pose in column-vector convention.
        view_transform = np.linalg.inv(view_transform).T
        # Flip Y and Z to convert from the USD camera convention (-Z forward) to the
        # computer-vision convention (+Z forward). np.array replaces the deprecated
        # np.mat here without changing the result.
        offset = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        view_transform = view_transform.dot(offset)
        return view_transform

    @staticmethod
    def save_image(rgb, filename):
        if rgb.dtype != np.uint8:
            rgb = rgb.astype(np.uint8)
        img = Image.fromarray(rgb, 'RGB')
        img.save(filename)

    @staticmethod
    def save_depth(depth, filename):
        # Depth is stored as a 16-bit single-channel image.
        if depth.dtype != np.uint16:
            depth = depth.astype(np.uint16)
        depth_img = Image.fromarray(depth)
        depth_img.save(filename)

    @staticmethod
    def save_segmentation(seg, filename):
        if seg.dtype != np.uint8:
            seg = seg.astype(np.uint8)
        seg_img = Image.fromarray(seg)
        seg_img.save(filename)

    @staticmethod
    def get_view(camera_pose, source, data_type, scene, port):
        # Key the cache by value so identical requests are served locally.
        camera_pose_tuple = tuple(map(tuple, camera_pose.tolist()))
        cache_key = (camera_pose_tuple, source, data_type, scene, port)
        cached_result = ViewUtil.view_cache.get(cache_key)
        # Assumes LRUCache.get returns None on a miss; an explicit None check
        # avoids falsy-but-valid cached values being re-fetched.
        if cached_result is not None:
            print("Cache hit")
            return cached_result

        url = f"http://127.0.0.1:{port}/get_images"
        headers = {'Content-Type': 'application/json'}
        data = {
            'camera_pose': camera_pose.tolist(),
            'data_type': data_type,
            'source': source,
            'scene': scene
        }
        response = requests.post(url, headers=headers, data=json.dumps(data))

        if response.status_code == 200:
            results = response.json()

            rgb = np.asarray(results['rgb'], dtype=np.uint8)
            # Convert depth from millimetres to metres.
            depth = np.asarray(results['depth']) / 1000
            seg = np.asarray(results['segmentation'])
            seg_labels = results['segmentation_labels']
            camera_params = results['camera_params']
            ViewUtil.view_cache.put(cache_key, (rgb, depth, seg, seg_labels, camera_params))
            return rgb, depth, seg, seg_labels, camera_params
        else:
            return None

    @staticmethod
    def get_object_pose_batch(K, mesh, rgb_batch, depth_batch, mask_batch, gt_pose_batch, port):
        url = f"http://127.0.0.1:{port}/predict_estimation_batch"
        headers = {'Content-Type': 'application/json'}
        # Serialize the mesh as plain vertex/face lists for the JSON payload.
        mesh_data = {
            'vertices': mesh.vertices.tolist(),
            'faces': mesh.faces.tolist()
        }
        data = {
            'K': K.tolist(),
            'rgb_batch': rgb_batch.tolist(),
            'depth_batch': depth_batch.tolist(),
            'mask_batch': mask_batch.tolist(),
            'mesh': mesh_data,
            'gt_pose_batch': gt_pose_batch.tolist()
        }
        response = requests.post(url, headers=headers, data=json.dumps(data))

        if response.status_code == 200:
            results = response.json()
            pose_batch = np.array(results['pose_batch'])
            results_batch = results["eval_result_batch"]
            return pose_batch, results_batch
        else:
            return None

    @staticmethod
    def get_visualized_result(K, mesh, rgb, pose, port):
        url = f"http://127.0.0.1:{port}/get_visualized_result"
        headers = {'Content-Type': 'application/json'}
        mesh_data = {
            'vertices': mesh.vertices.tolist(),
            'faces': mesh.faces.tolist()
        }
        data = {
            'K': K.tolist(),
            'rgb': rgb.tolist(),
            'mesh': mesh_data,
            'pose': pose.tolist()
        }
        response = requests.post(url, headers=headers, data=json.dumps(data))

        if response.status_code == 200:
            results = response.json()
            vis_rgb = np.array(results['vis_rgb'])
            return vis_rgb
        else:
            return None

    @staticmethod
    def get_object_pose(K, mesh, rgb, depth, mask, gt_pose, port):
        url = f"http://127.0.0.1:{port}/predict_estimation"
        headers = {'Content-Type': 'application/json'}
        mesh_data = {
            'vertices': mesh.vertices.tolist(),
            'faces': mesh.faces.tolist()
        }
        data = {
            'K': K.tolist(),
            'rgb': rgb.tolist(),
            'depth': depth.tolist(),
            'mask': mask.tolist(),
            'mesh': mesh_data,
            'gt_pose': gt_pose.tolist()
        }
        response = requests.post(url, headers=headers, data=json.dumps(data))

        if response.status_code == 200:
            results = response.json()
            # The server replies with batch-shaped fields even for a single query.
            pose_batch = np.array(results['pose_batch'])
            results_batch = results["eval_result_batch"]
            return pose_batch, results_batch
        else:
            return None

    @staticmethod
    def get_pts_dict(depth, seg, seg_labels, camera_params):
        cx = camera_params['cx']
        cy = camera_params['cy']
        fx = camera_params['fx']
        fy = camera_params['fy']
        width = camera_params['width']
        height = camera_params['height']
        pts_dict = {name: [] for name in seg_labels.values()}
        # Back-project every pixel through the pinhole model:
        # X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy.
        u = np.arange(width)
        v = np.arange(height)
        u, v = np.meshgrid(u, v)
        Z = depth
        X = (u - cx) * Z / fx
        Y = (v - cy) * Z / fy
        points = np.stack((X, Y, Z), axis=-1).reshape(-1, 3)
        labels = seg.reshape(-1)
        # Split the cloud into one point set per segmentation label.
        for label, name in seg_labels.items():
            mask = labels == int(label)
            pts_dict[name] = points[mask]
        return pts_dict

    @staticmethod
    def get_object_center_from_pts_dict(obj, pts_dict):
        # Fall back to the first non-empty point set when no object is given or
        # the requested object has no visible points; assumes at least one
        # entry in pts_dict is non-empty.
        if obj is None:
            for _, pts in pts_dict.items():
                if pts.size != 0:
                    obj_pts = pts
                    break
        else:
            obj_pts = pts_dict[obj]
            if obj_pts.size == 0:
                for _, pts in pts_dict.items():
                    if pts.size != 0:
                        obj_pts = pts
                        break
        obj_center = obj_pts.mean(axis=0)
        return obj_center

    @staticmethod
    def get_pts_center(pts):
        pts_center = pts.mean(axis=0)
        return pts_center

    @staticmethod
    def get_scene_pts(pts_dict):
        # Concatenate all per-object clouds; use torch if any entry is a tensor.
        if any(isinstance(pts, torch.Tensor) for pts in pts_dict.values()):
            scene_pts = torch.cat([pts for _, pts in pts_dict.items()], dim=0)
        else:
            scene_pts = np.concatenate([pts for _, pts in pts_dict.items()])
        return scene_pts

    @staticmethod
    def crop_pts(scene_pts, crop_center, radius=0.2):
        # Keep only the points within `radius` of the crop center.
        if isinstance(scene_pts, torch.Tensor):
            crop_mask = torch.norm(scene_pts - crop_center, dim=1) < radius
        else:
            crop_mask = np.linalg.norm(scene_pts - crop_center, axis=1) < radius
        return scene_pts[crop_mask]

    @staticmethod
    def crop_pts_dict(pts_dict, crop_center, radius=0.2, min_pts_num=5000):
        # Grow the crop radius until at least min_pts_num points survive,
        # up to a fixed number of iterations.
        crop_dict = {}
        max_loop = 100
        loop = 0
        while loop <= max_loop:
            cropped_length = 0
            for obj, pts in pts_dict.items():
                if isinstance(pts, torch.Tensor):
                    crop_mask = torch.norm(pts - crop_center, dim=1) < radius
                else:
                    crop_mask = np.linalg.norm(pts - crop_center, axis=1) < radius
                crop_dict[obj] = pts[crop_mask]
                cropped_length += crop_dict[obj].shape[0]
            if cropped_length >= min_pts_num:
                break
            radius += 0.02
            loop += 1
        return crop_dict

    @staticmethod
    def get_cam_pose_focused_on_point(point_w, cam_pose_w, old_camera_center_w):
        # Keep the original orientation and viewing distance, but translate the
        # camera along its z-axis so it looks at point_w.
        distance = np.linalg.norm(point_w - old_camera_center_w)
        z_axis_camera = cam_pose_w[:3, 2].reshape(-1)
        new_camera_position_w = point_w - distance * z_axis_camera
        new_camera_pose_w = cam_pose_w.copy()
        # Assign as a flat 3-vector: reshaping to (3, 1) does not broadcast
        # into the (3,) translation slice.
        new_camera_pose_w[:3, 3] = new_camera_position_w.reshape(-1)
        return new_camera_pose_w
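

# Usage sketch (illustration only, not part of the original module): fetch a
# rendered view from a locally running server, back-project it into per-object
# point clouds, and crop around an object center. The JSON path, port, source,
# data_type, and scene values below are hypothetical placeholders.
if __name__ == "__main__":
    camera_pose = ViewUtil.load_camera_pose_from_frame("camera_params.json")
    view = ViewUtil.get_view(camera_pose, "example_source", "rgbd", "example_scene", 5000)
    if view is not None:
        rgb, depth, seg, seg_labels, camera_params = view
        pts_dict = ViewUtil.get_pts_dict(depth, seg, seg_labels, camera_params)
        # obj=None picks the first non-empty point set.
        center = ViewUtil.get_object_center_from_pts_dict(None, pts_dict)
        cropped = ViewUtil.crop_pts_dict(pts_dict, center, radius=0.2)
        print({name: pts.shape for name, pts in cropped.items()})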