Detects if Softmax is used together with CrossEntropyLoss. The extra Softmax is redundant because CrossEntropyLoss already applies LogSoftmax internally (it combines nn.LogSoftmax and nn.NLLLoss), so normalizing the logits beforehand distorts the loss and weakens the gradients.
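As a quick, self-contained illustration of the redundancy (random logits and labels, independent of the examples below): nn.CrossEntropyLoss applied to raw logits matches nn.NLLLoss applied to log-softmax outputs, while pre-applying Softmax produces a different loss value.

import torch
import torch.nn.functional as F

# Illustration only: a random batch of 8 samples over 10 classes.
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

# CrossEntropyLoss on raw logits == LogSoftmax + NLLLoss.
ce = F.cross_entropy(logits, labels)
nll = F.nll_loss(F.log_softmax(logits, dim=1), labels)
assert torch.allclose(ce, nll)

# Applying Softmax before CrossEntropyLoss normalizes twice and changes the loss.
ce_double = F.cross_entropy(F.softmax(logits, dim=1), labels)
print(ce.item(), ce_double.item())  # the two values differ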
def pytorch_redundant_softmax_noncompliant():
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    train_dataloader = DataLoader(training_data, batch_size=64)
    test_dataloader = DataLoader(test_data, batch_size=64)

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, 10)
            )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            # Noncompliant: Softmax used with CrossEntropyLoss.
            logits = nn.functional.softmax(logits)
            return logits

    model = NeuralNetwork()

    def train_loop(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        for batch, (x, y) in enumerate(dataloader):
            # Compute prediction and loss
            pred = model(x)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), (batch + 1) * len(x)
                print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

    def test_loop(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        test_loss, correct = 0, 0

        with torch.no_grad():
            for x, y in dataloader:
                pred = model(x)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, "
              f"Avg loss: {test_loss:>8f} \n")

    loss_fn = nn.CrossEntropyLoss()
    learning_rate = 0.05
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    epochs = 10
    for t in range(epochs):
        print(f"Model - Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)

    print("Done!")
def pytorch_redundant_softmax_compliant():
    import torch
    import torch.nn as nn
    from transformers import BertModel, BertForSequenceClassification
    import default_constants

    class BERT(nn.Module):
        def __init__(self, tokenizer, bert_variant=default_constants.BERT):
            super(BERT, self).__init__()

            self.num_labels = default_constants.num_labels
            self.hidden_dim = default_constants.hidden_dim
            self.dropout_prob = default_constants.dropout_prob

            self.bert = BertModel.from_pretrained(bert_variant)
            self.bert.resize_token_embeddings(len(tokenizer))

            self.dropout = nn.Dropout(self.dropout_prob).cuda()
            self.classifier = nn.Linear(
                self.hidden_dim, self.num_labels).cuda()
            torch.nn.init.xavier_uniform_(self.classifier.weight)

        def forward(
            self,
            text,
            labels,
            attention_mask=None,
            token_type_ids=None
        ):
            outputs = self.bert(
                text,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids
            )

            pooled_output = outputs[1]

            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)

            # Compliant: Softmax is not used with CrossEntropyLoss.
            loss_fct = nn.CrossEntropyLoss(weight=torch.Tensor(
                default_constants.class_weight)).cuda(default_constants.DEVICE)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            return (
                loss,
                logits,
            )
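Keeping Softmax out of the training path does not mean probabilities are unavailable. If class probabilities are needed at inference time, Softmax can be applied to the returned logits outside the loss computation; a sketch with a placeholder classifier and input batch (both hypothetical):

import torch
from torch import nn

# Placeholders for illustration: a trained classifier and an input batch.
model = nn.Linear(28 * 28, 10)
x = torch.randn(64, 28 * 28)

model.eval()
with torch.no_grad():
    logits = model(x)                      # raw, unnormalized scores
    probs = torch.softmax(logits, dim=-1)  # probabilities for reporting only
    preds = probs.argmax(dim=-1)           # predicted class indices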