renate.data.data_module module

class renate.data.data_module.RenateDataModule(data_path, src_bucket=None, src_object_name=None, val_size=0.0, seed=0)

Bases: ABC

Data modules bundle code for data loading and preparation.

A data module implements two methods for data preparation:

  • prepare_data() downloads the data to the local machine and unpacks it.

  • setup() creates PyTorch dataset objects that return training, test and (possibly) validation data.

These two steps are separated to streamline launching multiple training jobs simultaneously, e.g., for hyperparameter optimization. In this case, prepare_data() is called only once per machine.

After these two methods have been called, the data can be accessed using:

  • train_data()

  • test_data()

  • val_data()

Each of these returns a torch dataset (torch.utils.data.Dataset).
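The two-phase lifecycle can be sketched without Renate itself. The class below is a hypothetical stand-in, not a RenateDataModule subclass: `ToyDataModule`, `ListDataset`, the `data.json` file and the `.prepared` marker are illustrative assumptions, and val_data() is omitted for brevity. It shows why prepare_data() is written to be idempotent (safe to call once per machine, or again by accident), while setup() only builds dataset objects from files that are already local.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple

Sample = Tuple[float, int]


class ListDataset:
    """Minimal map-style dataset: __len__ plus __getitem__, the same
    protocol torch.utils.data.Dataset subclasses provide."""

    def __init__(self, samples: List[Sample]) -> None:
        self._samples = samples

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        return self._samples[idx]


class ToyDataModule:
    """Hypothetical stand-in mirroring the prepare_data()/setup() split."""

    def __init__(self, data_path: str) -> None:
        self._data_path = Path(data_path)
        self._train: Optional[ListDataset] = None
        self._test: Optional[ListDataset] = None

    def prepare_data(self) -> None:
        # Machine-level step (stands in for a download and unpack).
        # The marker file turns repeated calls on the same machine into no-ops.
        marker = self._data_path / ".prepared"
        if marker.exists():
            return
        self._data_path.mkdir(parents=True, exist_ok=True)
        data = {"train": [[0.1, 0], [0.9, 1], [0.4, 0]], "test": [[0.7, 1]]}
        (self._data_path / "data.json").write_text(json.dumps(data))
        marker.touch()

    def setup(self) -> None:
        # Per-process step: build dataset objects from the local files.
        data = json.loads((self._data_path / "data.json").read_text())
        self._train = ListDataset([tuple(s) for s in data["train"]])
        self._test = ListDataset([tuple(s) for s in data["test"]])

    def train_data(self) -> ListDataset:
        assert self._train is not None, "call setup() first"
        return self._train

    def test_data(self) -> ListDataset:
        assert self._test is not None, "call setup() first"
        return self._test


with tempfile.TemporaryDirectory() as tmp:
    dm = ToyDataModule(tmp)
    dm.prepare_data()
    dm.prepare_data()  # returns early: the marker file already exists
    dm.setup()
    print(len(dm.train_data()), len(dm.test_data()))  # 3 1
```

In a multi-job setting, one process per machine would call prepare_data(), and every training job would then call setup() on its own.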

Parameters:
  • data_path (Union[Path, str]) – Path to the data to be loaded.

  • src_bucket (Union[Path, str, None]) – Name of the S3 bucket.

  • src_object_name (Union[Path, str, None]) – Folder path in the S3 bucket.

  • val_size (float) – Fraction of the training data to be used for validation.

  • seed (int) – Seed used to fix random number generation.
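This page does not state exactly how val_size and seed are applied. A common approach, assumed here, is to shuffle the training indices with a seeded generator and hold out a val_size fraction of them. `split_train_val` is a hypothetical helper, and Renate's own splitting (for example its rounding rule) may differ.

```python
import random
from typing import List, Tuple


def split_train_val(
    n_samples: int, val_size: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (train_indices, val_indices).

    A fraction `val_size` of the indices is held out for validation.
    The permutation depends only on `seed`, so the split is reproducible.
    """
    if not 0.0 <= val_size < 1.0:
        raise ValueError("val_size must be in [0, 1)")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, global state untouched
    n_val = round(val_size * n_samples)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_train_val(10, val_size=0.2, seed=0)
print(len(train_idx), len(val_idx))  # 8 2
```

With the default val_size=0.0, the validation part is empty and all samples stay in the training set.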

abstract prepare_data()

Downloads datasets.

Return type:

None

abstract setup()

Set up train, test and val datasets.

Return type:

None

train_data()

Returns training dataset.

Return type:

Dataset

val_data()

Returns validation dataset.

Return type:

Dataset

test_data()

Returns test dataset.

Return type:

Dataset

train_collate_fn()

Returns collate_fn for train DataLoader.

Return type:

Optional[Callable]

val_collate_fn()

Returns collate_fn for validation DataLoader.

Return type:

Optional[Callable]

test_collate_fn()

Returns collate_fn for test DataLoader.

Return type:

Optional[Callable]
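Each *_collate_fn() returns either a callable or None. In PyTorch, passing collate_fn=None to torch.utils.data.DataLoader makes it fall back to its default collation, so a typical call would be DataLoader(dm.train_data(), batch_size=..., collate_fn=dm.train_collate_fn()). The loop below imitates that fallback in plain Python so it runs without torch; `iterate_batches`, `default_collate_lists` and `pad_collate` are hypothetical names, and the padding collate is just one example of what a custom collate_fn might do.

```python
from typing import Callable, Iterator, List, Optional, Sequence, Tuple

Sample = Tuple[List[int], int]  # (token ids, label)


def default_collate_lists(batch: List[Sample]) -> Tuple[list, list]:
    # Stand-in for default collation: transpose samples into per-field lists.
    inputs, labels = zip(*batch)
    return list(inputs), list(labels)


def pad_collate(batch: List[Sample], pad_id: int = 0) -> Tuple[list, list]:
    # Custom collation: right-pad variable-length sequences to the batch max.
    max_len = max(len(x) for x, _ in batch)
    inputs = [x + [pad_id] * (max_len - len(x)) for x, _ in batch]
    labels = [y for _, y in batch]
    return inputs, labels


def iterate_batches(
    dataset: Sequence[Sample],
    batch_size: int,
    collate_fn: Optional[Callable[[List[Sample]], tuple]] = None,
) -> Iterator[tuple]:
    # None means "use the default", mirroring DataLoader(collate_fn=None).
    collate = collate_fn if collate_fn is not None else default_collate_lists
    for start in range(0, len(dataset), batch_size):
        # Index sample by sample: map-style datasets need not support slicing.
        stop = min(start + batch_size, len(dataset))
        yield collate([dataset[i] for i in range(start, stop)])


data = [([5, 6, 7], 1), ([8], 0), ([9, 4], 1)]
for inputs, labels in iterate_batches(data, batch_size=2, collate_fn=pad_collate):
    print(inputs, labels)
# [[5, 6, 7], [8, 0, 0]] [1, 0]
# [[9, 4]] [1]
```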

class renate.data.data_module.CSVDataModule(data_path, train_filename='train.csv', test_filename='test.csv', target_name='y', src_bucket=None, src_object_name=None, val_size=0.0, seed=0)

Bases: RenateDataModule

A data module loading data from CSV files.

Parameters:
  • data_path (Union[Path, str]) – Path to the folder containing the files.

  • train_filename (Union[Path, str]) – Name of the CSV file containing the training data.

  • test_filename (Union[Path, str]) – Name of the CSV file containing the test data.

  • target_name (str) – Header of the column containing the target values.

  • src_bucket (Union[Path, str, None]) – Name of an S3 bucket. If specified, the folder given by src_object_name is downloaded from S3 to data_path.

  • src_object_name (Union[Path, str, None]) – Folder path in the S3 bucket.

  • val_size (float) – Fraction of the training data to be used for validation.

  • seed (int) – Seed used to fix random number generation.
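How the target column is separated from the features is not specified on this page. The sketch below assumes a header row, numeric values, and that every column other than target_name is a feature. `load_csv_split` is a hypothetical helper built on Python's csv module, not Renate's own loader.

```python
import csv
import io
from typing import List, TextIO, Tuple


def load_csv_split(
    handle: TextIO, target_name: str = "y"
) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical helper: split CSV rows into features and targets.

    The column whose header equals `target_name` becomes the target;
    all remaining columns are parsed as float features.
    """
    reader = csv.DictReader(handle)
    if reader.fieldnames is None or target_name not in reader.fieldnames:
        raise KeyError(f"no column named {target_name!r}")
    feature_names = [name for name in reader.fieldnames if name != target_name]
    features: List[List[float]] = []
    targets: List[float] = []
    for row in reader:
        features.append([float(row[name]) for name in feature_names])
        targets.append(float(row[target_name]))
    return features, targets


# A real module would open the file at data_path / train_filename instead.
train_csv = io.StringIO("x1,x2,y\n0.5,1.0,0\n2.0,3.5,1\n")
X, y = load_csv_split(train_csv, target_name="y")
print(X, y)  # [[0.5, 1.0], [2.0, 3.5]] [0.0, 1.0]
```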

prepare_data()

Downloads the data folder from S3 if src_bucket is specified.

Return type:

None
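The control flow of this step can be sketched as: do nothing when no bucket is given, otherwise copy every object under the src_object_name prefix into data_path. This is an assumption about the behavior, not Renate's code. The S3 calls are injected as callables so the logic runs offline; with boto3, `list_keys` could wrap a paginated `list_objects_v2(Bucket=..., Prefix=...)` and `download` could wrap `client.download_file(Bucket, Key, Filename)`. `plan_download` and `prepare` are hypothetical names.

```python
from pathlib import Path, PurePosixPath
from typing import Callable, Iterable, List, Optional, Tuple


def plan_download(
    keys: Iterable[str], src_object_name: str, data_path: str
) -> List[Tuple[str, Path]]:
    """Map S3 keys below the src_object_name prefix to local file paths.

    Assumes the folder layout under the prefix is mirrored in data_path.
    """
    prefix = PurePosixPath(src_object_name)
    plan: List[Tuple[str, Path]] = []
    for key in keys:
        if key.endswith("/"):
            continue  # zero-byte "folder" marker objects: nothing to download
        try:
            # Compares whole path components, so "csvfoo/a" is not under "csv".
            relative = PurePosixPath(key).relative_to(prefix)
        except ValueError:
            continue  # key lies outside the requested folder
        if not relative.parts:
            continue  # key equals the prefix itself
        plan.append((key, Path(data_path, *relative.parts)))
    return plan


def prepare(
    src_bucket: Optional[str],
    src_object_name: str,
    data_path: str,
    list_keys: Callable[[str, str], Iterable[str]],
    download: Callable[[str, str, Path], None],
) -> int:
    """Download the folder only if a bucket is configured; return file count."""
    if src_bucket is None:
        return 0  # data is expected to be in data_path already
    keys = list_keys(src_bucket, src_object_name)
    plan = plan_download(keys, src_object_name, data_path)
    for key, local_path in plan:
        local_path.parent.mkdir(parents=True, exist_ok=True)
        download(src_bucket, key, local_path)
    return len(plan)


keys = ["datasets/csv/train.csv", "datasets/csv/test.csv", "datasets/other.csv"]
for key, local in plan_download(keys, "datasets/csv", "/tmp/renate_data"):
    print(key, "->", local)
# datasets/csv/train.csv -> /tmp/renate_data/train.csv
# datasets/csv/test.csv -> /tmp/renate_data/test.csv
```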

setup()

Set up train, test and val datasets.

Return type:

None