Spaces: Running on Zero
import random

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import spaces  # Hugging Face ZeroGPU helper; provides the @spaces.GPU decorator
# Available models
MODEL_OPTIONS_TITLE = {
    "V6 Model": "alakxender/t5-divehi-title-generation-v6",
    "XS Model": "alakxender/t5-dhivehi-title-generation-xs",
}

# Cache for loaded models/tokenizers
MODEL_CACHE = {}
def get_model_and_tokenizer(model_dir):
    """Load the tokenizer and model for a checkpoint, caching them per process."""
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]
# Task prefix prepended to every input (presumably matching the fine-tuning
# format), plus token-length limits for inputs and generated titles
prefix = "2title: "
max_input_length = 512
max_target_length = 32
@spaces.GPU  # request a ZeroGPU slot for this call; implied by the spaces import but missing from the listing
def generate_title(content, seed, use_sampling, model_choice):
    # Seed every RNG in play so results are reproducible for a given seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    model_dir = MODEL_OPTIONS_TITLE[model_choice]
    tokenizer, model = get_model_and_tokenizer(model_dir)
    input_text = prefix + content.strip()
    inputs = tokenizer(
        input_text,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_length": max_target_length,
        "no_repeat_ngram_size": 2,
    }
    if use_sampling:
        # Stochastic decoding: nucleus sampling, so the output varies with the seed
        gen_kwargs.update({
            "do_sample": True,
            "temperature": 1.0,
            "top_p": 0.95,
            "num_return_sequences": 1,
        })
    else:
        # Deterministic decoding: beam search
        gen_kwargs.update({
            "num_beams": 4,
            "do_sample": False,
            "early_stopping": True,
        })
    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    title = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return title
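
The listing stops at generate_title and omits the UI wiring. On a Gradio Space, the function would typically be exposed through an interface along these lines; this is a minimal sketch under that assumption, and the component choices, labels, and defaults are illustrative rather than taken from the original app.

import gradio as gr

# Hypothetical UI wiring: forwards the four arguments in the order
# generate_title expects (content, seed, use_sampling, model_choice)
demo = gr.Interface(
    fn=generate_title,
    inputs=[
        gr.Textbox(lines=8, label="Article content (Dhivehi)"),
        gr.Number(value=42, precision=0, label="Seed"),  # precision=0 keeps the seed an integer
        gr.Checkbox(value=False, label="Use sampling"),
        gr.Dropdown(choices=list(MODEL_OPTIONS_TITLE.keys()), value="V6 Model", label="Model"),
    ],
    outputs=gr.Textbox(label="Generated title"),
)

if __name__ == "__main__":
    demo.launch()

Passing the seed through gr.Number with precision=0 matters because the NumPy and Torch seeding calls expect an integer; the rest of the wiring simply maps one UI component to each function parameter.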