|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from timm import create_model
|
|
|
|
|
|
class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands back exactly what it receives.

    Useful as a drop-in replacement for a sub-module whose computation
    should be skipped while keeping the surrounding module structure intact.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged (the very same object, no copy)."""
        return x
|
|
|
|
|
|
class SwinT(nn.Module):
    """Feature extractor built around a timm model with its head removed.

    The backbone is created through ``create_model`` and its ``head``
    attribute is swapped for an ``Identity`` module, so calling the backbone
    yields whatever it produces just before the original head.

    Args:
        model_name: Model identifier forwarded to ``create_model``.
        global_pool: Pooling mode forwarded to ``create_model`` and also
            consulted in ``forward``; only the value ``'avg'`` triggers
            pooling there.
        pretrained: Forwarded to ``create_model``.
    """

    def __init__(self, model_name='swin_base_patch4_window7_224', global_pool='avg', pretrained=True):
        super().__init__()
        backbone = create_model(
            model_name,
            pretrained=pretrained,
            global_pool=global_pool,
        )
        # Drop the original head so the backbone returns pre-head features.
        backbone.head = Identity()
        self.swin_model = backbone
        self.global_pool = global_pool

    def forward(self, x):
        """Run the backbone on ``x`` and optionally average dims 1 and 2.

        When ``self.global_pool`` equals ``'avg'`` the backbone output is
        reduced with a mean over dimensions 1 and 2; for any other value the
        backbone output is returned untouched.
        """
        feats = self.swin_model(x)
        if self.global_pool != 'avg':
            return feats
        # NOTE(review): assumes the backbone output has at least 3 dims with
        # the spatial axes at positions 1 and 2 (e.g. batch, H, W, C) -
        # confirm against the installed timm version.
        return feats.mean(dim=[1, 2])
|
|
|
|
|
|
def extract_features_swint_pool(video, model, device):
    """Run ``model`` over every segment of ``video`` and stack the results.

    Each segment has its leading dimension squeezed (``squeeze(0``), is moved
    to ``device``, and is passed through ``model`` inside an autocast
    context. The per-segment outputs are concatenated along dimension 0.

    Args:
        video: Iterable of tensors, one per segment. It must yield at least
            one segment, because ``torch.cat`` cannot concatenate an empty
            list.
        model: Callable applied to each prepared segment.
        device: Target passed to ``Tensor.to`` for each segment.

    Returns:
        A single tensor holding all segment outputs joined along dim 0.

    Note:
        Gradient tracking is not disabled here; callers that only need
        inference should wrap the call in ``torch.no_grad()`` themselves.
    """
    # The autocast device type is fixed to 'cuda' independently of `device`.
    with torch.amp.autocast(device_type='cuda'):
        per_segment = [
            model(clip.squeeze(0).to(device))
            for clip in video
        ]

    return torch.cat(per_segment, dim=0)
|
|
|
|