|
|
import math |
|
|
import torch |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
|
__all__ = ["MyDistributedSampler", "WeightedDistributedSampler"] |
|
|
|
|
|
|
|
|
class MyDistributedSampler(DistributedSampler): |
|
|
"""Allow Subset Sampler in Distributed Training""" |
|
|
|
|
|
def __init__( |
|
|
self, dataset, num_replicas=None, rank=None, shuffle=True, sub_index_list=None |
|
|
): |
|
|
super(MyDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle) |
|
|
self.sub_index_list = sub_index_list |
|
|
|
|
|
self.num_samples = int( |
|
|
math.ceil(len(self.sub_index_list) * 1.0 / self.num_replicas) |
|
|
) |
|
|
self.total_size = self.num_samples * self.num_replicas |
|
|
print("Use MyDistributedSampler: %d, %d" % (self.num_samples, self.total_size)) |
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
g = torch.Generator() |
|
|
g.manual_seed(self.epoch) |
|
|
indices = torch.randperm(len(self.sub_index_list), generator=g).tolist() |
|
|
|
|
|
|
|
|
indices += indices[: (self.total_size - len(indices))] |
|
|
indices = self.sub_index_list[indices].tolist() |
|
|
assert len(indices) == self.total_size |
|
|
|
|
|
|
|
|
indices = indices[self.rank : self.total_size : self.num_replicas] |
|
|
assert len(indices) == self.num_samples |
|
|
|
|
|
return iter(indices) |
|
|
|
|
|
|
|
|
class WeightedDistributedSampler(DistributedSampler): |
|
|
"""Allow Weighted Random Sampling in Distributed Training""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dataset, |
|
|
num_replicas=None, |
|
|
rank=None, |
|
|
shuffle=True, |
|
|
weights=None, |
|
|
replacement=True, |
|
|
): |
|
|
super(WeightedDistributedSampler, self).__init__( |
|
|
dataset, num_replicas, rank, shuffle |
|
|
) |
|
|
|
|
|
self.weights = ( |
|
|
torch.as_tensor(weights, dtype=torch.double) |
|
|
if weights is not None |
|
|
else None |
|
|
) |
|
|
self.replacement = replacement |
|
|
print("Use WeightedDistributedSampler") |
|
|
|
|
|
def __iter__(self): |
|
|
if self.weights is None: |
|
|
return super(WeightedDistributedSampler, self).__iter__() |
|
|
else: |
|
|
g = torch.Generator() |
|
|
g.manual_seed(self.epoch) |
|
|
if self.shuffle: |
|
|
|
|
|
indices = torch.multinomial( |
|
|
self.weights, len(self.dataset), self.replacement, generator=g |
|
|
).tolist() |
|
|
else: |
|
|
indices = list(range(len(self.dataset))) |
|
|
|
|
|
|
|
|
indices += indices[: (self.total_size - len(indices))] |
|
|
assert len(indices) == self.total_size |
|
|
|
|
|
|
|
|
indices = indices[self.rank : self.total_size : self.num_replicas] |
|
|
assert len(indices) == self.num_samples |
|
|
|
|
|
return iter(indices) |
|
|
|