import time
import json
import uuid

import torch
from threading import Thread, Event
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    LogitsProcessor,
    LogitsProcessorList,
    StoppingCriteria,
    StoppingCriteriaList,
)

# ==========================================================
# Configuration
# ==========================================================
MODEL_ID = "/workspace/output/glm4_7_30b/hf_temp_07f"
VIEW_NAME = "RWKV-GLM-4.7-Flash-Preview-v0.1"
HOST = "0.0.0.0"
PORT = 8000

# ==========================================================
# Model loading (module-level side effect: runs once at import time)
# ==========================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

app = FastAPI()


# ==========================================================
# Logits Processors
# ==========================================================
class PresencePenaltyProcessor(LogitsProcessor):
    """Subtract a flat penalty from the logit of every token id already present.

    Each distinct token is penalized once, regardless of how often it occurs.

    NOTE(review): `input_ids` here includes the prompt, so prompt tokens are
    penalized too; OpenAI's presence_penalty is usually understood to apply
    to generated text only — confirm intended semantics.
    """

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        for batch_idx in range(input_ids.shape[0]):
            # Deduplicate so each token id is hit exactly once.
            unique_tokens = torch.unique(input_ids[batch_idx])
            scores[batch_idx, unique_tokens] -= self.penalty
        return scores


class FrequencyPenaltyProcessor(LogitsProcessor):
    """Subtract penalty * count(token) from each token's logit.

    NOTE(review): as above, counts include prompt tokens, not just
    generated ones — confirm this matches the desired API semantics.
    """

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        for batch_idx in range(input_ids.shape[0]):
            # Per-vocab-id occurrence counts, padded to vocab size so the
            # result broadcasts against the score row.
            token_counts = torch.bincount(
                input_ids[batch_idx], minlength=scores.shape[-1]
            )
            scores[batch_idx] -= token_counts * self.penalty
        return scores


# ==========================================================
# Cancellable Stopping Criteria
# ==========================================================
class CancelledStoppingCriteria(StoppingCriteria):
    """Abort generation once the given threading.Event is set."""

    def __init__(self, stop_event: Event):
self.stop_event = stop_event def __call__(self, input_ids, scores, **kwargs): return self.stop_event.is_set() # ========================================================== # Models Endpoint # ========================================================== @app.get("/v1/models") async def list_models(): return { "object": "list", "data": [ { "id": VIEW_NAME, "object": "model", "created": int(time.time()), "owned_by": "local", } ], } # ========================================================== # Chat Completions Endpoint # ========================================================== @app.post("/v1/chat/completions") async def chat_completions(request: Request): body = await request.json() model_name = body.get("model", MODEL_ID) messages = body["messages"] stream = body.get("stream", False) temperature = body.get("temperature", 1.0) top_p = body.get("top_p", 1.0) top_k = body.get("top_k", 50) repetition_penalty = body.get("repetition_penalty", 1.0) presence_penalty = body.get("presence_penalty", 0.0) frequency_penalty = body.get("frequency_penalty", 0.0) max_tokens = body.get("max_tokens", 2048) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) processors = LogitsProcessorList() if presence_penalty > 0: processors.append(PresencePenaltyProcessor(presence_penalty)) if frequency_penalty > 0: processors.append(FrequencyPenaltyProcessor(frequency_penalty)) generate_kwargs = dict( **inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, logits_processor=processors, do_sample=temperature > 0, ) # ================= Non-stream ================= if not stream: outputs = model.generate(**generate_kwargs) completion_tokens = outputs.shape[1] - inputs["input_ids"].shape[1] generated_text = tokenizer.decode( outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False ) return { "id": 
f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion", "created": int(time.time()), "model": model_name, "choices": [ { "index": 0, "message": {"role": "assistant", "content": generated_text}, "finish_reason": "stop", } ], "usage": { "prompt_tokens": inputs["input_ids"].shape[1], "completion_tokens": completion_tokens, "total_tokens": inputs["input_ids"].shape[1] + completion_tokens, }, } # ================= Streaming ================= stop_event = Event() stopping_criteria = StoppingCriteriaList( [CancelledStoppingCriteria(stop_event)] ) streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=False ) generation_kwargs = dict( **generate_kwargs, streamer=streamer, stopping_criteria=stopping_criteria, ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() async def event_generator(): completion_id = f"chatcmpl-{uuid.uuid4().hex}" firsttime = "" cancelled = False try: for new_text in streamer: if await request.is_disconnected(): stop_event.set() cancelled = True break chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_name, "choices": [ { "index": 0, "delta": {"content": firsttime + new_text}, "finish_reason": None, } ], } firsttime = "" yield f"data: {json.dumps(chunk)}\n\n" if not cancelled: yield "data: [DONE]\n\n" except Exception: stop_event.set() cancelled = True finally: if cancelled: for _ in streamer: pass thread.join(timeout=10) return StreamingResponse( event_generator(), media_type="text/event-stream" ) # ========================================================== # Python実行時に自動起動 # ========================================================== if __name__ == "__main__": import uvicorn uvicorn.run( "test_openai_api:app", host=HOST, port=PORT, reload=False, )