renate.data.datasets module

class renate.data.datasets.ImageDataset(data, labels, transform=None, target_transform=None)

Bases: Dataset

Dataset for image data in which each sample is loaded from a raw image file referenced by its path.

Parameters:
  • data (List[str]) – List of file paths to the images.

  • labels (List[int]) – Labels of the images, one per entry in data.

  • transform (Optional[Callable]) – Transformation or augmentation to apply to each image.

  • target_transform (Optional[Callable]) – Transformation or augmentation to apply to each label.
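The parameters above describe a map-style dataset that pairs each path in data with the label at the same position. The following is a minimal, hypothetical sketch of that pattern, not renate's implementation: the class name PathImageDataset and its loader argument are assumptions added for illustration, and in a real pipeline loader would typically be something like PIL's Image.open. A map-style dataset only needs __getitem__ and __len__ to be consumed by PyTorch's DataLoader, so the sketch does not import torch.

```python
from typing import Any, Callable, List, Optional, Tuple


class PathImageDataset:
    """Hypothetical sketch mirroring ImageDataset's parameters.

    `loader` is an assumption (not part of the documented signature): it turns
    a file path into an image object, e.g. PIL.Image.open in practice.
    """

    def __init__(
        self,
        data: List[str],
        labels: List[int],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        *,
        loader: Callable[[str], Any],
    ) -> None:
        if len(data) != len(labels):
            raise ValueError("data and labels must have the same length")
        self.data = data
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __len__(self) -> int:
        # One sample per image path.
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        # The image is loaded from disk only when this index is accessed.
        image = self.loader(self.data[idx])
        target = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target


# Usage with a stand-in loader that simply returns the path string.
ds = PathImageDataset(["a.png", "b.png"], [0, 1], transform=str.upper, loader=lambda p: p)
print(ds[1])  # ('B.PNG', 1)
```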

class renate.data.datasets.NestedTensorDataset(nested_tensors)

Bases: Dataset

A dataset of nested tensors.

Parameters:
  • nested_tensors (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – A nested tuple/dict structure of tensors. All tensors must have the same size along the batch dimension.
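To illustrate what a dataset over a nested tuple/dict structure involves, here is a hypothetical sketch, not renate's implementation. It assumes the batch dimension is the first dimension and that indexing the dataset returns the same nested structure with every leaf indexed at that position. The names NestedSequenceDataset, _leaf_sizes and _select are invented for this example. Leaves only need len() and integer indexing, so torch tensors (where len() returns the size of the first dimension) and plain lists both work; the usage below uses lists.

```python
from typing import Any, List


def _leaf_sizes(nested: Any) -> List[int]:
    """Collect the batch size (first-dimension length) of every leaf."""
    if isinstance(nested, tuple):
        return [size for item in nested for size in _leaf_sizes(item)]
    if isinstance(nested, dict):
        return [size for item in nested.values() for size in _leaf_sizes(item)]
    return [len(nested)]  # a leaf, e.g. a tensor or a list


def _select(nested: Any, idx: int) -> Any:
    """Index every leaf at `idx`, preserving the tuple/dict structure."""
    if isinstance(nested, tuple):
        return tuple(_select(item, idx) for item in nested)
    if isinstance(nested, dict):
        return {key: _select(value, idx) for key, value in nested.items()}
    return nested[idx]


class NestedSequenceDataset:
    """Hypothetical sketch of a dataset over a nested structure of tensors."""

    def __init__(self, nested: Any) -> None:
        sizes = set(_leaf_sizes(nested))
        if len(sizes) != 1:
            raise ValueError(f"all leaves must share one batch size, got {sorted(sizes)}")
        self._nested = nested
        self._length = sizes.pop()

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> Any:
        return _select(self._nested, idx)


ds = NestedSequenceDataset(({"x": [1, 2, 3]}, [4, 5, 6]))
print(ds[1])  # ({'x': 2}, 5)
```

Validating the batch sizes once in the constructor means a mismatched structure fails immediately rather than at some later index.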

class renate.data.datasets.IndexedSubsetDataset(dataset, indexes_to_keep)

Bases: Dataset

A dataset wrapper that keeps only the specified positions (indexes) within each element of a dataset.

Whereas PyTorch's Subset selects rows of a (tensor) dataset, IndexedSubsetDataset keeps specified columns, i.e. specified entries of each element. It currently supports only datasets whose elements are tuples.

Parameters:
  • dataset (Dataset) – The dataset to wrap.

  • indexes_to_keep (Union[List, Tuple, int]) – A list or tuple of indices to retain, or a single int index.
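The difference between selecting rows and keeping columns can be sketched as a small wrapper over any indexable dataset of tuples. This is a hypothetical illustration, not renate's implementation: the class name TupleColumnSubset is invented, and two behaviours are assumptions: kept entries follow the order given in indexes_to_keep, and the result is always a tuple, even when only one column is kept.

```python
from typing import Any, List, Sequence, Tuple, Union


class TupleColumnSubset:
    """Hypothetical sketch: keep selected positions of each tuple element."""

    def __init__(self, dataset: Sequence[Tuple[Any, ...]],
                 indexes_to_keep: Union[List[int], Tuple[int, ...], int]) -> None:
        # A single int is treated as a one-element selection.
        if isinstance(indexes_to_keep, int):
            indexes_to_keep = [indexes_to_keep]
        self._dataset = dataset
        self._indexes = tuple(indexes_to_keep)

    def __len__(self) -> int:
        # The number of rows is unchanged; only columns are dropped.
        return len(self._dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, ...]:
        element = self._dataset[idx]
        if not isinstance(element, tuple):
            raise TypeError(f"expected a tuple element, got {type(element).__name__}")
        return tuple(element[i] for i in self._indexes)


rows = [(10, "a", 0.5), (20, "b", 1.5)]
print(TupleColumnSubset(rows, [0, 2])[1])  # (20, 1.5)
```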