import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Conv (or transposed conv) -> InstanceNorm -> ReLU block used throughout the generator."""

    def __init__(self, in_channels: int, out_channels: int, downsample: bool = True, use_act: bool = True,
                 use_dropout: bool = False, **kwargs):
        super(ConvBlock, self).__init__()
        self.conv_block = nn.Sequential(
            # Downsampling mode uses a regular convolution with reflection padding;
            # upsampling mode uses a transposed convolution instead.
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding_mode="reflect", **kwargs)
            if downsample
            else nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, **kwargs),
            nn.InstanceNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )
        if use_dropout:
            self.conv_block = nn.Sequential(self.conv_block, nn.Dropout(p=0.5))

    def forward(self, x):
        return self.conv_block(x)
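

# Illustrative sketch (not part of the original Space): a quick shape check for
# ConvBlock's two modes. The kernel_size / stride / padding values below are
# assumptions that simply mirror how CycleGenerator configures the block further down.
def _convblock_shape_check():   # hypothetical helper, added for illustration only
    x = torch.randn(1, 64, 64, 64)
    down = ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
    up = ConvBlock(in_channels=128, out_channels=64, downsample=False,
                   kernel_size=3, stride=2, padding=1, output_padding=1)
    print(down(x).shape)       # torch.Size([1, 128, 32, 32]) -- stride-2 conv halves H and W
    print(up(down(x)).shape)   # torch.Size([1, 64, 64, 64])  -- transposed conv restores them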
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with a skip connection; preserves channel count and spatial size."""

    def __init__(self, features: int):
        super(ResidualBlock, self).__init__()
        self.residual_block = nn.Sequential(
            ConvBlock(in_channels=features, out_channels=features, kernel_size=3, padding=1),
            ConvBlock(in_channels=features, out_channels=features, kernel_size=3, padding=1, use_act=False),
        )

    def forward(self, x):
        # Skip connection: add the input back onto the transformed features.
        return x + self.residual_block(x)
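

# Illustrative sketch (not part of the original Space): a ResidualBlock is shape
# preserving, since both inner ConvBlocks use 3x3 kernels with stride 1 and
# padding 1, so any number of them can be stacked in the bottleneck.
def _residualblock_shape_check():   # hypothetical helper, added for illustration only
    block = ResidualBlock(features=256)
    x = torch.randn(1, 256, 64, 64)
    assert block(x).shape == x.shape   # same channels and spatial size as the input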
class CycleGenerator(nn.Module):
    """CycleGAN-style generator: 7x7 stem, two stride-2 downsampling blocks, a stack of
    residual blocks, two stride-2 upsampling blocks, and a 7x7 head followed by tanh."""

    def __init__(self, img_channels: int = 3, latent_dim: int = 64, num_residuals: int = 9):
        super(CycleGenerator, self).__init__()
        # Stem: 7x7 convolution with reflection padding, keeps the spatial size.
        self.base = nn.Sequential(
            nn.Conv2d(in_channels=img_channels, out_channels=latent_dim, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect"),
            nn.ReLU(inplace=True)
        )
        # Two downsampling blocks: each halves H and W and doubles the channels.
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(in_channels=latent_dim, out_channels=latent_dim * 2, kernel_size=3, stride=2, padding=1),
                ConvBlock(in_channels=latent_dim * 2, out_channels=latent_dim * 4, kernel_size=3, stride=2, padding=1),
            ]
        )
        # Bottleneck: num_residuals shape-preserving residual blocks.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(features=latent_dim * 4) for _ in range(num_residuals)]
        )
        # Two upsampling blocks (transposed convolutions) mirroring the downsampling path.
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(in_channels=latent_dim * 4, out_channels=latent_dim * 2, kernel_size=3, stride=2, padding=1,
                          output_padding=1, downsample=False),
                ConvBlock(in_channels=latent_dim * 2, out_channels=latent_dim, kernel_size=3, stride=2, padding=1,
                          output_padding=1, downsample=False),
            ]
        )
        # Head: 7x7 convolution back to img_channels.
        self.head = nn.Conv2d(in_channels=latent_dim, out_channels=img_channels, kernel_size=7, stride=1, padding=3,
                              padding_mode="reflect")

    def forward(self, x):
        x = self.base(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        x = self.head(x)
        # tanh keeps outputs in [-1, 1] to match normalised input images.
        return torch.tanh(x)
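

# Illustrative usage sketch (not part of the original Space), assuming 256x256 RGB
# inputs, which is the resolution the 9-residual-block CycleGAN generator is
# typically paired with.
if __name__ == "__main__":
    gen = CycleGenerator(img_channels=3, latent_dim=64, num_residuals=9)
    img = torch.randn(2, 3, 256, 256)
    out = gen(img)
    print(out.shape)                            # torch.Size([2, 3, 256, 256])
    print(out.min().item(), out.max().item())   # both inside [-1, 1] thanks to the final tanh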