Init ptb
This commit is contained in:
1
PytorchBoot/dataset/__init__.py
Normal file
1
PytorchBoot/dataset/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from PytorchBoot.dataset.base_dataset import BaseDataset
|
BIN
PytorchBoot/dataset/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
PytorchBoot/dataset/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/dataset/__pycache__/base_dataset.cpython-39.pyc
Normal file
BIN
PytorchBoot/dataset/__pycache__/base_dataset.cpython-39.pyc
Normal file
Binary file not shown.
44
PytorchBoot/dataset/base_dataset.py
Normal file
44
PytorchBoot/dataset/base_dataset.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from PytorchBoot.component import Component
|
||||
|
||||
class BaseDataset(ABC, Dataset, Component):
|
||||
def __init__(self, config):
|
||||
super(BaseDataset, self).__init__()
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
def process_batch(batch, device):
|
||||
for key in batch.keys():
|
||||
if isinstance(batch[key], list):
|
||||
continue
|
||||
batch[key] = batch[key].to(device)
|
||||
return batch
|
||||
|
||||
def get_collate_fn(self):
|
||||
return None
|
||||
|
||||
def get_loader(self, shuffle=False):
|
||||
ratio = self.config["ratio"]
|
||||
if ratio > 1 or ratio <= 0:
|
||||
raise ValueError(
|
||||
f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
|
||||
)
|
||||
subset_size = max(1,int(len(self) * ratio))
|
||||
indices = np.random.permutation(len(self))[:subset_size]
|
||||
subset = Subset(self, indices)
|
||||
return DataLoader(
|
||||
subset,
|
||||
batch_size=self.config["batch_size"],
|
||||
num_workers=self.config["num_workers"],
|
||||
shuffle=shuffle,
|
||||
collate_fn=self.get_collate_fn(),
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user