Spaces:
Runtime error
Runtime error
| import base64 | |
| import io | |
| import json | |
| import os | |
| import time | |
| from typing import Any, Dict, Optional | |
| from PIL import Image | |
| import requests | |
def _image_to_base64(image: Image.Image) -> str:
    """Encode a Pillow image as a base64 string for a JSON request payload.

    Images whose ``format`` is JPEG are re-encoded as JPEG; everything else
    (including images with no ``format``, e.g. ones built in memory) is sent
    as PNG.

    Args:
        image: The image to encode. The caller's object is not modified.

    Returns:
        The base64-encoded image bytes decoded to ``str``.
    """
    image_format = (image.format or "PNG").upper()
    # Pillow registers its JPEG writer only under "JPEG"; "JPG" is merely a
    # file-extension alias and would not be accepted by ``Image.save``.
    if image_format == "JPG":
        image_format = "JPEG"
    if image_format not in {"PNG", "JPEG"}:
        image_format = "PNG"
    if image_format == "PNG" and image.mode in {"CMYK", "YCbCr"}:
        # The PNG writer cannot store these modes (e.g. from CMYK TIFF
        # sources); convert instead of letting ``save`` raise OSError.
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
| def _extract_status(payload: Dict[str, Any]) -> Optional[str]: | |
| status_info = payload.get("status") or payload.get("state") | |
| if isinstance(status_info, dict): | |
| state = status_info.get("state") or status_info.get("status") | |
| if isinstance(state, str): | |
| return state.lower() | |
| elif isinstance(status_info, str): | |
| return status_info.lower() | |
| return None | |
def _poll_bria_status(
    status_url: str,
    headers: Dict[str, str],
    timeout_seconds: int = 120,
    poll_interval: float = 1.5,
) -> Any:
    """Poll a Bria status URL until the job succeeds, fails, or times out.

    Args:
        status_url: URL to GET for job progress.
        headers: Headers sent with every poll request.
        timeout_seconds: Overall polling budget, in seconds.
        poll_interval: Delay between consecutive polls, in seconds.

    Returns:
        ``payload["result"]`` when it is a dict, otherwise
        ``payload["results"]`` when present and not ``None`` (which need not
        be a dict), otherwise the whole status payload.

    Raises:
        requests.HTTPError: A status request returned an HTTP error status.
        RuntimeError: The job reported a failed/error/cancelled state.
        TimeoutError: No terminal state was seen within ``timeout_seconds``.
    """
    # A monotonic clock cannot jump backwards/forwards with wall-clock
    # adjustments (NTP sync, manual changes), so the budget stays accurate.
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = requests.get(status_url, headers=headers, timeout=30)
        response.raise_for_status()
        payload: Dict[str, Any] = response.json()
        state = _extract_status(payload)
        if state in {"succeeded", "success", "completed", "done"}:
            if isinstance(payload.get("result"), dict):
                return payload["result"]
            if payload.get("results") is not None:
                return payload["results"]
            return payload
        if state in {"failed", "error", "cancelled", "canceled"}:
            raise RuntimeError(
                f"Bria VLM API request failed: {json.dumps(payload, indent=2)}"
            )
        remaining = deadline - time.monotonic()
        if remaining < 0:
            raise TimeoutError(
                f"Bria VLM API request timed out while polling {status_url}"
            )
        # Never sleep past the deadline: the last poll happens right at it.
        time.sleep(min(poll_interval, remaining))
def _submit_bria_request(
    url: str, payload: Dict[str, Any], api_token: str
) -> Any:
    """POST a JSON request to a Bria endpoint and return its result.

    If the response advertises a status URL (``status_url``, ``statusUrl``
    or ``status.status_url``), the job is polled to completion via
    ``_poll_bria_status``; otherwise the immediate response is unwrapped.

    Args:
        url: Full endpoint URL.
        payload: JSON-serialisable request body.
        api_token: Bria API token, sent in the ``api_token`` header.

    Returns:
        ``result`` when it is a dict, else ``results`` when not ``None``,
        else the whole decoded response (or the polled equivalent).

    Raises:
        requests.HTTPError: The submit request returned an HTTP error status.
    """
    headers = {
        "Content-Type": "application/json",
        "api_token": api_token,
    }
    response = requests.post(url, json=payload, headers=headers, timeout=30)
    response.raise_for_status()
    initial_payload: Dict[str, Any] = response.json()
    if not isinstance(initial_payload, dict):
        # Nothing to look up in a non-object body; hand it back unchanged.
        return initial_payload

    # ``status`` may be a plain string such as "IN_PROGRESS"/"COMPLETED";
    # only a mapping can hold a nested ``status_url`` (calling ``.get`` on a
    # string would raise AttributeError).
    status_info = initial_payload.get("status")
    nested_status_url = (
        status_info.get("status_url") if isinstance(status_info, dict) else None
    )
    status_url = (
        initial_payload.get("status_url")
        or initial_payload.get("statusUrl")
        or nested_status_url
    )
    if status_url:
        return _poll_bria_status(status_url, headers)
    if isinstance(initial_payload.get("result"), dict):
        return initial_payload["result"]
    if initial_payload.get("results") is not None:
        return initial_payload["results"]
    return initial_payload
