189 lines
7.5 KiB
Python
189 lines
7.5 KiB
Python
|
import os
|
||
|
import numpy as np
|
||
|
import torch.utils.data as torch_data
|
||
|
import kitti_utils
|
||
|
import cv2
|
||
|
from PIL import Image
|
||
|
|
||
|
|
||
|
# When True, KittiDataset.__getitem__ appends the per-point intensity feature
# column to the point coordinates in `pts_input`; when False, `pts_input` is
# the coordinates only.
USE_INTENSITY = False
|
||
|
|
||
|
|
||
|
class KittiDataset(torch_data.Dataset):
    """Dataset over KITTI object-detection frames that yields sampled LiDAR points.

    Each item is a dict holding a fixed number (``self.npoints``) of points that
    project inside the image, plus per-point class labels unless ``mode`` is
    ``'TEST'``.

    Directory layout read by this class (relative to ``root_dir``)::

        KITTI/ImageSets/<split>.txt          one sample id per line
        KITTI/object/{training|testing}/     image_2, velodyne, calib, label_2, planes
    """

    def __init__(self, root_dir, split='train', mode='TRAIN'):
        """Read the split file and record the per-modality directories.

        Args:
            root_dir: Directory that contains the ``KITTI`` folder.
            split: Name of the split file (without ``.txt``) under
                ``KITTI/ImageSets``. The value ``'test'`` selects the
                ``testing`` sub-directory, anything else selects ``training``.
            mode: ``'TRAIN'`` widens the label whitelist (see
                ``filtrate_objects``); ``'TEST'`` makes ``__getitem__`` skip
                label loading.
        """
        self.split = split
        self.mode = mode
        self.classes = ['Car']
        is_test = self.split == 'test'
        self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training')

        split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt')
        # Close the file deterministically and skip blank lines, which would
        # otherwise make the int() conversion below fail.
        with open(split_dir) as split_file:
            self.image_idx_list = [line.strip() for line in split_file if line.strip()]
        self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list]
        self.num_sample = len(self.image_idx_list)

        # Number of points returned per sample by __getitem__.
        self.npoints = 16384

        self.image_dir = os.path.join(self.imageset_dir, 'image_2')
        self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
        self.calib_dir = os.path.join(self.imageset_dir, 'calib')
        self.label_dir = os.path.join(self.imageset_dir, 'label_2')
        self.plane_dir = os.path.join(self.imageset_dir, 'planes')

    @staticmethod
    def _sample_file(directory, idx, extension):
        """Return ``<directory>/<idx as 6 digits>.<extension>``, asserting it exists."""
        path = os.path.join(directory, '%06d.%s' % (idx, extension))
        # Kept as an assert so callers still see AssertionError for a missing file.
        assert os.path.exists(path), 'missing sample file: %s' % path
        return path

    def get_image(self, idx):
        """Load the image of sample ``idx`` with OpenCV."""
        img_file = self._sample_file(self.image_dir, idx, 'png')
        return cv2.imread(img_file)  # (H, W, 3) BGR mode

    def get_image_shape(self, idx):
        """Return ``(height, width, 3)`` for the image of sample ``idx``.

        Only the size reported by PIL is read; the channel count is hard-coded.
        """
        img_file = self._sample_file(self.image_dir, idx, 'png')
        # Context manager releases the underlying file handle right away.
        with Image.open(img_file) as im:
            width, height = im.size
        return height, width, 3

    def get_lidar(self, idx):
        """Read the ``.bin`` file of sample ``idx`` as a float32 array of shape (N, 4)."""
        lidar_file = self._sample_file(self.lidar_dir, idx, 'bin')
        return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)

    def get_calib(self, idx):
        """Build a ``kitti_utils.Calibration`` from the calib file of sample ``idx``."""
        calib_file = self._sample_file(self.calib_dir, idx, 'txt')
        return kitti_utils.Calibration(calib_file)

    def get_label(self, idx):
        """Return the objects parsed by ``kitti_utils`` from the label file of sample ``idx``."""
        label_file = self._sample_file(self.label_dir, idx, 'txt')
        return kitti_utils.get_objects_from_label(label_file)

    @staticmethod
    def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape):
        """Return a boolean mask of points that project inside the image.

        A point is valid when column 0 of ``pts_img`` lies in
        ``[0, img_shape[1])``, column 1 lies in ``[0, img_shape[0])`` and its
        entry in ``pts_rect_depth`` is non-negative.

        Args:
            pts_rect: Unused; kept so the call signature stays unchanged.
            pts_img: (N, 2+) array of projected image coordinates.
            pts_rect_depth: (N,) array of depths matching ``pts_img``.
            img_shape: Sequence whose first two entries are (height, width).
        """
        in_width = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1])
        in_height = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0])
        in_image = np.logical_and(in_width, in_height)
        return np.logical_and(in_image, pts_rect_depth >= 0)

    def filtrate_objects(self, obj_list):
        """Keep only the objects whose ``cls_type`` is whitelisted.

        The whitelist is ``self.classes``; in ``'TRAIN'`` mode ``'Van'`` is
        additionally accepted whenever ``'Car'`` is one of the classes.
        """
        type_whitelist = self.classes
        if self.mode == 'TRAIN':
            # Copy first so that self.classes itself is never mutated.
            type_whitelist = list(self.classes)
            if 'Car' in self.classes:
                type_whitelist.append('Van')

        return [obj for obj in obj_list if obj.cls_type in type_whitelist]

    def __len__(self):
        """Return the number of sample ids read from the split file."""
        return len(self.sample_id_list)

    def _sample_point_indices(self, pts_rect):
        """Pick exactly ``self.npoints`` row indices into ``pts_rect``.

        * More points than needed: every point whose third coordinate is not
          below 40.0 is kept and the remainder is drawn without replacement
          from the nearer points. If those far points alone exceed
          ``self.npoints`` they are subsampled instead.
        * Fewer points than needed: all points are kept and the shortfall is
          filled with randomly repeated indices.

        Returns:
            1-D integer array of length ``self.npoints`` in shuffled order.

        Raises:
            ValueError: If ``pts_rect`` is empty, since nothing can be sampled.
        """
        num_points = len(pts_rect)

        if self.npoints < num_points:
            near_flag = pts_rect[:, 2] < 40.0
            far_idxs = np.where(near_flag == 0)[0]
            near_idxs = np.where(near_flag == 1)[0]
            num_near = self.npoints - len(far_idxs)
            if num_near < 0:
                # More far points than the budget: a negative sample size would
                # make np.random.choice raise, so subsample the far points.
                choice = np.random.choice(far_idxs, self.npoints, replace=False)
            else:
                near_choice = np.random.choice(near_idxs, num_near, replace=False)
                choice = np.concatenate((near_choice, far_idxs), axis=0) \
                    if len(far_idxs) > 0 else near_choice
            np.random.shuffle(choice)
            return choice

        choice = np.arange(0, num_points, dtype=np.int32)
        shortfall = self.npoints - num_points
        if shortfall > 0:
            if num_points == 0:
                raise ValueError('cannot sample %d points from an empty point set' % self.npoints)
            # Sampling without replacement is only possible while the shortfall
            # does not exceed the population; otherwise indices must repeat.
            extra_choice = np.random.choice(choice, shortfall, replace=shortfall > num_points)
            choice = np.concatenate((choice, extra_choice), axis=0)
        np.random.shuffle(choice)
        return choice

    def __getitem__(self, index):
        """Load, filter and sample the points of the ``index``-th sample.

        Returns:
            dict with keys ``sample_id`` (int), ``pts_input`` and ``pts_rect``
            (arrays with ``self.npoints`` rows), plus ``pts_features`` in
            ``'TEST'`` mode or ``cls_labels`` in any other mode.
        """
        sample_id = int(self.sample_id_list[index])
        calib = self.get_calib(sample_id)
        img_shape = self.get_image_shape(sample_id)
        pts_lidar = self.get_lidar(sample_id)

        # get valid point (projected points should be in image)
        pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
        pts_intensity = pts_lidar[:, 3]

        pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)
        pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape)

        pts_rect = pts_rect[pts_valid_flag][:, 0:3]
        pts_intensity = pts_intensity[pts_valid_flag]

        choice = self._sample_point_indices(pts_rect)
        ret_pts_rect = pts_rect[choice, :]
        # translate intensity to [-0.5, 0.5] (assumes raw values lie in [0, 1] - TODO confirm)
        ret_pts_intensity = pts_intensity[choice] - 0.5
        ret_pts_features = ret_pts_intensity.reshape(-1, 1)

        # prepare input
        if USE_INTENSITY:
            pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1)  # (N, C)
        else:
            pts_input = ret_pts_rect

        sample_info = {
            'sample_id': sample_id,
            'pts_input': pts_input,
            'pts_rect': ret_pts_rect,
        }

        if self.mode == 'TEST':
            # Labels are not loaded in TEST mode.
            sample_info['pts_features'] = ret_pts_features
            return sample_info

        gt_obj_list = self.filtrate_objects(self.get_label(sample_id))
        gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)

        # generate training labels
        sample_info['cls_labels'] = self.generate_training_labels(ret_pts_rect, gt_boxes3d)
        return sample_info

    @staticmethod
    def generate_training_labels(pts_rect, gt_boxes3d):
        """Label each point as foreground (1), ignored (-1) or background (0).

        For every ground-truth box, points inside the box hull get label 1.
        Points that are inside exactly one of the original hull and the hull
        enlarged by ``extra_width=0.2`` get label -1.

        Args:
            pts_rect: (N, 3) array of points.
            gt_boxes3d: Array of boxes as produced by
                ``kitti_utils.objs_to_boxes3d``; one box per row.

        Returns:
            (N,) int32 array of labels.
        """
        cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32)
        gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True)
        extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2)
        extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True)
        for k in range(gt_boxes3d.shape[0]):
            box_corners = gt_corners[k]
            fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners)
            cls_label[fg_pt_flag] = 1

            # enlarge the bbox3d, ignore nearby points
            # NOTE(review): this can overwrite a foreground label written for an
            # earlier box when boxes lie closer together than the margin.
            extend_box_corners = extend_gt_corners[k]
            fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners)
            ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag)
            cls_label[ignore_flag] = -1

        return cls_label

    def collate_batch(self, batch):
        """Merge a list of sample dicts into one dict of batched values.

        The keys and value types are taken from the first sample: ndarray
        values are stacked along a new leading axis, ``int`` values become an
        int32 array, ``float`` values become a float32 array, and anything
        else is returned as a plain list.
        """
        first = batch[0]
        ans_dict = {}

        for key, first_value in first.items():
            values = [sample[key] for sample in batch]
            if isinstance(first_value, np.ndarray):
                ans_dict[key] = np.stack(values, axis=0)
            elif isinstance(first_value, int):
                ans_dict[key] = np.array(values, dtype=np.int32)
            elif isinstance(first_value, float):
                ans_dict[key] = np.array(values, dtype=np.float32)
            else:
                ans_dict[key] = values

        return ans_dict
|