renate.utils.pytorch module

renate.utils.pytorch.reinitialize_model_parameters(model)

Reinitializes the parameters of a model.

This relies on every submodule of model implementing a reset_parameters() method. Core torch.nn layers implement it, but custom implementations of exotic layers may not. A warning is logged for modules that do not implement reset_parameters().

The actual reinitialization logic depends on the type of layer. It may also affect the module’s buffers (non-trainable state such as batch norm running statistics).

Parameters:

model (Module) – The model to be reinitialized.

Return type:

None
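A minimal sketch of the behavior described above, assuming only standard torch.nn APIs. The name reinitialize_model_parameters_sketch is hypothetical and this is not Renate's implementation; in particular, warning only for modules that directly own parameters (so that containers such as nn.Sequential stay silent) is an assumption of this sketch.

```python
import logging

import torch
from torch import nn


def reinitialize_model_parameters_sketch(model: nn.Module) -> None:
    """Hypothetical sketch: call reset_parameters() on every submodule that has it."""
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            # Core layers (nn.Linear, nn.Conv2d, nn.BatchNorm1d, ...) implement this;
            # for batch norm it also resets the running statistics buffers.
            module.reset_parameters()
        elif any(True for _ in module.parameters(recurse=False)):
            # Assumption: only warn for modules that own parameters directly.
            logging.warning(
                "%s does not implement reset_parameters().", type(module).__name__
            )


# Usage: overwrite weights and batch norm statistics, then reinitialize.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
with torch.no_grad():
    model[0].weight.fill_(1.0)
    model[1].running_mean.fill_(5.0)
reinitialize_model_parameters_sketch(model)
```

After the call, the linear weights are freshly sampled and the batch norm running mean is back to zero, which illustrates that buffers are affected too.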

renate.utils.pytorch.get_generator(seed=None)

Provides a torch.Generator for the given seed.

torch.default_generator is returned if seed is None.

Return type:

Generator
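The documented behavior can be sketched with torch.Generator.manual_seed and torch.default_generator. The function name get_generator_sketch is hypothetical; whether get_generator creates a fresh generator per call, as this sketch does, is an assumption.

```python
from typing import Optional

import torch


def get_generator_sketch(seed: Optional[int] = None) -> torch.Generator:
    """Hypothetical sketch of the documented behavior of get_generator."""
    if seed is None:
        # Fall back to PyTorch's global CPU generator.
        return torch.default_generator
    # A fresh generator seeded deterministically, independent of the global RNG state.
    return torch.Generator().manual_seed(seed)


# Two generators with the same seed yield the same random permutation.
perm_a = torch.randperm(10, generator=get_generator_sketch(42))
perm_b = torch.randperm(10, generator=get_generator_sketch(42))
```

Using a dedicated seeded generator keeps sampling reproducible without touching the global RNG state.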

renate.utils.pytorch.randomly_split_data(dataset, proportions, seed=0)

Randomly splits dataset into chunks according to proportions; seed makes the split reproducible.

Return type:

List[Dataset]
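A minimal sketch built on torch.utils.data.random_split. It assumes that proportions are fractions summing to 1 and that rounding leftovers go to the last chunk; both are assumptions, not documented behavior. The name randomly_split_data_sketch is hypothetical.

```python
from typing import List, Sequence

import torch
from torch.utils.data import Dataset, TensorDataset, random_split


def randomly_split_data_sketch(
    dataset: Dataset, proportions: Sequence[float], seed: int = 0
) -> List[Dataset]:
    """Hypothetical sketch: split dataset into chunks of the given proportions."""
    num_items = len(dataset)
    # Assumption: proportions sum to 1; rounding leftovers go to the last chunk.
    sizes = [int(p * num_items) for p in proportions[:-1]]
    sizes.append(num_items - sum(sizes))
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, sizes, generator=generator)


dataset = TensorDataset(torch.arange(10))
train, val = randomly_split_data_sketch(dataset, [0.8, 0.2], seed=0)
train_again, _ = randomly_split_data_sketch(dataset, [0.8, 0.2], seed=0)
```

The same seed produces the same index assignment, so train and train_again contain identical indices.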

renate.utils.pytorch.move_tensors_to_device(tensors, device)

Moves a collection of tensors to device.

The collection tensors can be a nested structure of tensors, tuples, lists, and dicts.

Return type:

Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]
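The recursive traversal of a nested structure can be sketched as follows. The name move_tensors_to_device_sketch is hypothetical, and raising TypeError for unsupported leaf types is an assumption of this sketch; the container types (tuple, list, dict) are preserved.

```python
from typing import Any

import torch


def move_tensors_to_device_sketch(tensors: Any, device: torch.device) -> Any:
    """Hypothetical sketch: recursively move tensors in nested tuples/lists/dicts."""
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, tuple):
        return tuple(move_tensors_to_device_sketch(t, device) for t in tensors)
    if isinstance(tensors, list):
        return [move_tensors_to_device_sketch(t, device) for t in tensors]
    if isinstance(tensors, dict):
        return {k: move_tensors_to_device_sketch(v, device) for k, v in tensors.items()}
    raise TypeError(f"Unsupported type: {type(tensors).__name__}")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"x": torch.ones(2, 3), "meta": (torch.zeros(2), [torch.arange(2)])}
moved = move_tensors_to_device_sketch(batch, device)
```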

renate.utils.pytorch.get_length_nested_tensors(batch)

Given a NestedTensors object, returns its length, i.e., the size of its first axis.

Assumes that all elements share the same size along the first axis.

Return type:

Size
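Because all elements are assumed to share the first-axis size, it is enough to inspect a single tensor. A minimal sketch follows; the name get_length_nested_tensors_sketch is hypothetical, and this sketch returns shape[0] as a plain int rather than the documented Size.

```python
from typing import Any

import torch


def get_length_nested_tensors_sketch(batch: Any) -> int:
    """Hypothetical sketch: length of the first tensor found in a nested structure."""
    if isinstance(batch, torch.Tensor):
        return batch.shape[0]
    if isinstance(batch, (tuple, list)):
        # All elements are assumed to share the first-axis size, so inspect one.
        return get_length_nested_tensors_sketch(batch[0])
    if isinstance(batch, dict):
        return get_length_nested_tensors_sketch(next(iter(batch.values())))
    raise TypeError(f"Unsupported type: {type(batch).__name__}")


batch = {"image": torch.zeros(8, 3, 32, 32), "label": (torch.zeros(8),)}
length = get_length_nested_tensors_sketch(batch)
```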

renate.utils.pytorch.cat_nested_tensors(nested_tensors, axis=0)

Concatenates a sequence of NestedTensors.

Equivalent of PyTorch’s torch.cat function for NestedTensors.

Parameters:
  • nested_tensors (Union[Tuple[Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]], List[Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]]]) – Tensors to be concatenated.

  • axis (int) – Concatenation axis.

Return type:

Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]
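A leaf-wise application of torch.cat can be sketched as below, assuming every element of nested_tensors has the same structure (same tuple lengths and dict keys). The name cat_nested_tensors_sketch is hypothetical and this is not Renate's implementation.

```python
from typing import Any, Sequence

import torch


def cat_nested_tensors_sketch(nested_tensors: Sequence[Any], axis: int = 0) -> Any:
    """Hypothetical sketch: torch.cat applied leaf-wise to matching nested structures."""
    first = nested_tensors[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(list(nested_tensors), dim=axis)
    if isinstance(first, tuple):
        # Concatenate the i-th element of every tuple.
        return tuple(
            cat_nested_tensors_sketch([nt[i] for nt in nested_tensors], axis)
            for i in range(len(first))
        )
    if isinstance(first, dict):
        # Concatenate values that share the same key.
        return {
            key: cat_nested_tensors_sketch([nt[key] for nt in nested_tensors], axis)
            for key in first
        }
    raise TypeError(f"Unsupported type: {type(first).__name__}")


part_a = {"x": torch.zeros(2, 3), "y": (torch.zeros(2),)}
part_b = {"x": torch.ones(4, 3), "y": (torch.ones(4),)}
merged = cat_nested_tensors_sketch([part_a, part_b])
```

Any number of structures can be concatenated, not only two.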

renate.utils.pytorch.unique_classes(dataset)

Computes the unique class IDs in a dataset.

Parameters:

dataset (Dataset) – A PyTorch dataset.

Return type:

Set[int]
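A minimal sketch, assuming that indexing the dataset returns (input, target) pairs, which is common for classification datasets but not stated above. Because it only relies on __len__ and __getitem__, a plain Python list can stand in for a map-style Dataset. The name unique_classes_sketch is hypothetical.

```python
from typing import Any, Set


def unique_classes_sketch(dataset: Any) -> Set[int]:
    """Hypothetical sketch: collect targets from a map-style dataset of (input, target) pairs."""
    classes: Set[int] = set()
    for idx in range(len(dataset)):
        _, target = dataset[idx]
        # int() also converts 0-dim torch tensors holding a class id.
        classes.add(int(target))
    return classes


# Any object with __len__ and __getitem__ returning (input, target) works here.
toy_dataset = [("a", 0), ("b", 2), ("c", 0), ("d", 1)]
classes = unique_classes_sketch(toy_dataset)
```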

renate.utils.pytorch.complementary_indices(num_outputs, valid_classes)

Computes the asymmetric set difference between the class indices 0, ..., num_outputs - 1 and valid_classes, i.e., the indices the model can output that are not valid classes.

Parameters:
  • num_outputs (int) – Total number of classes the model can output.

  • valid_classes (Set[int]) – Set of valid class indices.

Return type:

List[int]
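The asymmetric difference amounts to a single filtered range. In this sketch (name complementary_indices_sketch is hypothetical) valid classes outside range(num_outputs) are simply ignored and the result is in ascending order; whether the library behaves identically in those edge cases is an assumption.

```python
from typing import List, Set


def complementary_indices_sketch(num_outputs: int, valid_classes: Set[int]) -> List[int]:
    """Hypothetical sketch: indices in range(num_outputs) that are not valid classes."""
    return [idx for idx in range(num_outputs) if idx not in valid_classes]


# A 5-way output head where only classes 1 and 3 are valid.
invalid = complementary_indices_sketch(5, {1, 3})  # [0, 2, 4]
```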

class renate.utils.pytorch.ConcatRandomSampler(dataset_lengths, batch_sizes, complete_dataset_iteration=None, generator=None, sampler=None)

Bases: Sampler[List[int]]

Sampler for sampling batches from ConcatDatasets.

Each sampled batch is composed of sub-batches drawn from different BatchSamplers with the specified batch sizes and index ranges.

To clarify the behavior, consider a small example with dataset_lengths = [5, 2] and batch_sizes = [3, 1].

With this setting, the two datasets have the indices A = {0, ..., 4} and B = {5, 6}. The total batch size is exactly 4: the first three elements in each batch are elements of A, and the last one is an element of B. An example batch could be [3, 1, 0, 6].

Since every batch has exactly sum(batch_sizes) elements, the last incomplete batch is dropped.

Parameters:
  • dataset_lengths (List[int]) – The lengths of the different datasets.

  • batch_sizes (List[int]) – The batch size used for each dataset.

  • complete_dataset_iteration (Optional[int]) – Index of the dataset to iterate over completely. By default, iteration stops as soon as it is complete for any dataset.

  • generator (Optional[Any]) – Generator used in sampling.

  • sampler (Optional[Sampler]) – Lightning automatically passes a DistributedSamplerWrapper. Only used as an indicator that we are in the distributed case.
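The default batch composition can be sketched in plain Python using the example above. This sketch is hypothetical (the name concat_random_batches is not part of Renate): it swaps in Python's random.Random for a torch generator, covers only the default stopping rule, and ignores complete_dataset_iteration and the distributed sampler case.

```python
import random
from typing import Iterator, List


def concat_random_batches(
    dataset_lengths: List[int], batch_sizes: List[int], seed: int = 0
) -> Iterator[List[int]]:
    """Hypothetical sketch of the default batch composition (no distributed case)."""
    rng = random.Random(seed)
    offset = 0
    shuffled = []
    for length in dataset_lengths:
        # Dataset i occupies the global index range [offset, offset + length).
        indices = list(range(offset, offset + length))
        rng.shuffle(indices)
        shuffled.append(indices)
        offset += length
    # Only full batches are produced; iteration stops when any dataset runs out.
    num_batches = min(length // size for length, size in zip(dataset_lengths, batch_sizes))
    for b in range(num_batches):
        batch: List[int] = []
        for indices, size in zip(shuffled, batch_sizes):
            batch.extend(indices[b * size : (b + 1) * size])
        yield batch


batches = list(concat_random_batches([5, 2], [3, 1]))
```

With dataset_lengths = [5, 2] and batch_sizes = [3, 1], only one full batch fits (5 // 3 = 1), so a single batch of four indices is produced: three from A followed by one from B.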