renate.benchmark.datasets.base module

class renate.benchmark.datasets.base.DataIncrementalDataModule(data_path, data_id, src_bucket=None, src_object_name=None, val_size=0.0, seed=0)

Bases: RenateDataModule, ABC

Base class for all RenateDataModule classes that are compatible with DataIncrementalScenario.

It defines the API required by DataIncrementalScenario. Every subclass must load the dataset corresponding to the value of data_id whenever setup() is called.
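The contract above can be illustrated with a minimal, self-contained sketch. It deliberately does not import Renate. SketchDataIncrementalDataModule is a hypothetical stand-in that only mirrors the documented constructor parameters. ToyYearlyDataModule and its in-memory _chunks dictionary are also hypothetical and stand in for a real subclass that would read files from data_path or from S3. The sketch shows one thing only: setup() loads whichever slice data_id currently names, so changing data_id and calling setup() again yields a different time slice.

```python
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union


class SketchDataIncrementalDataModule(ABC):
    """Hypothetical stand-in mirroring the documented constructor; not Renate's class."""

    def __init__(
        self,
        data_path: Union[Path, str],
        data_id: Union[int, str],
        src_bucket: Optional[str] = None,
        src_object_name: Optional[str] = None,
        val_size: float = 0.0,
        seed: int = 0,
    ) -> None:
        self.data_path = Path(data_path)
        self.data_id = data_id
        self.src_bucket = src_bucket
        self.src_object_name = src_object_name
        self.val_size = val_size
        self.seed = seed

    @abstractmethod
    def setup(self) -> None:
        """Load the dataset for the time slice named by self.data_id."""


class ToyYearlyDataModule(SketchDataIncrementalDataModule):
    """Hypothetical subclass whose 'files' are an in-memory dict keyed by data_id."""

    _chunks: Dict[int, List[str]] = {
        2019: ["a", "b"],
        2020: ["c", "d", "e"],
    }

    def setup(self) -> None:
        # Read data_id on every call so a scenario can switch slices between calls.
        self.train_data = list(self._chunks[self.data_id])


module = ToyYearlyDataModule(data_path="/tmp/unused", data_id=2019)
module.setup()
print(module.train_data)  # ['a', 'b']
module.data_id = 2020  # e.g. a scenario moving on to the next time slice
module.setup()
print(module.train_data)  # ['c', 'd', 'e']
```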

Parameters:
  • data_path (Union[Path, str]) – Path to the folder containing the dataset files.

  • data_id (Union[int, str]) – Identifier of the time slice to be loaded.

  • src_bucket (Optional[str]) – Name of the S3 bucket. If not provided, the data is downloaded from the original source.

  • src_object_name (Optional[str]) – Folder path within the S3 bucket.

  • val_size (float) – Fraction of the training data to be used for validation.

  • seed (int) – Seed used to fix random number generation.
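The following standard-library sketch shows how val_size and seed work together. A fraction val_size of the training examples is held out for validation, and the examples are shuffled by a random.Random generator seeded with seed, so the same arguments always produce the same split. split_train_val is a hypothetical helper written for this illustration. It is not Renate's own splitting code.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val(
    train_data: Sequence[T], val_size: float = 0.0, seed: int = 0
) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: hold out a val_size fraction of train_data for validation."""
    if not 0.0 <= val_size < 1.0:
        raise ValueError("val_size must be in [0, 1).")
    indices = list(range(len(train_data)))
    # A dedicated seeded generator makes the split reproducible without
    # touching the global random state.
    random.Random(seed).shuffle(indices)
    num_val = round(val_size * len(train_data))
    val_idx, train_idx = indices[:num_val], indices[num_val:]
    return [train_data[i] for i in train_idx], [train_data[i] for i in val_idx]


train, val = split_train_val(list(range(10)), val_size=0.2, seed=0)
print(len(train), len(val))  # 8 2
```

With the default val_size=0.0, the helper returns an empty validation list and keeps every example for training.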