Spaces:
Runtime error
Runtime error
| import argparse | |
| from flask import Flask, jsonify, request, Response | |
| import urllib.parse | |
| import requests | |
| import time | |
| import json | |
| app = Flask(__name__) | |
| parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.") | |
| parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') | |
| parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") | |
| parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") | |
| parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ") | |
| parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>") | |
| parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080') | |
| parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="") | |
| parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1') | |
| parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081) | |
| args = parser.parse_args() | |
| def is_present(json, key): | |
| try: | |
| buf = json[key] | |
| except KeyError: | |
| return False | |
| return True | |
| #convert chat to prompt | |
| def convert_chat(messages): | |
| prompt = "" + args.chat_prompt.replace("\\n", "\n") | |
| system_n = args.system_name.replace("\\n", "\n") | |
| user_n = args.user_name.replace("\\n", "\n") | |
| ai_n = args.ai_name.replace("\\n", "\n") | |
| stop = args.stop.replace("\\n", "\n") | |
| for line in messages: | |
| if (line["role"] == "system"): | |
| prompt += f"{system_n}{line['content']}" | |
| if (line["role"] == "user"): | |
| prompt += f"{user_n}{line['content']}" | |
| if (line["role"] == "assistant"): | |
| prompt += f"{ai_n}{line['content']}{stop}" | |
| prompt += ai_n.rstrip() | |
| return prompt | |
| def make_postData(body, chat=False, stream=False): | |
| postData = {} | |
| if (chat): | |
| postData["prompt"] = convert_chat(body["messages"]) | |
| else: | |
| postData["prompt"] = body["prompt"] | |
| if(is_present(body, "temperature")): postData["temperature"] = body["temperature"] | |
| if(is_present(body, "top_k")): postData["top_k"] = body["top_k"] | |
| if(is_present(body, "top_p")): postData["top_p"] = body["top_p"] | |
| if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"] | |
| if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"] | |
| if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"] | |
| if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"] | |
| if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"] | |
| if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"] | |
| if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"] | |
| if(is_present(body, "seed")): postData["seed"] = body["seed"] | |
| if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] | |
| if (args.stop != ""): | |
| postData["stop"] = [args.stop] | |
| else: | |
| postData["stop"] = [] | |
| if(is_present(body, "stop")): postData["stop"] += body["stop"] | |
| postData["n_keep"] = -1 | |
| postData["stream"] = stream | |
| return postData | |
| def make_resData(data, chat=False, promptToken=[]): | |
| resData = { | |
| "id": "chatcmpl" if (chat) else "cmpl", | |
| "object": "chat.completion" if (chat) else "text_completion", | |
| "created": int(time.time()), | |
| "truncated": data["truncated"], | |
| "model": "LLaMA_CPP", | |
| "usage": { | |
| "prompt_tokens": data["tokens_evaluated"], | |
| "completion_tokens": data["tokens_predicted"], | |
| "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"] | |
| } | |
| } | |
| if (len(promptToken) != 0): | |
| resData["promptToken"] = promptToken | |
| if (chat): | |
| #only one choice is supported | |
| resData["choices"] = [{ | |
| "index": 0, | |
| "message": { | |
| "role": "assistant", | |
| "content": data["content"], | |
| }, | |
| "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" | |
| }] | |
| else: | |
| #only one choice is supported | |
| resData["choices"] = [{ | |
| "text": data["content"], | |
| "index": 0, | |
| "logprobs": None, | |
| "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" | |
| }] | |
| return resData | |
| def make_resData_stream(data, chat=False, time_now = 0, start=False): | |
| resData = { | |
| "id": "chatcmpl" if (chat) else "cmpl", | |
| "object": "chat.completion.chunk" if (chat) else "text_completion.chunk", | |
| "created": time_now, | |
| "model": "LLaMA_CPP", | |
| "choices": [ | |
| { | |
| "finish_reason": None, | |
| "index": 0 | |
| } | |
| ] | |
| } | |
| if (chat): | |
| if (start): | |
| resData["choices"][0]["delta"] = { | |
| "role": "assistant" | |
| } | |
| else: | |
| resData["choices"][0]["delta"] = { | |
| "content": data["content"] | |
| } | |
| if (data["stop"]): | |
| resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" | |
| else: | |
| resData["choices"][0]["text"] = data["content"] | |
| if (data["stop"]): | |
| resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" | |
| return resData | |
| def chat_completions(): | |
| if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): | |
| return Response(status=403) | |
| body = request.get_json() | |
| stream = False | |
| tokenize = False | |
| if(is_present(body, "stream")): stream = body["stream"] | |
| if(is_present(body, "tokenize")): tokenize = body["tokenize"] | |
| postData = make_postData(body, chat=True, stream=stream) | |
| promptToken = [] | |
| if (tokenize): | |
| tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json() | |
| promptToken = tokenData["tokens"] | |
| if (not stream): | |
| data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData)) | |
| print(data.json()) | |
| resData = make_resData(data.json(), chat=True, promptToken=promptToken) | |
| return jsonify(resData) | |
| else: | |
| def generate(): | |
| data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) | |
| time_now = int(time.time()) | |
| resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) | |
| yield 'data: {}\n'.format(json.dumps(resData)) | |
| for line in data.iter_lines(): | |
| if line: | |
| decoded_line = line.decode('utf-8') | |
| resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) | |
| yield 'data: {}\n'.format(json.dumps(resData)) | |
| return Response(generate(), mimetype='text/event-stream') | |
| def completion(): | |
| if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): | |
| return Response(status=403) | |
| body = request.get_json() | |
| stream = False | |
| tokenize = False | |
| if(is_present(body, "stream")): stream = body["stream"] | |
| if(is_present(body, "tokenize")): tokenize = body["tokenize"] | |
| postData = make_postData(body, chat=False, stream=stream) | |
| promptToken = [] | |
| if (tokenize): | |
| tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json() | |
| promptToken = tokenData["tokens"] | |
| if (not stream): | |
| data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData)) | |
| print(data.json()) | |
| resData = make_resData(data.json(), chat=False, promptToken=promptToken) | |
| return jsonify(resData) | |
| else: | |
| def generate(): | |
| data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) | |
| time_now = int(time.time()) | |
| for line in data.iter_lines(): | |
| if line: | |
| decoded_line = line.decode('utf-8') | |
| resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) | |
| yield 'data: {}\n'.format(json.dumps(resData)) | |
| return Response(generate(), mimetype='text/event-stream') | |
| if __name__ == '__main__': | |
| app.run(args.host, port=args.port) | |