Spaces:
Running
Running
File size: 5,742 Bytes
85e6d30 d7fd659 85e6d30 d7fd659 84bb19d 85e6d30 d7fd659 85e6d30 d7fd659 84bb19d d7fd659 4a9a5d4 d7fd659 84bb19d d7fd659 84bb19d d7fd659 84bb19d d7fd659 4a9a5d4 85e6d30 d7fd659 4a9a5d4 d7fd659 4a9a5d4 d7fd659 85e6d30 d7fd659 cb3dc44 d7fd659 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# app.py
import os
import json
import time
import logging
from typing import Any
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx
# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-proxy")
# --- Configuration (read once at import time) --------------------------------
# HF_TOKEN: Hugging Face access token, injected via Space secrets; when unset,
# requests to the Inference API are sent unauthenticated.
HF_TOKEN = os.environ.get("HF_TOKEN") # set in Space secrets
# Model repo this proxy forwards to, and the derived Inference API endpoint.
REPO_ID = "deepseek-ai/DeepSeek-V3.2-Exp-Base"
INFERENCE_URL = f"https://api-inference.huggingface.co/models/{REPO_ID}"
# API_KEY: optional shared secret; when set, /predict requires a matching
# x-api-key header (see predict()).
API_KEY = os.environ.get("API_KEY") # optional simple client key
app = FastAPI(title="HF Proxy + Simple UI")
# Allow CORS (for testing allow all; lock down in prod)
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
# Simple HTML test UI (served at /).
# Self-contained page: a textarea holding a JSON messages array, a
# max_new_tokens field, and a button that POSTs to this server's /predict
# endpoint and dumps the raw HTTP status + body into a <pre>.
# NOTE: this literal is runtime content served verbatim — do not edit casually.
INDEX_HTML = """
<!doctype html>
<html>
<head><meta charset="utf-8"><title>HF Proxy Test UI</title></head>
<body style="font-family:system-ui;padding:18px;">
<h2>HF Proxy Test UI</h2>
<div>
<label>Space base URL (this server):</label>
<input id="base" value="" style="width:100%" />
<p>Leave base blank to use current origin.</p>
</div>
<div>
<textarea id="messages" style="width:100%;height:90px">[{"role":"user","content":"Who are you?"}]</textarea>
</div>
<div>
<input id="max" type="number" value="150" />
<button id="send">Send /predict</button>
</div>
<pre id="out" style="background:#222;color:#eee;padding:12px;border-radius:8px;"></pre>
<script>
document.getElementById("send").onclick = async function(){
let base = document.getElementById("base").value.trim() || window.location.origin;
const url = base + "/predict";
let messages;
try { messages = JSON.parse(document.getElementById("messages").value); } catch(e){ document.getElementById("out").textContent = "Invalid JSON"; return; }
const max = parseInt(document.getElementById("max").value) || 150;
const body = { messages: messages, max_new_tokens: max };
document.getElementById("out").textContent = "Calling " + url + " ...";
try {
const resp = await fetch(url, {
method: "POST",
headers: {"Content-Type":"application/json"},
body: JSON.stringify(body)
});
const txt = await resp.text();
document.getElementById("out").textContent = "HTTP " + resp.status + "\\n\\n" + txt;
} catch (err) {
document.getElementById("out").textContent = "Fetch error: " + err;
}
};
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the built-in HTML test page at the site root."""
    page = HTMLResponse(INDEX_HTML)
    return page
@app.get("/health")
async def health():
    """Liveness probe: report an OK status plus the current Unix timestamp."""
    now = time.time()
    return {"status": "ok", "time": now}
def render_messages_as_prompt(messages: Any) -> str:
    """Flatten a chat-style payload into a single prompt string.

    Accepts a plain string (returned unchanged), a single message dict with a
    "content" key (its content is stringified), or a list of {role, content}
    dicts joined as "<|role|>content<|end|>" lines. Any other value — or any
    malformed list element — falls back to str().
    """
    if isinstance(messages, str):
        return messages
    if isinstance(messages, dict) and "content" in messages:
        return str(messages["content"])
    if not isinstance(messages, list):
        return str(messages)
    rendered = []
    for item in messages:
        is_chat_message = isinstance(item, dict) and "role" in item and "content" in item
        if is_chat_message:
            rendered.append(f"<|{item['role']}|>{item['content']}<|end|>")
        else:
            rendered.append(str(item))
    return "\n".join(rendered)
@app.post("/predict")
async def predict(request: Request):
    """Proxy a text-generation request to the HF Inference API.

    Expects a JSON object containing one of "messages" / "inputs" / "prompt"
    and an optional integer "max_new_tokens" (default 256). The messages are
    flattened to a plain prompt via render_messages_as_prompt, forwarded
    upstream, and the response is normalized to {"output": str, "raw": data}.

    Raises:
        HTTPException 400: bad/missing JSON, non-object body, missing prompt,
            or non-integer max_new_tokens.
        HTTPException 401: API_KEY is configured and x-api-key doesn't match.
        HTTPException 502: upstream HF error or network failure.
    """
    # Optional shared-secret auth via the x-api-key header.
    if API_KEY:
        key = request.headers.get("x-api-key")
        if key != API_KEY:
            raise HTTPException(status_code=401, detail="Invalid API key")
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    # Fix: a valid-JSON but non-object body (e.g. a bare list or string) used
    # to hit AttributeError on .get() and surface as a 500; reject it as 400.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")
    messages_raw = payload.get("messages") or payload.get("inputs") or payload.get("prompt")
    if messages_raw is None:
        raise HTTPException(status_code=400, detail="Missing 'messages' or 'inputs' or 'prompt' in JSON body")
    # Fix: a non-numeric max_new_tokens used to raise ValueError/TypeError and
    # surface as a 500; validate explicitly and report a client error instead.
    try:
        max_new_tokens = int(payload.get("max_new_tokens", 256))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="'max_new_tokens' must be an integer")
    # Build safe HF request: prefer to send a plain prompt (works for most text-generation models)
    # If your model accepts messages array, you can switch to sending messages_raw instead.
    prompt = render_messages_as_prompt(messages_raw)
    body = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens},
        # wait_for_model: block until the model is loaded instead of 503ing.
        "options": {"wait_for_model": True}
    }
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            resp = await client.post(INFERENCE_URL, headers=headers, json=body)
            if resp.status_code != 200:
                # bubble up hf api error
                detail = {"hf_status": resp.status_code, "content": resp.text}
                raise HTTPException(status_code=502, detail=detail)
            data = resp.json()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error contacting HF Inference API: %s", e)
        raise HTTPException(status_code=502, detail=str(e))
    # Normalize response: the HF API may return a list of generations, a
    # single dict, or a bare string; anything else is JSON-dumped verbatim.
    out_text = ""
    if isinstance(data, list) and len(data) > 0:
        first = data[0]
        if isinstance(first, dict) and "generated_text" in first:
            out_text = first["generated_text"]
        elif isinstance(first, str):
            out_text = first
        else:
            out_text = json.dumps(data)
    elif isinstance(data, dict) and "generated_text" in data:
        out_text = data["generated_text"]
    elif isinstance(data, str):
        out_text = data
    else:
        out_text = json.dumps(data)
    return JSONResponse({"output": out_text, "raw": data})
|