import copy
import torch
import torch.nn as nn
from typing import Optional


def _get_clones(mod, n):
    return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])


class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """
    Edited TransformerDecoderLayer implementation that omits the redundant self-attention operation performed in the
    standard implementation.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation)
        # The self-attention sub-layer is never used (see forward below), so drop its parameters entirely.
        del self.self_attn

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        # **kwargs absorbs extra keyword arguments (e.g. tgt_is_causal) that newer versions of
        # nn.TransformerDecoder may pass to their layers.

        # In place of the self-attention sub-layer, the residual branch is fed the target itself.
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        # Cross-attention over the encoder memory.
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Position-wise feed-forward block.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
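

# Minimal sketch of the modified layer in isolation (illustrative assumption, not part of the
# original module): a one-token target cross-attends over a ten-step memory; the dimensions
# used here (d_model=108, nhead=9) are placeholders and only need d_model % nhead == 0.
def _demo_decoder_layer():
    layer = SPOTERTransformerDecoderLayer(d_model=108, nhead=9, dim_feedforward=2048,
                                          dropout=0.1, activation="relu")
    tgt = torch.rand(1, 1, 108)      # (target length, batch, d_model)
    memory = torch.rand(10, 1, 108)  # (source length, batch, d_model)
    return layer(tgt, memory)        # -> (1, 1, 108); no self-attention is applied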


class SPOTER(nn.Module):
    """
    Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from
    sequences of skeletal data.
    """

    def __init__(self, num_classes, hidden_dim=55):
        super().__init__()
        # Learned row embeddings; the positional term below is derived from the first row
        # and has shape (1, 1, hidden_dim), so it broadcasts over the input sequence.
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
        self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
        # Single learned query token whose decoder output is used for classification.
        self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
        # d_model = hidden_dim, 9 attention heads (hidden_dim must be divisible by 9),
        # 6 encoder layers, 6 decoder layers.
        self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
        self.linear_class = nn.Linear(hidden_dim, num_classes)

        # Deactivate the initial self-attention mechanism in the decoder by swapping in the custom layer.
        custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048,
                                                             0.1, "relu")
        self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers)

    def forward(self, inputs):
        # Flatten each frame's landmarks into a single feature vector:
        # (frames, landmarks, coords) -> (frames, 1, landmarks * coords), i.e. a sequence with batch size 1.
        h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
        # Encode the positionally augmented sequence and decode a single class query token.
        h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
        res = self.linear_class(h)
        return res


if __name__ == "__main__":
    pass
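    # Minimal usage sketch (assumption, not part of the original training or demo code):
    # hidden_dim has to match num_landmarks * 2 for the positional term to broadcast over the
    # flattened input, so 54 landmarks with (x, y) coordinates are assumed here, and the class
    # count and frame count below are placeholders.
    model = SPOTER(num_classes=100, hidden_dim=108)
    dummy_sequence = torch.rand(150, 54, 2)   # (frames, landmarks, coordinates)
    logits = model(dummy_sequence)            # -> (1, 1, num_classes)
    print(logits.shape)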