Source code for torch_deterministic.batch_rng

import numpy as np
from torch.utils.data import default_collate

class BatchGenerator:
    """
    Wrap a collection of NumPy pseudorandom number generators (PRNGs) such
    that samples can easily be drawn from all of them at once.

    Instantiate this class with a list of NumPy PRNGs.  Then, any method
    invoked on this class will automatically be invoked on all of those
    PRNGs, and the results will be collated into a PyTorch tensor.  For
    example::

        >>> import numpy as np
        >>> from torch_deterministic import BatchGenerator
        >>> bg = BatchGenerator([
        ...     np.random.default_rng(0),
        ...     np.random.default_rng(1),
        ... ])
        >>> bg.uniform()
        tensor([0.6370, 0.5118], dtype=torch.float64)

    This class is meant to facilitate the idea that all of the randomness in
    each training step should come from a PRNG seeded based on the index of
    the corresponding training example.  This PRNG would be created by the
    dataset, used to build the training example, then returned in case the
    training loop itself requires any more randomness.

    The benefit of this approach is that it's very robust.  The randomness
    does not depend on the number of data loader processes, and every
    training example can be reproduced without having to replay the whole
    dataset or constantly log the PRNG state.

    However, it's worth noting that from the point of view of trying to get
    the best possible distribution of random numbers, this approach is
    suboptimal.  PRNGs are only designed to output high-quality randomness
    if seeded once.  There's no guarantee that two PRNGs with different
    seeds won't output correlated values.  In practice, though, this doesn't
    seem to be a significant issue.

    The `collate_rngs()` function can be used to make PyTorch data loaders
    automatically wrap collections of NumPy PRNGs with this class.
    """
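
    def __init__(self, rngs):
        self._rngs = rngs

    def __repr__(self):
        return f'<{self.__class__.__name__} n={len(self._rngs)}>'

    def __len__(self):
        return len(self._rngs)

    def __getattr__(self, name):
        # Only forward public names.  Private and dunder lookups fail
        # normally, which also avoids infinite recursion if `_rngs` has not
        # been set yet.
        if name.startswith('_'):
            raise AttributeError(name)

        def method_wrapper(*args, **kwargs):
            # Call the same method on every PRNG, then collate the results.
            return default_collate([
                getattr(rng, name)(*args, **kwargs)
                for rng in self._rngs
            ])

        return method_wrapper

    def pin_memory(self):
        # Return self unchanged; there are no tensors here to pin.
        return self
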
def collate_rngs(x):
    """
    A collate function for PyTorch data loaders that automatically wraps
    NumPy pseudorandom number generators (PRNGs) in a `BatchGenerator`
    object.

    All the data types normally recognized by PyTorch's `default_collate`
    are also recognized by this function, so this function can be passed
    directly to the data loader as the `collate_fn` argument.

    Example::

        from torch.utils.data import DataLoader
        from torch_deterministic import collate_rngs
        from my_dataset import dataset

        DataLoader(dataset, collate_fn=collate_rngs)
    """
    from torch.utils.data._utils.collate import collate, default_collate_fn_map

    def collate_rng_fn(x, *, collate_fn_map=None):
        return BatchGenerator(x)

    # Handle NumPy generators in addition to all the default types.
    collate_fn_map = {
        np.random.Generator: collate_rng_fn,
        **default_collate_fn_map,
    }
    return collate(x, collate_fn_map=collate_fn_map)
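
The per-example seeding idea described in the `BatchGenerator` docstring can be illustrated with a minimal, NumPy-only sketch. It is not part of this module: `make_example` and `base_seed` are hypothetical names, and `np.stack` stands in for the PyTorch collation step so that the sketch runs without PyTorch installed.

```python
import numpy as np


def make_example(index, base_seed=0):
    """Hypothetical dataset item whose randomness depends only on its index."""
    # Seed from (base_seed, index), so this example never depends on which
    # other examples were built before it, or in which process.
    rng = np.random.default_rng([base_seed, index])
    features = rng.normal(size=3)
    # Return the PRNG too, in case the training loop needs more randomness.
    return features, rng


# Building the examples in a different order gives identical results.
in_order = {i: make_example(i)[0] for i in range(4)}
shuffled = {i: make_example(i)[0] for i in [2, 0, 3, 1]}
assert all(np.array_equal(in_order[i], shuffled[i]) for i in range(4))

# Stand-in for the batched draw: one call per PRNG, results stacked.
rngs = [make_example(i)[1] for i in range(4)]
batch_draw = np.stack([rng.uniform() for rng in rngs])
print(batch_draw.shape)
```

With the real classes, the list of returned generators would instead be wrapped by `collate_rngs`, and `bg.uniform()` would return a PyTorch tensor rather than a NumPy array.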