| def _parse_vlm_response(data: Any, prompt_role: str) -> str: | |
| if isinstance(data, dict): | |
| direct_match = data.get(prompt_role) | |
| if isinstance(direct_match, str): | |
| return direct_match | |
| for key in ("prompt", "structured_prompt", "structuredPrompt", "text"): | |
| if key in data: | |
| value = data[key] | |
| if isinstance(value, str): | |
| return value | |
| if isinstance(value, dict): | |
| nested = value.get(prompt_role) | |
| if isinstance(nested, str): | |
| return nested | |
| for key in ("result", "results"): | |
| if key in data: | |
| nested_result = _parse_vlm_response(data[key], prompt_role) | |
| if nested_result: | |
| return nested_result | |
| if isinstance(data, list): | |
| for item in data: | |
| nested_result = _parse_vlm_response(item, prompt_role) | |
| if nested_result: | |
| return nested_result | |
| return json.dumps(data) | |
def get_prompt_api(image_path: str, prompt_role: str) -> str:
    """Send an image to the Bria VLM API and return the extracted prompt text.

    The payload keys are aligned with the current public docs but may require
    adjustment if your Bria workspace is configured differently. Override the
    default endpoint via the ``BRIA_API_VLM_ENDPOINT`` environment variable if
    you are using a custom workflow.

    Args:
        image_path: Path of an image file that Pillow can open.
        prompt_role: Key used to locate the prompt when the response has no
            top-level ``structured_prompt`` (forwarded to
            ``_parse_vlm_response``).

    Returns:
        ``response["structured_prompt"]`` as returned by the service when
        present; otherwise the prompt found by ``_parse_vlm_response``, which
        falls back to the JSON-encoded response.

    Raises:
        EnvironmentError: ``BRIA_API_KEY`` is unset or empty.
    """
    api_token = os.environ.get("BRIA_API_KEY")
    if not api_token:
        raise EnvironmentError(
            "BRIA_API_KEY environment variable is required to use the Bria VLM API."
        )
    base_url = os.environ.get("BRIA_API_BASE_URL", "https://engine.prod.bria-api.com")
    endpoint = os.environ.get("BRIA_API_VLM_ENDPOINT", "/v2/structured_prompt/generate")
    url = f"{base_url.rstrip('/')}{endpoint}"
    # Encode inside the ``with`` so the file handle is released afterwards.
    with Image.open(image_path) as image:
        image_b64 = _image_to_base64(image)
    payload = {"images": [image_b64]}
    response = _submit_bria_request(url, payload, api_token)
    if isinstance(response, dict) and "structured_prompt" in response:
        return response["structured_prompt"]
    # Other response shapes (a ``results`` list, differently named keys)
    # previously raised KeyError/TypeError and ignored ``prompt_role``.
    return _parse_vlm_response(response, prompt_role)
def get_image_from_url(image_url: str, timeout: float = 30) -> Image.Image:
    """Download an image over HTTP and open it with Pillow.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up; defaults to
            the same 30 s used for the Bria API calls in this module.

    Returns:
        A PIL image backed by the downloaded bytes.

    Raises:
        requests.HTTPError: The server answered with an HTTP error status.
        PIL.UnidentifiedImageError: The response body is not a recognised image.
    """
    # Without a timeout ``requests`` can block forever on a stalled server.
    response = requests.get(image_url, timeout=timeout)
    # Fail on 4xx/5xx instead of handing an error page to Pillow.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def generate_image(prompt: str) -> Image.Image:
    """Render an image from a structured prompt via Bria's generation endpoint.

    The endpoint defaults to ``/v2/image/generate`` on
    ``BRIA_API_BASE_URL`` and can be overridden with
    ``BRIA_API_GENERATE_ENDPOINT``.

    Args:
        prompt: Value sent as ``structured_prompt`` in the request body.

    Returns:
        The image downloaded from the ``image_url`` in the API result.

    Raises:
        EnvironmentError: ``BRIA_API_KEY`` is unset or empty.
    """
    token = os.environ.get("BRIA_API_KEY")
    if not token:
        raise EnvironmentError(
            "BRIA_API_KEY environment variable is required to use the Bria VLM API."
        )
    root = os.environ.get(
        "BRIA_API_BASE_URL", "https://engine.prod.bria-api.com"
    ).rstrip("/")
    path = os.environ.get("BRIA_API_GENERATE_ENDPOINT", "/v2/image/generate")
    result = _submit_bria_request(f"{root}{path}", {"structured_prompt": prompt}, token)
    return get_image_from_url(result["image_url"])