torch_deterministic.collate_rngs
torch_deterministic.collate_rngs(x)
A collate function for PyTorch data loaders that automatically wraps NumPy pseudorandom number generators (PRNGs) in a `BatchGenerator` object. All the data types normally recognized by PyTorch's `default_collate` are also recognized by this function, so it can be passed directly to the data loader as the `collate_fn` argument.

Example:
```python
from torch.utils.data import DataLoader

from torch_deterministic import collate_rngs
from my_dataset import dataset

DataLoader(dataset, collate_fn=collate_rngs)
```
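To illustrate the general idea, here is a minimal sketch of a collate function that dispatches on element type. NumPy `Generator` objects are gathered into a wrapper, dicts and tuples/lists are collated recursively, and everything else goes to a fallback collate function. This is not the `torch_deterministic` implementation. `BatchGeneratorSketch`, `collate_rngs_sketch`, and the `fallback` parameter are hypothetical names chosen for this example, and the real `BatchGenerator` interface may differ. The fallback defaults to `torch.utils.data.default_collate` (public in recent PyTorch releases). On its own, that function raises a `TypeError` for types it does not recognize, such as NumPy `Generator` objects.

```python
import numpy as np


class BatchGeneratorSketch:
    """Hypothetical stand-in for torch_deterministic's BatchGenerator.

    Keeps one NumPy Generator per sample, so each sample's random draws
    come from its own generator.
    """

    def __init__(self, generators):
        self.generators = list(generators)

    def __len__(self):
        return len(self.generators)

    def random(self):
        # One uniform draw per sample, collected into a batch-shaped array.
        return np.array([g.random() for g in self.generators])


def collate_rngs_sketch(batch, fallback=None):
    """Collate a list of samples, wrapping NumPy Generators specially.

    `fallback` (a hypothetical parameter) handles every other type; it
    defaults to PyTorch's default_collate, imported lazily.
    """
    if fallback is None:
        from torch.utils.data import default_collate
        fallback = default_collate

    elem = batch[0]
    if isinstance(elem, np.random.Generator):
        # Gather the per-sample generators instead of trying to stack them.
        return BatchGeneratorSketch(batch)
    if isinstance(elem, dict):
        # Collate each key's values across the batch.
        return {key: collate_rngs_sketch([d[key] for d in batch], fallback)
                for key in elem}
    if isinstance(elem, (tuple, list)):
        # Like default_collate: transpose the batch of tuples/lists and
        # return a list with one collated entry per field.
        return [collate_rngs_sketch(list(field), fallback)
                for field in zip(*batch)]
    return fallback(batch)
```

Passing `fallback=np.stack` (or any function that takes a list of samples) lets the sketch run without PyTorch installed. In a real data loader it would be passed as `collate_fn=collate_rngs_sketch`, in the same way as `collate_rngs` in the example above.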