# proard/utils/pytorch_utils.py
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import math
import copy
import time
import torch
import torch.nn as nn
__all__ = [
"mix_images",
"mix_labels",
"label_smooth",
"cross_entropy_loss_with_soft_target",
"cross_entropy_with_label_smoothing",
"clean_num_batch_tracked",
"rm_bn_from_net",
"get_net_device",
"count_parameters",
"count_net_flops",
"measure_net_latency",
"get_net_info",
"build_optimizer",
"calc_learning_rate",
]
""" Mixup """
def mix_images(images, lam):
flipped_images = torch.flip(images, dims=[0]) # flip along the batch dimension
return lam * images + (1 - lam) * flipped_images
def mix_labels(target, lam, n_classes, label_smoothing=0.1):
onehot_target = label_smooth(target, n_classes, label_smoothing)
flipped_target = torch.flip(onehot_target, dims=[0])
return lam * onehot_target + (1 - lam) * flipped_target
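# Usage sketch (illustrative, not part of the original training code): a typical
# mixup step draws `lam` from a Beta distribution and pairs the mixed images
# with correspondingly mixed soft labels. `images`, `targets`, `net`, and
# `alpha` are placeholders here.
#
#   lam = float(torch.distributions.Beta(alpha, alpha).sample())
#   mixed_images = mix_images(images, lam)
#   mixed_labels = mix_labels(targets, lam, n_classes=1000)
#   loss = cross_entropy_loss_with_soft_target(net(mixed_images), mixed_labels)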
""" Label smooth """
def label_smooth(target, n_classes: int, label_smoothing=0.1):
# convert to one-hot
batch_size = target.size(0)
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros((batch_size, n_classes), device=target.device)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return soft_target
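# Worked example (illustrative): with n_classes=4 and label_smoothing=0.1, every
# class gets a floor of 0.1 / 4 = 0.025 and the true class 0.9 + 0.025 = 0.925:
#
#   >>> label_smooth(torch.tensor([2]), n_classes=4, label_smoothing=0.1)
#   tensor([[0.0250, 0.0250, 0.9250, 0.0250]])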
def cross_entropy_loss_with_soft_target(pred, soft_target):
    logsoftmax = nn.LogSoftmax(dim=1)  # explicit dim: normalize over the class dimension
    return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1))
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
soft_target = label_smooth(target, pred.size(1), label_smoothing)
return cross_entropy_loss_with_soft_target(pred, soft_target)
""" BN related """
def clean_num_batch_tracked(net):
    for m in net.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            if m.num_batches_tracked is not None:
                m.num_batches_tracked.zero_()
def rm_bn_from_net(net):
    # Replace each BN's forward with the identity on the *instance*, so that
    # latency profiling excludes BN (it can be folded into the preceding conv
    # at deployment time). Note: this mutates `net` in place.
    for m in net.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.forward = lambda x: x
""" Network profiling """
def get_net_device(net):
    return next(net.parameters()).device
def count_parameters(net):
total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
return total_params
def count_net_flops(net, data_shape=(1, 3, 32, 32)):
    from .flops_counter import profile  # imported here so the profiler is only loaded when needed

    if isinstance(net, nn.DataParallel):
        net = net.module
    flop, _ = profile(copy.deepcopy(net), data_shape)
    return flop
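# Usage sketch (torchvision's resnet18 is just an example model; any nn.Module
# with a single-tensor forward should work):
#
#   from torchvision.models import resnet18
#   net = resnet18()
#   print("%.2fM params" % (count_parameters(net) / 1e6))
#   print("%.2fM FLOPs" % (count_net_flops(net, (1, 3, 224, 224)) / 1e6))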
def measure_net_latency(
net, l_type="gpu8", fast=True, input_shape=(3, 32, 32), clean=False
):
if isinstance(net, nn.DataParallel):
net = net.module
# remove bn from graph
rm_bn_from_net(net)
# return `ms`
if "gpu" in l_type:
l_type, batch_size = l_type[:3], int(l_type[3:])
else:
batch_size = 1
data_shape = [batch_size] + list(input_shape)
if l_type == "cpu":
if fast:
n_warmup = 5
n_sample = 10
else:
n_warmup = 50
n_sample = 50
        if get_net_device(net) != torch.device("cpu"):
            if not clean:
                print("moving net to CPU to measure CPU latency")
            net = copy.deepcopy(net).cpu()
elif l_type == "gpu":
if fast:
n_warmup = 5
n_sample = 10
else:
n_warmup = 50
n_sample = 50
    else:
        raise NotImplementedError("unsupported latency type: %s" % l_type)
images = torch.zeros(data_shape, device=get_net_device(net))
measured_latency = {"warmup": [], "sample": []}
net.eval()
    with torch.no_grad():
        for i in range(n_warmup):
            inner_start_time = time.time()
            net(images)
            if images.device.type == "cuda":
                torch.cuda.synchronize()  # CUDA launches are async; sync before reading the clock
            used_time = (time.time() - inner_start_time) * 1e3  # ms
            measured_latency["warmup"].append(used_time)
            if not clean:
                print("Warmup %d: %.3f" % (i, used_time))
        outer_start_time = time.time()
        for i in range(n_sample):
            net(images)
        if images.device.type == "cuda":
            torch.cuda.synchronize()
        total_time = (time.time() - outer_start_time) * 1e3  # ms
        measured_latency["sample"].append((total_time, n_sample))
    return total_time / n_sample, measured_latency
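# Usage sketch (assumes `net` already lives on the target device; note that the
# call strips BN from `net` in place via rm_bn_from_net):
#
#   lat_ms, hist = measure_net_latency(net, l_type="gpu8", fast=True)   # batch 8 on GPU
#   lat_ms, hist = measure_net_latency(net, l_type="cpu", fast=False)   # batch 1 on CPU
#
# `lat_ms` is the average forward latency in milliseconds; `hist` holds the raw
# warmup and sample timings.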
def get_net_info(net, input_shape=(3, 32, 32), measure_latency=None, print_info=True):
net_info = {}
if isinstance(net, nn.DataParallel):
net = net.module
# parameters
net_info["params"] = count_parameters(net) / 1e6
# flops
net_info["flops"] = count_net_flops(net, [1] + list(input_shape)) / 1e6
# latencies
latency_types = [] if measure_latency is None else measure_latency.split("#")
for l_type in latency_types:
latency, measured_latency = measure_net_latency(
net, l_type, fast=False, input_shape=input_shape
)
net_info["%s latency" % l_type] = {"val": latency, "hist": measured_latency}
if print_info:
print(net)
print("Total training params: %.2fM" % (net_info["params"]))
print("Total FLOPs: %.2fM" % (net_info["flops"]))
for l_type in latency_types:
print(
"Estimated %s latency: %.3fms"
% (l_type, net_info["%s latency" % l_type]["val"])
)
return net_info
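# Usage sketch: `measure_latency` is a '#'-separated list of latency types, e.g.
#
#   info = get_net_info(net, input_shape=(3, 32, 32), measure_latency="gpu8#cpu")
#   # info["params"], info["flops"], info["gpu8 latency"]["val"], ...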
""" optimizer """
def build_optimizer(
net_params, opt_type, opt_param, init_lr, weight_decay, no_decay_keys
):
if no_decay_keys is not None:
assert isinstance(net_params, list) and len(net_params) == 2
net_params = [
{"params": net_params[0], "weight_decay": weight_decay},
{"params": net_params[1], "weight_decay": 0},
]
else:
net_params = [{"params": net_params, "weight_decay": weight_decay}]
if opt_type == "sgd":
opt_param = {} if opt_param is None else opt_param
momentum, nesterov = opt_param.get("momentum", 0.9), opt_param.get(
"nesterov", True
)
optimizer = torch.optim.SGD(
net_params, init_lr, momentum=momentum, nesterov=nesterov
)
elif opt_type == "adam":
optimizer = torch.optim.Adam(net_params, init_lr)
    else:
        raise NotImplementedError("unsupported optimizer type: %s" % opt_type)
return optimizer
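# Usage sketch (hedged): with no_decay_keys=None a plain parameter iterator is
# enough; with no_decay_keys set, pass a two-element list of parameter groups
# [decay_params, no_decay_params] as the assert above requires.
#
#   optimizer = build_optimizer(
#       net.parameters(), "sgd", {"momentum": 0.9, "nesterov": True},
#       init_lr=0.05, weight_decay=4e-5, no_decay_keys=None,
#   )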
""" learning rate schedule """
def calc_learning_rate(
epoch, init_lr, n_epochs, batch=0, nBatch=None, lr_schedule_type="cosine"
):
    if lr_schedule_type == "cosine":
        # cosine annealing over all training iterations; nBatch (iterations per
        # epoch) must be provided for this schedule
        t_total = n_epochs * nBatch
        t_cur = epoch * nBatch + batch
        lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))
    elif lr_schedule_type is None:
        lr = init_lr
    else:
        raise ValueError("unsupported lr_schedule_type: %s" % lr_schedule_type)
return lr
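# Usage sketch: with the cosine schedule the LR is updated every iteration, so
# the caller passes both the epoch and the batch index (nBatch = batches/epoch):
#
#   for epoch in range(n_epochs):
#       for i, (images, targets) in enumerate(loader):
#           lr = calc_learning_rate(epoch, init_lr, n_epochs, i, len(loader))
#           for group in optimizer.param_groups:
#               group["lr"] = lr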