torch_deterministic.collate_rngs

torch_deterministic.collate_rngs(x)

A collate function for PyTorch dataloaders that automatically wraps NumPy pseudorandom number generators (PRNGs) in a BatchGenerator object.

Every data type recognized by PyTorch’s default_collate is also recognized by this function, so it can be passed directly to a DataLoader as the collate_fn argument.
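The following is a minimal sketch of how a collate function can recognize NumPy generators and defer everything else to default_collate. It is not the library's implementation. It assumes PyTorch 1.11 or later, where default_collate is exported from torch.utils.data, and that the PRNGs are np.random.Generator instances. BatchGeneratorSketch and collate_rngs_sketch are hypothetical names. The random() method on the stand-in wrapper is invented for illustration and is not the real BatchGenerator API.

```python
import numpy as np
from torch.utils.data import default_collate  # public in PyTorch >= 1.11


class BatchGeneratorSketch:
    """Hypothetical stand-in for BatchGenerator: one NumPy PRNG per sample."""

    def __init__(self, rngs):
        self.rngs = list(rngs)

    def __len__(self):
        return len(self.rngs)

    def random(self):
        # Hypothetical method: draw one float from each sample's own
        # generator, in batch order, so each sample keeps its own stream.
        return np.array([rng.random() for rng in self.rngs])


def collate_rngs_sketch(batch):
    """Hypothetical collate_fn: wrap PRNGs, defer everything else."""
    elem = batch[0]
    if isinstance(elem, np.random.Generator):
        # default_collate cannot handle Generator objects, so wrap them here.
        return BatchGeneratorSketch(batch)
    if isinstance(elem, dict):
        # Recurse per key so PRNGs nested inside dicts are found too.
        return {key: collate_rngs_sketch([d[key] for d in batch]) for key in elem}
    if isinstance(elem, (tuple, list)):
        # Transpose samples into per-field groups; like default_collate,
        # return a list for sequence-valued samples.
        return [collate_rngs_sketch(list(field)) for field in zip(*batch)]
    # Tensors, NumPy arrays, numbers, strings, etc. use the default behavior.
    return default_collate(batch)
```

The generator check has to run before the fallback. default_collate raises a TypeError when it meets an element type it does not recognize, such as a Generator. For the same reason, dicts and sequences are walked here rather than handed to default_collate whole.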

Example:

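from torch.utils.data import DataLoader
from torch_deterministic import collate_rngs
from my_dataset import dataset  # placeholder: your own Dataset, whose samples may contain NumPy PRNGs

# Samples are collated as by default_collate, except that PRNGs are wrapped in a BatchGenerator.
loader = DataLoader(dataset, collate_fn=collate_rngs)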