# Cass-Beta4.0 / app.py
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import os
import torch
# =============================
# MODEL SETTINGS
# =============================
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
MODEL_ID = "DSDUDEd/Cass_beta1.0"
# Use writable cache directories
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HF_METRICS_CACHE"] = CACHE_DIR
os.environ["TORCH_HOME"] = CACHE_DIR
# =============================
# FASTAPI APP INIT
# =============================
app = FastAPI()
app.mount("/static", StaticFiles(directory="."), name="static")
templates = Jinja2Templates(directory=".")
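# NOTE: static files and templates are both served from the repo root, which
# also exposes app.py itself; acceptable for a demo, but a dedicated static/
# and templates/ layout would be safer.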
# =============================
# LOAD TOKENIZER + MODEL
# =============================
print("πŸš€ Loading Cass tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
# Load base model first, then PEFT adapter
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base_model, MODEL_ID, cache_dir=CACHE_DIR)
device = "cuda" if torch.cuda.is_available() else "cpu"
# device_map="auto" already places the model on GPU; only move it manually on CPU,
# since calling .to() on an accelerate-dispatched model can fail.
if not torch.cuda.is_available():
    model.to(device)
model.eval()  # inference only; disables dropout
print(f"✅ Model loaded successfully on {device.upper()}")
# =============================
# CHAT HISTORY
# =============================
history = []
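# NOTE: a module-level list is shared across every client and request; fine
# for a single-user demo, not for concurrent sessions.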
# =============================
# ROUTES
# =============================
@app.get("/", response_class=HTMLResponse)
async def get_chat(request: Request):
return templates.TemplateResponse("index.html", {"request": request, "history": history})
@app.post("/", response_class=HTMLResponse)
async def post_chat(request: Request, user_input: str = Form(...)):
    history.append({"role": "user", "content": user_input})
    inputs = tokenizer(user_input, return_tensors="pt").to(device)
    with torch.no_grad():  # no gradients needed for generation
        outputs = model.generate(
            **inputs,
            max_new_tokens=250,  # max_length would count the prompt tokens too
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; outputs[0] also contains the prompt.
    ai_response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    history.append({"role": "ai", "content": ai_response})
    return templates.TemplateResponse("index.html", {"request": request, "history": history})
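# Example exchange (assuming the app is running locally on port 7860 and
# index.html renders `history` and posts a `user_input` form field back to "/"):
#   curl -X POST -F "user_input=Hello, Cass!" http://localhost:7860/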
# =============================
# RUN LOCAL (optional)
# =============================
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)