import json
from pathlib import Path

import torch
import gradio as gr
from huggingface_hub import hf_hub_download

# -------------------- DEVICE --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- MODEL CONFIG --------------------
MODEL_NAME = "FlameF0X/i3-80m"
LOCAL_SAFETENSORS = Path("model.safetensors")
LOCAL_BIN = Path("pytorch_model.bin")
VOCAB_JSON = Path("chunk_vocab_combined.json")

# -------------------- LOAD VOCAB --------------------
with open(VOCAB_JSON, "r") as f:
    vocab_data = json.load(f)
VOCAB_SIZE = vocab_data["vocab_size"]

# -------------------- IMPORT YOUR MODEL CLASS --------------------
from app_classes import i3Model, ChunkTokenizer

tokenizer = ChunkTokenizer()
tokenizer.load(VOCAB_JSON)

model = i3Model(
    vocab_size=VOCAB_SIZE,
    d_model=512,
    n_heads=16,
    max_seq_len=256,
    d_state=32,
).to(DEVICE)

# -------------------- LOAD WEIGHTS --------------------
# Prefer local weights; fall back to downloading from the Hugging Face Hub.
try:
    if LOCAL_SAFETENSORS.exists():
        from safetensors.torch import load_file

        state_dict = load_file(LOCAL_SAFETENSORS)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local safetensors")
    elif LOCAL_BIN.exists():
        # weights_only=False runs the full pickle machinery; only load trusted files.
        state_dict = torch.load(LOCAL_BIN, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local .bin")
    else:
        print("⚡ Downloading model from HuggingFace...")
        bin_file = hf_hub_download(repo_id=MODEL_NAME, filename="pytorch_model.bin")
        state_dict = torch.load(bin_file, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from HuggingFace")
except Exception as e:
    raise RuntimeError(f"Failed to load model weights: {e}") from e

model.eval()

# -------------------- GENERATION FUNCTION --------------------
def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=40):
    if not prompt.strip():
        return "⚠️ Please enter a prompt to generate text."
    try:
        idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
        # Inference only: disable autograd bookkeeping during generation.
        with torch.no_grad():
            out_idx = model.generate(
                idx,
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_k=int(top_k),
            )
        return tokenizer.decode(out_idx[0].cpu())
    except Exception as e:
        return f"❌ Generation error: {e}"
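
# `model.generate` above comes from app_classes, which is not shown in this
# file. As a reference for readers without that module, here is a minimal
# sketch of the temperature + top-k sampling loop such a method typically
# implements. The name `_reference_topk_generate` is hypothetical, the
# function is never called, and it assumes the model's forward pass returns
# logits of shape (batch, seq_len, vocab_size); the real i3Model may differ.
@torch.no_grad()
def _reference_topk_generate(model, idx, max_new_tokens, temperature=0.8, top_k=40, max_seq_len=256):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -max_seq_len:]          # crop context to the model's window
        logits = model(idx_cond)[:, -1, :]        # logits for the last position
        logits = logits / max(temperature, 1e-6)  # higher T flattens the distribution
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float("inf")  # mask everything outside the top-k
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id
        idx = torch.cat([idx, idx_next], dim=1)
    return idx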
""", elem_classes="main-header" ) # Main Generation Area with gr.Row(): with gr.Column(scale=2): prompt_input = gr.Textbox( label="✍️ Enter Your Prompt", placeholder="Once upon a time in a distant galaxy...", lines=4, max_lines=8 ) with gr.Accordion("⚙️ Generation Parameters", open=True): with gr.Row(): max_tokens_input = gr.Slider( 10, 500, value=100, step=10, label="Max Tokens", info="Maximum number of tokens to generate" ) temp_input = gr.Slider( 0.1, 2.0, value=0.8, step=0.05, label="Temperature", info="Higher = more creative, Lower = more focused" ) topk_input = gr.Slider( 1, 100, value=40, step=1, label="Top-k Sampling", info="Number of top tokens to consider" ) with gr.Row(): generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg") clear_btn = gr.ClearButton(components=[prompt_input], value="🗑️ Clear", size="lg") with gr.Column(scale=2): output_text = gr.Textbox( label="📝 Generated Output", lines=12, max_lines=20, show_copy_button=True ) # Examples Section with gr.Row(): gr.Examples( examples=[ ["The future of artificial intelligence is", 150, 0.7, 50], ["In a world where technology and nature coexist", 200, 0.9, 40], ["The scientist discovered something remarkable", 120, 0.8, 45], ], inputs=[prompt_input, max_tokens_input, temp_input, topk_input], label="💡 Try These Examples" ) # Developer Panel with gr.Accordion("🔧 Developer Info", open=False): total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) with gr.Row(): with gr.Column(): gr.Markdown(f""" **Model Architecture:** - **Model:** i3-80M - **Device:** {DEVICE} - **Vocab Size:** {VOCAB_SIZE:,} - **Parameters:** {total_params:,} ({total_params/1e6:.2f}M) """) with gr.Column(): gr.Markdown(f""" **Configuration:** - **d_model:** 512 - **n_heads:** 16 - **max_seq_len:** 256 - **d_state:** 32 """) # Footer gr.Markdown( """ ---

Built with ❤️ using Gradio | Model: FlameF0X/i3-80m

""", ) # Connect UI generate_btn.click( generate_text, inputs=[prompt_input, max_tokens_input, temp_input, topk_input], outputs=[output_text] ) # -------------------- RUN -------------------- if __name__ == "__main__": demo.launch(share=False)