| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from huggingface_hub import PyTorchModelHubMixin | |
class IntermediateBlock(nn.Module):
    """Block of parallel convolutions whose summed output is gated per channel.

    Every convolution receives the same input. Their leaky-ReLU activations
    are summed, the sum is passed through each batch-norm layer (those
    results are summed as well), and the result is multiplied by a
    per-channel gate computed from the spatial mean of the input.

    Args:
        in_channels: Number of channels of the input feature map.
        num_conv_layers: Number of parallel convolutions (and batch norms).
            Must be at least 1.
        conv_params: Positional arguments forwarded to ``nn.Conv2d`` after
            ``in_channels``, i.e. ``(out_channels, kernel_size, stride,
            padding, dilation, ...)`` in that order. The first entry is
            also used as the width of the batch norms and of the gate.

    Raises:
        ValueError: If ``num_conv_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, num_conv_layers, conv_params):
        super().__init__()
        # Fail early: with no convolutions, forward() would only break later
        # inside torch.stack, which rejects an empty list of tensors.
        if num_conv_layers < 1:
            raise ValueError(
                f"num_conv_layers must be at least 1, got {num_conv_layers}"
            )
        self.conv_layers = nn.ModuleList(
            [nn.Conv2d(in_channels, *conv_params) for _ in range(num_conv_layers)]
        )
        out_channels = conv_params[0]
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm2d(out_channels) for _ in range(num_conv_layers)]
        )
        # Maps the per-channel input means to one gate value per output channel.
        self.fc = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """Apply the block to ``x``.

        Args:
            x: Tensor with batch on dim 0, channels on dim 1 and two spatial
                dims (dims 2 and 3 are averaged for the gate).

        Returns:
            The summed batch-norm outputs scaled by the per-channel gate.
        """
        batch_size = x.size(0)

        # Gate branch: spatial average of each input channel -> linear layer.
        gate = self.fc(x.mean(dim=[2, 3]))

        # Main branch: sum of activated parallel convolutions ...
        conv_sum = torch.stack(
            [F.leaky_relu(conv(x)) for conv in self.conv_layers], dim=-1
        ).sum(dim=-1)
        # ... followed by the sum of every batch norm applied to that sum.
        norm_sum = torch.stack(
            [bn(conv_sum) for bn in self.batch_norms], dim=-1
        ).sum(dim=-1)

        # Reshape the gate so it broadcasts over both spatial dimensions.
        return norm_sum * F.leaky_relu(gate.view(batch_size, -1, 1, 1))
class OutputBlock(nn.Module):
    """Classification head: spatial averaging followed by an MLP.

    The input feature map is averaged over its two spatial dimensions and
    passed through a stack of ``Linear -> BatchNorm1d -> leaky ReLU`` hidden
    layers. A final linear layer then maps to ``num_classes`` outputs and is
    returned without normalisation or activation.

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_classes: Number of output values per sample.
        hidden_sizes: Widths of the hidden layers, in order. ``None`` or an
            empty sequence means no hidden layers, i.e. a single linear
            layer from ``in_channels`` to ``num_classes``.
    """

    def __init__(self, in_channels, num_classes, hidden_sizes=None):
        super().__init__()
        # None replaces the former mutable ``[]`` default; copy so that the
        # caller's sequence is never aliased.
        hidden_sizes = [] if hidden_sizes is None else list(hidden_sizes)

        # Consecutive pairs of this list are the (in, out) sizes of each layer.
        layer_sizes = [in_channels, *hidden_sizes, num_classes]
        self.fc_layers = nn.ModuleList(
            [
                nn.Linear(n_in, n_out)
                for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
            ]
        )
        # One batch norm per hidden layer; the final layer has none.
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(size) for size in hidden_sizes]
        )

    def forward(self, x):
        """Compute the per-class outputs for ``x``.

        Args:
            x: Tensor with batch on dim 0, channels on dim 1 and two spatial
                dims (dims 2 and 3 are averaged away).

        Returns:
            Tensor with ``num_classes`` values per sample.
        """
        out = F.leaky_relu(x.mean(dim=[2, 3]))

        # fc_layers holds one more entry than batch_norms; zip stops after
        # the hidden layers, so only those get batch norm + activation.
        for fc, bn in zip(self.fc_layers, self.batch_norms):
            out = F.leaky_relu(bn(fc(out)))

        # The last linear layer must be applied explicitly, otherwise the
        # block would return hidden features instead of num_classes outputs.
        return self.fc_layers[-1](out)
class CustomCIFAR10Net(nn.Module, PyTorchModelHubMixin):
    """Five stacked ``IntermediateBlock`` stages followed by an ``OutputBlock``.

    The stages widen the channel dimension 3 -> 64 -> 128 -> 256 -> 512 ->
    1024, each using three parallel convolutions. Dropout (p=0.5) is applied
    to the last feature map before it enters the output block, which is
    built with hidden sizes 512 and 256.

    Args:
        num_classes: Forwarded to ``OutputBlock`` as its ``num_classes``.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # NOTE(review): IntermediateBlock unpacks [c_out, 3, 3, 1, 1]
        # positionally into nn.Conv2d, which reads it as out_channels=c_out,
        # kernel_size=3, stride=3, padding=1, dilation=1. Confirm that a
        # stride of 3 (rather than a 3x3 kernel with stride 1) is intended.
        widths = (3, 64, 128, 256, 512, 1024)
        self.intermediate_blocks = nn.ModuleList(
            [
                IntermediateBlock(c_in, 3, [c_out, 3, 3, 1, 1])
                for c_in, c_out in zip(widths, widths[1:])
            ]
        )
        self.output_block = OutputBlock(widths[-1], num_classes, [512, 256])
        self.dropout = nn.Dropout(0.5)

        # Re-initialise the weights (biases are left untouched) of every
        # convolution and linear layer in the whole network.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_in', nonlinearity='relu'
                )

    def forward(self, x):
        """Run ``x`` through all stages, dropout and the output block.

        Args:
            x: Input batch; the first stage expects 3 channels.

        Returns:
            The result of the output block.
        """
        features = x
        for stage in self.intermediate_blocks:
            features = stage(features)
        return self.output_block(self.dropout(features))