import json
import logging

import requests
import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Set up logging
logging.basicConfig(level=logging.INFO)
def check_server_health(cloud_gateway_api: str, header: dict) -> bool:
    """
    Probe the metadata endpoint to check whether the server is healthy.

    Args:
        cloud_gateway_api: Base URL of the API endpoint to probe.
        header: Headers carrying the authorization credentials.

    Returns:
        True if the server responds successfully, False otherwise.
    """
    try:
        response = requests.get(
            cloud_gateway_api + "models/metadata",
            headers=header,
            verify=False,
        )
        response.raise_for_status()
        return True
    except requests.RequestException as e:
        logging.error(f"Failed to check server health: {e}")
        return False
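
# A minimal usage sketch (hypothetical base URL and token; the real values come
# from your deployment). Note that the base URL must end with a trailing slash,
# since the endpoint paths are concatenated directly onto it:
#
#   auth_header = {"Authorization": "Bearer <token>"}
#   if check_server_health("https://gateway.example.com/v1/", auth_header):
#       logging.info("Gateway is healthy")
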
def request_generation(
    header: dict,
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    model_name: str,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request a streaming generation from the cloud gateway API. Uses requests
    with stream=True so tokens can be consumed as the LLM produces them.

    Args:
        header: Authorization header for the API.
        message: Prompt from the user.
        system_prompt: System prompt to prepend to the conversation.
        cloud_gateway_api: Base URL of the API endpoint to send the request to.
        model_name: Name of the model to generate with.
        temperature: Value used to modulate the next-token probabilities.
        frequency_penalty: Penalty applied to tokens in proportion to how often
            they have already appeared. 0.0 means no penalty.
        presence_penalty: Penalty applied to tokens that have appeared at all.
            0.0 means no penalty.

    Yields:
        Content chunks (str) as they are streamed back from the model.
    """
    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }
    try:
        # Create the conversation; the gateway returns a conversation ID that
        # the streaming endpoint uses to identify this request
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()

        # Attach the conversation ID to the header under X-Conversation-ID
        header["X-Conversation-ID"] = response.json()["conversationId"]

        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
                if chunk:
                    # Convert the chunk from bytes to a string
                    chunk_str = chunk.decode("utf-8")
                    # Strip the SSE "data: " prefix; run twice in case a chunk
                    # arrives with a doubled prefix
                    for _ in range(2):
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]
                    # Stop once the stream signals completion
                    if chunk_str.strip() == "[DONE]":
                        break
                    # Parse the chunk as JSON and pull out the streamed content
                    try:
                        chunk_json = json.loads(chunk_str)
                        # Extract the "content" field from the choices
                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get("content", "")
                        else:
                            content = ""
                        # Yield the generated content as it streams in
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that are not valid JSON
                        continue
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."