# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import Optional

import torch
from torch.distributed import ProcessGroup

_GROUP: Optional[ProcessGroup] = None
_WORLD_SIZE: Optional[int] = None
_LOCAL_RANK: int = 0


def initialize(
    world_size: int,
    local_rank: int,
    group: Optional[ProcessGroup] = None,
    use_gpu: bool = True,
    seed: int = 80486,
) -> str:
    """
    Initialize model parallelism support.

    Args:
        world_size (int): the number of processes running on
            the current node available for model parallelism.
        local_rank (int): the present process' rank.
        group (torch.distributed.ProcessGroup, optional): the
            process group to use for model parallel communications.
        use_gpu (bool, optional): whether computations are
            happening on a GPU or not (defaults to True).
        seed (int, optional): the seed used to seed the PRNG
            on all model parallel processes.

    Returns:
        The PyTorch device to use in the present process.

    Note:
        If ``group`` is not specified, the default process group is
        used for model parallelism. This means that the present
        module may be incompatible with other forms of parallelism
        such as data parallelism.
    """
    global _GROUP
    global _WORLD_SIZE
    global _LOCAL_RANK

    assert local_rank < world_size
    if use_gpu:
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(local_rank)
    else:
        device = "cpu"
    if group is None:
        if "MASTER_ADDR" not in os.environ:
            # No rendezvous information in the environment: only a
            # single-process run can be bootstrapped locally.
            assert world_size == 1
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "1234"
        torch.distributed.init_process_group(
            backend="nccl" if use_gpu else "gloo",
            init_method="env://",
            world_size=world_size,
            rank=local_rank,
        )
    _GROUP = group
    _WORLD_SIZE = world_size
    _LOCAL_RANK = local_rank

    torch.manual_seed(seed)
    return device
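

# Usage sketch (hypothetical helper, not from the original file): how
# ``initialize`` is typically called on a single node. It assumes the
# process was started by a launcher such as ``torchrun``, which exports
# the standard WORLD_SIZE and LOCAL_RANK environment variables; with no
# launcher present it falls back to a single process.
def _example_initialize_from_env() -> str:
    """Illustrative sketch: set up model parallelism from launcher env vars."""
    return initialize(
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
        use_gpu=torch.cuda.is_available(),
    )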


def get_world_size() -> int:
    if _WORLD_SIZE is None:
        raise RuntimeError("model parallelism was not initialized")
    return _WORLD_SIZE


def get_rank() -> int:
    if _WORLD_SIZE is None:
        raise RuntimeError("model parallelism was not initialized")
    return _LOCAL_RANK


def all_gather(x: torch.Tensor) -> torch.Tensor:
    """
    Gather a tensor of shape (n, m) into a tensor of shape (n, mp_size * m).
    """
    mp_size = get_world_size()
    if mp_size == 1:
        return x
    gather = [torch.empty_like(x) for _ in range(mp_size)]
    torch.distributed.all_gather(gather, x, group=_GROUP)
    return torch.cat(gather, dim=-1)
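

# Sketch (hypothetical helper, not from the original file): ``all_gather``
# is the recombination step for a column-sharded linear layer. Each rank
# computes its slice of the output features and the gather along the last
# dimension rebuilds the full output. The weight layout below is made up;
# in a real run the tensors live on the device returned by ``initialize``.
def _example_column_parallel_gather(
    x: torch.Tensor, local_weight: torch.Tensor
) -> torch.Tensor:
    """Return the full (n, mp_size * m) output from this rank's (n, m) slice."""
    local_out = x @ local_weight  # (n, m): this rank's slice of the output
    return all_gather(local_out)  # (n, mp_size * m) after concatenation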


def all_reduce(x: torch.Tensor):
    """
    Sum ``x`` in place across all model parallel processes.
    """
    if get_world_size() > 1:
        # reduce with a sum
        torch.distributed.all_reduce(x, group=_GROUP)
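

# Sketch (hypothetical helper, not from the original file): ``all_reduce``
# is the counterpart used for row-sharded layers, where every rank produces
# a partial sum of the full output that must be added across ranks.
def _example_row_parallel_reduce(partial_out: torch.Tensor) -> torch.Tensor:
    """Sum this rank's partial output with the other ranks' partials, in place."""
    all_reduce(partial_out)
    return partial_out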