renate.benchmark.datasets.vision_datasets module
- class renate.benchmark.datasets.vision_datasets.TinyImageNetDataModule(data_path, src_bucket=None, src_object_name=None, val_size=0.0, seed=0)
Bases: RenateDataModule
Data module that processes the TinyImageNet dataset.
Source: http://cs231n.stanford.edu/
The TinyImageNet dataset is a subset of the ImageNet dataset. It contains 200 classes, each with 500 training images and 50 labeled validation images. There are also 50 unlabeled test images per class, which are not used here; the validation split serves as the test set.
- Parameters:
  - data_path (Union[Path, str]) – Path to the directory where the dataset should be stored.
  - src_bucket (Optional[str]) – Name of the bucket where the dataset is stored.
  - src_object_name (Optional[str]) – Name of the object in the bucket where the dataset is stored.
  - val_size (float) – Fraction of the training data to be used for validation.
  - seed (int) – Seed to be used for splitting the dataset.
- md5s = {'tiny-imagenet-200.zip': '90528d7ca1a48142e341f4ef8d21d0de'}
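The md5s attribute maps an archive name to its expected MD5 digest. A minimal sketch of how such a mapping can be used to check a downloaded file is given below. The helpers file_md5 and verify_md5 are hypothetical names and not part of Renate, and the check Renate itself performs may differ.

```python
import hashlib
from pathlib import Path
from typing import Dict, Union

# Expected digest, copied from TinyImageNetDataModule.md5s.
MD5S: Dict[str, str] = {"tiny-imagenet-200.zip": "90528d7ca1a48142e341f4ef8d21d0de"}


def file_md5(path: Union[Path, str], chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read fixed-size chunks so large archives do not need to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_md5(path: Union[Path, str], expected: Dict[str, str]) -> bool:
    """Hypothetical helper: compare a file's digest with the entry for its name."""
    path = Path(path)
    if path.name not in expected:
        raise KeyError(f"No checksum recorded for {path.name}")
    return file_md5(path) == expected[path.name]
```

With a downloaded archive at data/tiny-imagenet-200.zip (an illustrative path), verify_md5("data/tiny-imagenet-200.zip", MD5S) returns True only if the file matches the recorded digest.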
- class renate.benchmark.datasets.vision_datasets.TorchVisionDataModule(data_path, src_bucket=None, src_object_name=None, dataset_name='MNIST', val_size=0.0, seed=0)
Bases: RenateDataModule
Data module wrapping torchvision datasets.
- Parameters:
  - data_path (Union[Path, str]) – The path to the folder containing the dataset files.
  - src_bucket (Optional[str]) – The name of the S3 bucket. If not provided, the data is downloaded from the original source.
  - src_object_name (Optional[str]) – The folder path in the S3 bucket.
  - dataset_name (str) – Name of the torchvision dataset. The keys of dataset_dict are CIFAR10, CIFAR100, FashionMNIST, and MNIST.
  - val_size (float) – Fraction of the training data to be used for validation.
  - seed (int) – Seed used to fix random number generation.
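Together, val_size and seed describe a reproducible train/validation split. The sketch below illustrates the idea using only the standard library. split_indices is a hypothetical helper, and the rounding and shuffling shown are assumptions for illustration, not taken from Renate's source.

```python
import random
from typing import List, Tuple


def split_indices(
    num_samples: int, val_size: float = 0.0, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split range(num_samples) into (train, val) index lists."""
    if not 0.0 <= val_size < 1.0:
        raise ValueError("val_size must be in [0, 1)")
    indices = list(range(num_samples))
    # A dedicated generator keeps the split independent of global random state;
    # the same seed always yields the same permutation.
    random.Random(seed).shuffle(indices)
    # Assumption: the validation size is rounded to the nearest integer.
    num_val = int(round(val_size * num_samples))
    return indices[num_val:], indices[:num_val]
```

For example, split_indices(100, val_size=0.2, seed=0) yields 80 training and 20 validation indices, and repeating the call with the same seed returns the same split.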
- dataset_dict = {'CIFAR10': (<class 'torchvision.datasets.cifar.CIFAR10'>, 'cifar-10-batches-py'), 'CIFAR100': (<class 'torchvision.datasets.cifar.CIFAR100'>, 'cifar-100-python'), 'FashionMNIST': (<class 'torchvision.datasets.mnist.FashionMNIST'>, 'FashionMNIST'), 'MNIST': (<class 'torchvision.datasets.mnist.MNIST'>, 'MNIST')}
- dataset_stats = {'CIFAR10': {'mean': (0.49139967861519607, 0.48215840839460783, 0.44653091444546567), 'std': (0.24703223246174102, 0.24348512800151828, 0.26158784172803257)}, 'CIFAR100': {'mean': (0.5070751592371323, 0.48654887331495095, 0.4409178433670343), 'std': (0.26733428587941854, 0.25643846292120615, 0.2761504713263903)}, 'FashionMNIST': {'mean': 0.2860405969887955, 'std': 0.3530242445149223}, 'MNIST': {'mean': 0.1306604762738429, 'std': 0.30810780385646264}}
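dataset_stats holds a mean and standard deviation per dataset: one value per channel for the CIFAR datasets and a single scalar for the grayscale MNIST variants. Statistics like these are commonly used to standardize pixel values that have been scaled to [0, 1]. The sketch below shows that arithmetic in plain Python. normalize_pixel is a hypothetical helper, the values are rounded copies of two entries above, and how Renate applies the statistics is not shown here.

```python
from typing import Sequence, Tuple, Union

Stat = Union[float, Sequence[float]]

# Rounded copies of two entries of TorchVisionDataModule.dataset_stats.
STATS = {
    "MNIST": {"mean": 0.1307, "std": 0.3081},
    "CIFAR10": {"mean": (0.4914, 0.4822, 0.4465), "std": (0.2470, 0.2435, 0.2616)},
}


def normalize_pixel(pixel: Sequence[float], mean: Stat, std: Stat) -> Tuple[float, ...]:
    """Standardize one pixel (channel values in [0, 1]) as (x - mean) / std."""
    # Broadcast scalar statistics (grayscale datasets) to every channel.
    means = [mean] * len(pixel) if isinstance(mean, (int, float)) else list(mean)
    stds = [std] * len(pixel) if isinstance(std, (int, float)) else list(std)
    if not len(means) == len(stds) == len(pixel):
        raise ValueError("mean and std must match the number of channels")
    return tuple((p - m) / s for p, m, s in zip(pixel, means, stds))
```

A call such as normalize_pixel([0.5, 0.5, 0.5], **STATS["CIFAR10"]) standardizes one RGB pixel; a pixel equal to the dataset mean maps to zero in every channel.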
- class renate.benchmark.datasets.vision_datasets.CLEARDataModule(data_path, time_step=0, src_bucket=None, src_object_name=None, dataset_name='CLEAR10', val_size=0.0, seed=0)
Bases: DataIncrementalDataModule
Data module that processes the CLEAR datasets: CLEAR10 and CLEAR100.
Source: https://clear-benchmark.github.io/.
- Parameters:
  - data_path (Union[Path, str]) – The path to the folder containing the dataset files.
  - time_step (int) – Time step for which the CLEAR dataset is loaded. Valid range: [0, 9] for CLEAR10 and [0, 10] for CLEAR100.
  - src_bucket (Optional[str]) – The name of the S3 bucket. If not provided, the data is downloaded from the original source.
  - src_object_name (Optional[str]) – The folder path in the S3 bucket.
  - dataset_name (str) – CLEAR dataset name, options are clear10 and clear100.
  - val_size (float) – Fraction of the training data to be used for validation.
  - seed (int) – Seed used to fix random number generation.
- dataset_stats = {'CLEAR10': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}, 'CLEAR100': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}}
- md5s = {'clear10-test.zip': 'bf9a85bfb78fe742c7ed32648c9a3275', 'clear10-train-image-only.zip': '5171f720810d60b471c308dee595d430', 'clear100-test.zip': 'e160815fb5fd4bc71dacd339ff41e6a9', 'clear100-train-image-only.zip': 'ea85cdba9efcb3abf77eaab5554052c8'}
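The valid time_step range depends on the dataset: 0 to 9 for CLEAR10 and 0 to 10 for CLEAR100, both inclusive. A minimal sketch of a range check follows. check_time_step and MAX_TIME_STEP are hypothetical names, and both the case-insensitive name handling and the idea that Renate validates the argument this way are assumptions.

```python
# Largest valid time step per dataset, taken from the parameter description.
MAX_TIME_STEP = {"CLEAR10": 9, "CLEAR100": 10}


def check_time_step(dataset_name: str, time_step: int) -> int:
    """Hypothetical helper: return time_step if it is valid for dataset_name."""
    # Assumption: accept 'clear10' as well as 'CLEAR10'.
    name = dataset_name.upper()
    if name not in MAX_TIME_STEP:
        raise ValueError(f"Unknown CLEAR dataset: {dataset_name!r}")
    upper = MAX_TIME_STEP[name]
    if not 0 <= time_step <= upper:
        raise ValueError(f"time_step for {name} must be in [0, {upper}], got {time_step}")
    return time_step
```

Checking the argument before constructing the data module gives an early, readable error instead of a failure during download or loading.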
- class renate.benchmark.datasets.vision_datasets.DomainNetDataModule(data_path, src_bucket=None, src_object_name=None, domain='clipart', val_size=0.0, seed=0)
Bases: DataIncrementalDataModule
Data module that provides access to DomainNet.
- Parameters:
  - data_path (Union[Path, str]) – The path to the folder containing the dataset files.
  - src_bucket (Optional[str]) – The name of the S3 bucket. If not provided, the data is downloaded from the original source.
  - src_object_name (Optional[str]) – The folder path in the S3 bucket.
  - domain (str) – DomainNet domain name, options are clipart, infograph, painting, quickdraw, real, and sketch.
  - val_size (float) – Fraction of the training data to be used for validation.
  - seed (int) – Seed used to fix random number generation.
- md5s = {'clipart.zip': 'cd0d8f2d77a4e181449b78ed62bccf1e', 'clipart_test.txt': 'f5ddbcfd657a3acf9d0f7da10db22565', 'clipart_train.txt': 'b4349693a7f9c05c53955725c47ed6cb', 'infograph.zip': '720380b86f9e6ab4805bb38b6bd135f8', 'infograph_test.txt': '779626b50869edffe8ea6941c3755c71', 'infograph_train.txt': '379b50054f4ac2018dca4f89421b92d9', 'painting.zip': '1ae32cdb4f98fe7ab5eb0a351768abfd', 'painting_test.txt': '232b35dc53f26d414686ae579e38d9b5', 'painting_train.txt': '7db0e7ca73ad9982f6e1f7f3ef372c0a', 'quickdraw.zip': 'bdc1b6f09f277da1a263389efe0c7a66', 'quickdraw_test.txt': 'c1a828fdfe216fb109f1c0083a252c6f', 'quickdraw_train.txt': 'b732ced3939ac8efdd8c0a889dca56cc', 'real.zip': 'dcc47055e8935767784b7162e7c7cca6', 'real_test.txt': '6098816791c3ebed543c71ffa11b9054', 'real_train.txt': '8ebf02c2075fadd564705f0dc7cd6291', 'sketch.zip': '658d8009644040ff7ce30bb2e820850f', 'sketch_test.txt': 'd8a222e4672cfd585298aa14d02ea441', 'sketch_train.txt': '1233bd18aa9a8a200bf4cecf1c34ef3e'}
- domains = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
- dataset_stats = {'all': {'mean': [0.7491, 0.7391, 0.7179], 'std': [0.3318, 0.3314, 0.3512]}, 'clipart': {'mean': [0.7395, 0.7195, 0.6865], 'std': [0.3621, 0.364, 0.3873]}, 'infograph': {'mean': [0.6882, 0.6962, 0.6644], 'std': [0.3328, 0.3095, 0.3277]}, 'painting': {'mean': [0.5737, 0.5456, 0.5067], 'std': [0.3079, 0.3003, 0.3161]}, 'quickdraw': {'mean': [0.9525, 0.9525, 0.9525], 'std': [0.2127, 0.2127, 0.2127]}, 'real': {'mean': [0.6066, 0.5897, 0.5564], 'std': [0.3335, 0.327, 0.3485]}, 'sketch': {'mean': [0.8325, 0.8269, 0.818], 'std': [0.2723, 0.2747, 0.2801]}}
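The keys of md5s follow a regular pattern: each domain has one image archive and two split files. The sketch below derives those three file names for a given domain. files_for_domain is a hypothetical helper that only mirrors the naming visible in the md5s keys; it is not a Renate function.

```python
from typing import List

# Copied from DomainNetDataModule.domains.
DOMAINS = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"]


def files_for_domain(domain: str) -> List[str]:
    """Hypothetical helper: list the archive and split files named in md5s for a domain."""
    if domain not in DOMAINS:
        raise ValueError(f"Unknown domain {domain!r}; expected one of {DOMAINS}")
    return [f"{domain}.zip", f"{domain}_train.txt", f"{domain}_test.txt"]
```

Each returned name is a key of md5s, so the list can be used to look up the expected checksums for a single domain.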
- class renate.benchmark.datasets.vision_datasets.CDDBDataModule(data_path, src_bucket=None, src_object_name=None, domain='gaugan', val_size=0.0, seed=0)
Bases: DataIncrementalDataModule
- md5s = {'CDDB.tar.zip': '823b6496270ba03019dbd6af60cbcb6b'}
- domains = ['gaugan', 'biggan', 'wild', 'whichfaceisreal', 'san']
- dataset_stats = {'CDDB': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}}
- google_drive_id = '1NgB8ytBMFBFwyXJQvdVT_yek1EaaEHrg'
- class renate.benchmark.datasets.vision_datasets.CORE50DataModule(data_path, src_bucket=None, src_object_name=None, scenario='ni', data_id=0, val_size=0.0, seed=0)
Bases: DataIncrementalDataModule
Data module that processes the CORe50 dataset.
It can download any of the scenarios, using the 0th run (as per S-Prompts). The scenario and the data batch are selected by scenario and data_id, respectively.
Source: https://vlomonaco.github.io/core50/. Adapted from: vlomonaco/core50
- Parameters:
  - data_path (Union[Path, str]) – The path to the folder containing the dataset files.
  - src_bucket (Optional[str]) – The name of the S3 bucket. If not provided, the data is downloaded from the original source.
  - src_object_name (Optional[str]) – The folder path in the S3 bucket.
  - scenario (Literal['ni', 'nc', 'nic', 'nicv2_79', 'nicv2_196', 'nicv2_391']) – One of ni, nc, nic, nicv2_79, nicv2_196, and nicv2_391. This is different from the usage of scenario elsewhere in Renate.
  - data_id (int) – One of the several data batches, which depend on scenario.
- md5s = {'LUP.pkl': '33afc26faa460aca98739137fdfa606e', 'core50_128x128.zip': '745f3373fed08d69343f1058ee559e13', 'labels.pkl': '281c95774306a2196f4505f22fd60ab1', 'paths.pkl': 'b568f86998849184df3ec3465290f1b0'}
- dataset_stats = {'Core50': {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}}