|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
|
|
|
try: |
|
|
from urllib import urlretrieve |
|
|
except ImportError: |
|
|
from urllib.request import urlretrieve |
|
|
|
|
|
# Names re-exported by `from <this module> import *`.
__all__ = [
    "sort_dict",
    "get_same_padding",
    "get_split_list",
    "list_sum",
    "list_mean",
    "list_join",
    "subset_mean",
    "sub_filter_start_end",
    "min_divisible_value",
    "val2list",
    "download_url",
    "write_log",
    "pairwise_accuracy",
    "accuracy",
    "AverageMeter",
    "MultiClassAverageMeter",
    "DistributedMetric",
    "DistributedTensor",
]
|
|
|
|
|
|
|
|
def sort_dict(src_dict, reverse=False, return_dict=True):
    """Sort a dict by its values.

    Returns a new dict (insertion-ordered by sorted value) when
    `return_dict` is True, otherwise the sorted list of (key, value) pairs.
    """
    ordered_items = sorted(src_dict.items(), key=lambda kv: kv[1], reverse=reverse)
    return dict(ordered_items) if return_dict else ordered_items
|
|
|
|
|
|
|
|
def get_same_padding(kernel_size):
    """Return the 'same' padding for an odd kernel size.

    Accepts an int (returns kernel_size // 2) or a 2-tuple (returns a
    tuple of per-dimension paddings).
    """
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size
        # recurse per spatial dimension
        return tuple(get_same_padding(k) for k in kernel_size)
    assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2
|
|
|
|
|
|
|
|
def get_split_list(in_dim, child_num, accumulate=False):
    """Split `in_dim` into `child_num` near-equal integer parts.

    The first `in_dim % child_num` parts get one extra unit. When
    `accumulate` is True, returns the running (cumulative) sums instead.
    """
    base, remainder = divmod(in_dim, child_num)
    # first `remainder` children absorb the leftover units
    split = [base + 1 if idx < remainder else base for idx in range(child_num)]
    if accumulate:
        for idx in range(1, child_num):
            split[idx] += split[idx - 1]
    return split
|
|
|
|
|
|
|
|
def list_sum(x):
    """Sum the elements of a non-empty sequence using `+`.

    Works with any type supporting `+` (numbers, tensors, lists), so it
    cannot use `sum()` with a numeric start value. The previous recursive
    implementation sliced the list at every step (O(n^2)) and hit Python's
    recursion limit on long lists; this iterative form is O(n) and
    depth-free.

    Raises IndexError on an empty sequence (same as before).
    """
    total = x[0]
    for item in x[1:]:
        total = total + item
    return total
|
|
|
|
|
|
|
|
def list_mean(x):
    """Arithmetic mean of a non-empty sequence (works for tensors too)."""
    n = len(x)
    return list_sum(x) / n
|
|
|
|
|
|
|
|
def list_join(val_list, sep="\t"):
    """Join the str() of each element with `sep`."""
    return sep.join(map(str, val_list))
|
|
|
|
|
|
|
|
def subset_mean(val_list, sub_indexes):
    """Mean of `val_list` restricted to the positions in `sub_indexes`.

    `sub_indexes` may be a single index or a list of indices.
    """
    indexes = val2list(sub_indexes, 1)
    selected = [val_list[idx] for idx in indexes]
    return list_mean(selected)
|
|
|
|
|
|
|
|
def sub_filter_start_end(kernel_size, sub_kernel_size):
    """Start/end slice indices of a centered sub-kernel inside a kernel.

    Both sizes are assumed odd so the sub-kernel centers exactly on the
    full kernel; returns (start, end) with end exclusive.
    """
    mid = kernel_size // 2
    half_span = sub_kernel_size // 2
    start = mid - half_span
    end = mid + half_span + 1
    assert end - start == sub_kernel_size
    return start, end
|
|
|
|
|
|
|
|
def min_divisible_value(n1, v1):
    """Return the largest value <= v1 that divides n1 (n1 itself if v1 >= n1).

    Used to shrink a requested group/split size `v1` so that it evenly
    divides `n1`.
    """
    if v1 >= n1:
        return n1
    candidate = v1
    # walk downward until we hit a divisor of n1 (1 always terminates the loop)
    while n1 % candidate != 0:
        candidate -= 1
    return candidate
