# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import numpy as np
import os
import sys
import torch

try:
    from urllib import urlretrieve  # Python 2
except ImportError:
    from urllib.request import urlretrieve  # Python 3

__all__ = [
    "sort_dict",
    "get_same_padding",
    "get_split_list",
    "list_sum",
    "list_mean",
    "list_join",
    "subset_mean",
    "sub_filter_start_end",
    "min_divisible_value",
    "val2list",
    "download_url",
    "write_log",
    "pairwise_accuracy",
    "accuracy",
    "AverageMeter",
    "MultiClassAverageMeter",
    "DistributedMetric",
    "DistributedTensor",
]


def sort_dict(src_dict, reverse=False, return_dict=True):
    output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse)
    if return_dict:
        return dict(output)
    else:
        return output


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        # note: `str(...)` is needed here; "%s" % tuple would raise TypeError
        assert len(kernel_size) == 2, "invalid kernel size: %s" % str(kernel_size)
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2


def get_split_list(in_dim, child_num, accumulate=False):
    # split `in_dim` into `child_num` parts as evenly as possible
    in_dim_list = [in_dim // child_num] * child_num
    for _i in range(in_dim % child_num):
        in_dim_list[_i] += 1
    if accumulate:
        for i in range(1, child_num):
            in_dim_list[i] += in_dim_list[i - 1]
    return in_dim_list


def list_sum(x):
    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])


def list_mean(x):
    return list_sum(x) / len(x)


def list_join(val_list, sep="\t"):
    return sep.join([str(val) for val in val_list])


def subset_mean(val_list, sub_indexes):
    sub_indexes = val2list(sub_indexes, 1)
    return list_mean([val_list[idx] for idx in sub_indexes])


def sub_filter_start_end(kernel_size, sub_kernel_size):
    # index range of the centered sub-kernel inside the full kernel
    center = kernel_size // 2
    dev = sub_kernel_size // 2
    start, end = center - dev, center + dev + 1
    assert end - start == sub_kernel_size
    return start, end


def min_divisible_value(n1, v1):
    """make sure v1 divides n1; otherwise decrease v1 until it does (capped at n1)"""
    if v1 >= n1:
        return n1
    while n1 % v1 != 0:
        v1 -= 1
    return v1


def val2list(val, repeat_time=1):
    if isinstance(val, (list, np.ndarray)):
        return val
    elif isinstance(val, tuple):
        return list(val)
    else:
        return [val for _ in range(repeat_time)]


def download_url(url, model_dir="~/.torch/", overwrite=False):
    filename = url.split("/")[-1]
    model_dir = os.path.expanduser(model_dir)
    cached_file = os.path.join(model_dir, filename)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        if not os.path.exists(cached_file) or overwrite:
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # remove the partial download so it can be retried next time.
        if os.path.exists(cached_file):
            os.remove(cached_file)
        sys.stderr.write("Failed to download from url %s" % url + "\n" + str(e) + "\n")
        return None
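
# ---------------------------------------------------------------------------
# Illustrative usage (added for exposition; not part of the original file).
# A minimal sketch of how the kernel helpers above compose: `get_same_padding`
# returns the "same" padding for an odd kernel, and `sub_filter_start_end`
# locates the centered sub-kernel that an elastic-kernel layer slices out of
# the full-size weight tensor.
def _example_kernel_helpers():
    # a 7x7 kernel needs padding 3 to preserve spatial size at stride 1
    assert get_same_padding(7) == 3
    # the centered 3x3 sub-filter of a 7x7 kernel spans indices [2, 5)
    assert sub_filter_start_end(7, 3) == (2, 5)
    # `val2list` broadcasts a scalar setting, e.g. one value per stage
    assert val2list(0.25, repeat_time=3) == [0.25, 0.25, 0.25]
    # `get_split_list` splits 10 channels across 3 branches as evenly as possible
    assert get_split_list(10, 3) == [4, 3, 3]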

def write_log(logs_path, log_str, prefix="valid", should_print=True, mode="a"):
    """prefix: valid, train, test"""
    if not os.path.exists(logs_path):
        os.makedirs(logs_path, exist_ok=True)
    if prefix in ["valid", "test"]:
        with open(os.path.join(logs_path, "valid_console.txt"), mode) as fout:
            fout.write(log_str + "\n")
            fout.flush()
    if prefix in ["valid", "test", "train"]:
        with open(os.path.join(logs_path, "train_console.txt"), mode) as fout:
            if prefix in ["valid", "test"]:
                fout.write("=" * 10)
            fout.write(log_str + "\n")
            fout.flush()
    else:
        with open(os.path.join(logs_path, "%s.txt" % prefix), mode) as fout:
            fout.write(log_str + "\n")
            fout.flush()
    if should_print:
        print(log_str)


def pairwise_accuracy(la, lb, n_samples=200000):
    """Estimate the fraction of random pairs that `la` and `lb` rank consistently."""
    n = len(la)
    assert n == len(lb)
    total = 0
    count = 0
    for _ in range(n_samples):
        i = np.random.randint(n)
        j = np.random.randint(n)
        while i == j:
            j = np.random.randint(n)
        if la[i] >= la[j] and lb[i] >= lb[j]:
            count += 1
        if la[i] < la[j] and lb[i] < lb[j]:
            count += 1
        total += 1
    return float(count) / total


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MultiClassAverageMeter:
    """Multi Binary Classification Tasks"""

    def __init__(self, num_classes, balanced=False, **kwargs):
        super(MultiClassAverageMeter, self).__init__()
        self.num_classes = num_classes
        self.balanced = balanced

        # one 2x2 confusion matrix (rows: target, cols: prediction) per task
        self.counts = []
        for k in range(self.num_classes):
            self.counts.append(np.zeros((2, 2), dtype=np.float32))
        self.reset()

    def reset(self):
        for k in range(self.num_classes):
            self.counts[k].fill(0)

    def add(self, outputs, targets):
        outputs = outputs.data.cpu().numpy()
        targets = targets.data.cpu().numpy()
        for k in range(self.num_classes):
            output = np.argmax(outputs[:, k, :], axis=1)
            target = targets[:, k]
            # encode each (prediction, target) pair as an index into the
            # flattened 2x2 confusion matrix
            x = output + 2 * target
            bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2)
            self.counts[k] += bincount.reshape((2, 2))

    def value(self):
        mean = 0
        for k in range(self.num_classes):
            if self.balanced:
                # balanced accuracy: mean of per-class recalls
                value = np.mean(
                    (
                        self.counts[k]
                        / np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None]
                    ).diagonal()
                )
            else:
                # plain accuracy: correct predictions over all samples
                value = np.sum(self.counts[k].diagonal()) / np.maximum(
                    np.sum(self.counts[k]), 1
                )
            mean += value / self.num_classes * 100.0
        return mean
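
# ---------------------------------------------------------------------------
# Illustrative usage (added for exposition; not part of the original file).
# A minimal sketch of the usual evaluation loop around `accuracy` and
# `AverageMeter`; the logits and labels below are random stand-ins with
# assumed shapes (batch of 8, 10 classes).
def _example_accuracy_meter():
    top1_meter = AverageMeter()
    output = torch.randn(8, 10)          # fake logits
    target = torch.randint(0, 10, (8,))  # fake labels
    (top1,) = accuracy(output, target, topk=(1,))
    top1_meter.update(top1.item(), n=output.size(0))
    return top1_meter.avg                # running top-1 accuracy in percent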
""" def __init__(self, name): self.name = name self.sum = torch.zeros(1)[0] self.count = torch.zeros(1)[0] def update(self, val, delta_n=1): import horovod.torch as hvd val *= delta_n self.sum += hvd.allreduce(val.detach().cpu(), name=self.name) self.count += delta_n @property def avg(self): return self.sum / self.count class DistributedTensor(object): def __init__(self, name): self.name = name self.sum = None self.count = torch.zeros(1)[0] self.synced = False def update(self, val, delta_n=1): val *= delta_n if self.sum is None: self.sum = val.detach() else: self.sum += val.detach() self.count += delta_n @property def avg(self): import horovod.torch as hvd if not self.synced: self.sum = hvd.allreduce(self.sum, name=self.name) self.synced = True return self.sum / self.count