# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

from .my_modules import MyNetwork

__all__ = [
    "make_divisible",
    "build_activation",
    "ShuffleLayer",
    "MyGlobalAvgPool2d",
    "Hswish",
    "Hsigmoid",
    "SEModule",
    "MultiHeadCrossEntropyLoss",
]


def make_divisible(v, divisor, min_val=None):
    """
    This function is taken from the original tf repo. It ensures that all
    layers have a channel number that is divisible by `divisor`. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: original channel number
    :param divisor: the value the result must be divisible by
    :param min_val: lower bound for the result (defaults to `divisor`)
    :return: the channel number rounded to the nearest multiple of `divisor`
    """
    if min_val is None:
        min_val = divisor
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce v by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def build_activation(act_func, inplace=True):
    if act_func == "relu":
        return nn.ReLU(inplace=inplace)
    elif act_func == "relu6":
        return nn.ReLU6(inplace=inplace)
    elif act_func == "tanh":
        return nn.Tanh()
    elif act_func == "sigmoid":
        return nn.Sigmoid()
    elif act_func == "h_swish":
        return Hswish(inplace=inplace)
    elif act_func == "h_sigmoid":
        return Hsigmoid(inplace=inplace)
    elif act_func is None or act_func == "none":
        return None
    else:
        raise ValueError("do not support: %s" % act_func)


class ShuffleLayer(nn.Module):
    def __init__(self, groups):
        super(ShuffleLayer, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        channels_per_group = num_channels // self.groups
        # reshape into (batch, groups, channels_per_group, height, width)
        x = x.view(batch_size, self.groups, channels_per_group, height, width)
        # swap the group and channel axes, then flatten back to (batch, channels, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(batch_size, -1, height, width)
        return x

    def __repr__(self):
        return "ShuffleLayer(groups=%d)" % self.groups


class MyGlobalAvgPool2d(nn.Module):
    def __init__(self, keep_dim=True):
        super(MyGlobalAvgPool2d, self).__init__()
        self.keep_dim = keep_dim

    def forward(self, x):
        # average over the spatial dimensions (H, W)
        return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)

    def __repr__(self):
        return "MyGlobalAvgPool2d(keep_dim=%s)" % self.keep_dim


class Hswish(nn.Module):
    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # hard swish: x * relu6(x + 3) / 6
        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0

    def __repr__(self):
        return "Hswish()"


class Hsigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # hard sigmoid: relu6(x + 3) / 6
        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0

    def __repr__(self):
        return "Hsigmoid()"


class SEModule(nn.Module):
    REDUCTION = 4

    def __init__(self, channel, reduction=None):
        super(SEModule, self).__init__()

        self.channel = channel
        self.reduction = SEModule.REDUCTION if reduction is None else reduction

        num_mid = make_divisible(
            self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )

        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
                    ("relu", nn.ReLU(inplace=True)),
                    ("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
                    ("h_sigmoid", Hsigmoid(inplace=True)),
                ]
            )
        )

    def forward(self, x):
        # squeeze: global average pooling over the spatial dimensions
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # excitation: per-channel gates in [0, 1]
        y = self.fc(y)
        return x * y

    def __repr__(self):
        return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction)
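

# Usage sketch (illustrative only, not part of the original OFA code): SEModule
# re-weights the channels of a (N, C, H, W) feature map and preserves its shape.
#
#   se = SEModule(channel=64)          # reduction defaults to SEModule.REDUCTION = 4
#   x = torch.randn(2, 64, 7, 7)
#   assert se(x).shape == x.shape      # squeeze-and-excitation keeps the tensor shape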


class MultiHeadCrossEntropyLoss(nn.Module):
    def forward(self, outputs, targets):
        # outputs: (batch, num_heads, num_classes); targets: (batch, num_heads)
        assert outputs.dim() == 3, outputs
        assert targets.dim() == 2, targets
        assert outputs.size(1) == targets.size(1), (outputs, targets)

        num_heads = targets.size(1)
        loss = 0
        for k in range(num_heads):
            loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
        return loss
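

# Minimal sanity-check sketch (not part of the original OFA repo). Because of the
# relative import above, run it as a module, e.g. `python -m ofa.utils.pytorch_modules`
# if this file lives at that path (the exact path is an assumption).
if __name__ == "__main__":
    print(make_divisible(37, 8))  # -> 40: rounded to the nearest multiple of 8

    shuffle = ShuffleLayer(groups=4)
    x = torch.randn(2, 16, 8, 8)
    assert shuffle(x).shape == x.shape  # channel shuffle preserves the tensor shape

    loss_fn = MultiHeadCrossEntropyLoss()
    outputs = torch.randn(5, 3, 10)          # (batch, num_heads, num_classes)
    targets = torch.randint(0, 10, (5, 3))   # (batch, num_heads)
    print(loss_fn(outputs, targets).item())  # cross-entropy averaged over the 3 heads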