Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision.transforms as T | |
| from typing import Optional | |
| from src.dataset import generate_image | |
| from src.models import CrossAttentionClassifier, VGGLikeEncode | |
class CrossAttentionInference:
    """Load a trained ``CrossAttentionClassifier`` and run it on image pairs.

    The constructor builds the encoder/classifier, restores the weights found
    at ``model_path`` and prepares the preprocessing pipeline. Afterwards
    :meth:`predict_random_pair` can be called repeatedly to classify freshly
    generated image pairs.
    """

    def __init__(
        self,
        model_path: str,
        shape_params: Optional[dict] = None,
        device: torch.device = torch.device("cpu"),
    ):
        """Build the model, restore its weights and set up preprocessing.

        Args:
            model_path: Location of the saved state dict handed to
                ``torch.load``.
            shape_params: Keyword arguments forwarded to ``generate_image``
                on every call. ``None`` or an empty mapping means "no extra
                arguments".
            device: Device the weights are mapped to and the model and input
                tensors are moved to.
        """
        # A falsy value (None or an empty dict) becomes a fresh empty dict;
        # a non-empty dict is stored as-is, without copying.
        self.shape_params = shape_params or {}
        self.device = device

        backbone = VGGLikeEncode(
            in_channels=1,
            out_channels=128,
            feature_dim=32,
            apply_pooling=False,
        )
        self.encoder = backbone

        classifier = CrossAttentionClassifier(encoder=backbone)
        self.model = classifier

        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be passed here.
        weights = torch.load(model_path, map_location=device)
        classifier.load_state_dict(weights)
        # Switch to evaluation mode before moving the model to the device.
        classifier.eval()
        classifier.to(device)

        # Preprocessing: convert to a tensor, then normalise with mean 0.5
        # and std 0.5 (one value each, matching in_channels=1 above).
        preprocessing_steps = [
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,)),
        ]
        self.transform = T.Compose(preprocessing_steps)

    def pil_to_tensor(self, img):
        """Preprocess ``img`` and return it as a batch of one on ``self.device``.

        Args:
            img: Any input accepted by ``self.transform`` (presumably a PIL
                image, going by the method name -- confirm with callers).

        Returns:
            The transformed tensor with a new leading dimension of size 1,
            moved to ``self.device``.
        """
        transformed = self.transform(img)
        batched = transformed.unsqueeze(0)
        return batched.to(self.device)

    def predict_random_pair(self):
        """Generate two images and classify the pair.

        Returns:
            A tuple ``(predicted_label, (img1, img2))`` where
            ``predicted_label`` is ``1`` if the sigmoid of the model's logit
            exceeds 0.5 and ``0`` otherwise, and ``img1``/``img2`` are the
            generated images exactly as returned by ``generate_image``.
        """
        # generate_image returns a pair; only its first element is used.
        first_image, _ = generate_image(**self.shape_params)
        second_image, _ = generate_image(**self.shape_params)

        first_batch = self.pil_to_tensor(first_image)
        second_batch = self.pil_to_tensor(second_image)

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            # The model returns two values; the second one is discarded.
            logits, _ = self.model(first_batch, second_batch)
            probabilities = torch.sigmoid(logits)
            decisions = (probabilities > 0.5).float()

        # .item() requires a single-element tensor, i.e. one logit per pair.
        label = int(decisions.item())
        return label, (first_image, second_image)