| | import time |
| | import json |
| | import uuid |
| | import torch |
| | from threading import Thread, Event |
| | from fastapi import FastAPI, Request |
| | from fastapi.responses import StreamingResponse |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | TextIteratorStreamer, |
| | LogitsProcessor, |
| | LogitsProcessorList, |
| | StoppingCriteria, |
| | StoppingCriteriaList, |
| | ) |
| |
|
| | |
| | |
| | |
# Local path to the converted HF checkpoint this server loads and serves.
MODEL_ID = "/workspace/output/glm4_7_30b/hf_temp_07f"
# Model name advertised to clients via /v1/models (requests may override it).
VIEW_NAME = "RWKV-GLM-4.7-Flash-Preview-v0.1"
# Bind address/port for uvicorn; 0.0.0.0 exposes the server on all interfaces.
HOST = "0.0.0.0"
PORT = 8000
| |
|
| | |
| | |
| | |
# Load tokenizer and model once at import time; all requests share them.
# trust_remote_code is needed because the checkpoint ships custom model code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # bf16: fp32-like range at half the memory
    device_map="auto",  # place/shard weights across available devices
    trust_remote_code=True,
)

app = FastAPI()
| |
|
| |
|
| | |
| | |
| | |
class PresencePenaltyProcessor(LogitsProcessor):
    """OpenAI-style presence penalty.

    Subtracts a flat ``penalty`` from the logit of every token id that has
    already appeared in the sequence, no matter how many times it appeared.
    NOTE(review): ``input_ids`` includes the prompt, so prompt tokens are
    penalized too — confirm that is intended (OpenAI penalizes only
    generated tokens).
    """

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        # One pass per batch row: find the distinct token ids seen so far
        # and shift their logits down by a constant amount.
        for row, ids in enumerate(input_ids):
            seen = torch.unique(ids)
            scores[row, seen] -= self.penalty
        return scores
| |
|
| |
|
class FrequencyPenaltyProcessor(LogitsProcessor):
    """OpenAI-style frequency penalty.

    Subtracts ``penalty * count`` from each token's logit, where ``count``
    is how many times that token id occurs in the sequence so far, so the
    penalty grows with repetition (unlike the flat presence penalty).
    """

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        vocab_size = scores.shape[-1]
        for row, ids in enumerate(input_ids):
            # bincount gives per-token-id occurrence counts over the vocab.
            counts = torch.bincount(ids, minlength=vocab_size)
            scores[row] -= counts * self.penalty
        return scores
| |
|
| |
|
| | |
| | |
| | |
class CancelledStoppingCriteria(StoppingCriteria):
    """Abort generation as soon as the given ``threading.Event`` is set."""

    def __init__(self, stop_event: Event):
        self._cancel = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        # Polled by model.generate() between decode steps; returning True
        # makes generation stop after the current step.
        return self._cancel.is_set()
| |
|
| |
|
| | |
| | |
| | |
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing: exactly one locally served model."""
    entry = {
        "id": VIEW_NAME,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "local",
    }
    return {"object": "list", "data": [entry]}
| |
|
| |
|
| | |
| | |
| | |
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint.

    Reads an OpenAI-style JSON body (``messages`` required; sampling
    parameters optional), renders the chat template, and either returns a
    full completion object or an SSE stream of ``chat.completion.chunk``
    events followed by ``data: [DONE]``.

    Raises:
        KeyError: if the request body has no "messages" field.
    """
    body = await request.json()

    model_name = body.get("model", MODEL_ID)
    messages = body["messages"]
    stream = body.get("stream", False)

    # Sampling parameters with OpenAI-style defaults.
    temperature = body.get("temperature", 1.0)
    top_p = body.get("top_p", 1.0)
    top_k = body.get("top_k", 50)
    repetition_penalty = body.get("repetition_penalty", 1.0)
    presence_penalty = body.get("presence_penalty", 0.0)
    frequency_penalty = body.get("frequency_penalty", 0.0)
    max_tokens = body.get("max_tokens", 2048)

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    processors = LogitsProcessorList()
    # FIX: the OpenAI API allows negative penalties (-2.0 .. 2.0) to
    # *encourage* repetition; the previous `> 0` check silently dropped them.
    if presence_penalty != 0:
        processors.append(PresencePenaltyProcessor(presence_penalty))
    if frequency_penalty != 0:
        processors.append(FrequencyPenaltyProcessor(frequency_penalty))

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        logits_processor=processors,
        do_sample=temperature > 0,  # temperature == 0 -> greedy decoding
    )

    # ---- non-streaming path ------------------------------------------------
    if not stream:
        outputs = model.generate(**generate_kwargs)
        completion_tokens = outputs.shape[1] - prompt_len
        generated_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=False
        )
        # FIX: report "length" when generation was cut off by max_tokens
        # instead of always claiming a natural "stop".
        finish_reason = "length" if completion_tokens >= max_tokens else "stop"
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": generated_text},
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_len + completion_tokens,
            },
        }

    # ---- streaming path ----------------------------------------------------
    stop_event = Event()

    stopping_criteria = StoppingCriteriaList(
        [CancelledStoppingCriteria(stop_event)]
    )

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=False
    )

    generation_kwargs = dict(
        **generate_kwargs,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
    )

    # Generation runs in a worker thread; the streamer hands tokens back here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    async def event_generator():
        completion_id = f"chatcmpl-{uuid.uuid4().hex}"
        # NOTE(review): the first chunk is prefixed with a literal "<think>"
        # tag while the non-streaming path emits no such prefix — presumably
        # the chat template opens a thinking block that the streamer skips;
        # confirm against the template.
        first_prefix = "<think>"
        cancelled = False

        def make_chunk(delta, finish_reason):
            # Build one OpenAI chat.completion.chunk payload.
            return {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model_name,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }

        try:
            # NOTE(review): this iterates a *blocking* queue inside an async
            # generator, so the event loop stalls between tokens; consider
            # draining the streamer via a thread + asyncio queue.
            for new_text in streamer:
                if await request.is_disconnected():
                    # Client went away: signal the stopping criteria and bail.
                    stop_event.set()
                    cancelled = True
                    break

                if first_prefix:
                    # FIX: the first delta must carry the assistant role per
                    # the OpenAI streaming protocol.
                    delta = {
                        "role": "assistant",
                        "content": first_prefix + new_text,
                    }
                    first_prefix = ""
                else:
                    delta = {"content": new_text}
                yield f"data: {json.dumps(make_chunk(delta, None))}\n\n"

            if not cancelled:
                # FIX: emit the terminating chunk carrying finish_reason
                # before [DONE], as the OpenAI streaming protocol requires.
                final = make_chunk({}, "stop")
                yield f"data: {json.dumps(final)}\n\n"
                yield "data: [DONE]\n\n"

        except Exception:
            stop_event.set()
            cancelled = True
        finally:
            if cancelled:
                # Drain leftover tokens so the generation thread can finish.
                for _ in streamer:
                    pass
            thread.join(timeout=10)

    return StreamingResponse(
        event_generator(), media_type="text/event-stream"
    )
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    import uvicorn

    # Serve the app defined above; reload is off because the model is too
    # expensive to re-load on every code change.
    uvicorn.run("test_openai_api:app", host=HOST, port=PORT, reload=False)