# Hugging Face Space, running on ZeroGPU ("Running on Zero")

import random

import numpy as np
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

import spaces

# Available models for content generation
MODEL_OPTIONS_CONTENT = {
    "MX02 (mixed)": {
        "model_id": "alakxender/flan-t5-corpora-mixed",
        "default_prompt": "Tell me about: ",
    },
    "MX01 (articles)": {
        "model_id": "alakxender/flan-t5-news-articles",
        "default_prompt": "Create an article about: ",
    },
}
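
# A minimal sketch of how these options might feed the Space's UI (assumption:
# the front end is a Gradio dropdown; the actual interface code is not part of
# this excerpt):
#
#   import gradio as gr
#   model_dropdown = gr.Dropdown(
#       choices=list(MODEL_OPTIONS_CONTENT.keys()),
#       value="MX02 (mixed)",
#       label="Model",
#   )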

# Cache for loaded models/tokenizers
MODEL_CACHE = {}


def get_model_and_tokenizer(model_choice):
    model_dir = MODEL_OPTIONS_CONTENT[model_choice]["model_id"]
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = T5Tokenizer.from_pretrained(model_dir)
        model = T5ForConditionalGeneration.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]
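
# Example usage: the first call downloads the checkpoint and fills MODEL_CACHE;
# repeated calls for the same choice return the cached pair instead of
# reloading.
#
#   tokenizer, model = get_model_and_tokenizer("MX02 (mixed)")
#   tokenizer2, model2 = get_model_and_tokenizer("MX02 (mixed)")
#   assert model is model2  # served from MODEL_CACHE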

def get_default_prompt(model_choice):
    return MODEL_OPTIONS_CONTENT[model_choice]["default_prompt"]


# On ZeroGPU ("Running on Zero") Spaces, functions that use the GPU must be
# decorated with @spaces.GPU so a GPU is attached for the duration of the call;
# without the decorator the imported spaces module goes unused.
@spaces.GPU
def generate_content(prompt, max_new_tokens, num_beams, repetition_penalty,
                     no_repeat_ngram_size, do_sample, model_choice):
    tokenizer, model = get_model_and_tokenizer(model_choice)
    # Prepend the model's instruction prefix, e.g. "Tell me about: "
    prompt = get_default_prompt(model_choice) + prompt
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move the input tensors to the same device the model lives on
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        do_sample=do_sample,
        early_stopping=False,
    )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Trim to the last period so the output does not end mid-sentence
    if '.' in output_text:
        last_period = output_text.rfind('.')
        output_text = output_text[:last_period + 1]
    return output_text
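
# Quick local smoke test. A sketch only: the argument values below are
# illustrative defaults, not settings taken from the Space's UI.
if __name__ == "__main__":
    sample = generate_content(
        prompt="the history of the Maldives",
        max_new_tokens=128,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        do_sample=False,
        model_choice="MX02 (mixed)",
    )
    print(sample)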