import torch
import torch.nn as nn

class ResDown(nn.Module):
    """Residual downsampling block: halves spatial size, doubles channels.

    The output is the sum of two paths computed from the same input:
      * main path: a 4x4 convolution with stride 2 and padding 1;
      * shortcut path: 2x2 average pooling (stride 2) followed by a 1x1
        convolution that maps ``channels`` -> ``2 * channels``.

    Args:
        channels: number of input channels; the output has ``2 * channels``.
        padding_mode: forwarded to the main-path ``nn.Conv2d``
            (e.g. ``'zeros'``, ``'reflect'``, ``'replicate'``, ``'circular'``).
    """

    def __init__(self, channels, padding_mode='zeros'):
        super().__init__()
        widened = 2 * channels
        # Main path. For input size H this produces floor(H / 2), the same
        # size the 2x2/stride-2 pooling below produces, so the sum is valid.
        self.conv = nn.Conv2d(
            channels,
            widened,
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode=padding_mode,
        )
        # Shortcut path: parameter-free pooling, then a learned 1x1 projection.
        self.down = nn.AvgPool2d(kernel_size=2, stride=2)
        self.linear = nn.Conv2d(channels, widened, kernel_size=1)

    def forward(self, x):
        main = self.conv(x)
        skip = self.linear(self.down(x))
        return skip + main

class ResDownStrided(nn.Module):
    """Residual downsampling block using a strided 1x1 convolution as shortcut.

    Halves the spatial size (rounding up) and doubles the channel count. The
    output is the sum of:
      * main path: a 3x3 convolution with stride 2 and padding 1;
      * shortcut path: a 1x1 convolution with stride 2 and no padding, which
        reads the top-left pixel of every 2x2 cell.

    Args:
        channels: number of input channels; the output has ``2 * channels``.
        padding_mode: forwarded to the main-path ``nn.Conv2d``, matching
            ``ResDown``'s signature. The default ``'zeros'`` reproduces the
            previous behavior.
    """

    def __init__(self, channels, padding_mode='zeros'):
        super().__init__()
        # Main path. For input size H this produces ceil(H / 2).
        self.conv = nn.Conv2d(channels, 2 * channels, 3, 2, 1, padding_mode=padding_mode)
        # Strided 1x1 conv. With no padding it also produces ceil(H / 2), so
        # both paths line up. padding_mode has no effect when padding is 0, so
        # it is not passed here.
        self.residual = nn.Conv2d(channels, 2 * channels, 1, 2)

    def forward(self, x):
        return self.conv(x) + self.residual(x)