# Hugging Face Space "agent_building" — app.py (uploaded by jarvis0852, commit d7fd659)
# app.py
import os
import json
import time
import logging
from typing import Any
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx
# Basic logging so proxy activity is visible in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-proxy")
HF_TOKEN = os.environ.get("HF_TOKEN") # set in Space secrets
# Model whose hosted Inference API endpoint this server proxies.
REPO_ID = "deepseek-ai/DeepSeek-V3.2-Exp-Base"
INFERENCE_URL = f"https://api-inference.huggingface.co/models/{REPO_ID}"
API_KEY = os.environ.get("API_KEY") # optional simple client key
app = FastAPI(title="HF Proxy + Simple UI")
# Allow CORS (for testing allow all; lock down in prod)
# NOTE(review): browsers reject a wildcard origin on credentialed requests, so
# allow_origins=["*"] combined with allow_credentials=True may not behave as
# intended — confirm whether credentials are actually needed here.
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
# Simple HTML test UI (served at /). The embedded script POSTs
# {"messages": <parsed JSON>, "max_new_tokens": <int>} to this server's
# /predict endpoint and prints the raw HTTP status and response body.
INDEX_HTML = """
<!doctype html>
<html>
<head><meta charset="utf-8"><title>HF Proxy Test UI</title></head>
<body style="font-family:system-ui;padding:18px;">
<h2>HF Proxy Test UI</h2>
<div>
<label>Space base URL (this server):</label>
<input id="base" value="" style="width:100%" />
<p>Leave base blank to use current origin.</p>
</div>
<div>
<textarea id="messages" style="width:100%;height:90px">[{"role":"user","content":"Who are you?"}]</textarea>
</div>
<div>
<input id="max" type="number" value="150" />
<button id="send">Send /predict</button>
</div>
<pre id="out" style="background:#222;color:#eee;padding:12px;border-radius:8px;"></pre>
<script>
document.getElementById("send").onclick = async function(){
let base = document.getElementById("base").value.trim() || window.location.origin;
const url = base + "/predict";
let messages;
try { messages = JSON.parse(document.getElementById("messages").value); } catch(e){ document.getElementById("out").textContent = "Invalid JSON"; return; }
const max = parseInt(document.getElementById("max").value) || 150;
const body = { messages: messages, max_new_tokens: max };
document.getElementById("out").textContent = "Calling " + url + " ...";
try {
const resp = await fetch(url, {
method: "POST",
headers: {"Content-Type":"application/json"},
body: JSON.stringify(body)
});
const txt = await resp.text();
document.getElementById("out").textContent = "HTTP " + resp.status + "\\n\\n" + txt;
} catch (err) {
document.getElementById("out").textContent = "Fetch error: " + err;
}
};
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the built-in browser test UI."""
    return HTMLResponse(content=INDEX_HTML)
@app.get("/health")
async def health():
    """Liveness probe: report OK plus the current server time."""
    now = time.time()
    return {"status": "ok", "time": now}
def render_messages_as_prompt(messages: Any) -> str:
    """Flatten chat-style input into a single prompt string.

    Accepts a plain string (returned unchanged), a single {role, content}
    dict (its content is returned), or a list of such dicts (rendered as
    "<|role|>content<|end|>" lines joined with newlines). Anything else is
    stringified as a fallback.
    """
    if isinstance(messages, str):
        return messages
    if isinstance(messages, dict) and "content" in messages:
        return str(messages["content"])
    if not isinstance(messages, list):
        return str(messages)

    def _render_one(item: Any) -> str:
        # Well-formed chat turns get the role-tagged template; everything
        # else falls back to plain stringification.
        if isinstance(item, dict) and "role" in item and "content" in item:
            return f"<|{item['role']}|>{item['content']}<|end|>"
        return str(item)

    return "\n".join(_render_one(item) for item in messages)
def _normalize_hf_output(data: Any) -> str:
    """Extract generated text from an HF Inference API response payload.

    Handles the common response shapes: [{"generated_text": ...}],
    ["text"], {"generated_text": ...}, or a bare string. Any other shape
    is JSON-dumped so the caller still gets a string.
    """
    if isinstance(data, list) and len(data) > 0:
        first = data[0]
        if isinstance(first, dict) and "generated_text" in first:
            return first["generated_text"]
        if isinstance(first, str):
            return first
        return json.dumps(data)
    if isinstance(data, dict) and "generated_text" in data:
        return data["generated_text"]
    if isinstance(data, str):
        return data
    return json.dumps(data)


@app.post("/predict")
async def predict(request: Request):
    """Proxy a text-generation request to the HF Inference API.

    Expects a JSON object body with 'messages' (or 'inputs'/'prompt') and an
    optional integer 'max_new_tokens' (default 256). Returns
    {"output": <generated text>, "raw": <full HF response>}.

    Raises:
        HTTPException 401: missing/invalid x-api-key (when API_KEY is set).
        HTTPException 400: unparseable or malformed request body.
        HTTPException 502: upstream HF API error or connection failure.
    """
    # optional API key auth
    if API_KEY:
        key = request.headers.get("x-api-key")
        if key != API_KEY:
            raise HTTPException(status_code=401, detail="Invalid API key")
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    # Fix: a top-level JSON array/string body previously crashed on .get()
    # with an AttributeError (HTTP 500); reject it with a 400 instead.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")
    messages_raw = payload.get("messages") or payload.get("inputs") or payload.get("prompt")
    if messages_raw is None:
        raise HTTPException(status_code=400, detail="Missing 'messages' or 'inputs' or 'prompt' in JSON body")
    # Fix: a non-numeric max_new_tokens previously raised ValueError/TypeError
    # uncaught (HTTP 500); reject it explicitly with a 400.
    try:
        max_new_tokens = int(payload.get("max_new_tokens", 256))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="'max_new_tokens' must be an integer")
    # Build safe HF request: prefer to send a plain prompt (works for most text-generation models)
    # If your model accepts messages array, you can switch to sending messages_raw instead.
    prompt = render_messages_as_prompt(messages_raw)
    body = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens},
        "options": {"wait_for_model": True}  # queue until the model is loaded
    }
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            resp = await client.post(INFERENCE_URL, headers=headers, json=body)
            if resp.status_code != 200:
                # bubble up hf api error
                detail = {"hf_status": resp.status_code, "content": resp.text}
                raise HTTPException(status_code=502, detail=detail)
            data = resp.json()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error contacting HF Inference API: %s", e)
        raise HTTPException(status_code=502, detail=str(e))
    # Normalize the upstream response to a plain string, but keep the raw
    # payload so clients can inspect the full HF answer.
    return JSONResponse({"output": _normalize_hf_output(data), "raw": data})