
Torch Train with Multi GPU

Multi-GPU training in PyTorch: two common approaches are nn.DataParallel and nn.parallel.DistributedDataParallel.

1. nn.DataParallel

import torch
from torch import nn

# net, loss, trainer (the optimizer), train_iter and num_epochs are assumed to be defined elsewhere.
devices = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]
net = nn.DataParallel(net, device_ids=devices).to(devices[0])
for epoch in range(num_epochs):
    for X, y in train_iter:
        trainer.zero_grad()
        # inputs go to the first GPU; DataParallel scatters them across the others
        X, y = X.to(devices[0]), y.to(devices[0])
        l = loss(net(X), y)
        l.backward()
        trainer.step()
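
Note that nn.DataParallel wraps the original network in a .module attribute, so a common gotcha is saving the wrapper instead of the underlying model. A minimal sketch (the checkpoint file name is just an example):

# save the underlying model's weights rather than the DataParallel wrapper
torch.save(net.module.state_dict(), 'checkpoint.pt')

Loading this state dict into a plain, unwrapped model then works without key-name mismatches.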

2. nn.parallel.DistributedDataParallel

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def init_distributed(rank, world_size):
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',  # reads MASTER_ADDR / MASTER_PORT; suitable for single-node multi-GPU
        world_size=world_size,
        rank=rank
    )
    torch.cuda.set_device(rank)

# One process per GPU: rank runs from 0 to world_size - 1,
# with world_size = torch.cuda.device_count() on a single node.
init_distributed(rank, world_size)
# Model, criterion, optimizer, dataloader and num_epochs are assumed to be defined elsewhere.
model = Model().cuda(rank)
model = DDP(model, device_ids=[rank])
for epoch in range(num_epochs):
    for data, target in dataloader:
        data, target = data.cuda(rank), target.cuda(rank)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()   # DDP all-reduces gradients across processes here
        optimizer.step()
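
The snippet above shows what runs inside a single process; in practice DDP starts one process per GPU, and each process should see only its own shard of the data. Below is a minimal single-node launch sketch using torch.multiprocessing.spawn and DistributedSampler; the toy linear model, random data, port number and hyperparameters are placeholders, not part of the original code.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def worker(rank, world_size, num_epochs=2):
    # env:// reads MASTER_ADDR / MASTER_PORT from the environment
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    # toy model and random data as stand-ins for a real network and dataset
    model = DDP(torch.nn.Linear(20, 2).cuda(rank), device_ids=[rank])
    dataset = TensorDataset(torch.randn(1024, 20), torch.randint(0, 2, (1024,)))
    # DistributedSampler gives each rank a disjoint shard of the data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)          # reshuffle shards differently each epoch
        for data, target in loader:
            data, target = data.cuda(rank), target.cuda(rank)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()               # gradients are all-reduced across ranks here
            optimizer.step()

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

The same worker can alternatively be launched with torchrun, which supplies the rank and world size through environment variables instead of mp.spawn.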