from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import torch

# =============================
# MODEL SETTINGS
# =============================
# Use writable cache directories; set these before importing
# transformers so the hub libraries pick them up at import time.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HF_METRICS_CACHE"] = CACHE_DIR
os.environ["TORCH_HOME"] = CACHE_DIR

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

MODEL_ID = "DSDUDEd/Cass_beta1.0"

# =============================
# FASTAPI APP INIT
# =============================
app = FastAPI()
app.mount("/static", StaticFiles(directory="."), name="static")
templates = Jinja2Templates(directory=".")

# =============================
# LOAD TOKENIZER + MODEL
# =============================
print("🚀 Loading Cass tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)

# Load the base model first, then attach the PEFT adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base_model, MODEL_ID, cache_dir=CACHE_DIR)

device = "cuda" if torch.cuda.is_available() else "cpu"
# device_map="auto" already places the model on the GPU, and calling
# .to() on an accelerate-dispatched model can raise, so only move the
# model explicitly on the CPU-only path.
if not torch.cuda.is_available():
    model.to(device)
model.eval()
print(f"✅ Model loaded successfully on {device.upper()}")

# =============================
# CHAT HISTORY
# =============================
# Module-level history: shared by every visitor and reset on restart.
# Fine for a single-user demo, not for production.
history = []

# =============================
# ROUTES
# =============================
@app.get("/", response_class=HTMLResponse)
async def get_chat(request: Request):
    return templates.TemplateResponse("index.html", {"request": request, "history": history})


@app.post("/", response_class=HTMLResponse)
async def post_chat(request: Request, user_input: str = Form(...)):
    history.append({"role": "user", "content": user_input})
    inputs = tokenizer(user_input, return_tensors="pt").to(device)
    # generate() is blocking, so a long generation stalls the event loop.
    # max_new_tokens bounds the reply length regardless of prompt size
    # (max_length would count the prompt tokens too).
    outputs = model.generate(
        **inputs,
        max_new_tokens=250,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    ai_response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    history.append({"role": "ai", "content": ai_response})
    return templates.TemplateResponse("index.html", {"request": request, "history": history})

# =============================
# RUN LOCAL (optional)
# =============================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
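
# =============================
# EXAMPLE CLIENT (optional)
# =============================
# A minimal sketch of how to exercise the chat endpoint from another
# process, assuming the server above is running on localhost:7860.
# It uses only the standard library; the prompt text is a made-up
# example and is not part of the original app.
#
#   import urllib.parse
#   import urllib.request
#
#   # urlopen() sends a POST with an application/x-www-form-urlencoded
#   # body when `data` is given, matching the Form(...) parameter above.
#   form = urllib.parse.urlencode({"user_input": "Hello, Cass!"}).encode()
#   with urllib.request.urlopen("http://localhost:7860/", data=form) as resp:
#       # The endpoint returns the rendered index.html, including the
#       # updated chat history; print a prefix of it as a smoke test.
#       print(resp.read().decode()[:500])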