# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention

@dataclass
class VisualAttentionConfig(AttentionConfig):
    dim_model: int  # dimension of the input sequence

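# Usage note (sketch, assuming xformers' standard attention factory): since the
# attention below is registered under the name "visual", it should also be buildable
# through the factory, e.g.:
#
#   from xformers.components.attention import build_attention
#   attn = build_attention({"name": "visual", "dim_model": 64})
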
class LKA(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Local spatial context: 5x5 depthwise convolution
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Long-range spatial context: 7x7 depthwise convolution with dilation 3
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
        )
        # Channel mixing: 1x1 pointwise convolution
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x: torch.Tensor):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # The convolution stack acts as an attention map which gates the input
        return u * attn

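# Note on the decomposition above (sketch of the paper's argument): the 7x7 depthwise
# convolution with dilation 3 spans (7 - 1) * 3 + 1 = 19 pixels per side, so composed
# with the 5x5 depthwise convolution the pair covers a neighbourhood on the order of
# the large 21x21 kernel targeted in the paper, at a fraction of the parameters,
# before the 1x1 convolution mixes channels.
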
@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
    def __init__(
        self,
        dim_model: int,
        *_,
        **__,
    ):
        """
        Large kernel attention mechanism, as proposed in `Visual Attention Network`_,
        Guo et al. (2022). The original notation is tentatively kept as is. See
        https://github.com/Visual-Attention-Network for the reference implementation.

        .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
            and the prior and posterior transformations (Conv2d and activation)

        .. _`Visual Attention Network`: https://arxiv.org/pdf/2202.09741.pdf
        """
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, 1),
            nn.GELU(),
            LKA(dim_model),
            nn.Conv2d(dim_model, dim_model, 1),
        )

        # MHA related flags:
        self.requires_same_k_q_dimensions = (
            True  # This mechanism only really supports self attention
        )
        self.supports_attention_mask = False
        self.requires_skip_multi_head = (
            True  # This mechanism skips the multihead attention altogether
        )
        self.requires_squared_context = (
            True  # Recovering the 2D structure from the context assumes a squared content
        )
        self.requires_input_projection = (
            False  # This mechanism does not require that the MHA projects inputs
        )

    def forward(self, q: torch.Tensor, *_, **__):
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW

        x = q.transpose(-2, -1).reshape(B, C, H, H)

        # Large kernel attention
        residual = x.clone()
        x = self.block(x)
        x = x + residual

        # Get back to B HW C
        return x.flatten(2, 3).transpose(-2, -1)
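
if __name__ == "__main__":
    # Minimal smoke test (sketch): run the block on a random square token grid.
    # The batch size, dim_model=64 and the 16x16 grid (HW = 256) are illustrative
    # assumptions, not values from the original file.
    attention = Visual(dim_model=64)
    tokens = torch.randn(2, 16 * 16, 64)  # (batch, HW, channels), HW a perfect square
    out = attention(tokens)
    assert out.shape == tokens.shape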