renate.data.data_module module
- class renate.data.data_module.RenateDataModule(data_path, src_bucket=None, src_object_name=None, val_size=0.0, seed=0)
Bases: ABC
Data modules bundle code for data loading and preparation.
A data module implements two methods for data preparation:
- prepare_data() downloads the data to the local machine and unpacks it.
- setup() creates PyTorch dataset objects that return training, test and (possibly) validation data.

These two steps are separated to streamline the process when launching multiple training jobs simultaneously, e.g., for hyperparameter optimization. In this case, prepare_data() is only called once per machine.

After these two methods have been called, the data can be accessed using
- train_data()
- test_data()
- val_data()

which return torch datasets (torch.utils.data.Dataset).
- Parameters:
- data_path (Union[Path, str]) – The path to the data to be loaded.
- src_bucket (Union[Path, str, None]) – The name of the S3 bucket.
- src_object_name (Union[Path, str, None]) – The folder path in the S3 bucket.
- val_size (float) – Fraction of the training data to be used for validation.
- seed (int) – Seed used to fix random number generation.
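The two-step lifecycle described above can be sketched without Renate or PyTorch installed. The classes `SketchDataModule` and `ToyDataModule`, the file name `numbers.txt`, and the helper `build_toy_module` are hypothetical names invented for this illustration. The seeded split is one plausible reading of `val_size` and `seed`, not necessarily how Renate implements it.

```python
import random
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path


class SketchDataModule(ABC):
    """Hypothetical stand-in mirroring the documented lifecycle; not Renate code."""

    def __init__(self, data_path, val_size=0.0, seed=0):
        self._data_path = Path(data_path)
        self._val_size = val_size
        self._seed = seed
        self._train = self._val = self._test = None

    @abstractmethod
    def prepare_data(self):
        """Fetch and unpack raw files; intended to run once per machine."""

    @abstractmethod
    def setup(self):
        """Build dataset objects; intended to run in every training job."""

    def _split_train_val(self, samples):
        # Assumption: a seeded shuffle, so every job derives the same split.
        indices = list(range(len(samples)))
        random.Random(self._seed).shuffle(indices)
        n_val = int(len(samples) * self._val_size)
        val = [samples[i] for i in indices[:n_val]]
        train = [samples[i] for i in indices[n_val:]]
        return train, val

    def train_data(self):
        return self._train

    def val_data(self):
        return self._val

    def test_data(self):
        return self._test


class ToyDataModule(SketchDataModule):
    def prepare_data(self):
        # Stand-in for a download: write the raw file only if it is missing.
        self._data_path.mkdir(parents=True, exist_ok=True)
        raw_file = self._data_path / "numbers.txt"
        if not raw_file.exists():
            raw_file.write_text("\n".join(str(i) for i in range(10)))

    def setup(self):
        lines = (self._data_path / "numbers.txt").read_text().splitlines()
        values = [int(line) for line in lines]
        # First eight values form the training pool, the last two the test set.
        self._train, self._val = self._split_train_val(values[:8])
        self._test = values[8:]


def build_toy_module(val_size=0.25, seed=0):
    module = ToyDataModule(tempfile.mkdtemp(), val_size=val_size, seed=seed)
    module.prepare_data()  # once per machine
    module.setup()         # once per training job
    return module
```

Because `prepare_data()` only touches the file system and `setup()` only reads from it, several jobs on one machine can share a single prepared copy of the data in this sketch.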
- train_collate_fn()
Returns collate_fn for train DataLoader.
- Return type:
Optional[Callable]
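As an illustration of what such a callable might look like: a collate function takes a list of samples and merges them into one batch. In PyTorch it is passed as the `collate_fn` argument of `torch.utils.data.DataLoader`. The sketch below uses plain Python lists so it runs without PyTorch. The function `pad_collate` and the padding behaviour are assumptions made for this example, not something Renate provides.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def pad_collate(
    batch: Sequence[Tuple[List[int], int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical collate function: pad inputs to the longest one in the batch."""
    max_len = max(len(x) for x, _ in batch)
    inputs = [list(x) + [pad_value] * (max_len - len(x)) for x, _ in batch]
    targets = [y for _, y in batch]
    return inputs, targets


def train_collate_fn() -> Optional[Callable]:
    # Mirrors the documented return type: a callable, or None for no custom collation.
    return pad_collate
```

A real implementation would typically return tensors rather than nested lists.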
- class renate.data.data_module.CSVDataModule(data_path, train_filename='train.csv', test_filename='test.csv', target_name='y', src_bucket=None, src_object_name=None, val_size=0.0, seed=0)
Bases: RenateDataModule
A data module loading data from CSV files.
- Parameters:
- data_path (Union[Path, str]) – Path to the folder containing the files.
- train_filename (Union[Path, str]) – Name of the CSV file containing the training data.
- test_filename (Union[Path, str]) – Name of the CSV file containing the test data.
- src_bucket (Union[Path, str, None]) – Name of an S3 bucket. If specified, the folder given by src_object_name will be downloaded from S3 to data_path.
- src_object_name (Union[Path, str, None]) – Folder path in the S3 bucket.
- target_name (str) – The header of the column containing the target values.
- val_size (float) – Fraction of the training data to be used for validation.
- seed (int) – Seed used to fix random number generation.
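To make the expected file layout concrete, the sketch below writes a `train.csv` and a `test.csv` with a target column named `y` (the documented defaults) and reads them back using only the standard library. The helpers `write_csv` and `read_features_and_targets`, the feature columns `x1` and `x2`, and the choice to treat every non-target column as a numeric feature are assumptions for this example. This is not how CSVDataModule is implemented.

```python
import csv
import tempfile
from pathlib import Path


def write_csv(path, header, rows):
    # newline="" is the csv module's recommended setting for file objects.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def read_features_and_targets(path, target_name="y"):
    """Split each row into feature values and the target column (illustrative only)."""
    features, targets = [], []
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            targets.append(float(row.pop(target_name)))
            # Assumption: all remaining columns are numeric features.
            features.append([float(value) for value in row.values()])
    return features, targets


data_path = Path(tempfile.mkdtemp())
write_csv(data_path / "train.csv", ["x1", "x2", "y"], [[0.1, 0.2, 0], [0.3, 0.4, 1]])
write_csv(data_path / "test.csv", ["x1", "x2", "y"], [[0.5, 0.6, 1]])

train_x, train_y = read_features_and_targets(data_path / "train.csv")
test_x, test_y = read_features_and_targets(data_path / "test.csv")
```

A folder laid out this way is what `data_path` would point to, with `target_name` selecting the column that holds the labels.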