# proard/utils/pytorch_modules.py
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .my_modules import MyNetwork
__all__ = [
"make_divisible",
"build_activation",
"ShuffleLayer",
"MyGlobalAvgPool2d",
"Hswish",
"Hsigmoid",
"SEModule",
"MultiHeadCrossEntropyLoss",
]
def make_divisible(v, divisor, min_val=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_val:
:return:
"""
if min_val is None:
min_val = divisor
new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
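
# Illustrative usage of make_divisible (a sketch; values follow from the code above):
#   make_divisible(37, 8)    # -> 40 (rounded to the nearest multiple of 8)
#   make_divisible(20, 16)   # -> 32 (rounding down to 16 would drop more than 10%, so round up)
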
def build_activation(act_func, inplace=True):
if act_func == "relu":
return nn.ReLU(inplace=inplace)
elif act_func == "relu6":
return nn.ReLU6(inplace=inplace)
elif act_func == "tanh":
return nn.Tanh()
elif act_func == "sigmoid":
return nn.Sigmoid()
elif act_func == "h_swish":
return Hswish(inplace=inplace)
elif act_func == "h_sigmoid":
return Hsigmoid(inplace=inplace)
elif act_func is None or act_func == "none":
return None
else:
        raise ValueError("unsupported act_func: %s" % act_func)
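
# Illustrative usage of build_activation (a sketch):
#   build_activation("h_swish")   # -> Hswish(inplace=True)
#   build_activation("none")      # -> None (caller skips the activation)
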
class ShuffleLayer(nn.Module):
def __init__(self, groups):
super(ShuffleLayer, self).__init__()
self.groups = groups
def forward(self, x):
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // self.groups
# reshape
x = x.view(batch_size, self.groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
def __repr__(self):
return "ShuffleLayer(groups=%d)" % self.groups
class MyGlobalAvgPool2d(nn.Module):
def __init__(self, keep_dim=True):
super(MyGlobalAvgPool2d, self).__init__()
self.keep_dim = keep_dim
def forward(self, x):
return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)
def __repr__(self):
return "MyGlobalAvgPool2d(keep_dim=%s)" % self.keep_dim
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
def __repr__(self):
return "Hswish()"
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return F.relu6(x + 3.0, inplace=self.inplace) / 6.0
def __repr__(self):
return "Hsigmoid()"
class SEModule(nn.Module):
REDUCTION = 4
def __init__(self, channel, reduction=None):
super(SEModule, self).__init__()
self.channel = channel
self.reduction = SEModule.REDUCTION if reduction is None else reduction
num_mid = make_divisible(
self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
)
self.fc = nn.Sequential(
OrderedDict(
[
("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
("relu", nn.ReLU(inplace=True)),
("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
("h_sigmoid", Hsigmoid(inplace=True)),
]
)
)
def forward(self, x):
y = x.mean(3, keepdim=True).mean(2, keepdim=True)
y = self.fc(y)
return x * y
def __repr__(self):
return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction)
class MultiHeadCrossEntropyLoss(nn.Module):
def forward(self, outputs, targets):
assert outputs.dim() == 3, outputs
assert targets.dim() == 2, targets
assert outputs.size(1) == targets.size(1), (outputs, targets)
num_heads = targets.size(1)
loss = 0
for k in range(num_heads):
loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
return loss
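
# Illustrative usage of MultiHeadCrossEntropyLoss (a sketch):
# outputs are (batch, num_heads, num_classes) logits; targets are (batch, num_heads) class indices.
#   criterion = MultiHeadCrossEntropyLoss()
#   logits = torch.randn(8, 3, 10)            # 3 heads, 10 classes each
#   labels = torch.randint(0, 10, (8, 3))
#   criterion(logits, labels)                 # scalar: average of the per-head cross-entropy losses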