# app.py
import os
import json
import time
import logging
from typing import Any
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-proxy")

HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space secrets
REPO_ID = "deepseek-ai/DeepSeek-V3.2-Exp-Base"
INFERENCE_URL = f"https://api-inference.huggingface.co/models/{REPO_ID}"
API_KEY = os.environ.get("API_KEY")  # optional shared key clients must send in the x-api-key header

app = FastAPI(title="HF Proxy + Simple UI")

# Allow CORS (allow all origins while testing; lock this down in production)
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
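
# A minimal sketch of tightening CORS for production (not active here): replace
# the wildcard above with an explicit allow-list, e.g.
#   origins = ["https://your-frontend.example.com"]  # hypothetical frontend origin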

# Simple HTML test UI (served at /)
INDEX_HTML = """
<!doctype html>
<html>
<head><meta charset="utf-8"><title>HF Proxy Test UI</title></head>
<body style="font-family:system-ui;padding:18px;">
<h2>HF Proxy Test UI</h2>
<div>
  <label>Space base URL (this server):</label>
  <input id="base" value="" style="width:100%" />
  <p>Leave base blank to use current origin.</p>
</div>
<div>
  <textarea id="messages" style="width:100%;height:90px">[{"role":"user","content":"Who are you?"}]</textarea>
</div>
<div>
  <input id="max" type="number" value="150" />
  <button id="send">Send /predict</button>
</div>
<pre id="out" style="background:#222;color:#eee;padding:12px;border-radius:8px;"></pre>
<script>
document.getElementById("send").onclick = async function(){
  let base = document.getElementById("base").value.trim() || window.location.origin;
  const url = base + "/predict";
  let messages;
  try { messages = JSON.parse(document.getElementById("messages").value); } catch(e){ document.getElementById("out").textContent = "Invalid JSON"; return; }
  const max = parseInt(document.getElementById("max").value) || 150;
  const body = { messages: messages, max_new_tokens: max };
  document.getElementById("out").textContent = "Calling " + url + " ...";
  try {
    const resp = await fetch(url, {
      method: "POST",
      headers: {"Content-Type":"application/json"},
      body: JSON.stringify(body)
    });
    const txt = await resp.text();
    document.getElementById("out").textContent = "HTTP " + resp.status + "\\n\\n" + txt;
  } catch (err) {
    document.getElementById("out").textContent = "Fetch error: " + err;
  }
};
</script>
</body>
</html>
"""

@app.get("/", response_class=HTMLResponse)
async def root():
    return HTMLResponse(INDEX_HTML)

@app.get("/health")
async def health():
    return {"status": "ok", "time": time.time()}

def render_messages_as_prompt(messages: Any) -> str:
    # Fallback: flatten a chat-style messages payload (a string, a single
    # {role, content} dict, or a list of them) into one prompt string.
    if isinstance(messages, str):
        return messages
    if isinstance(messages, dict) and "content" in messages:
        return str(messages["content"])
    if isinstance(messages, list):
        parts = []
        for m in messages:
            if isinstance(m, dict) and "role" in m and "content" in m:
                parts.append(f"<|{m['role']}|>{m['content']}<|end|>")
            else:
                parts.append(str(m))
        return "\n".join(parts)
    return str(messages)
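
# Illustration of the fallback prompt format produced above, e.g.
#   render_messages_as_prompt([{"role": "user", "content": "Who are you?"}])
#   -> "<|user|>Who are you?<|end|>"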

@app.post("/predict")
async def predict(request: Request):
    # optional API key auth
    if API_KEY:
        key = request.headers.get("x-api-key")
        if key != API_KEY:
            raise HTTPException(status_code=401, detail="Invalid API key")
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON body must be a JSON object")

    messages_raw = payload.get("messages") or payload.get("inputs") or payload.get("prompt")
    if messages_raw is None:
        raise HTTPException(status_code=400, detail="Missing 'messages', 'inputs', or 'prompt' in JSON body")
    max_new_tokens = int(payload.get("max_new_tokens", 256))

    # Build the HF request. Sending a plain string prompt works for most
    # text-generation models; if your model accepts a chat-style messages
    # array, you can send messages_raw instead (see the sketch below).
    prompt = render_messages_as_prompt(messages_raw)
    body = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens},
        "options": {"wait_for_model": True}
    }
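    # Alternative (illustrative sketch, not active): some models accept a
    # chat-style payload directly; check the model card for the expected
    # shape before switching to something like
    #   body = {"inputs": messages_raw, "parameters": {"max_new_tokens": max_new_tokens}}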
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

    try:
        async with httpx.AsyncClient(timeout=120) as client:
            resp = await client.post(INFERENCE_URL, headers=headers, json=body)
            if resp.status_code != 200:
                # bubble up hf api error
                detail = {"hf_status": resp.status_code, "content": resp.text}
                raise HTTPException(status_code=502, detail=detail)
            data = resp.json()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error contacting HF Inference API: %s", e)
        raise HTTPException(status_code=502, detail=str(e))

    # Normalize response
    out_text = ""
    if isinstance(data, list) and len(data) > 0:
        first = data[0]
        if isinstance(first, dict) and "generated_text" in first:
            out_text = first["generated_text"]
        elif isinstance(first, str):
            out_text = first
        else:
            out_text = json.dumps(data)
    elif isinstance(data, dict) and "generated_text" in data:
        out_text = data["generated_text"]
    elif isinstance(data, str):
        out_text = data
    else:
        out_text = json.dumps(data)

    return JSONResponse({"output": out_text, "raw": data})
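
# Example client call to /predict (a sketch; $SPACE_URL and $API_KEY are
# placeholders for your deployment URL and the optional x-api-key value):
#
#   curl -X POST "$SPACE_URL/predict" \
#        -H "Content-Type: application/json" \
#        -H "x-api-key: $API_KEY" \
#        -d '{"messages":[{"role":"user","content":"Who are you?"}],"max_new_tokens":150}'

# Optional local-run helper (assumes uvicorn is installed; a Hugging Face Space
# normally starts the app itself, so this is only for testing outside the Space):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)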