Basic Framework
This commit is contained in:
35
datasets/dataset.py
Normal file
35
datasets/dataset.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
class BaseDataset(ABC, Dataset):
|
||||
def __init__(self, config):
|
||||
super(BaseDataset, self).__init__()
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
def process_batch(batch, device):
|
||||
for key in batch.keys():
|
||||
if isinstance(batch[key], list):
|
||||
continue
|
||||
batch[key] = batch[key].to(device)
|
||||
return batch
|
||||
|
||||
def get_loader(self, shuffle=False):
|
||||
ratio = self.config["ratio"]
|
||||
if ratio > 1 or ratio <= 0:
|
||||
raise ValueError(
|
||||
f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
|
||||
)
|
||||
subset_size = int(len(self) * ratio)
|
||||
indices = np.random.permutation(len(self))[:subset_size]
|
||||
subset = Subset(self, indices)
|
||||
return DataLoader(
|
||||
subset,
|
||||
batch_size=self.config["batch_size"],
|
||||
num_workers=self.config["num_workers"],
|
||||
shuffle=shuffle,
|
||||
)
|
30
datasets/dataset_factory.py
Normal file
30
datasets/dataset_factory.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import sys
|
||||
import os
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(2):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from datasets.dataset import BaseDataset
|
||||
|
||||
class DatasetFactory:
|
||||
@staticmethod
|
||||
def create(config) -> BaseDataset:
|
||||
pass
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
|
||||
from configs.config import ConfigManager
|
||||
|
||||
ConfigManager.load_config_with('/home/data/hofee/project/ActivePerception/ActivePerception/configs/server_train_config.yaml')
|
||||
ConfigManager.print_config()
|
||||
dataset = DatasetFactory.create(ConfigManager.get("settings", "test", "dataset_list")[1])
|
||||
print(len(dataset))
|
||||
data_test = dataset.__getitem__(107000)
|
||||
print(data_test['src_path'])
|
||||
import pickle
|
||||
# with open("data_sample_new.pkl", "wb") as f:
|
||||
# pickle.dump(data_test, f)
|
Reference in New Issue
Block a user