PyTorch sigmoid before BCELoss
Severity: Medium

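Applying BCELoss to the outputs of a Sigmoid layer can be replaced by a single BCEWithLogitsLoss, which takes the raw logits directly. Because it combines the two operations, PyTorch can use the log-sum-exp trick, which is more numerically stable than computing the sigmoid and the logarithm separately. With separate operations, a float32 sigmoid rounds to exactly 0 or 1 for logits of large magnitude, so BCELoss would have to take log(0); PyTorch clamps that logarithm to -100, and the loss silently saturates instead of reflecting how wrong the prediction is.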
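For a logit x and a target y, the loss -[y log σ(x) + (1 - y) log(1 - σ(x))] can be rewritten as max(x, 0) - x·y + log(1 + e^(-|x|)). That form never evaluates σ(x) on its own, and the exponent -|x| is never positive, so nothing overflows and no log(0) occurs. The sketch below implements this identity by hand to show the idea and compares it with torch.nn.functional.binary_cross_entropy_with_logits. The helper name stable_bce_with_logits is hypothetical, and PyTorch's internal expression may be arranged differently even though it computes the same quantity.

```python
import torch
import torch.nn.functional as F

def stable_bce_with_logits(logits, targets):
    # Hypothetical helper. Per-element loss:
    #   max(x, 0) - x * y + log(1 + exp(-|x|))
    # exp(-|x|) is at most 1, so it cannot overflow, and log1p keeps
    # precision when exp(-|x|) is tiny.
    per_element = (
        logits.clamp(min=0)
        - logits * targets
        + torch.log1p(torch.exp(-logits.abs()))
    )
    return per_element.mean()  # same as the default reduction="mean"

logits = torch.tensor([-300.0, -2.0, 0.0, 3.0, 300.0])
targets = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])

manual = stable_bce_with_logits(logits, targets)
builtin = F.binary_cross_entropy_with_logits(logits, targets)
print(manual.item(), builtin.item())
```

Even with logits of ±300, both values stay finite and agree, because neither path takes the logarithm of a saturated sigmoid.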

Detector ID
python/pytorch-sigmoid-before-bceloss@v1.0
Category
Common Weakness Enumeration (CWE)
-

Noncompliant example

def pytorch_sigmoid_before_bceloss_noncompliant():
    import torch
    import torch.nn as nn
    # Noncompliant: `Sigmoid` layer followed by `BCELoss`
    # is not numerically robust.
    m = nn.Sigmoid()
    loss = nn.BCELoss()

    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)

    output = loss(m(input), target)
    output.backward()
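The difference between the two examples only becomes visible for logits of large magnitude. The following sketch is an illustration that assumes float32 tensors and the default reduction="mean"; the helper name compare_losses is hypothetical. It first checks that both formulations agree on moderate logits, then feeds them logits of ±200, each paired with the opposite target, where the exact loss is 200 per element.

```python
import torch
import torch.nn as nn

def compare_losses(logits, targets):
    # Hypothetical helper: returns (Sigmoid + BCELoss, BCEWithLogitsLoss).
    naive = nn.BCELoss()(torch.sigmoid(logits), targets)
    stable = nn.BCEWithLogitsLoss()(logits, targets)
    return naive.item(), stable.item()

# Moderate logits: both formulations agree to float32 precision.
moderate = torch.linspace(-5.0, 5.0, steps=11)
moderate_targets = (moderate > 0).float()
print(compare_losses(moderate, moderate_targets))

# Extreme logits: in float32, torch.sigmoid rounds these to exactly 1.0
# and 0.0, so BCELoss would need log(0); it clamps that log to -100.
extreme = torch.tensor([200.0, -200.0])
wrong_targets = torch.tensor([0.0, 1.0])  # each target contradicts its logit
print(compare_losses(extreme, wrong_targets))
```

In the extreme case, the noncompliant path reports a loss capped at 100, while BCEWithLogitsLoss reports the exact value of 200. In addition, the sigmoid's derivative is zero once its output has saturated, so little or no gradient reaches the logits through the noncompliant path.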

Compliant example

def pytorch_sigmoid_before_bceloss_compliant():
    import torch
    import torch.nn as nn
    # Compliant: `BCEWithLogitsLoss` function integrates a `Sigmoid`
    # layer and the `BCELoss` into one class
    # and is numerically robust.
    loss = nn.BCEWithLogitsLoss()

    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)

    output = loss(input, target)
    output.backward()
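In a full model, following the compliant pattern usually means removing the final nn.Sigmoid layer so that the model returns raw logits, and applying torch.sigmoid only when probabilities are actually needed, for example at inference time. The following is a minimal sketch of that arrangement. The model architecture, the learning rate, and the names train_step and predict_proba are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical binary classifier: the last layer returns raw logits,
# so there is no nn.Sigmoid at the end of the model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_step(features, targets):
    optimizer.zero_grad()
    logits = model(features).squeeze(-1)  # shape: (batch,)
    loss = loss_fn(logits, targets)       # sigmoid is applied inside the loss
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def predict_proba(features):
    # Apply the sigmoid explicitly only when probabilities are needed.
    return torch.sigmoid(model(features).squeeze(-1))

features = torch.randn(16, 4)
targets = torch.empty(16).random_(2)

first_loss = train_step(features, targets)
probs = predict_proba(features)
print(first_loss, probs.shape)
```

Keeping the sigmoid out of the model means the training loss always sees logits, while callers that need probabilities still get values in [0, 1].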