45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
|
from abc import ABC
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from torch.utils.data import Dataset
|
||
|
from torch.utils.data import DataLoader, Subset
|
||
|
|
||
|
from PytorchBoot.component import Component
|
||
|
|
||
|
class BaseDataset(ABC, Dataset, Component):
    """Base class for datasets that can build their own ``DataLoader``.

    Subclasses provide the ``torch.utils.data.Dataset`` protocol
    (``__len__`` / ``__getitem__``); this base stores the configuration and
    offers batching helpers.

    ``config`` is indexed with the keys ``"ratio"``, ``"batch_size"``,
    ``"num_workers"`` and, in the ratio error message, ``"name"``.
    """

    def __init__(self, config):
        super(BaseDataset, self).__init__()
        # Configuration object indexed with [] (see get_loader for the keys read).
        self.config = config

    @staticmethod
    def process_batch(batch, device):
        """Move every non-list value of *batch* to *device*, in place.

        List values are left untouched; every other value must expose a
        ``.to(device)`` method (e.g. a ``torch.Tensor``).

        Args:
            batch: mapping of key -> collated value; it is mutated in place.
            device: target passed to each value's ``.to``.

        Returns:
            The same *batch* object.
        """
        for key in batch.keys():
            if isinstance(batch[key], list):
                continue
            # Rebinding an existing key does not change the mapping's size,
            # so this is safe while iterating over its keys.
            batch[key] = batch[key].to(device)
        return batch

    def get_collate_fn(self):
        """Return the ``collate_fn`` for the DataLoader.

        ``None`` selects PyTorch's default collation; subclasses may override.
        """
        return None

    def get_loader(self, shuffle=False):
        """Build a ``DataLoader`` over a (possibly subsampled) view of this dataset.

        ``config["ratio"]`` in ``(0, 1]`` selects the fraction of samples to
        use; at least one index is always requested. Which samples are used
        is drawn with ``np.random`` (so it depends on NumPy's global RNG
        state). The chosen indices are then kept in ascending dataset order,
        so ``shuffle=False`` iterates deterministically and ordering is only
        randomised when ``shuffle=True`` asks the DataLoader to do so.

        Args:
            shuffle: forwarded to ``DataLoader``.

        Returns:
            A ``torch.utils.data.DataLoader`` over a ``Subset`` of ``self``.

        Raises:
            ValueError: if ``config["ratio"]`` is outside ``(0, 1]``.
        """
        ratio = self.config["ratio"]
        if ratio > 1 or ratio <= 0:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        total = len(self)
        subset_size = max(1, int(total * ratio))
        # Sample the subset at random, but sort it: iterating a raw
        # permutation would shuffle the data even when shuffle=False.
        # .tolist() hands __getitem__ plain Python ints instead of np.int64.
        indices = np.sort(np.random.permutation(total)[:subset_size]).tolist()
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            collate_fn=self.get_collate_fn(),
        )
|
||
|
|
||
|
|
||
|
|