import torch
import torch.nn as nn

class MinibatchStd(nn.Module):
    """Append a standard-deviation feature map to the input as one extra channel.

    The standard deviation is computed jointly over dims 0 and 1 (batch and
    channels), leaving one value per remaining position. That single map is
    shared by every sample: it is broadcast across the batch and concatenated
    after the existing channels, so the output has ``C + 1`` channels.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # One spread value per trailing position; keepdim leaves shape
        # (1, 1, *x.shape[2:]). torch.std's default (unbiased) estimator is used.
        spread = x.std(dim=(0, 1), keepdim=True)
        # Machine epsilon of the input dtype keeps the feature strictly positive.
        spread = (spread + torch.finfo(x.dtype).eps).to(x.dtype)
        # Broadcast the shared map to every sample in the batch.
        target_shape = (x.size(0), 1) + tuple(x.shape[2:])
        return torch.cat((x, spread.expand(*target_shape)), dim=1)


class CompactGramMatrix(nn.Module):
    """Return the unique (lower-triangular) entries of each sample's Gram matrix.

    The channel Gram matrix is symmetric, so only its lower triangle including
    the diagonal is returned: ``C * (C + 1) // 2`` values per sample.

    Args:
        in_channels: Number of channels ``C`` that inputs must have.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Row/column indices of the lower triangle (diagonal included), shape
        # (2, C*(C+1)//2). int64 is the index dtype that every PyTorch release
        # accepts for advanced indexing (older releases reject int32 indices).
        # The buffer stays persistent so checkpoints that contain it still load.
        self.register_buffer(
            'tril_indices',
            torch.tril_indices(in_channels, in_channels, offset=0, dtype=torch.long),
        )

    def forward(self, x):
        """
        Input: (B, C, *spatial), typically (B, C, H, W)
        Output: (B, C*(C+1)//2) compact Gram features; each entry is the mean
            over spatial positions of the product of two channels.

        Raises:
            ValueError: if the channel dimension does not equal ``in_channels``.
        """
        b, c = x.shape[:2]
        if c != self.in_channels:
            # Without this check, extra channels would silently be dropped from the
            # features, and too few channels would fail with an opaque index error.
            raise ValueError(
                f"expected {self.in_channels} channels, got {c} "
                f"(input shape {tuple(x.shape)})"
            )
        spatial = x.shape[2:].numel()  # H*W for 4-D input; 1 for (B, C) input
        # reshape (not view) so that non-contiguous inputs (e.g. after transpose or
        # permute) are accepted. Scaling both factors by 1/sqrt(spatial) makes the
        # Gram entries spatial means.
        x = x.reshape(b, c, spatial) / (spatial ** 0.5)

        gram = torch.bmm(x, x.transpose(1, 2))  # (B, C, C)

        # .long() is a no-op for the int64 buffer. It guards against a legacy int32
        # buffer, e.g. one restored with load_state_dict(assign=True).
        rows, cols = self.tril_indices.long()
        return gram[:, rows, cols]  # (B, C*(C+1)//2)