import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

from .my_modules import MyNetwork

__all__ = [
    "make_divisible",
    "build_activation",
    "ShuffleLayer",
    "MyGlobalAvgPool2d",
    "Hswish",
    "Hsigmoid",
    "SEModule",
    "MultiHeadCrossEntropyLoss",
]

def make_divisible(v, divisor, min_val=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value (e.g. a channel count) to round
    :param divisor: the result will be a multiple of this number
    :param min_val: lower bound for the result; defaults to `divisor`
    :return: `v` rounded to the nearest multiple of `divisor`, never less than
        `min_val` and never more than 10% below the original value
    """
    if min_val is None:
        min_val = divisor
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

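# Illustrative values for the rounding rule above (inputs chosen only for the
# example, not taken from any model configuration):
#   make_divisible(30, 8)  -> 32   nearest multiple of 8
#   make_divisible(17, 8)  -> 16   16 >= 0.9 * 17, so no correction is needed
#   make_divisible(47, 32) -> 32 would lose more than 10%, so it is bumped to 64
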
def build_activation(act_func, inplace=True):
    """Build an activation module from its name; returns None for "none"."""
    if act_func == "relu":
        return nn.ReLU(inplace=inplace)
    elif act_func == "relu6":
        return nn.ReLU6(inplace=inplace)
    elif act_func == "tanh":
        return nn.Tanh()
    elif act_func == "sigmoid":
        return nn.Sigmoid()
    elif act_func == "h_swish":
        return Hswish(inplace=inplace)
    elif act_func == "h_sigmoid":
        return Hsigmoid(inplace=inplace)
    elif act_func is None or act_func == "none":
        return None
    else:
        raise ValueError("do not support: %s" % act_func)

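# Example (illustrative): the returned object is a ready-to-use nn.Module, so
# it can be dropped into an nn.Sequential or called directly, e.g.
#   act = build_activation("h_swish", inplace=False)
#   y = act(torch.randn(2, 8))
# Passing None or "none" yields None, which callers are expected to handle.
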
class ShuffleLayer(nn.Module):
    """Channel shuffle (as in ShuffleNet), mixing information across groups."""

    def __init__(self, groups):
        super(ShuffleLayer, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        channels_per_group = num_channels // self.groups
        # Reshape into (batch, groups, channels_per_group, H, W) ...
        x = x.view(batch_size, self.groups, channels_per_group, height, width)
        # ... swap the group and channel axes ...
        x = torch.transpose(x, 1, 2).contiguous()
        # ... and flatten back, interleaving channels across groups.
        x = x.view(batch_size, -1, height, width)
        return x

    def __repr__(self):
        return "ShuffleLayer(groups=%d)" % self.groups

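# Illustrative trace of the shuffle (values are an assumption for the example):
# with num_channels = 6 and groups = 2, channels ordered [0, 1, 2 | 3, 4, 5]
# come out as [0, 3, 1, 4, 2, 5], so a following group convolution sees
# channels from both original groups.
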
class MyGlobalAvgPool2d(nn.Module):
    """Global average pooling over the spatial dimensions (H and W)."""

    def __init__(self, keep_dim=True):
        super(MyGlobalAvgPool2d, self).__init__()
        self.keep_dim = keep_dim

    def forward(self, x):
        return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)

    def __repr__(self):
        return "MyGlobalAvgPool2d(keep_dim=%s)" % self.keep_dim

class Hswish(nn.Module):
    """Hard swish: x * ReLU6(x + 3) / 6, a cheap approximation of x * sigmoid(x)."""

    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0

    def __repr__(self):
        return "Hswish()"

class Hsigmoid(nn.Module):
    """Hard sigmoid: ReLU6(x + 3) / 6, a piecewise-linear approximation of sigmoid(x)."""

    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0

    def __repr__(self):
        return "Hsigmoid()"

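# Quick numeric check of the two hard approximations above (values exact):
#   Hswish:   x = -3 -> 0.0,  x = 0 -> 0.0,  x = 3 -> 3.0
#   Hsigmoid: x = -3 -> 0.0,  x = 0 -> 0.5,  x = 3 -> 1.0
# Both are linear between -3 and 3 and saturate outside that range.
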
class SEModule(nn.Module):
    """Squeeze-and-excitation block with a hard-sigmoid gate."""

    REDUCTION = 4

    def __init__(self, channel, reduction=None):
        super(SEModule, self).__init__()

        self.channel = channel
        self.reduction = SEModule.REDUCTION if reduction is None else reduction

        num_mid = make_divisible(
            self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )

        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
                    ("relu", nn.ReLU(inplace=True)),
                    ("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
                    ("h_sigmoid", Hsigmoid(inplace=True)),
                ]
            )
        )

    def forward(self, x):
        # Squeeze: global average pool to a (N, C, 1, 1) channel descriptor.
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # Excite: per-channel gates in [0, 1] from the bottleneck MLP.
        y = self.fc(y)
        return x * y

    def __repr__(self):
        return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction)

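# Illustrative sizing (the channel count is an assumption for the example, and
# MyNetwork.CHANNEL_DIVISIBLE is assumed to be 8): SEModule(64) uses a
# bottleneck of make_divisible(64 // 4, 8) = 16 channels and returns a tensor
# with the same shape as its input, re-weighted per channel.
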
class MultiHeadCrossEntropyLoss(nn.Module):
    """Cross-entropy averaged over several classification heads."""

    def forward(self, outputs, targets):
        # outputs: (batch, num_heads, num_classes); targets: (batch, num_heads)
        assert outputs.dim() == 3, outputs
        assert targets.dim() == 2, targets
        assert outputs.size(1) == targets.size(1), (outputs, targets)
        num_heads = targets.size(1)

        loss = 0
        for k in range(num_heads):
            loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
        return loss
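
# A minimal smoke-test sketch showing how these utilities fit together. All
# shapes and hyper-parameters below are illustrative assumptions, and the
# relative import above means this only runs when invoked as part of its
# package (e.g. with `python -m <package>.<module>`).
if __name__ == "__main__":
    x = torch.randn(2, 16, 8, 8)

    se = SEModule(16, reduction=4)
    assert se(x).shape == x.shape  # SE re-weights channels; shape is unchanged

    shuffle = ShuffleLayer(groups=4)
    assert shuffle(x).shape == x.shape  # the shuffle only permutes channels

    pool = MyGlobalAvgPool2d(keep_dim=False)
    assert pool(x).shape == (2, 16)  # spatial dimensions are averaged away

    # Multi-head classification: 2 samples, 3 heads, 10 classes per head.
    outputs = torch.randn(2, 3, 10)
    targets = torch.randint(0, 10, (2, 3))
    print(MultiHeadCrossEntropyLoss()(outputs, targets))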