# proard/utils/layers.py
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
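
# Basic layer building blocks used throughout ProArd / OFA-style networks:
# plain and grouped convolutions, identity / zero placeholders, linear heads,
# MobileNetV2-style MBConv blocks, residual wrappers, and ResNet bottlenecks.
# Every layer exposes a serializable `config` dict; `set_layer_from_config`
# rebuilds a layer from such a dict.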
import torch
import torch.nn as nn
from collections import OrderedDict
from proard.utils import get_same_padding, min_divisible_value, SEModule, ShuffleLayer
from proard.utils import MyNetwork, MyModule
from proard.utils import build_activation, make_divisible
__all__ = [
"set_layer_from_config",
"ConvLayer",
"IdentityLayer",
"LinearLayer",
"MultiHeadLinearLayer",
"ZeroLayer",
"MBConvLayer",
"ResidualBlock",
"ResNetBottleneckBlock",
]
def set_layer_from_config(layer_config):
if layer_config is None:
return None
name2layer = {
ConvLayer.__name__: ConvLayer,
IdentityLayer.__name__: IdentityLayer,
LinearLayer.__name__: LinearLayer,
MultiHeadLinearLayer.__name__: MultiHeadLinearLayer,
ZeroLayer.__name__: ZeroLayer,
MBConvLayer.__name__: MBConvLayer,
"MBInvertedConvLayer": MBConvLayer,
##########################################################
ResidualBlock.__name__: ResidualBlock,
ResNetBottleneckBlock.__name__: ResNetBottleneckBlock,
}
layer_name = layer_config.pop("name")
layer = name2layer[layer_name]
return layer.build_from_config(layer_config)
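
# Illustrative round trip (arbitrary example values): a layer's `config` property
# can be fed straight back through `set_layer_from_config`:
#   conv = ConvLayer(3, 32, kernel_size=3, stride=2)
#   rebuilt = set_layer_from_config(conv.config)  # an equivalent ConvLayer
# The legacy name "MBInvertedConvLayer" is kept in the registry so older configs
# still deserialize to MBConvLayer.

# My2DLayer is the shared base class for ConvLayer and IdentityLayer. `ops_order`
# (e.g. "weight_bn_act" or "bn_act_weight") controls both the execution order of
# the ops and whether BatchNorm is built on `in_channels` (BN placed before the
# weight op) or on `out_channels` (BN placed after it).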
class My2DLayer(MyModule):
def __init__(
self,
in_channels,
out_channels,
use_bn=True,
act_func="relu",
dropout_rate=0,
ops_order="weight_bn_act",
):
super(My2DLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order
""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules["bn"] = nn.BatchNorm2d(in_channels)
else:
modules["bn"] = nn.BatchNorm2d(out_channels)
else:
modules["bn"] = None
# activation
modules["act"] = build_activation(
self.act_func, self.ops_list[0] != "act" and self.use_bn
)
# dropout
if self.dropout_rate > 0:
modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)
else:
modules["dropout"] = None
# weight
modules["weight"] = self.weight_op()
# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == "weight":
# dropout before weight operation
if modules["dropout"] is not None:
self.add_module("dropout", modules["dropout"])
for key in modules["weight"]:
self.add_module(key, modules["weight"][key])
else:
self.add_module(op, modules[op])
@property
def ops_list(self):
return self.ops_order.split("_")
@property
def bn_before_weight(self):
for op in self.ops_list:
if op == "bn":
return True
elif op == "weight":
return False
raise ValueError("Invalid ops_order: %s" % self.ops_order)
def weight_op(self):
raise NotImplementedError
""" Methods defined in MyModule """
def forward(self, x):
# similar to nn.Sequential
for module in self._modules.values():
x = module(x)
return x
@property
def module_str(self):
raise NotImplementedError
@property
def config(self):
return {
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"use_bn": self.use_bn,
"act_func": self.act_func,
"dropout_rate": self.dropout_rate,
"ops_order": self.ops_order,
}
@staticmethod
def build_from_config(config):
raise NotImplementedError
class ConvLayer(My2DLayer):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=False,
has_shuffle=False,
use_se=False,
use_bn=True,
act_func="relu",
dropout_rate=0,
ops_order="weight_bn_act",
):
# default normal 3x3_Conv with bn and relu
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.bias = bias
self.has_shuffle = has_shuffle
self.use_se = use_se
super(ConvLayer, self).__init__(
in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
)
if self.use_se:
self.add_module("se", SEModule(self.out_channels))
def weight_op(self):
        padding = get_same_padding(self.kernel_size)
        if isinstance(padding, int):
            padding *= self.dilation
        else:
            # for non-int kernel sizes get_same_padding yields a (pad_h, pad_w) pair;
            # build a new tuple rather than mutating the pair in place
            padding = (padding[0] * self.dilation, padding[1] * self.dilation)
weight_dict = OrderedDict(
{
"conv": nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=padding,
dilation=self.dilation,
groups=min_divisible_value(self.in_channels, self.groups),
bias=self.bias,
)
}
)
if self.has_shuffle and self.groups > 1:
weight_dict["shuffle"] = ShuffleLayer(self.groups)
return weight_dict
@property
def module_str(self):
if isinstance(self.kernel_size, int):
kernel_size = (self.kernel_size, self.kernel_size)
else:
kernel_size = self.kernel_size
if self.groups == 1:
if self.dilation > 1:
conv_str = "%dx%d_DilatedConv" % (kernel_size[0], kernel_size[1])
else:
conv_str = "%dx%d_Conv" % (kernel_size[0], kernel_size[1])
else:
if self.dilation > 1:
conv_str = "%dx%d_DilatedGroupConv" % (kernel_size[0], kernel_size[1])
else:
conv_str = "%dx%d_GroupConv" % (kernel_size[0], kernel_size[1])
conv_str += "_O%d" % self.out_channels
if self.use_se:
conv_str = "SE_" + conv_str
conv_str += "_" + self.act_func.upper()
if self.use_bn:
if isinstance(self.bn, nn.GroupNorm):
conv_str += "_GN%d" % self.bn.num_groups
elif isinstance(self.bn, nn.BatchNorm2d):
conv_str += "_BN"
return conv_str
@property
def config(self):
return {
"name": ConvLayer.__name__,
"kernel_size": self.kernel_size,
"stride": self.stride,
"dilation": self.dilation,
"groups": self.groups,
"bias": self.bias,
"has_shuffle": self.has_shuffle,
"use_se": self.use_se,
**super(ConvLayer, self).config,
}
@staticmethod
def build_from_config(config):
return ConvLayer(**config)
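
# IdentityLayer has no weight op; with the default use_bn=False and act_func=None
# it is a plain pass-through, typically used as the shortcut branch of a
# ResidualBlock or as the downsample path of ResNetBottleneckBlock when the
# spatial size and channel count are unchanged.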
class IdentityLayer(My2DLayer):
def __init__(
self,
in_channels,
out_channels,
use_bn=False,
act_func=None,
dropout_rate=0,
ops_order="weight_bn_act",
):
super(IdentityLayer, self).__init__(
in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
)
def weight_op(self):
return None
@property
def module_str(self):
return "Identity"
@property
def config(self):
return {
"name": IdentityLayer.__name__,
**super(IdentityLayer, self).config,
}
@staticmethod
def build_from_config(config):
return IdentityLayer(**config)
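
# LinearLayer mirrors My2DLayer's ops_order handling for flat features, using
# BatchNorm1d, nn.Dropout and nn.Linear instead of their 2-D counterparts.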
class LinearLayer(MyModule):
def __init__(
self,
in_features,
out_features,
bias=True,
use_bn=False,
act_func=None,
dropout_rate=0,
ops_order="weight_bn_act",
):
super(LinearLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order
""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules["bn"] = nn.BatchNorm1d(in_features)
else:
modules["bn"] = nn.BatchNorm1d(out_features)
else:
modules["bn"] = None
# activation
modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act")
# dropout
if self.dropout_rate > 0:
modules["dropout"] = nn.Dropout(self.dropout_rate, inplace=True)
else:
modules["dropout"] = None
# linear
modules["weight"] = {
"linear": nn.Linear(self.in_features, self.out_features, self.bias)
}
# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == "weight":
if modules["dropout"] is not None:
self.add_module("dropout", modules["dropout"])
for key in modules["weight"]:
self.add_module(key, modules["weight"][key])
else:
self.add_module(op, modules[op])
@property
def ops_list(self):
return self.ops_order.split("_")
@property
def bn_before_weight(self):
for op in self.ops_list:
if op == "bn":
return True
elif op == "weight":
return False
raise ValueError("Invalid ops_order: %s" % self.ops_order)
def forward(self, x):
for module in self._modules.values():
x = module(x)
return x
@property
def module_str(self):
return "%dx%d_Linear" % (self.in_features, self.out_features)
@property
def config(self):
return {
"name": LinearLayer.__name__,
"in_features": self.in_features,
"out_features": self.out_features,
"bias": self.bias,
"use_bn": self.use_bn,
"act_func": self.act_func,
"dropout_rate": self.dropout_rate,
"ops_order": self.ops_order,
}
@staticmethod
def build_from_config(config):
return LinearLayer(**config)
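
# MultiHeadLinearLayer applies `num_heads` independent nn.Linear heads to the same
# (optionally dropped-out) input and stacks the results along dim=1, so the output
# shape is (batch, num_heads, out_features).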
class MultiHeadLinearLayer(MyModule):
def __init__(
self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0
):
super(MultiHeadLinearLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.bias = bias
self.dropout_rate = dropout_rate
if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None
self.layers = nn.ModuleList()
for k in range(num_heads):
layer = nn.Linear(in_features, out_features, self.bias)
self.layers.append(layer)
def forward(self, inputs):
if self.dropout is not None:
inputs = self.dropout(inputs)
outputs = []
for layer in self.layers:
            output = layer(inputs)  # call the module directly so hooks still fire
outputs.append(output)
outputs = torch.stack(outputs, dim=1)
return outputs
@property
def module_str(self):
return self.__repr__()
@property
def config(self):
return {
"name": MultiHeadLinearLayer.__name__,
"in_features": self.in_features,
"out_features": self.out_features,
"num_heads": self.num_heads,
"bias": self.bias,
"dropout_rate": self.dropout_rate,
}
@staticmethod
def build_from_config(config):
return MultiHeadLinearLayer(**config)
def __repr__(self):
return (
"MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)"
% (
self.in_features,
self.out_features,
self.num_heads,
self.bias,
self.dropout_rate,
)
)
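
# ZeroLayer is a structural placeholder for a "skipped" op in the search space.
# It is never meant to be executed directly (its forward raises); ResidualBlock
# checks for it and routes around it instead.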
class ZeroLayer(MyModule):
def __init__(self):
super(ZeroLayer, self).__init__()
def forward(self, x):
        raise ValueError("ZeroLayer is a placeholder and should never be called directly")
@property
def module_str(self):
return "Zero"
@property
def config(self):
return {
"name": ZeroLayer.__name__,
}
@staticmethod
def build_from_config(config):
return ZeroLayer()
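
# MBConvLayer is a MobileNetV2-style inverted bottleneck: an optional 1x1 expand
# conv (skipped when expand_ratio == 1), a depthwise (or grouped) kxk conv that
# carries the stride, an optional SE module, and a 1x1 linear projection. The
# block itself has no residual connection; wrap it in ResidualBlock to add one.
# Illustrative shapes (arbitrary example values, assuming "same" padding):
#   mb = MBConvLayer(24, 40, kernel_size=5, stride=2, expand_ratio=6, use_se=True)
#   y = mb(torch.randn(1, 24, 56, 56))  # -> (1, 40, 28, 28)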
class MBConvLayer(MyModule):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
expand_ratio=6,
mid_channels=None,
act_func="relu6",
use_se=False,
groups=None,
):
super(MBConvLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.use_se = use_se
self.groups = groups
if self.mid_channels is None:
feature_dim = round(self.in_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels
if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
self.in_channels, feature_dim, 1, 1, 0, bias=False
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)
pad = get_same_padding(self.kernel_size)
groups = (
feature_dim
if self.groups is None
else min_divisible_value(feature_dim, self.groups)
)
depth_conv_modules = [
(
"conv",
nn.Conv2d(
feature_dim,
feature_dim,
kernel_size,
stride,
pad,
groups=groups,
bias=False,
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
if self.use_se:
depth_conv_modules.append(("se", SEModule(feature_dim)))
self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))
self.point_linear = nn.Sequential(
OrderedDict(
[
("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)
def forward(self, x):
if self.inverted_bottleneck:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@property
def module_str(self):
if self.mid_channels is None:
expand_ratio = self.expand_ratio
else:
expand_ratio = self.mid_channels // self.in_channels
layer_str = "%dx%d_MBConv%d_%s" % (
self.kernel_size,
self.kernel_size,
expand_ratio,
self.act_func.upper(),
)
if self.use_se:
layer_str = "SE_" + layer_str
layer_str += "_O%d" % self.out_channels
if self.groups is not None:
layer_str += "_G%d" % self.groups
if isinstance(self.point_linear.bn, nn.GroupNorm):
layer_str += "_GN%d" % self.point_linear.bn.num_groups
elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
layer_str += "_BN"
return layer_str
@property
def config(self):
return {
"name": MBConvLayer.__name__,
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.expand_ratio,
"mid_channels": self.mid_channels,
"act_func": self.act_func,
"use_se": self.use_se,
"groups": self.groups,
}
@staticmethod
def build_from_config(config):
return MBConvLayer(**config)
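
# ResidualBlock sums a main conv op and a shortcut. If the conv is None or a
# ZeroLayer the block reduces to an identity; if the shortcut is None or a
# ZeroLayer only the conv path is used. `build_from_config` also accepts the
# legacy "mobile_inverted_conv" config key for backward compatibility.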
class ResidualBlock(MyModule):
def __init__(self, conv, shortcut):
super(ResidualBlock, self).__init__()
self.conv = conv
self.shortcut = shortcut
def forward(self, x):
if self.conv is None or isinstance(self.conv, ZeroLayer):
res = x
elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
res = self.conv(x)
else:
res = self.conv(x) + self.shortcut(x)
return res
@property
def module_str(self):
return "(%s, %s)" % (
self.conv.module_str if self.conv is not None else None,
self.shortcut.module_str if self.shortcut is not None else None,
)
@property
def config(self):
return {
"name": ResidualBlock.__name__,
"conv": self.conv.config if self.conv is not None else None,
"shortcut": self.shortcut.config if self.shortcut is not None else None,
}
@staticmethod
def build_from_config(config):
conv_config = (
config["conv"] if "conv" in config else config["mobile_inverted_conv"]
)
conv = set_layer_from_config(conv_config)
shortcut = set_layer_from_config(config["shortcut"])
return ResidualBlock(conv, shortcut)
@property
def mobile_inverted_conv(self):
return self.conv
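
# ResNetBottleneckBlock: 1x1 reduce -> kxk grouped conv (carries the stride) ->
# 1x1 expand, followed by a residual addition and a final activation. The
# shortcut is an IdentityLayer when the shape is preserved, otherwise a strided
# 1x1 conv ("conv") or an average-pool plus 1x1 conv ("avgpool_conv").
# mid_channels defaults to round(out_channels * expand_ratio), rounded to
# MyNetwork.CHANNEL_DIVISIBLE.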
class ResNetBottleneckBlock(MyModule):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
expand_ratio=0.25,
mid_channels=None,
act_func="relu",
groups=1,
downsample_mode="avgpool_conv",
):
super(ResNetBottleneckBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.groups = groups
self.downsample_mode = downsample_mode
if self.mid_channels is None:
feature_dim = round(self.out_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels
feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
self.mid_channels = feature_dim
# build modules
self.conv1 = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)
pad = get_same_padding(self.kernel_size)
self.conv2 = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
feature_dim,
feature_dim,
kernel_size,
stride,
pad,
groups=groups,
bias=False,
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)
self.conv3 = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False),
),
("bn", nn.BatchNorm2d(self.out_channels)),
]
)
)
if stride == 1 and in_channels == out_channels:
self.downsample = IdentityLayer(in_channels, out_channels)
elif self.downsample_mode == "conv":
self.downsample = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
in_channels, out_channels, 1, stride, 0, bias=False
),
),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)
elif self.downsample_mode == "avgpool_conv":
self.downsample = nn.Sequential(
OrderedDict(
[
(
"avg_pool",
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
padding=0,
ceil_mode=True,
),
),
(
"conv",
nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)
else:
raise NotImplementedError
self.final_act = build_activation(self.act_func, inplace=True)
def forward(self, x):
residual = self.downsample(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x + residual
x = self.final_act(x)
return x
@property
def module_str(self):
return "(%s, %s)" % (
"%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d"
% (
self.kernel_size,
self.kernel_size,
self.in_channels,
self.mid_channels,
self.out_channels,
self.stride,
self.groups,
),
"Identity"
if isinstance(self.downsample, IdentityLayer)
else self.downsample_mode,
)
@property
def config(self):
return {
"name": ResNetBottleneckBlock.__name__,
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.expand_ratio,
"mid_channels": self.mid_channels,
"act_func": self.act_func,
"groups": self.groups,
"downsample_mode": self.downsample_mode,
}
@staticmethod
def build_from_config(config):
return ResNetBottleneckBlock(**config)
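
# Illustrative usage (arbitrary example values): build a block, serialize it to a
# config dict, and rebuild an equivalent block from that dict.
#   block = ResNetBottleneckBlock(64, 256, kernel_size=3, stride=2, expand_ratio=0.25)
#   rebuilt = set_layer_from_config(block.config)
#   y = rebuilt(torch.randn(1, 64, 56, 56))  # -> (1, 256, 28, 28)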