from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import torch
# =============================
# MODEL SETTINGS
# =============================
MODEL_ID = "DSDUDEd/Cass_beta1.0"

# Use writable cache directories (Spaces containers only allow writes under /tmp).
# Set these BEFORE importing transformers so the cache paths take effect;
# TRANSFORMERS_CACHE is deprecated in newer releases in favor of HF_HOME,
# so both are set here.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HF_METRICS_CACHE"] = CACHE_DIR
os.environ["TORCH_HOME"] = CACHE_DIR

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftConfig, PeftModel
# =============================
# FASTAPI APP INIT
# =============================
app = FastAPI()
app.mount("/static", StaticFiles(directory="."), name="static")
templates = Jinja2Templates(directory=".")
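
# Templates and static assets are served straight from the repo root ("."), so
# index.html must sit next to this file. A more conventional layout (a sketch,
# not what this app uses) would be:
#
#     app.mount("/static", StaticFiles(directory="static"), name="static")
#     templates = Jinja2Templates(directory="templates")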
# =============================
# LOAD TOKENIZER + MODEL
# =============================
print("Loading Cass tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)

# Load the base model first, then the PEFT adapter. The base checkpoint is
# resolved from the adapter's config here, since loading MODEL_ID as both
# base and adapter fails when the repo ships only adapter weights.
peft_config = PeftConfig.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
base_model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base_model, MODEL_ID, cache_dir=CACHE_DIR)
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    model.to(device)  # with device_map="auto", accelerate has already placed the model
print(f"Model loaded successfully on {device.upper()}")
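
# Optional startup smoke test (a sketch; uncomment to verify generation works
# before wiring up the routes):
#
#     sample = tokenizer("Hello", return_tensors="pt").to(device)
#     out = model.generate(**sample, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
#     print(tokenizer.decode(out[0], skip_special_tokens=True))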
# =============================
# CHAT HISTORY
# =============================
history = []
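# NOTE: a module-level list is shared by every visitor to the app; per-user
# chats would need a session-keyed store (e.g. a dict keyed by a session
# cookie), which is not shown here.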
# =============================
# ROUTES
# =============================
@app.get("/", response_class=HTMLResponse)
async def get_chat(request: Request):
    return templates.TemplateResponse("index.html", {"request": request, "history": history})

@app.post("/", response_class=HTMLResponse)
async def post_chat(request: Request, user_input: str = Form(...)):
    history.append({"role": "user", "content": user_input})
    inputs = tokenizer(user_input, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model.generate(
            **inputs,
            max_new_tokens=250,  # cap new tokens; max_length also counted the prompt
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    ai_response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    history.append({"role": "ai", "content": ai_response})
    return templates.TemplateResponse("index.html", {"request": request, "history": history})
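
# NOTE: model.generate blocks the event loop when called from an async route.
# One common fix (a sketch, not part of the original app) is to offload it to
# Starlette's threadpool:
#
#     from starlette.concurrency import run_in_threadpool
#     outputs = await run_in_threadpool(
#         lambda: model.generate(**inputs, max_new_tokens=250,
#                                pad_token_id=tokenizer.eos_token_id)
#     )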
# =============================
# RUN LOCAL (optional)
# =============================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
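
# To exercise the app locally (assuming this file is named app.py, the usual
# Spaces convention):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST -F "user_input=Hello" http://localhost:7860/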