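"""Gradio chat demo for UstaModel, a small transformer language model built from
scratch. The app loads the UstaTokenizer and UstaModel from the v2 package,
fetches the bundled checkpoint weights from GitHub when they are missing locally,
and lets users chat with the model or swap in custom weights from a URL or an
uploaded .pth file."""
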
import os

import gradio as gr
import torch

from v2.usta_model import UstaModel
from v2.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v2/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8
        device = "cpu"  # Use CPU for compatibility

        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
            device=device
        )
        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v2/u_model_4000.pth"

            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("📥 Downloading model weights from GitHub...")
                    import requests
                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    headers = {
                        'Accept': 'application/octet-stream',
                        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                    }
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes

                    # Check if we got a proper binary file (checkpoints saved with torch.save
                    # use a ZIP container, so they start with the 'PK' magic bytes)
                    if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
                        raise Exception("Downloaded HTML instead of binary file - check URL")

                    print(f"📦 Downloaded {len(response.content)} bytes")

                    # Create v2 directory if it doesn't exist
                    os.makedirs("v2", exist_ok=True)

                    # Save the model weights to the local file system
                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("✅ Model weights saved successfully!")
                except Exception as e:
                    print(f"❌ Failed to download model weights: {e}")
                    print("Using random initialization.")
        if os.path.exists(model_path):
            try:
                state_dict = torch.load(model_path, map_location="cpu", weights_only=False)

                # Handle potential key mapping issues
                if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
                    # Map old key names to new key names
                    new_state_dict = {}
                    for key, value in state_dict.items():
                        if key == "embedding.weight":
                            new_state_dict["embedding.embedding.weight"] = value
                        elif key == "pos_embedding.weight":
                            # Skip positional embedding if not expected
                            continue
                        else:
                            new_state_dict[key] = value
                    state_dict = new_state_dict

                u_model.load_state_dict(state_dict)
                u_model.eval()
                print("✅ Model weights loaded successfully!")
                return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
            except Exception as e:
                print(f"⚠️ Warning: Could not load trained weights: {e}")
                print("Using random initialization.")
                return u_model, u_tokenizer, f"⚠️ Failed to load weights: {e}"
        else:
            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
            return u_model, u_tokenizer, "⚠️ Using random initialization"

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e
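
# Note: load_model() returns a (model, tokenizer, status_message) triple. It is
# called once at import time below with the bundled checkpoint, and again with a
# temporary path whenever a custom checkpoint is loaded from a URL or file upload.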

# Global model variables
model, tokenizer, model_status = None, None, "Not loaded"

# Initialize model and tokenizer globally
try:
    model, tokenizer, model_status = load_model()
    print("🚀 UstaModel and tokenizer initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize model: {e}")
    model, tokenizer, model_status = None, None, f"❌ Error: {e}"

def load_model_from_url(url):
    """Load model from a URL"""
    global model, tokenizer, model_status

    if not url.strip():
        return "❌ Please provide a URL"

    try:
        print(f"📥 Downloading model from URL: {url}")
        import requests
        headers = {
            'Accept': 'application/octet-stream',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers)
        response.raise_for_status()

        # Check if we got a proper binary file
        if response.content[:4] != b'PK\x03\x04' and b'<html' in response.content[:100].lower():
            return "❌ Downloaded HTML instead of binary file - check URL"

        # Save temporary file
        temp_path = "temp_model.pth"
        with open(temp_path, "wb") as f:
            f.write(response.content)

        # Load the model
        new_model, new_tokenizer, status = load_model(temp_path)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load model from URL: {e}"
        model_status = error_msg
        return error_msg

def load_model_from_file(uploaded_file):
    """Load model from an uploaded file"""
    global model, tokenizer, model_status

    if uploaded_file is None:
        return "❌ No file uploaded"

    try:
        # Check if the file path exists and is valid
        file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else str(uploaded_file)

        # For HF Spaces compatibility, also try the upload path
        if not os.path.exists(file_path) and hasattr(uploaded_file, 'orig_name'):
            # Sometimes HF Spaces provides different paths
            print(f"Original path not found: {file_path}")
            print(f"Trying original name: {uploaded_file.orig_name}")
            file_path = uploaded_file.orig_name

        print(f"🔍 Attempting to load model from: {file_path}")

        # Load the new model
        new_model, new_tokenizer, status = load_model(file_path)

        # Update global variables
        model = new_model
        tokenizer = new_tokenizer
        model_status = status

        return status
    except Exception as e:
        error_msg = f"❌ Failed to load uploaded model: {e}"
        print(f"Error details: {e}")
        model_status = error_msg
        return error_msg

def chat_with_usta(message, history, max_tokens=20, temperature=1.0, top_k=64, top_p=1.0):
    """Simple chat function"""
    if model is None or tokenizer is None:
        return history + [["Error", "UstaModel is not available. Please try again later."]]

    try:
        # Encode the input message
        tokens = tokenizer.encode(message)

        # Make sure we don't exceed context length
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(
                tokens,
                max_new_tokens=actual_max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Clean up any unwanted tokens
        response = response.replace("<unk>", "").replace("<pad>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Add to history
        history.append([message, response])
        return history
    except Exception as e:
        history.append([message, f"Sorry, I encountered an error: {str(e)}"])
        return history
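
# Usage sketch (hypothetical prompt; assumes the weights loaded successfully):
#   history = chat_with_usta("what is the capital of france", [], max_tokens=10)
#   # -> [["what is the capital of france", "<generated reply>"]]
# Each history entry is a [user_message, model_response] pair, the tuple-style
# format the gr.Chatbot component below expects.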

# Create simple interface
with gr.Blocks(title="🤖 Usta Model Chat") as demo:
    gr.Markdown("# 🤖 Usta Model Chat")
    gr.Markdown("Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge.")

    # Simple chat interface
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your message", placeholder="Ask about countries, capitals, or cities...")

    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    # Generation settings
    gr.Markdown("## ⚙️ Generation Settings")
    with gr.Row():
        max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
    with gr.Row():
        top_k = gr.Slider(minimum=1, maximum=64, value=40, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")

    # Model loading (simplified)
    gr.Markdown("## 🔧 Load Custom Model (Optional)")
    with gr.Row():
        model_url = gr.Textbox(
            label="Model URL",
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth",
            scale=3
        )
        load_url_btn = gr.Button("Load from URL", scale=1)
    with gr.Row():
        model_file = gr.File(label="Upload model file (.pth, .pt, .bin)")
        load_file_btn = gr.Button("Load File", scale=1)
    status = gr.Textbox(label="Status", value=model_status, interactive=False)

    # Event handlers
    def send_message(message, history, max_tok, temp, k, p):
        if not message.strip():
            return history, ""
        return chat_with_usta(message, history, max_tok, temp, k, p), ""

    send_btn.click(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg]
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
        outputs=[chatbot, msg]
    )
    clear_btn.click(lambda: [], outputs=[chatbot])

    load_url_btn.click(
        load_model_from_url,
        inputs=[model_url],
        outputs=[status]
    )
    load_file_btn.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status]
    )

if __name__ == "__main__":
    demo.launch()