PyTorch use nondeterministic algorithm (Medium severity)

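This code calls PyTorch APIs whose default implementations are nondeterministic, so repeated runs with the same inputs can produce different results and hurt reproducibility. Call `torch.use_deterministic_algorithms(True)` before these operations to ensure deterministic algorithms are used.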

Detector ID
python/pytorch-use-nondeterministic-algorithm@v1.0
Category
Common Weakness Enumeration (CWE)
-

Noncompliant example

def pytorch_use_nondeterministic_algorithm_noncompliant():
    import torch
    # Noncompliant: `torch.bmm` doesn't use deterministic algorithms
    # by default.
    torch.bmm(torch.randn(2, 2, 2).to_sparse().cuda(),
              torch.randn(2, 2, 2).cuda())

Compliant example

def pytorch_use_nondeterministic_algorithm_compliant():
    import torch
    # Compliant: configure `torch.bmm` to use deterministic algorithms.
    torch.use_deterministic_algorithms(True)
    torch.bmm(torch.randn(2, 2, 2).to_sparse().cuda(),
              torch.randn(2, 2, 2).cuda())
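
Deterministic algorithms alone do not make a run reproducible if the random inputs differ between runs. The following is a minimal sketch, not part of the detector, of how the compliant setting might be combined with seeding. The helper name `enable_reproducibility` and the seed value are hypothetical. The `CUBLAS_WORKSPACE_CONFIG` value is the setting the PyTorch documentation gives for deterministic cuBLAS behavior on CUDA 10.2 and later, and the sketch assumes it is harmless on CPU-only machines. The example runs on CPU so that it does not require a GPU.

```python
import os
import random

import torch


def enable_reproducibility(seed: int = 0) -> None:
    """Hypothetical helper: seed the RNGs and request deterministic algorithms."""
    # Assumption: running on CUDA 10.2+, where cuBLAS needs this setting to be
    # deterministic. It must be set before the first CUDA call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the PyTorch generators on all devices
    # From here on, operations that have no deterministic implementation
    # raise a RuntimeError instead of silently running nondeterministically.
    torch.use_deterministic_algorithms(True)


# Re-seeding with the same value should reproduce the same random tensor.
enable_reproducibility(seed=0)
a = torch.randn(2, 2, 2)
enable_reproducibility(seed=0)
b = torch.randn(2, 2, 2)
same = torch.equal(a, b)
deterministic = torch.are_deterministic_algorithms_enabled()
```

Calling the helper once at program start, before any tensors are created, keeps the setting global for the whole run. Deterministic implementations can be slower than the defaults, so some projects enable them only in tests or when debugging.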