Pytorch data loader with multiple workers

Severity
Medium

Using DataLoader with num_workers greater than 0 can cause increased memory consumption over time when the dataset is stored in native Python objects such as list or dict. PyTorch uses multiprocessing in this scenario, and the forked worker processes initially share the parent's memory through copy-on-write. However, Python's reference counting writes to every object a worker touches, which triggers copy-on-write page copies and steadily increases memory consumption. This behavior resembles a memory leak. Storing the dataset in pandas, numpy, or pyarrow arrays avoids the problem, because their data lives in a few large buffers rather than in many reference-counted Python objects.
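
For example, a map-style dataset that would otherwise keep its samples in a large Python list can store them in a numpy array instead. The sketch below is a minimal illustration of that idea; the IndexDataset class and the sizes are hypothetical and are not part of the detector's examples.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class IndexDataset(Dataset):
    # Hypothetical dataset holding one integer value per index.
    def __init__(self, values):
        # Keep the values in a numpy array rather than a Python list, so the
        # worker processes do not update per-element reference counts.
        self.values = np.asarray(values)

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        return torch.tensor(self.values[idx])

# Backed by a numpy array, iterating with multiple workers does not trigger
# the copy-on-write memory growth that a plain list would.
dataset = IndexDataset(list(range(1_000_000)))
loader = DataLoader(dataset, batch_size=256, num_workers=4)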

Detector ID
python/pytorch-data-loader-with-multiple-workers@v1.0
Category
Common Weakness Enumeration (CWE)
-

Noncompliant example

def pytorch_data_loader_with_multiple_workers_noncompliant():
    import torch
    from torch.utils.data import DataLoader
    import numpy as np

    # The sampler, the graph 'g', and the hyperparameters come from the
    # surrounding training script.
    sampler = InfomaxNodeRecNeighborSampler(g, [fanout] * (n_layers),
                                            device=device, full_neighbor=True)
    pr_node_ids = list(sampler.hetero_map.keys())
    pr_val_ind = list(np.random.choice(len(pr_node_ids),
                                       int(len(pr_node_ids) * val_pct),
                                       replace=False))
    pr_train_ind = list(set(list(np.arange(len(pr_node_ids))))
                        .difference(set(pr_val_ind)))

    # Noncompliant: num_workers is 8 and a native Python 'list'
    # is used to store the dataset.
    loader = DataLoader(dataset=pr_train_ind,
                        batch_size=batch_size,
                        collate_fn=sampler.sample_blocks,
                        shuffle=True,
                        num_workers=8)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=l2norm)

    # Training loop.
    print("start training...")

    for epoch in range(n_epochs):
        model.train()
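
A minimal way to remediate the finding above, assuming sampler.sample_blocks accepts numpy integer IDs and the rest of the training script stays the same, is to convert the index list to a numpy array before it is handed to the DataLoader:

# Convert the training indices to a numpy array so the workers do not
# iterate over a native Python list.
pr_train_ind = np.asarray(pr_train_ind)

loader = DataLoader(dataset=pr_train_ind,
                    batch_size=batch_size,
                    collate_fn=sampler.sample_blocks,
                    shuffle=True,
                    num_workers=8)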

Compliant example

def pytorch_data_loader_with_multiple_workers_compliant(args):
    import os
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Standard ImageNet training transforms.
    imagenet_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(traindir, imagenet_transforms)
    train_sampler = torch.utils.data.distributed \
        .DistributedSampler(train_dataset)

    # Compliant: num_workers is set from args.workers, but the dataset is an
    # ImageFolder rather than a native Python 'list' or 'dict'.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)