# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
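"""Shared model utilities: BatchNorm/GroupNorm configuration helpers, a
weight-standardized convolution (MyConv2d), weight initialization, and the
MyModule/MyNetwork base classes used across proard."""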
import math
import torch.nn as nn
import torch.nn.functional as F
from .common_tools import min_divisible_value
__all__ = [
"MyModule",
"MyNetwork",
"init_models",
"set_bn_param",
"get_bn_param",
"replace_bn_with_gn",
"MyConv2d",
"replace_conv2d_with_my_conv2d",
]
def set_bn_param(net, momentum, eps, gn_channel_per_group=None, ws_eps=None, **kwargs):
    """Set momentum/eps on all normalization layers of `net`. If
    `gn_channel_per_group` is given, BatchNorm2d layers are first replaced by
    GroupNorm; if `ws_eps` is given, bias-free Conv2d layers are replaced by
    weight-standardized MyConv2d."""
    replace_bn_with_gn(net, gn_channel_per_group)
for m in net.modules():
if type(m) in [nn.BatchNorm1d, nn.BatchNorm2d]:
m.momentum = momentum
m.eps = eps
elif isinstance(m, nn.GroupNorm):
m.eps = eps
    replace_conv2d_with_my_conv2d(net, ws_eps)
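# Illustrative usage sketch (not part of the original file; the toy model is
# hypothetical): configure BN statistics, or additionally pass
# `gn_channel_per_group` / `ws_eps` to swap in GroupNorm / MyConv2d first.
def _example_set_bn_param():
    net = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1, bias=False),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
    )
    set_bn_param(net, momentum=0.1, eps=1e-5)
    assert net[1].momentum == 0.1 and net[1].eps == 1e-5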
def get_bn_param(net):
    """Return the configuration of the first normalization layer in `net`."""
    ws_eps = None
    for m in net.modules():
        if isinstance(m, MyConv2d):
            ws_eps = m.WS_EPS
            break
    for m in net.modules():
        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
return {
"momentum": m.momentum,
"eps": m.eps,
"ws_eps": ws_eps,
}
elif isinstance(m, nn.GroupNorm):
return {
"momentum": None,
"eps": m.eps,
"gn_channel_per_group": m.num_channels // m.num_groups,
"ws_eps": ws_eps,
}
return None
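# Illustrative sketch (hypothetical toy model): read the configuration back.
# For a BN setup this yields momentum/eps plus the ws_eps of any MyConv2d
# layers; for a GroupNorm setup it also reports gn_channel_per_group.
def _example_get_bn_param():
    net = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8))
    set_bn_param(net, momentum=0.1, eps=1e-5)
    print(get_bn_param(net))  # {'momentum': 0.1, 'eps': 1e-5, 'ws_eps': None}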
def replace_bn_with_gn(model, gn_channel_per_group):
    """Replace every BatchNorm2d in `model` with an affine GroupNorm whose
    groups hold (at most) `gn_channel_per_group` channels each."""
    if gn_channel_per_group is None:
        return
for m in model.modules():
to_replace_dict = {}
for name, sub_m in m.named_children():
if isinstance(sub_m, nn.BatchNorm2d):
num_groups = sub_m.num_features // min_divisible_value(
sub_m.num_features, gn_channel_per_group
)
gn_m = nn.GroupNorm(
num_groups=num_groups,
num_channels=sub_m.num_features,
eps=sub_m.eps,
affine=True,
)
# load weight
gn_m.weight.data.copy_(sub_m.weight.data)
gn_m.bias.data.copy_(sub_m.bias.data)
# load requires_grad
gn_m.weight.requires_grad = sub_m.weight.requires_grad
gn_m.bias.requires_grad = sub_m.bias.requires_grad
to_replace_dict[name] = gn_m
m._modules.update(to_replace_dict)
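# Illustrative sketch (not part of the original file): after the swap, each
# GroupNorm keeps the BN affine parameters and normalizes groups of (at most)
# `gn_channel_per_group` channels.
def _example_replace_bn_with_gn():
    net = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
    replace_bn_with_gn(net, gn_channel_per_group=4)
    gn = net[1]
    assert isinstance(gn, nn.GroupNorm)
    assert gn.num_groups == 4 and gn.num_channels == 16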
def replace_conv2d_with_my_conv2d(net, ws_eps=None):
    """Swap bias-free Conv2d layers in `net` for weight-standardized MyConv2d
    layers and set their WS_EPS to `ws_eps`."""
    if ws_eps is None:
        return
    for m in net.modules():
        to_update_dict = {}
        for name, sub_module in m.named_children():
            if isinstance(sub_module, nn.Conv2d) and sub_module.bias is None:
# only replace conv2d layers that are followed by normalization layers (i.e., no bias)
to_update_dict[name] = sub_module
for name, sub_module in to_update_dict.items():
m._modules[name] = MyConv2d(
sub_module.in_channels,
sub_module.out_channels,
sub_module.kernel_size,
sub_module.stride,
sub_module.padding,
sub_module.dilation,
sub_module.groups,
                sub_module.bias is not None,
)
# load weight
m._modules[name].load_state_dict(sub_module.state_dict())
# load requires_grad
m._modules[name].weight.requires_grad = sub_module.weight.requires_grad
if sub_module.bias is not None:
m._modules[name].bias.requires_grad = sub_module.bias.requires_grad
# set ws_eps
for m in net.modules():
if isinstance(m, MyConv2d):
m.WS_EPS = ws_eps
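# Illustrative sketch (hypothetical toy model): only the bias-free convolution
# (the one followed by a norm layer) is swapped for a weight-standardized
# MyConv2d with WS_EPS enabled.
def _example_replace_conv2d_with_my_conv2d():
    net = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
    replace_conv2d_with_my_conv2d(net, ws_eps=1e-5)
    assert isinstance(net[0], MyConv2d) and net[0].WS_EPS == 1e-5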
def init_models(net, model_init="he_fout"):
    """Initialize the network weights:
    Conv2d: He normal, scaled by fan-out ("he_fout", default) or fan-in ("he_fin");
    BatchNorm2d, BatchNorm1d, GroupNorm: weight 1, bias 0;
    Linear: uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)], bias 0.
    """
if isinstance(net, list):
for sub_net in net:
init_models(sub_net, model_init)
return
for m in net.modules():
if isinstance(m, nn.Conv2d):
if model_init == "he_fout":
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif model_init == "he_fin":
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
else:
raise NotImplementedError
if m.bias is not None:
m.bias.data.zero_()
elif type(m) in [nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm]:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
stdv = 1.0 / math.sqrt(m.weight.size(1))
m.weight.data.uniform_(-stdv, stdv)
if m.bias is not None:
m.bias.data.zero_()
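# Illustrative sketch (not part of the original file): He fan-out
# initialization; pass model_init="he_fin" to scale by fan-in instead.
def _example_init_models():
    net = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
    init_models(net, model_init="he_fout")
    # conv weight ~ N(0, 2 / (3 * 3 * 16)); norm layers reset to weight 1, bias 0
    assert net[1].weight.eq(1).all() and net[1].bias.eq(0).all()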
class MyConv2d(nn.Conv2d):
"""
Conv2d with Weight Standardization
https://github.com/joe-siyuan-qiao/WeightStandardization
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super(MyConv2d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
self.WS_EPS = None
    def weight_standardization(self, weight):
        if self.WS_EPS is not None:
            # standardize each output filter: subtract its mean over the
            # input-channel and spatial dims, then divide by its std + WS_EPS
            weight_mean = (
                weight.mean(dim=1, keepdim=True)
                .mean(dim=2, keepdim=True)
                .mean(dim=3, keepdim=True)
            )
            weight = weight - weight_mean
            std = (
                weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1)
                + self.WS_EPS
            )
            weight = weight / std.expand_as(weight)
        return weight
def forward(self, x):
if self.WS_EPS is None:
return super(MyConv2d, self).forward(x)
else:
return F.conv2d(
x,
self.weight_standardization(self.weight),
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
def __repr__(self):
return super(MyConv2d, self).__repr__()[:-1] + ", ws_eps=%s)" % self.WS_EPS
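# Illustrative sketch (not part of the original file): with WS_EPS set, each
# output filter is standardized to roughly zero mean and unit std before the
# convolution is applied; with WS_EPS = None, MyConv2d behaves like nn.Conv2d.
def _example_my_conv2d():
    import torch

    conv = MyConv2d(3, 8, 3, bias=False)
    conv.WS_EPS = 1e-5
    w = conv.weight_standardization(conv.weight)
    assert w.view(8, -1).mean(dim=1).abs().max().item() < 1e-5
    y = conv(torch.randn(1, 3, 32, 32))  # forward uses the standardized weight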
class MyModule(nn.Module):
    """Abstract module: subclasses implement forward, module_str, config,
    and build_from_config."""

def forward(self, x):
raise NotImplementedError
@property
def module_str(self):
raise NotImplementedError
@property
def config(self):
raise NotImplementedError
@staticmethod
def build_from_config(config):
raise NotImplementedError
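# Illustrative sketch (hypothetical subclass): a concrete MyModule implements
# forward plus the serialization hooks: module_str for human-readable logs,
# config/build_from_config for round-tripping an architecture description.
def _example_my_module():
    class _Identity(MyModule):
        def forward(self, x):
            return x

        @property
        def module_str(self):
            return "Identity"

        @property
        def config(self):
            return {"name": "Identity"}

        @staticmethod
        def build_from_config(config):
            return _Identity()

    m = _Identity.build_from_config({"name": "Identity"})
    assert m.module_str == "Identity"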
class MyNetwork(MyModule):
    """Abstract network; subclasses keep channel counts divisible by
    CHANNEL_DIVISIBLE."""

    CHANNEL_DIVISIBLE = 8
def forward(self, x):
raise NotImplementedError
@property
def module_str(self):
raise NotImplementedError
@property
def config(self):
raise NotImplementedError
@staticmethod
def build_from_config(config):
raise NotImplementedError
def zero_last_gamma(self):
raise NotImplementedError
@property
def grouped_block_index(self):
raise NotImplementedError
""" implemented methods """
def set_bn_param(self, momentum, eps, gn_channel_per_group=None, **kwargs):
set_bn_param(self, momentum, eps, gn_channel_per_group, **kwargs)
def get_bn_param(self):
return get_bn_param(self)
    def get_parameters(self, keys=None, mode="include"):
        if keys is None:
            for name, param in self.named_parameters():
                if param.requires_grad:
                    yield param
        elif mode == "include":
            # keep only parameters whose name contains one of the keys
            for name, param in self.named_parameters():
                if any(key in name for key in keys) and param.requires_grad:
                    yield param
        elif mode == "exclude":
            # drop parameters whose name contains one of the keys
            for name, param in self.named_parameters():
                if not any(key in name for key in keys) and param.requires_grad:
                    yield param
        else:
            raise ValueError("unsupported mode: %s" % mode)
def weight_parameters(self):
return self.get_parameters()
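# Illustrative sketch (hypothetical subclass and key names): use
# get_parameters to build optimizer groups, e.g. excluding normalization
# parameters from weight decay.
def _example_get_parameters():
    import torch

    class _TinyNet(MyNetwork):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, bias=False)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    net = _TinyNet()
    return torch.optim.SGD(
        [
            {"params": net.get_parameters(["bn"], mode="exclude"), "weight_decay": 4e-5},
            {"params": net.get_parameters(["bn"], mode="include"), "weight_decay": 0.0},
        ],
        lr=0.1,
        momentum=0.9,
    )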