PyTorch - 3 Distributed Computing


PyTorch has its own distributed implementation, torch.distributed, which runs on top of one of several backends:

  1. MPI, the point-to-point and collective message-passing standard that inspired torch.distributed
  2. Meta’s own [GLOO](https://github.com/facebookincubator/gloo), which is included in the pre-compiled PyTorch binaries
  3. Nvidia’s NCCL (pronounced “Nickel”), optimized primitives for inter-GPU communication; CUDA only.

    Check more details at this link; a minimal backend-selection sketch is shown below.
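Which backend to pick depends on the hardware. As a rough rule of thumb (my own heuristic, not from the original post): NCCL for multi-GPU CUDA jobs, Gloo for CPU or mixed setups, MPI only if you already have a tuned MPI installation. A minimal selection sketch:

```python
import torch

def pick_backend():
    # NCCL is CUDA-only; Gloo ships with the pre-compiled PyTorch binaries
    # and also works on CPU.
    return "nccl" if torch.cuda.is_available() else "gloo"
```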

Distributed Computing

Main program

Here we go, bringing back old memories of rank and world_size. To simulate multiple machines, we use a separate process for each “machine” (all running on a single host). The GLOO backend is used.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        # run is the distributed function, to be implemented later.
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
```
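An alternative launcher (my own variant, not part of the original walkthrough) is torch.multiprocessing.spawn, which starts nprocs workers, passes each one its rank as the first argument, and joins them for you:

```python
import torch.multiprocessing as mp

if __name__ == "__main__":
    size = 2
    # Calls init_process(rank, size, run) for rank = 0 .. size - 1.
    mp.spawn(init_process, args=(size, run), nprocs=size, join=True)
```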

Send and Receive

A transfer of data from one process to another is called point-to-point communication. It is achieved through the send/isend and recv/irecv functions.

"""Blocking point-to-point communication."""

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])
```
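The same primitives compose into a round trip; here is a small ping-pong sketch (my own illustration, not from the original post) in which rank 1 increments the value and sends it back:

```python
"""Ping-pong with blocking point-to-point communication (illustrative sketch)."""

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor=tensor, dst=1)  # ship 1.0 to rank 1
        dist.recv(tensor=tensor, src=1)  # get 2.0 back
    else:
        dist.recv(tensor=tensor, src=0)  # receive 1.0 from rank 0
        tensor += 1
        dist.send(tensor=tensor, dst=0)  # return 2.0
    print('Rank ', rank, ' has data ', tensor[0])
```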

Non-blocking send and receive (isend/irecv) return a request object; the transfer can be synchronized with a wait() call.

"""Non-blocking point-to-point communication."""

def run(rank, size):
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        req = dist.isend(tensor=tensor, dst=1)
        print('Rank 0 started sending')
    else:
        # Receive tensor from process 0
        req = dist.irecv(tensor=tensor, src=0)
        print('Rank 1 started receiving')
    # We should NOT modify the sent tensor nor read the received tensor before
    # req.wait() completes; until then the buffers are in an undefined state.
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0])
```
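A typical reason to prefer the non-blocking variants is to overlap communication with computation. A sketch of that pattern (my own, not from the original post):

```python
"""Overlapping an irecv with unrelated local work (illustrative sketch)."""

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor=tensor, dst=1)
    else:
        req = dist.irecv(tensor=tensor, src=0)
        local = torch.randn(100, 100)
        local = local @ local  # unrelated work while the transfer is in flight
        req.wait()             # only now is `tensor` safe to read
    print('Rank ', rank, ' has data ', tensor[0])
```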

Collective Communications

Beyond point-to-point, collective operations act on every process in a group: broadcast, scatter, gather, reduce, all-reduce, all-gather, and barrier (the original post includes a diagram of these patterns). Here is an example of all-reduce:

""" All-Reduce example."""
def run(rank, size):
    """ Simple collective communication. """
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])
```
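Other collectives follow the same call pattern; for example, broadcast and all_gather (a quick sketch of my own, not from the original post):

```python
""" Broadcast / all-gather sketch (illustrative). """

def run(rank, size):
    # Broadcast: rank 0's value is copied to every process in the default group.
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 42
    dist.broadcast(tensor, src=0)

    # All-gather: every process ends up with every rank's tensor.
    gathered = [torch.zeros(1) for _ in range(size)]
    dist.all_gather(gathered, torch.full((1,), float(rank)))
    print('Rank ', rank, ' broadcast ', tensor[0], ' gathered ', gathered)
```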

With the data partitioned across processes, training is a standard loop plus one extra step: averaging the gradients across ranks after each backward pass.
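The partition_dataset() helper is built in the original post with a hand-rolled DataPartitioner; as a stand-in, here is a simplified sketch (my own substitution) that shards MNIST with DistributedSampler and keeps the global batch size at 128:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

def partition_dataset():
    dataset = datasets.MNIST('./data', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,))
                             ]))
    size = dist.get_world_size()
    bsz = 128 // size                      # keep the global batch size at 128
    sampler = DistributedSampler(dataset)  # each rank sees a disjoint shard
    train_set = DataLoader(dataset, batch_size=bsz, sampler=sampler)
    return train_set, bsz
```

With the data sharded per rank, the synchronous SGD loop is: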

""" Distributed Synchronous SGD Example """
def run(rank, size):
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()
    model = Net()
    optimizer = optim.SGD(model.parameters(),
                          lr=0.01, momentum=0.5)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            ### need to average the grad.
            average_gradients(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ',
              epoch, ': ', epoch_loss / num_batches)
```

Inside the loop, average_gradients() uses all_reduce to sum the gradients across all processes and then divides by the world size:

""" Gradient averaging. """
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
```
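For completeness (not part of the original walkthrough): torch.nn.parallel.DistributedDataParallel wraps a model and performs this gradient averaging automatically during backward(), so in practice the manual average_gradients() call can be dropped. A minimal usage sketch, assuming the same Net and process group as above:

```python
import torch.nn as nn

# DDP all-reduces and averages gradients across ranks during backward(),
# replacing the manual average_gradients() step.
ddp_model = nn.parallel.DistributedDataParallel(Net())
```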
