import torch
import torch.nn as nn

class ResConvBlock(nn.Module):
    """Residual 3x3 conv block whose second normalization is modulated by a time embedding.

    Residual branch:
        GroupNorm(affine) -> LeakyReLU -> Conv3x3 (no bias)
        -> GroupNorm(no affine) -> h * (1 + scale) + shift -> LeakyReLU -> Conv3x3
    where ``scale`` and ``shift`` are per-channel values produced from
    ``t_emb`` by a linear layer. The branch result is added back onto the
    input, so channel count and spatial size are unchanged.

    Args:
        channels: number of input (and output) channels.
        time_dim: feature size of the time embedding fed to ``embd_affine``.
        padding_mode: forwarded to both ``nn.Conv2d`` layers.
    """

    def __init__(self, channels, time_dim, padding_mode='zeros'):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode=padding_mode)
        # NOTE: the convs and the linear layer are built in this exact order so
        # that seeded parameter initialization matches earlier checkpoints/runs.
        self.first_conv = nn.Conv2d(channels, channels, bias=False, **conv_kwargs)
        self.second_conv = nn.Conv2d(channels, channels, **conv_kwargs)
        # A single group normalizes over all non-batch dimensions of each sample.
        self.gn1 = nn.GroupNorm(num_groups=1, num_channels=channels, affine=True)
        # No learnable affine here: the time embedding provides scale/shift instead.
        self.gn2 = nn.GroupNorm(num_groups=1, num_channels=channels, affine=False)
        # Emits [scale | shift], each of length `channels`.
        self.embd_affine = nn.Linear(time_dim, 2 * channels)
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x, t_emb):
        """Apply the block.

        Args:
            x: feature map with ``channels`` channels, indexed as (N, C, H, W).
            t_emb: time embedding indexed as (N, time_dim); the projected
                scale/shift are split along dim 1 and broadcast over H and W.

        Returns:
            ``x`` plus the residual branch (same shape as ``x``).
        """
        modulation = self.embd_affine(t_emb)
        gamma, beta = modulation.chunk(2, dim=1)
        # (N, C) -> (N, C, 1, 1) so the modulation broadcasts over H and W.
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]

        branch = self.act(self.gn1(x))
        branch = self.first_conv(branch)

        branch = self.gn2(branch)
        # Offset by 1 so a zero scale leaves the normalized features unscaled.
        branch = branch * (gamma + 1) + beta
        branch = self.second_conv(self.act(branch))

        return x + branch


class UpBlock(nn.Module):
    """Decoder stage: time-conditioned residual block, 2x upsample, optional skip merge.

    The transposed convolution (kernel 4, stride 2, padding 1) doubles H and W:
    H_out = (H - 1) * 2 - 2 + 4 = 2 * H.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by the upsampling transposed convolution.
        time_dim: feature size of the time embedding consumed by ``ResConvBlock``.
        cat: if True, ``skip`` is concatenated along dim 1 (output has
            ``out_ch + skip_channels`` channels); otherwise ``skip`` is added
            element-wise (it must then broadcast against the upsampled map).
        padding_mode: forwarded to the residual block's convolutions. Defaults
            to ``'zeros'``, matching the previously hard-coded behavior.
    """

    def __init__(self, in_ch, out_ch, time_dim, cat, *, padding_mode='zeros'):
        super().__init__()
        self.res = ResConvBlock(in_ch, time_dim, padding_mode=padding_mode)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.cat = cat

    def forward(self, x, t_emb, skip=None):
        """Run the residual block, upsample, then merge the skip tensor if given.

        Args:
            x: input feature map with ``in_ch`` channels.
            t_emb: time embedding passed through to the residual block.
            skip: encoder feature map to merge; ``None`` means "no skip
                connection" and the upsampled tensor is returned as-is.
                (Previously ``None`` crashed inside ``torch.cat`` / ``+``.)

        Returns:
            The upsampled tensor, merged with ``skip`` when one is provided.
        """
        x = self.res(x, t_emb)
        x = self.up(x)
        if skip is None:
            return x
        if self.cat:
            return torch.cat([x, skip], dim=1)
        return x + skip