# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import numpy as np
import os
import sys
import torch
# Python 2/3 compatibility: urlretrieve moved to urllib.request in Python 3.
try:
    from urllib import urlretrieve  # Python 2
except ImportError:
    from urllib.request import urlretrieve  # Python 3
__all__ = [
"sort_dict",
"get_same_padding",
"get_split_list",
"list_sum",
"list_mean",
"list_join",
"subset_mean",
"sub_filter_start_end",
"min_divisible_value",
"val2list",
"download_url",
"write_log",
"pairwise_accuracy",
"accuracy",
"AverageMeter",
"MultiClassAverageMeter",
"DistributedMetric",
"DistributedTensor",
]
def sort_dict(src_dict, reverse=False, return_dict=True):
output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse)
if return_dict:
return dict(output)
else:
return output
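# Example (illustrative): entries are ordered by value.
#   sort_dict({"b": 2, "a": 1})               -> {"a": 1, "b": 2}
#   sort_dict({"b": 2, "a": 1}, reverse=True) -> {"b": 2, "a": 1}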
def get_same_padding(kernel_size):
if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, "invalid kernel size: %s" % str(kernel_size)
p1 = get_same_padding(kernel_size[0])
p2 = get_same_padding(kernel_size[1])
return p1, p2
assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
assert kernel_size % 2 > 0, "kernel size should be odd number"
return kernel_size // 2
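# Example (illustrative): "same" padding for odd kernels with stride 1.
#   get_same_padding(3)      -> 1
#   get_same_padding((3, 5)) -> (1, 2)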
def get_split_list(in_dim, child_num, accumulate=False):
in_dim_list = [in_dim // child_num] * child_num
for _i in range(in_dim % child_num):
in_dim_list[_i] += 1
if accumulate:
for i in range(1, child_num):
in_dim_list[i] += in_dim_list[i - 1]
return in_dim_list
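# Example (illustrative): split 10 channels across 3 children, remainder first.
#   get_split_list(10, 3)                  -> [4, 3, 3]
#   get_split_list(10, 3, accumulate=True) -> [4, 7, 10]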
def list_sum(x):
return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
def list_mean(x):
return list_sum(x) / len(x)
def list_join(val_list, sep="\t"):
return sep.join([str(val) for val in val_list])
def subset_mean(val_list, sub_indexes):
sub_indexes = val2list(sub_indexes, 1)
return list_mean([val_list[idx] for idx in sub_indexes])
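# Examples (illustrative) for the small list helpers above:
#   list_sum([1, 2, 3])                  -> 6
#   list_mean([1.0, 2.0, 3.0])           -> 2.0
#   list_join([1, 2, 3], sep=", ")       -> "1, 2, 3"
#   subset_mean([1.0, 2.0, 3.0], [0, 2]) -> 2.0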
def sub_filter_start_end(kernel_size, sub_kernel_size):
center = kernel_size // 2
dev = sub_kernel_size // 2
start, end = center - dev, center + dev + 1
assert end - start == sub_kernel_size
return start, end
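# Example (illustrative): indices of a centered 3x3 sub-kernel inside a 7x7 kernel.
#   sub_filter_start_end(7, 3) -> (2, 5)   # i.e., weight[..., 2:5, 2:5]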
def min_divisible_value(n1, v1):
"""make sure v1 is divisible by n1, otherwise decrease v1"""
if v1 >= n1:
return n1
while n1 % v1 != 0:
v1 -= 1
return v1
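# Example (illustrative):
#   min_divisible_value(8, 3) -> 2   # 3 does not divide 8; fall back to 2
#   min_divisible_value(4, 6) -> 4   # v1 >= n1, so n1 itself is returned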
def val2list(val, repeat_time=1):
if isinstance(val, list) or isinstance(val, np.ndarray):
return val
elif isinstance(val, tuple):
return list(val)
else:
return [val for _ in range(repeat_time)]
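# Example (illustrative): scalars are repeated, sequences pass through.
#   val2list(0.5, repeat_time=3) -> [0.5, 0.5, 0.5]
#   val2list((1, 2))             -> [1, 2]
#   val2list([1, 2], 3)          -> [1, 2]   # repeat_time is ignored for lists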
def download_url(url, model_dir="~/.torch/", overwrite=False):
    filename = url.split("/")[-1]
    model_dir = os.path.expanduser(model_dir)
    cached_file = os.path.join(model_dir, filename)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        if not os.path.exists(cached_file) or overwrite:
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # remove the partial download so it can be retried next time
        if os.path.exists(cached_file):
            os.remove(cached_file)
        sys.stderr.write("Failed to download from url %s" % url + "\n" + str(e) + "\n")
        return None
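# Hedged usage sketch (the URL below is a placeholder, not a real checkpoint):
#   path = download_url("https://example.com/checkpoints/net.pth")
#   if path is not None:
#       state_dict = torch.load(path)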
def write_log(logs_path, log_str, prefix="valid", should_print=True, mode="a"):
    """prefix: valid, train, or test; any other value logs to <prefix>.txt"""
    os.makedirs(logs_path, exist_ok=True)
if prefix in ["valid", "test"]:
with open(os.path.join(logs_path, "valid_console.txt"), mode) as fout:
fout.write(log_str + "\n")
fout.flush()
if prefix in ["valid", "test", "train"]:
with open(os.path.join(logs_path, "train_console.txt"), mode) as fout:
if prefix in ["valid", "test"]:
fout.write("=" * 10)
fout.write(log_str + "\n")
fout.flush()
else:
with open(os.path.join(logs_path, "%s.txt" % prefix), mode) as fout:
fout.write(log_str + "\n")
fout.flush()
if should_print:
print(log_str)
def pairwise_accuracy(la, lb, n_samples=200000):
    """Monte Carlo estimate of the fraction of concordant pairs between score lists la and lb."""
    n = len(la)
    assert n == len(lb)
total = 0
count = 0
for _ in range(n_samples):
i = np.random.randint(n)
j = np.random.randint(n)
while i == j:
j = np.random.randint(n)
if la[i] >= la[j] and lb[i] >= lb[j]:
count += 1
if la[i] < la[j] and lb[i] < lb[j]:
count += 1
total += 1
return float(count) / total
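# Example (illustrative): fully concordant score lists give exactly 1.0; for
# imperfect rankings the result is a stochastic Monte Carlo estimate.
#   pairwise_accuracy([1, 2, 3, 4], [10, 20, 30, 40]) -> 1.0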
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
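# Hedged usage sketch (random tensors, just to show the expected shapes):
#   logits = torch.randn(8, 10)            # (batch, num_classes)
#   labels = torch.randint(0, 10, (8,))    # (batch,)
#   top1, top5 = accuracy(logits, labels, topk=(1, 5))  # percentages in [0, 100]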
class AverageMeter(object):
"""
Computes and stores the average and current value
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
    def __init__(self):
        self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
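# Hedged usage sketch: track a running per-sample average over batches.
#   loss_meter = AverageMeter()
#   loss_meter.update(0.42, n=32)  # batch mean 0.42 over 32 samples
#   loss_meter.avg                 # -> 0.42 until more batches arrive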
class MultiClassAverageMeter:
"""Multi Binary Classification Tasks"""
def __init__(self, num_classes, balanced=False, **kwargs):
super(MultiClassAverageMeter, self).__init__()
self.num_classes = num_classes
self.balanced = balanced
        # one 2x2 confusion matrix per task: rows = target, cols = prediction
        self.counts = []
        for k in range(self.num_classes):
            self.counts.append(np.zeros((2, 2), dtype=np.float32))
self.reset()
def reset(self):
for k in range(self.num_classes):
self.counts[k].fill(0)
def add(self, outputs, targets):
outputs = outputs.data.cpu().numpy()
targets = targets.data.cpu().numpy()
for k in range(self.num_classes):
output = np.argmax(outputs[:, k, :], axis=1)
target = targets[:, k]
            # encode each (target, prediction) pair as 2 * target + output in 0..3,
            # then histogram into a flat 2x2 confusion matrix
            x = output + 2 * target
            bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2)
            self.counts[k] += bincount.reshape((2, 2))
def value(self):
mean = 0
for k in range(self.num_classes):
if self.balanced:
value = np.mean(
(
self.counts[k]
/ np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None]
).diagonal()
)
else:
value = np.sum(self.counts[k].diagonal()) / np.maximum(
np.sum(self.counts[k]), 1
)
mean += value / self.num_classes * 100.0
return mean
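# Hedged usage sketch for N samples and 4 independent binary tasks:
#   meter = MultiClassAverageMeter(num_classes=4)
#   meter.add(outputs, targets)  # outputs: (N, 4, 2) logits, targets: (N, 4) in {0, 1}
#   meter.value()                # mean per-task accuracy, in [0, 100]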
class DistributedMetric(object):
"""
Horovod: average metrics from distributed training.
"""
def __init__(self, name):
self.name = name
self.sum = torch.zeros(1)[0]
self.count = torch.zeros(1)[0]
    def update(self, val, delta_n=1):
        import horovod.torch as hvd

        # multiply out of place so the caller's tensor is not mutated
        val = val * delta_n
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.count += delta_n
@property
def avg(self):
return self.sum / self.count
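# Hedged usage sketch (assumes hvd.init() has been called on every worker):
#   loss_metric = DistributedMetric("val_loss")
#   loss_metric.update(loss.detach(), delta_n=images.size(0))
#   loss_metric.avg  # loss averaged across workers (equal batch sizes assumed)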
class DistributedTensor(object):
def __init__(self, name):
self.name = name
self.sum = None
self.count = torch.zeros(1)[0]
self.synced = False
    def update(self, val, delta_n=1):
        # multiply out of place so the caller's tensor is not mutated
        val = val * delta_n
        if self.sum is None:
            self.sum = val.detach()
        else:
            self.sum += val.detach()
        self.count += delta_n
@property
def avg(self):
import horovod.torch as hvd
if not self.synced:
self.sum = hvd.allreduce(self.sum, name=self.name)
self.synced = True
return self.sum / self.count
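# Hedged usage sketch (assumes hvd.init() has been called on every worker):
#   acc_tensor = DistributedTensor("val_top1")
#   acc_tensor.update(top1, delta_n=images.size(0))
#   acc_tensor.avg  # all-reduced once on first access, then cached via self.synced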