```python
# generate.py
import argparse

import torch
from PIL import Image
from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer

from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
from vit_captioning.models.decoder import TransformerDecoder


class CaptionGenerator:
    def __init__(self, model_type: str, checkpoint_path: str, quantized=False, runAsContainer=False):
        print(f"Loading {model_type} | Quantized: {quantized}")

        # Set up the device: prefer CUDA, then Apple MPS, then fall back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using NVIDIA CUDA GPU acceleration.")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Apple MPS GPU acceleration.")
        else:
            self.device = torch.device("cpu")
            print("No GPU found, falling back to CPU.")

        # Load the tokenizer: from a baked-in local path when running in a
        # container, otherwise from the Hugging Face Hub.
        if runAsContainer:
            self.tokenizer = AutoTokenizer.from_pretrained('/models/bert-tokenizer')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

        # Select encoder, image processor, and encoder output dimension.
        if model_type == "ViTEncoder":
            self.encoder = ViTEncoder().to(self.device)
            self.encoder_dim = 768
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        elif model_type == "CLIPEncoder":
            self.encoder = CLIPEncoder().to(self.device)
            self.encoder_dim = 512
            if runAsContainer:
                self.processor = CLIPProcessor.from_pretrained("/models/clip")
            else:
                self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        if quantized:
            print("Applying dynamic quantization to encoder...")
            self.encoder = torch.ao.quantization.quantize_dynamic(
                self.encoder,
                {torch.nn.Linear},
                dtype=torch.qint8
            )
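            # quantize_dynamic replaces each nn.Linear with an int8 version:
            # weights are stored quantized and activations are quantized on
            # the fly per batch. This mainly benefits CPU inference; PyTorch
            # dynamic quantization does not run on CUDA or MPS.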
        # Initialize the decoder; 30522 is the bert-base-uncased vocab size.
        self.decoder = TransformerDecoder(
            vocab_size=30522,
            hidden_dim=self.encoder_dim,
            encoder_dim=self.encoder_dim
        ).to(self.device)

        # Load trained weights for encoder and decoder from one checkpoint,
        # then switch both to inference mode.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        self.encoder.eval()
        self.decoder.eval()

    def generate_caption(self, image_path: str) -> dict:
        image = Image.open(image_path).convert("RGB")
        encoding = self.processor(images=image, return_tensors='pt')
        pixel_values = encoding['pixel_values'].to(self.device)

        captions = {}
        with torch.no_grad():
            encoder_outputs = self.encoder(pixel_values)

            # Greedy: always take the most probable next token (deterministic).
            caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
            captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-k: sample from the 30 most probable next tokens.
            caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
            captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-p (nucleus): sample from the smallest token set whose
            # cumulative probability exceeds 0.92.
            caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
            captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

        return captions


if __name__ == "__main__":
    # CLI usage
    parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
    parser.add_argument("--model", type=str, default="ViTEncoder",
                        choices=["ViTEncoder", "CLIPEncoder"],
                        help="Choose encoder: ViTEncoder or CLIPEncoder")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the .pth checkpoint file")
    parser.add_argument("--image", type=str, required=True,
                        help="Path to input image file")
    parser.add_argument("--quantized", action="store_true",
                        help="Load encoder with dynamic quantization")
    args = parser.parse_args()
    generator = CaptionGenerator(
        model_type=args.model,
        checkpoint_path=args.checkpoint,
        quantized=args.quantized,  # pass the --quantized flag through
        runAsContainer=True  # expects the baked-in /models/* paths to exist
    )
    captions = generator.generate_caption(args.image)
    print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
    print(f"Top-k (diverse, creative): {captions['topk']}")
    print(f"Top-p (diverse, human-like): {captions['topp']}")
```