|
|
|
|
|
|
|
|
def val2list(val, repeat_time=1):
    """Coerce `val` into a list-like value.

    Lists and numpy arrays pass through unchanged, tuples become lists,
    and any scalar is repeated `repeat_time` times.
    """
    if isinstance(val, (list, np.ndarray)):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val] * repeat_time
|
|
|
|
|
|
|
|
def download_url(url, model_dir="~/.torch/", overwrite=False):
    """Download `url` into `model_dir`, returning the local file path.

    The file keeps the last component of the URL as its name. An existing
    cached file is reused unless `overwrite` is True. Returns None on any
    failure (after logging to stderr).

    Bug fixed: the original rebound `model_dir` to the cached *file* path,
    so the except-handler's `os.remove(os.path.join(model_dir,
    "download.lock"))` pointed at a nonexistent path and raised inside the
    handler, masking the real error and never returning None.
    """
    filename = url.split("/")[-1]
    model_dir = os.path.expanduser(model_dir)
    cached_file = os.path.join(model_dir, filename)
    try:
        os.makedirs(model_dir, exist_ok=True)
        if overwrite or not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # Best-effort cleanup of a stale lock file in the cache directory;
        # its absence must not hide the original download error.
        try:
            os.remove(os.path.join(model_dir, "download.lock"))
        except OSError:
            pass
        sys.stderr.write("Failed to download from url %s" % url + "\n" + str(e) + "\n")
        return None
|
|
|
|
|
|
|
|
def write_log(logs_path, log_str, prefix="valid", should_print=True, mode="a"):
    """Append `log_str` to console log files under `logs_path`.

    prefix: "valid"/"test" write to valid_console.txt AND train_console.txt
    (prefixed there with a "=" * 10 marker); "train" writes to
    train_console.txt only; any other prefix writes to "<prefix>.txt".
    Optionally echoes to stdout.
    """
    if not os.path.exists(logs_path):
        os.makedirs(logs_path, exist_ok=True)

    def _append(file_name, text):
        # open in caller-supplied mode so "w" can truncate if requested
        with open(os.path.join(logs_path, file_name), mode) as fout:
            fout.write(text)
            fout.flush()

    if prefix in ["valid", "test"]:
        _append("valid_console.txt", log_str + "\n")
        _append("train_console.txt", "=" * 10 + log_str + "\n")
    elif prefix == "train":
        _append("train_console.txt", log_str + "\n")
    else:
        _append("%s.txt" % prefix, log_str + "\n")
    if should_print:
        print(log_str)
