PyTorch Learning Note

Using multiple GPUs

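import torch
import torch.nn as nn

# gpu_id1 is the primary GPU; it must be the first entry in device_ids
device = torch.device(f"cuda:{gpu_id1}")

model = nn.DataParallel(model, device_ids=[gpu_id1, gpu_id2, …])
model.to(device)  # moves the wrapped module's parameters to the primary GPU

# put the batch on the primary GPU; DataParallel scatters it across device_ids
inputs = inputs.to(device)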

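Note that gpu_id1 must be the first GPU in the device list passed to nn.DataParallel(module, device_ids). DataParallel expects the module's parameters and buffers to be on device_ids[0] before it runs, and by default it also gathers the outputs on that device.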
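A minimal, self-contained sketch of the same pattern follows. It assumes a toy nn.Linear model and a hypothetical helper named build_data_parallel; neither comes from the note above. The helper derives the device from gpu_ids[0], so the "first GPU" rule holds by construction. When the requested GPUs are not all present, it falls back to the CPU, so the snippet also runs on a machine without them.

```python
import torch
import torch.nn as nn


def build_data_parallel(model: nn.Module, gpu_ids: list[int]):
    """Hypothetical helper: wrap `model` for multi-GPU use, return (model, device).

    The primary device is always gpu_ids[0], because nn.DataParallel expects the
    module's parameters and buffers on device_ids[0] and, by default, gathers
    outputs there.
    """
    gpus_ok = (
        len(gpu_ids) > 0
        and torch.cuda.is_available()
        and all(0 <= i < torch.cuda.device_count() for i in gpu_ids)
    )
    if gpus_ok:
        device = torch.device(f"cuda:{gpu_ids[0]}")
        model = model.to(device)  # parameters must sit on device_ids[0]
        model = nn.DataParallel(model, device_ids=gpu_ids)
    else:
        # Assumed fallback when the requested GPUs are not all present: plain CPU model.
        device = torch.device("cpu")
        model = model.to(device)
    return model, device


# Usage with a toy model (an assumption, not from the note).
model, device = build_data_parallel(nn.Linear(8, 2), gpu_ids=[0, 1])
inputs = torch.randn(16, 8).to(device)  # batch goes to the primary device
outputs = model(inputs)  # on the GPU path, DataParallel splits the batch across gpu_ids
print(outputs.shape)  # torch.Size([16, 2])
```

This sketch moves the model before wrapping it. That is equivalent to the order used above, because calling .to() on the DataParallel wrapper also moves the wrapped module.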