renate.benchmark.models.resnet module
- class renate.benchmark.models.resnet.ResNet(block, layers, num_outputs=10, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, cifar_stem=True, gray_scale=False, prediction_strategy=None, add_icarl_class_means=True)
Bases: RenateBenchmarkingModule
ResNet model base class.
Reference: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. CVPR 2016, pp. 770-778.
- Parameters:
  - block (Type[Union[BasicBlock, Bottleneck]]) – The type of residual block to use as the core building block.
  - layers (List[int]) – The number of blocks in each of the four stages of the ResNet.
  - num_outputs (int) – The number of output units.
  - zero_init_residual (bool) – Whether to zero-initialize the weight of the last Batch Normalization layer in each residual block, so that every residual branch initially outputs zero.
  - groups (int) – The number of groups used in the grouped convolution.
  - width_per_group (int) – The width of each group in the grouped convolution.
  - replace_stride_with_dilation (Optional[List[bool]]) – For each of the last three stages, whether to replace the stride-2 downsampling with a dilated convolution, which keeps the spatial resolution of the feature maps.
  - norm_layer (Type[Module]) – The normalization layer class applied after each convolution.
  - cifar_stem (bool) – Whether to use a stem designed for CIFAR-sized images.
  - gray_scale (bool) – Whether input images are gray-scale, i.e., have only one color channel.
  - prediction_strategy (Optional[PredictionStrategy]) – An optional strategy through which a continual learning method can alter predictions at training or test time.
  - add_icarl_class_means (bool) – If True, adds parameters that are used only by the ICaRLModelUpdater. Only required when using that updater.
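To illustrate how cifar_stem and replace_stride_with_dilation affect the spatial resolution of the final feature map, the pure-Python sketch below applies the standard convolution output-size formula. It assumes the torchvision ImageNet stem (a 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool) when cifar_stem is False. For cifar_stem=True it assumes a single 3x3 stride-1 convolution without max pooling; Renate's actual CIFAR stem may differ. The helper names conv_out and final_feature_size are hypothetical.

```python
from typing import List, Optional


def conv_out(size: int, kernel: int, stride: int, padding: int, dilation: int = 1) -> int:
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def final_feature_size(
    input_size: int,
    cifar_stem: bool,
    replace_stride_with_dilation: Optional[List[bool]] = None,
) -> int:
    """Hypothetical helper: spatial size of the feature map after the last stage."""
    if replace_stride_with_dilation is None:
        replace_stride_with_dilation = [False, False, False]
    if cifar_stem:
        # Assumed CIFAR stem: 3x3 convolution, stride 1, padding 1, no max pooling.
        size = conv_out(input_size, kernel=3, stride=1, padding=1)
    else:
        # torchvision ImageNet stem: 7x7 stride-2 convolution, then 3x3 stride-2 max pool.
        size = conv_out(input_size, kernel=7, stride=2, padding=3)
        size = conv_out(size, kernel=3, stride=2, padding=1)
    # The first stage keeps the resolution; each of the last three stages halves it
    # with a stride-2 3x3 convolution unless that stride is replaced by dilation.
    for dilate in replace_stride_with_dilation:
        if not dilate:
            size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(final_feature_size(32, cifar_stem=True))
print(final_feature_size(224, cifar_stem=False))
print(final_feature_size(224, cifar_stem=False, replace_stride_with_dilation=[False, True, True]))
```

Under these assumptions, a 32x32 input leaves a 4x4 map with the CIFAR stem, whereas the ImageNet stem would already shrink it to 8x8 before the first stage and to 1x1 at the end, which is why a dedicated stem is useful for small images.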
- class renate.benchmark.models.resnet.ResNet18CIFAR(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[2, 2, 2, 2], cifar_stem=True, **kwargs)
Bases: ResNet
- class renate.benchmark.models.resnet.ResNet34CIFAR(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[3, 4, 6, 3], cifar_stem=True, **kwargs)
Bases: ResNet
- class renate.benchmark.models.resnet.ResNet50CIFAR(block=<class 'torchvision.models.resnet.Bottleneck'>, layers=[3, 4, 6, 3], cifar_stem=True, **kwargs)
Bases: ResNet
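ResNet50CIFAR uses Bottleneck blocks, which are where the groups and width_per_group arguments take effect (torchvision's BasicBlock only accepts groups=1 and width_per_group=64). The sketch below reproduces the inner-width formula of torchvision's Bottleneck, assuming Renate passes these arguments through unchanged. The (groups, width_per_group) pairs mirror torchvision's resnet50, resnext50_32x4d and wide_resnet50_2 builders; whether Renate exposes such variants is not stated here, and bottleneck_width is a hypothetical helper.

```python
from typing import Dict, List, Tuple

STAGE_PLANES: List[int] = [64, 128, 256, 512]  # base channel count of each of the four stages
EXPANSION = 4  # a Bottleneck block outputs planes * 4 channels


def bottleneck_width(planes: int, groups: int = 1, width_per_group: int = 64) -> int:
    """Channels of the grouped 3x3 convolution inside a torchvision-style Bottleneck."""
    return int(planes * (width_per_group / 64.0)) * groups


# (groups, width_per_group) settings matching torchvision's resnet50,
# resnext50_32x4d and wide_resnet50_2 builders.
CONFIGS: Dict[str, Tuple[int, int]] = {
    "ResNet-50": (1, 64),
    "ResNeXt-50 32x4d": (32, 4),
    "Wide ResNet-50-2": (1, 128),
}

for name, (groups, wpg) in CONFIGS.items():
    inner = [bottleneck_width(p, groups, wpg) for p in STAGE_PLANES]
    out = [p * EXPANSION for p in STAGE_PLANES]
    print(f"{name}: inner widths {inner}, block outputs {out}")
```

With the default settings the inner widths are 64, 128, 256 and 512; both alternative settings double them to 128, 256, 512 and 1024, while the block outputs stay at 256, 512, 1024 and 2048 channels in all three cases.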
- class renate.benchmark.models.resnet.ResNet18(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[2, 2, 2, 2], cifar_stem=False, **kwargs)
Bases: ResNet
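The variant names follow from their block and layers defaults. A minimal sketch of the conventional depth count from He et al.: each BasicBlock contributes two weighted layers and each Bottleneck three, plus one stem convolution and the final fully connected layer. It assumes the stem has a single convolution in both the CIFAR and ImageNet configurations; resnet_depth and CONVS_PER_BLOCK are hypothetical names.

```python
from typing import Dict, List

# Weighted layers per residual block: BasicBlock has two 3x3 convolutions,
# Bottleneck has three (1x1, 3x3, 1x1).
CONVS_PER_BLOCK: Dict[str, int] = {"BasicBlock": 2, "Bottleneck": 3}


def resnet_depth(block: str, layers: List[int]) -> int:
    """Nominal depth: block convolutions + stem convolution + final linear layer."""
    return CONVS_PER_BLOCK[block] * sum(layers) + 2


VARIANTS = {
    "ResNet18CIFAR": ("BasicBlock", [2, 2, 2, 2]),
    "ResNet34CIFAR": ("BasicBlock", [3, 4, 6, 3]),
    "ResNet50CIFAR": ("Bottleneck", [3, 4, 6, 3]),
    "ResNet18": ("BasicBlock", [2, 2, 2, 2]),
}

for name, (block, layers) in VARIANTS.items():
    print(name, resnet_depth(block, layers))
```

ResNet34CIFAR and ResNet50CIFAR share the same layers, so the difference in depth comes entirely from the block type. By convention, the 1x1 projection convolutions on downsampling shortcuts are not counted.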