# test_openai_api.py
# OpenAI-compatible chat completions server for RWKV-GLM-4.7-Flash-Preview-v0.1.
import time
import json
import uuid
import torch
from threading import Thread, Event
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
LogitsProcessor,
LogitsProcessorList,
StoppingCriteria,
StoppingCriteriaList,
)
# ==========================================================
# Configuration
# ==========================================================
MODEL_ID = "/workspace/output/glm4_7_30b/hf_temp_07f"
VIEW_NAME = "RWKV-GLM-4.7-Flash-Preview-v0.1"
HOST = "0.0.0.0"
PORT = 8000
# ==========================================================
# Model loading
# ==========================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
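# Note: device_map="auto" relies on the accelerate package to place (and, on
# multi-GPU boxes, shard) the weights; bfloat16 roughly halves memory vs fp32.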
app = FastAPI()
# ==========================================================
# Logits Processors
# ==========================================================
class PresencePenaltyProcessor(LogitsProcessor):
    """OpenAI-style presence penalty: a flat penalty on every token that has
    already appeared (prompt tokens included), regardless of count."""

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        for batch_idx in range(input_ids.shape[0]):
            unique_tokens = torch.unique(input_ids[batch_idx])
            scores[batch_idx, unique_tokens] -= self.penalty
        return scores
class FrequencyPenaltyProcessor(LogitsProcessor):
    """OpenAI-style frequency penalty: penalize each token in proportion to
    how many times it has appeared (prompt tokens included)."""

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        for batch_idx in range(input_ids.shape[0]):
            token_counts = torch.bincount(
                input_ids[batch_idx], minlength=scores.shape[-1]
            )
            scores[batch_idx] -= token_counts * self.penalty
        return scores
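# A minimal sketch of how the two processors differ (not called by the
# server; the token ids and penalty value are made up for illustration):
def _demo_penalty_processors():
    ids = torch.tensor([[5, 5, 9]])  # token 5 seen twice, token 9 once
    logits = torch.zeros(1, 16)
    out = PresencePenaltyProcessor(0.5)(ids, logits.clone())
    assert out[0, 5] == -0.5 and out[0, 9] == -0.5  # flat: same hit per token
    out = FrequencyPenaltyProcessor(0.5)(ids, logits.clone())
    assert out[0, 5] == -1.0 and out[0, 9] == -0.5  # scaled by occurrence count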
# ==========================================================
# Cancellable Stopping Criteria
# ==========================================================
class CancelledStoppingCriteria(StoppingCriteria):
"""threading.Event がセットされたら生成を打ち切る"""
def __init__(self, stop_event: Event):
self.stop_event = stop_event
def __call__(self, input_ids, scores, **kwargs):
return self.stop_event.is_set()
# ==========================================================
# Models Endpoint
# ==========================================================
@app.get("/v1/models")
async def list_models():
return {
"object": "list",
"data": [
{
"id": VIEW_NAME,
"object": "model",
"created": int(time.time()),
"owned_by": "local",
}
],
}
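# Example (assuming the host/port configured above):
#   curl http://localhost:8000/v1/models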
# ==========================================================
# Chat Completions Endpoint
# ==========================================================
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
body = await request.json()
    # Echo the client's model name back; default to the advertised view name
    # rather than leaking the local MODEL_ID path.
    model_name = body.get("model", VIEW_NAME)
messages = body["messages"]
stream = body.get("stream", False)
temperature = body.get("temperature", 1.0)
top_p = body.get("top_p", 1.0)
top_k = body.get("top_k", 50)
repetition_penalty = body.get("repetition_penalty", 1.0)
presence_penalty = body.get("presence_penalty", 0.0)
frequency_penalty = body.get("frequency_penalty", 0.0)
max_tokens = body.get("max_tokens", 2048)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    processors = LogitsProcessorList()
    # OpenAI permits negative penalties (which encourage repetition), so only
    # skip a processor when its penalty is exactly zero.
    if presence_penalty != 0:
        processors.append(PresencePenaltyProcessor(presence_penalty))
    if frequency_penalty != 0:
        processors.append(FrequencyPenaltyProcessor(frequency_penalty))
    do_sample = temperature > 0
    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
        logits_processor=processors,
        do_sample=do_sample,
    )
    if do_sample:
        # Pass sampling parameters only when sampling; transformers warns if
        # they are set alongside greedy decoding (temperature == 0).
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)
# ================= Non-stream =================
    if not stream:
        # Blocks the event loop for the whole generation, which is acceptable
        # for a single-user test server.
        outputs = model.generate(**generate_kwargs)
completion_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
)
return {
"id": f"chatcmpl-{uuid.uuid4().hex}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": generated_text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": inputs["input_ids"].shape[1],
"completion_tokens": completion_tokens,
"total_tokens": inputs["input_ids"].shape[1] + completion_tokens,
},
}
# ================= Streaming =================
stop_event = Event()
stopping_criteria = StoppingCriteriaList(
[CancelledStoppingCriteria(stop_event)]
)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=False
)
generation_kwargs = dict(
**generate_kwargs,
streamer=streamer,
stopping_criteria=stopping_criteria,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
async def event_generator():
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
        # Prefix the first chunk with an opening <think> tag; the chat
        # template appears to start the assistant turn inside the think block.
        firsttime = "<think>"
cancelled = False
try:
for new_text in streamer:
if await request.is_disconnected():
stop_event.set()
cancelled = True
break
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": {"content": firsttime + new_text},
"finish_reason": None,
}
],
}
firsttime = ""
yield f"data: {json.dumps(chunk)}\n\n"
            if not cancelled:
                # Match the OpenAI API: a terminal chunk carrying the
                # finish_reason, then the [DONE] sentinel.
                final = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model_name,
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(final)}\n\n"
                yield "data: [DONE]\n\n"
except Exception:
stop_event.set()
cancelled = True
        finally:
            if cancelled:
                # Drain the streamer so the generation thread is not left
                # blocked on its output queue before joining it.
                for _ in streamer:
                    pass
            thread.join(timeout=10)
return StreamingResponse(
event_generator(), media_type="text/event-stream"
)
# ==========================================================
# Auto-start when executed as a script
# ==========================================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"test_openai_api:app",
host=HOST,
port=PORT,
reload=False,
)
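# ==========================================================
# Example request (assuming the defaults above, reachable at
# http://localhost:8000):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "RWKV-GLM-4.7-Flash-Preview-v0.1",
#          "messages": [{"role": "user", "content": "Hello!"}],
#          "stream": true, "max_tokens": 256}'
# ==========================================================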