|
|
|
|
|
|
|
|
def pairwise_accuracy(la, lb, n_samples=200000):
    """Estimate how often `la` and `lb` rank a random pair the same way.

    Samples `n_samples` random pairs (i, j), i != j, and returns the
    fraction where the `>=` comparison agrees between the two lists
    (a Monte-Carlo Kendall-tau-like concordance in [0, 1]).
    """
    n = len(la)
    assert n == len(lb)
    agreements = 0
    sampled = 0
    for _ in range(n_samples):
        i = np.random.randint(n)
        j = np.random.randint(n)
        while i == j:
            j = np.random.randint(n)
        # concordant iff both lists order the pair the same way
        if (la[i] >= la[j]) == (lb[i] >= lb[j]):
            agreements += 1
        sampled += 1
    return float(agreements) / sampled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    `output` is (batch, num_classes) logits/scores, `target` is (batch,)
    class indices. Returns a list of 1-element tensors, one per k, each
    holding the top-k accuracy in percent.
    """
    max_k = max(topk)
    batch_size = target.size(0)

    # (batch, max_k) indices of the top predictions, transposed to (max_k, batch)
    _, pred = output.topk(max_k, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.reshape(1, -1).expand_as(pred))

    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
        for k in topk
    ]
|
|
|
|
|
|
|
|
class AverageMeter(object):
    """
    Computes and stores the average and current value.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        # delegate to reset() so init and reset can never drift apart
        self.reset()

    def reset(self):
        self.val = 0    # most recent value
        self.avg = 0    # running mean (sum / count)
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1):
        """Record `val` observed with weight `n` and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
|
|
|
|
|
|
|
class MultiClassAverageMeter:
    """Multi Binary Classification Tasks.

    Maintains one 2x2 confusion matrix per task (rows = target 0/1,
    columns = predicted 0/1). `value()` returns the mean accuracy over
    tasks, in percent; with `balanced=True` each task's accuracy is the
    mean of its per-class recalls instead of plain accuracy.
    """

    def __init__(self, num_classes, balanced=False, **kwargs):
        super(MultiClassAverageMeter, self).__init__()
        self.num_classes = num_classes
        self.balanced = balanced
        # one zeroed 2x2 confusion matrix per task
        self.counts = [
            np.zeros((2, 2), dtype=np.float32) for _ in range(self.num_classes)
        ]
        self.reset()

    def reset(self):
        """Zero every confusion matrix."""
        for mat in self.counts:
            mat.fill(0)

    def add(self, outputs, targets):
        """Accumulate a batch.

        outputs: (batch, num_classes, 2) scores; targets: (batch, num_classes)
        binary labels. Both are torch tensors (moved to CPU here).
        """
        outputs = outputs.data.cpu().numpy()
        targets = targets.data.cpu().numpy()

        for k in range(self.num_classes):
            preds = np.argmax(outputs[:, k, :], axis=1)
            labels = targets[:, k]

            # encode each (label, pred) pair as an index 0..3, histogram it,
            # and fold into the 2x2 matrix: row = label, col = prediction
            flat = preds + 2 * labels
            hist = np.bincount(flat.astype(np.int32), minlength=2 ** 2)
            self.counts[k] += hist.reshape((2, 2))

    def value(self):
        """Mean (optionally class-balanced) accuracy over tasks, in percent."""
        mean_acc = 0
        for k in range(self.num_classes):
            mat = self.counts[k]
            if self.balanced:
                # per-class recall, averaged; guard empty rows with max(.., 1)
                row_totals = np.maximum(np.sum(mat, axis=1), 1)[:, None]
                acc = np.mean((mat / row_totals).diagonal())
            else:
                # plain accuracy; guard the empty-matrix case
                acc = np.sum(mat.diagonal()) / np.maximum(np.sum(mat), 1)
            mean_acc += acc / self.num_classes * 100.0
        return mean_acc
|
|
|
|
|
|
|
|
class DistributedMetric(object):
    """
    Horovod: average metrics from distributed training.

    `update` all-reduces each new (weighted) value across workers and
    accumulates it locally; `avg` returns the running mean.
    """

    def __init__(self, name):
        self.name = name                 # used as the horovod allreduce key
        self.sum = torch.zeros(1)[0]     # accumulated all-reduced values
        self.count = torch.zeros(1)[0]   # total weight seen so far

    def update(self, val, delta_n=1):
        """Accumulate `val` (a tensor) with weight `delta_n` across workers."""
        import horovod.torch as hvd

        # Scale a detached copy instead of `val *= delta_n`: the original
        # mutated the caller's tensor in place (and in-place ops raise on
        # leaf tensors that require grad).
        scaled = val.detach().cpu() * delta_n
        self.sum += hvd.allreduce(scaled, name=self.name)
        self.count += delta_n

    @property
    def avg(self):
        """Running mean of all updates (sum / count)."""
        return self.sum / self.count
|
|
|
|
|
|
|
|
class DistributedTensor(object):
    """Accumulate a tensor locally; all-reduce lazily on first `avg` read."""

    def __init__(self, name):
        self.name = name                 # used as the horovod allreduce key
        self.sum = None                  # running (unsynced) sum of updates
        self.count = torch.zeros(1)[0]   # total weight seen so far
        self.synced = False              # True once `avg` has all-reduced `sum`

    def update(self, val, delta_n=1):
        """Accumulate `val` (a tensor) with weight `delta_n` locally."""
        # Work on a detached, scaled copy: the original `val *= delta_n`
        # mutated the caller's tensor in place (raising on leaf tensors that
        # require grad) and stored an alias of the caller's storage in
        # `self.sum`.
        scaled = val.detach() * delta_n
        if self.sum is None:
            self.sum = scaled
        else:
            self.sum = self.sum + scaled
        self.count += delta_n

    @property
    def avg(self):
        """All-reduced mean of the accumulated tensor (syncs once, lazily)."""
        import horovod.torch as hvd

        if not self.synced:
            self.sum = hvd.allreduce(self.sum, name=self.name)
            self.synced = True
        return self.sum / self.count
|
|
|