# Source code for renate.models.layers.shared_linear

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import torch.nn as nn


class SharedMultipleLinear(nn.ModuleDict):
    """This implements a linear classification layer for multiple tasks (updates).

    This linear layer can be shared across all tasks or can have a separate layer per task.
    This follows the `_task_params` in the `RenateBenchmarkingModule` that is a `nn.ModuleDict`
    that holds a classifier per task (as in TIL).

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True``
        share_parameters: Flag whether to share parameters or use individual linears per task.
            The interface remains identical, and the underlying linear layer is shared (or not).
        num_updates: Number of updates that have happened/is happening.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        share_parameters: bool = True,
        num_updates: int = 0,
    ) -> None:
        # Plain (non-module) attributes; safe to set before Module.__init__ since
        # nn.Module.__setattr__ only intercepts Parameter/Module assignments.
        self._share_parameters = share_parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        super().__init__()
        # Create one entry per update that has already happened.
        for _ in range(num_updates):
            self.increment_task()

    def increment_task(self) -> None:
        """Register the classifier for the next task under the key ``str(len(self))``.

        If parameters are shared, every key maps to the *same* ``nn.Linear`` instance
        (the one created for task 0); otherwise a fresh ``nn.Linear`` is created.
        """
        currlen = len(self)
        if self._share_parameters:
            # Keys are assigned sequentially as "0", "1", ..., so entry "0" always
            # exists once any task has been registered; reuse it directly instead of
            # materializing the full key list just to grab the first element.
            self[f"{currlen}"] = (
                self["0"]
                if currlen > 0
                else nn.Linear(in_features=self.in_features, out_features=self.out_features)
            )
        else:
            self[f"{currlen}"] = nn.Linear(
                in_features=self.in_features, out_features=self.out_features
            )