import math import torch from torch.utils.data.distributed import DistributedSampler __all__ = ["MyDistributedSampler", "WeightedDistributedSampler"] class MyDistributedSampler(DistributedSampler): """Allow Subset Sampler in Distributed Training""" def __init__( self, dataset, num_replicas=None, rank=None, shuffle=True, sub_index_list=None ): super(MyDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle) self.sub_index_list = sub_index_list # numpy self.num_samples = int( math.ceil(len(self.sub_index_list) * 1.0 / self.num_replicas) ) self.total_size = self.num_samples * self.num_replicas print("Use MyDistributedSampler: %d, %d" % (self.num_samples, self.total_size)) def __iter__(self): # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) indices = torch.randperm(len(self.sub_index_list), generator=g).tolist() # add extra samples to make it evenly divisible indices += indices[: (self.total_size - len(indices))] indices = self.sub_index_list[indices].tolist() assert len(indices) == self.total_size # subsample indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) class WeightedDistributedSampler(DistributedSampler): """Allow Weighted Random Sampling in Distributed Training""" def __init__( self, dataset, num_replicas=None, rank=None, shuffle=True, weights=None, replacement=True, ): super(WeightedDistributedSampler, self).__init__( dataset, num_replicas, rank, shuffle ) self.weights = ( torch.as_tensor(weights, dtype=torch.double) if weights is not None else None ) self.replacement = replacement print("Use WeightedDistributedSampler") def __iter__(self): if self.weights is None: return super(WeightedDistributedSampler, self).__iter__() else: g = torch.Generator() g.manual_seed(self.epoch) if self.shuffle: # original: indices = torch.randperm(len(self.dataset), generator=g).tolist() indices = torch.multinomial( self.weights, len(self.dataset), self.replacement, generator=g ).tolist() else: indices = list(range(len(self.dataset))) # add extra samples to make it evenly divisible indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size # subsample indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices)