Source code for renate.benchmark.models.resnet

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Type, Union

import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.resnet import ResNet as _ResNet

from renate.benchmark.models.base import RenateBenchmarkingModule
from renate.models.prediction_strategies import PredictionStrategy


class ResNet(RenateBenchmarkingModule):
    """ResNet model base class.

    Builds a torchvision ``ResNet`` as backbone, optionally swaps its stem for a
    CIFAR-style and/or single-channel one, and replaces the final fully connected
    layer with an identity so that the backbone emits embeddings of size
    ``embedding_size``.

    TODO: Fix citation
    Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun:
    Deep Residual Learning for Image Recognition. CVPR 2016: 770-778

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet.
        num_outputs: The number of output units.
        zero_init_residual: Whether to set the initial weights of the residual blocks to zero
            through initializing the Batch Normalization parameters at the end of the block to
            zero.
        groups: The number of groups to be used for the group convolution.
        width_per_group: The width of the group convolution.
        replace_stride_with_dilation: Whether to replace the stride with a dilation to save
            memory.
        norm_layer: What kind of normalization layer to use, following convolutions.
        cifar_stem: Whether to use a stem for CIFAR-sized images. If ``True``, the first
            convolution becomes a 3x3, stride-1 convolution and the max-pooling is removed.
        gray_scale: Whether input images are gray-scale images, i.e. only 1 color channel.
        prediction_strategy: Continual learning strategies may alter the prediction at train or
            test time.
        add_icarl_class_means: If ``True``, additional parameters used only by the
            ``ICaRLModelUpdater`` are added. Only required when using that updater.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_outputs: int = 10,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        cifar_stem: bool = True,
        gray_scale: bool = False,
        prediction_strategy: Optional[PredictionStrategy] = None,
        add_icarl_class_means: bool = True,
    ) -> None:
        model = _ResNet(
            block=block,
            layers=layers,
            num_classes=num_outputs,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
        )
        super().__init__(
            # The input width of the original classification head is the embedding size.
            embedding_size=model.fc.in_features,
            num_outputs=num_outputs,
            constructor_arguments={
                "block": block,
                "layers": layers,
                "zero_init_residual": zero_init_residual,
                "groups": groups,
                "width_per_group": width_per_group,
                "replace_stride_with_dilation": replace_stride_with_dilation,
                "norm_layer": norm_layer,
                "cifar_stem": cifar_stem,
                "gray_scale": gray_scale,
            },
            prediction_strategy=prediction_strategy,
            add_icarl_class_means=add_icarl_class_means,
        )
        self._backbone = model
        if cifar_stem:
            # Small images: 3x3 stride-1 stem and no max-pooling.
            self._backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self._backbone.maxpool = nn.Identity()
        if gray_scale:
            # Same geometry as the current stem (default or CIFAR), but one input channel.
            stem = self._backbone.conv1
            self._backbone.conv1 = nn.Conv2d(
                1,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )
        # The backbone only produces embeddings; drop torchvision's classification head.
        self._backbone.fc = nn.Identity()
        for m in self.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()
        if zero_init_residual:
            # ``reset_parameters`` on normalization layers sets their scale back to one, which
            # discards the zero initialization applied inside the torchvision constructor.
            # Re-apply it so the option keeps its documented effect.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last_norm = m.bn3
                elif isinstance(m, BasicBlock):
                    last_norm = m.bn2
                else:
                    continue
                if getattr(last_norm, "weight", None) is not None:
                    nn.init.constant_(last_norm.weight, 0)
class ResNet18CIFAR(ResNet):
    """ResNet built from ``BasicBlock`` with layers ``[2, 2, 2, 2]`` and the CIFAR stem enabled.

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet. ``None`` selects
            ``[2, 2, 2, 2]``.
        cifar_stem: Whether to use a stem for CIFAR-sized images.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = BasicBlock,
        layers: Optional[List[int]] = None,
        cifar_stem: bool = True,
        **kwargs,
    ) -> None:
        # Build the default per call: a list literal in the signature would be one object
        # shared by every instance (and by their stored constructor arguments).
        if layers is None:
            layers = [2, 2, 2, 2]
        super().__init__(block=block, layers=layers, cifar_stem=cifar_stem, **kwargs)
class ResNet34CIFAR(ResNet):
    """ResNet built from ``BasicBlock`` with layers ``[3, 4, 6, 3]`` and the CIFAR stem enabled.

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet. ``None`` selects
            ``[3, 4, 6, 3]``.
        cifar_stem: Whether to use a stem for CIFAR-sized images.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = BasicBlock,
        layers: Optional[List[int]] = None,
        cifar_stem: bool = True,
        **kwargs,
    ) -> None:
        # Build the default per call: a list literal in the signature would be one object
        # shared by every instance (and by their stored constructor arguments).
        if layers is None:
            layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, cifar_stem=cifar_stem, **kwargs)
class ResNet50CIFAR(ResNet):
    """ResNet built from ``Bottleneck`` with layers ``[3, 4, 6, 3]`` and the CIFAR stem enabled.

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet. ``None`` selects
            ``[3, 4, 6, 3]``.
        cifar_stem: Whether to use a stem for CIFAR-sized images.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck,
        layers: Optional[List[int]] = None,
        cifar_stem: bool = True,
        **kwargs,
    ) -> None:
        # Build the default per call: a list literal in the signature would be one object
        # shared by every instance (and by their stored constructor arguments).
        if layers is None:
            layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, cifar_stem=cifar_stem, **kwargs)
class ResNet18(ResNet):
    """ResNet built from ``BasicBlock`` with layers ``[2, 2, 2, 2]`` and the CIFAR stem disabled.

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet. ``None`` selects
            ``[2, 2, 2, 2]``.
        cifar_stem: Whether to use a stem for CIFAR-sized images.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = BasicBlock,
        layers: Optional[List[int]] = None,
        cifar_stem: bool = False,
        **kwargs,
    ) -> None:
        # Build the default per call: a list literal in the signature would be one object
        # shared by every instance (and by their stored constructor arguments).
        if layers is None:
            layers = [2, 2, 2, 2]
        super().__init__(block=block, layers=layers, cifar_stem=cifar_stem, **kwargs)
class ResNet34(ResNet):
    """ResNet built from ``BasicBlock`` with layers ``[3, 4, 6, 3]`` and the CIFAR stem disabled.

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet. ``None`` selects
            ``[3, 4, 6, 3]``.
        cifar_stem: Whether to use a stem for CIFAR-sized images.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = BasicBlock,
        layers: Optional[List[int]] = None,
        cifar_stem: bool = False,
        **kwargs,
    ) -> None:
        # Build the default per call: a list literal in the signature would be one object
        # shared by every instance (and by their stored constructor arguments).
        if layers is None:
            layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, cifar_stem=cifar_stem, **kwargs)
class ResNet50(ResNet):
    """ResNet built from ``Bottleneck`` with layers ``[3, 4, 6, 3]`` and the CIFAR stem disabled.

    Args:
        block: The type of the block to use as the core building block.
        layers: The number of blocks in the respective parts of ResNet. ``None`` selects
            ``[3, 4, 6, 3]``.
        cifar_stem: Whether to use a stem for CIFAR-sized images.
        **kwargs: Forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck,
        layers: Optional[List[int]] = None,
        cifar_stem: bool = False,
        **kwargs,
    ) -> None:
        # Build the default per call: a list literal in the signature would be one object
        # shared by every instance (and by their stored constructor arguments).
        if layers is None:
            layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, cifar_stem=cifar_stem, **kwargs)