| | import gradio as gr |
| | import json, os, re, traceback, contextlib, math, random |
| | from typing import Any, List, Dict, Optional, Tuple |
| |
|
| | import spaces |
| | import torch |
| | from PIL import Image, ImageDraw |
| | import requests |
| | from transformers import AutoModelForImageTextToText, AutoProcessor |
| | from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
| |
|
| | |
# Hugging Face Hub repository id; used for both the model and the processor
# `from_pretrained` calls and for the model link rendered in the page header.
MODEL_ID = "Hcompany/Holo1-3B"
| |
|
| | |
| |
|
def pick_device() -> str:
    """
    Choose the torch device string to run inference on.

    Resolution order:
      1. ``FORCE_DEVICE`` environment variable, when it is one of
         ``cpu`` / ``cuda`` / ``mps`` (case-insensitive, surrounding
         whitespace ignored) - handy for local testing.
      2. ``cuda`` when ``torch.cuda.is_available()``.
      3. ``mps`` when the MPS backend exists and reports available.
      4. ``cpu`` otherwise.

    On HF Spaces (ZeroGPU), CUDA is only available inside @spaces.GPU calls,
    so this should be called from within such a function.
    """
    override = os.getenv("FORCE_DEVICE", "").lower().strip()
    if override in ("cpu", "cuda", "mps"):
        return override

    if torch.cuda.is_available():
        return "cuda"

    # Older torch builds may not ship the `mps` backend module at all.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ready = bool(mps_backend) and torch.backends.mps.is_available()
    return "mps" if mps_ready else "cpu"
| |
|
def pick_dtype(device: str) -> torch.dtype:
    """
    Map a device string to the dtype used for the model on that device.

    - ``"mps"``  -> float16
    - ``"cuda"`` -> bfloat16 when the device's major compute capability is
      >= 8, otherwise float16
    - anything else -> float32
    """
    if device == "mps":
        return torch.float16
    if device != "cuda":
        return torch.float32

    capability_major, _minor = torch.cuda.get_device_capability()
    if capability_major >= 8:
        return torch.bfloat16
    return torch.float16
| |
|
def move_to_device(batch, device: str):
    """
    Move *batch* to *device*.

    A plain ``dict`` is rebuilt with every value that exposes ``.to`` moved
    (other values are kept as they are). Any other object exposing ``.to`` is
    moved directly. Objects without ``.to`` are returned unchanged.
    Transfers are requested with ``non_blocking=True``.
    """
    def _shift(obj):
        # Only objects that expose `.to` can be moved; everything else passes through.
        if hasattr(obj, "to"):
            return obj.to(device, non_blocking=True)
        return obj

    if isinstance(batch, dict):
        return {key: _shift(value) for key, value in batch.items()}
    return _shift(batch)
| |
|
| | |
def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
    """
    Render *messages* into a single prompt string.

    Uses ``processor.apply_chat_template`` when present, otherwise the
    tokenizer's ``apply_chat_template``. Both are called with
    ``tokenize=False`` and ``add_generation_prompt=True``. If neither exists,
    falls back to joining the ``text`` of every ``{"type": "text"}`` content
    part with newlines.
    """
    tokenizer = getattr(processor, "tokenizer", None)

    # The processor takes precedence over its tokenizer.
    for owner in (processor, tokenizer):
        if owner is not None and hasattr(owner, "apply_chat_template"):
            return owner.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

    # Last resort: no template support at all, concatenate the text parts.
    return "\n".join(
        part.get("text", "")
        for message in messages
        for part in message.get("content", [])
        if isinstance(part, dict) and part.get("type") == "text"
    )
| |
|
def batch_decode_compat(processor, token_id_batches, **kw):
    """
    Decode batches of token ids with whichever object supports it.

    The processor's ``tokenizer`` is tried first, then the processor itself;
    extra keyword arguments are forwarded unchanged.

    Raises:
        AttributeError: when neither object exposes ``batch_decode``.
    """
    candidates = (getattr(processor, "tokenizer", None), processor)
    for holder in candidates:
        if holder is not None and hasattr(holder, "batch_decode"):
            return holder.batch_decode(token_id_batches, **kw)
    raise AttributeError("No batch_decode available on processor or tokenizer.")
| |
|
def get_image_proc_params(processor) -> Dict[str, int]:
    """
    Collect the image-processor settings needed to compute the resize target.

    Reads ``patch_size``, ``merge_size``, ``min_pixels`` and ``max_pixels``
    from ``processor.image_processor``. An attribute that is missing *or set
    to None* is treated as unavailable. For the pixel bounds, the image
    processor's ``size`` entry (``shortest_edge`` / ``longest_edge``) is
    consulted next, and finally a built-in default is used. This guarantees
    every returned value is non-None.

    Returns:
        Dict with keys ``patch_size``, ``merge_size``, ``min_pixels``,
        ``max_pixels``.
    """
    ip = getattr(processor, "image_processor", None)
    size = getattr(ip, "size", None)

    def _from_size(key: str):
        # `size` may be a mapping or an attribute-style object; support both.
        if size is None:
            return None
        if isinstance(size, dict):
            return size.get(key)
        return getattr(size, key, None)

    def _resolve(attr: str, size_key: Optional[str], default: int):
        value = getattr(ip, attr, None)
        if value is None and size_key is not None:
            value = _from_size(size_key)
        return default if value is None else value

    return {
        "patch_size": _resolve("patch_size", None, 14),
        "merge_size": _resolve("merge_size", None, 1),
        "min_pixels": _resolve("min_pixels", "shortest_edge", 256 * 256),
        "max_pixels": _resolve("max_pixels", "longest_edge", 1280 * 1280),
    }
| |
|
def trim_generated(generated_ids, inputs):
    """
    Strip the prompt tokens from each generated sequence.

    ``input_ids`` is looked up as an attribute of *inputs* first and, failing
    that, as a key when *inputs* is a dict. Each output sequence is sliced to
    drop as many leading tokens as its matching input sequence has. When no
    ``input_ids`` can be found, the generated sequences are returned as a list
    without trimming.
    """
    prompt_ids = getattr(inputs, "input_ids", None)
    if prompt_ids is None and isinstance(inputs, dict):
        prompt_ids = inputs.get("input_ids", None)

    if prompt_ids is None:
        return list(generated_ids)

    trimmed = []
    for prompt_seq, output_seq in zip(prompt_ids, generated_ids):
        trimmed.append(output_seq[len(prompt_seq):])
    return trimmed
| |
|
| | |
# Load the model and processor once at import time, on CPU in float32.
# Failures are recorded (not raised) so the UI can still start and show
# the error message instead of crashing the app.
print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
model, processor = None, None
model_loaded, load_error_message = False, ""

try:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model.eval()
    model_loaded = True
    print("Model and processor loaded on CPU.")
except Exception as exc:
    load_error_message = "\n".join(
        [
            f"Error loading model/processor: {exc}",
            "This might be due to network/model ID/library versions.",
            "Check the full traceback in the logs.",
        ]
    )
    print(load_error_message)
    traceback.print_exc()
| |
|
| | |
def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
    """
    Build the single-turn chat message list for the localization task.

    The result is one ``user`` message whose content holds the image followed
    by a text part made of the fixed task guidelines, a newline, and the
    caller's *instruction*.
    """
    guidelines: str = (
        "Localize an element on the GUI image according to my instructions "
        "and output a click position as Click(x, y) "
        "with x num pixels from the left edge and y num pixels from the top edge."
    )
    image_part = {"type": "image", "image": pil_image}
    text_part = {"type": "text", "text": f"{guidelines}\n{instruction}"}
    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]
| |
|
| | |
@torch.inference_mode()
def run_inference_localization(
    messages_for_template: List[dict[str, Any]],
    pil_image_for_processing: Image.Image,
    device: str,
    dtype: torch.dtype,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_new_tokens: int = 128,
) -> str:
    """
    Run one generation pass and return the decoded model output.

    Args:
        messages_for_template: Chat messages rendered into the text prompt.
        pil_image_for_processing: Image handed to the processor alongside the prompt.
        device: Device string the inputs are moved to ("cuda", "mps" or "cpu").
        dtype: Autocast dtype used when *device* is "cuda".
        do_sample: Sample stochastically when True; decode deterministically otherwise.
        temperature: Sampling temperature; only forwarded when *do_sample* is True.
        top_p: Nucleus-sampling threshold; only forwarded when *do_sample* is True.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The first decoded sequence with the prompt tokens removed, or "" when
        decoding produced nothing.
    """
    text_prompt = apply_chat_template_compat(processor, messages_for_template)

    inputs = processor(
        text=[text_prompt],
        images=[pil_image_for_processing],
        padding=True,
        return_tensors="pt",
    )
    inputs = move_to_device(inputs, device)

    # Mixed precision on accelerators; plain execution on CPU.
    if device == "cuda":
        amp_ctx = torch.autocast(device_type="cuda", dtype=dtype)
    elif device == "mps":
        amp_ctx = torch.autocast(device_type="mps", dtype=torch.float16)
    else:
        amp_ctx = contextlib.nullcontext()

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
    }
    if do_sample:
        # Sampling knobs are meaningless for deterministic decoding; passing
        # them there only produces "flag is not used" generation warnings.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    with amp_ctx:
        generated_ids = model.generate(**inputs, **gen_kwargs)

    generated_ids_trimmed = trim_generated(generated_ids, inputs)
    decoded_output = batch_decode_compat(
        processor,
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded_output[0] if decoded_output else ""
| |
|
| | |
# Matches the first "Click(x, y)" with non-negative integer coordinates.
CLICK_RE = re.compile(r"Click\((\d+),\s*(\d+)\)")


def parse_click(s: str) -> Optional[Tuple[int, int]]:
    """
    Extract the first ``Click(x, y)`` occurrence from *s*.

    Returns:
        ``(x, y)`` as integers, or ``None`` when no match is found or the
        matched digits cannot be converted.
    """
    found = CLICK_RE.search(s)
    if found is None:
        return None
    raw_x, raw_y = found.group(1), found.group(2)
    try:
        x_val = int(raw_x)
        y_val = int(raw_y)
    except Exception:
        return None
    return x_val, y_val
| |
|
@torch.inference_mode()
def sample_clicks(
    messages: List[dict],
    img: Image.Image,
    device: str,
    dtype: torch.dtype,
    n_samples: int = 7,
    temperature: float = 0.6,
    top_p: float = 0.9,
    seed: Optional[int] = None,
) -> List[Optional[Tuple[int, int]]]:
    """
    Draw several stochastic decodes so their agreement can be measured.

    When *seed* is given, the torch and ``random`` generators are seeded
    up front and then re-seeded before every sample with ``seed + k``
    (k = 1..n_samples), so each draw is reproducible yet distinct.

    Returns:
        One entry per sample: the parsed ``(x, y)`` click, or ``None`` when
        the model output could not be parsed.
    """
    use_seed = seed is not None
    if use_seed:
        torch.manual_seed(seed)
        random.seed(seed)

    results: List[Optional[Tuple[int, int]]] = []
    for step in range(1, n_samples + 1):
        if use_seed:
            step_seed = seed + step
            torch.manual_seed(step_seed)
            # `random` gets the seed masked to 32 bits.
            random.seed(step_seed & 0xFFFFFFFF)
        raw_output = run_inference_localization(
            messages,
            img,
            device,
            dtype,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
        results.append(parse_click(raw_output))
    return results
| |
|
def cluster_and_confidence(
    clicks: List[Optional[Tuple[int,int]]],
    img_w: int,
    img_h: int,
) -> Dict[str, Any]:
    """
    Reduce sampled clicks to one consensus point plus a confidence score.

    Procedure:
      - Drop failed samples (``None``).
      - Consensus = per-axis median of the valid points (integer floor of the
        two middle values when the count is even).
      - Inlier threshold = max(8 px, 2% of min(img_w, img_h)).
      - A point is an inlier when its distance to the consensus is <= threshold.
      - Confidence = (#inliers / #total_samples) * clamp(1 - rms_inlier_dist / thr, 0, 1).

    Returns:
        ``{"ok": False, "reason": ...}`` when there are no samples or no valid
        points; otherwise a dict with ``ok=True``, the consensus ``x``/``y``,
        ``confidence``, sample counts, ``sigma_px`` (RMS inlier distance, or
        ``None`` when there are no inliers), the threshold, and the point lists.
    """
    total = len(clicks)
    if total == 0:
        return {"ok": False, "reason": "no_samples"}

    points = [pt for pt in clicks if pt is not None]
    if not points:
        return {"ok": False, "reason": "no_valid_points", "total": total}

    def _median(values):
        ordered = sorted(values)
        half, is_odd = divmod(len(ordered), 2)
        if is_odd:
            return ordered[half]
        return (ordered[half - 1] + ordered[half]) // 2

    x_med = _median([px for px, _ in points])
    y_med = _median([py for _, py in points])

    thr = max(8.0, 0.02 * min(img_w, img_h))

    # Split the valid points around the threshold in a single pass.
    inlier_points = []
    outlier_points = []
    inlier_dists = []
    for pt in points:
        dist = math.hypot(pt[0] - x_med, pt[1] - y_med)
        if dist <= thr:
            inlier_points.append(pt)
            inlier_dists.append(dist)
        elif dist > thr:
            outlier_points.append(pt)

    inlier_count = len(inlier_points)
    if inlier_count:
        rms = math.sqrt(sum(d * d for d in inlier_dists) / inlier_count)
        sharp = max(0.0, min(1.0, 1.0 - (rms / thr)))
    else:
        rms = float("inf")
        sharp = 0.0

    return {
        "ok": True,
        "x": x_med,
        "y": y_med,
        "confidence": (inlier_count / total) * sharp,
        "total_samples": total,
        "valid_samples": len(points),
        "inliers": inlier_count,
        "outliers": len(outlier_points),
        "sigma_px": rms if math.isfinite(rms) else None,
        "inlier_threshold_px": thr,
        "all_points": points,
        "inlier_points": inlier_points,
        "outlier_points": outlier_points,
    }
| |
|
def draw_samples(
    base_img: Image.Image,
    consensus_xy: Optional[Tuple[int,int]],
    inliers: List[Tuple[int,int]],
    outliers: List[Tuple[int,int]],
    ring_color: str = "red",
) -> Image.Image:
    """
    Return an RGB copy of *base_img* annotated with the sampled points.

    Inliers are drawn as filled green dots, outliers as filled red dots, and
    the consensus point (when given) as an outlined ring in *ring_color*.
    Marker sizes scale with the shorter image side. The input image is not
    modified.
    """
    canvas = base_img.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    short_side = min(canvas.size)

    # Dot radius: at least 3 px, growing with image size.
    dot_r = max(3, short_side // 200)

    # Inliers first, then outliers, so outliers end up on top where they overlap.
    for fill_color, points in (("green", inliers), ("red", outliers)):
        for (px, py) in points:
            pen.ellipse(
                (px - dot_r, py - dot_r, px + dot_r, py + dot_r),
                fill=fill_color,
                outline=None,
            )

    if consensus_xy is None:
        return canvas

    cx, cy = consensus_xy
    ring_r = max(5, short_side // 100, dot_r * 3)
    ring_box = (cx - ring_r, cy - ring_r, cx + ring_r, cy + ring_r)
    pen.ellipse(ring_box, outline=ring_color, width=max(2, ring_r // 4))
    return canvas
| |
|
| | |
| | |
def _format_runtime_info(device: str, dtype: Any) -> str:
    """Render the runtime-info string shown in the UI, without the ``torch.`` prefix."""
    return f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"


def _greedy_click(messages: List[dict], image: Image.Image, device: str, dtype: Any):
    """Run one deterministic decode; return (raw text, annotated RGB image, parsed?)."""
    coord_str = run_inference_localization(messages, image, device, dtype, do_sample=False)
    out_img = image.copy().convert("RGB")
    match = CLICK_RE.search(coord_str or "")
    if match:
        click_xy = (int(match.group(1)), int(match.group(2)))
        out_img = draw_samples(out_img, click_xy, [], [])
    return coord_str, out_img, match is not None


@spaces.GPU(duration=120)
def predict_click_location(
    input_pil_image: Image.Image,
    instruction: str,
    estimate_confidence: bool = True,
    num_samples: int = 7,
    temperature: float = 0.6,
    top_p: float = 0.9,
    seed: Optional[int] = None,
):
    """
    Localize the UI element described by *instruction* on *input_pil_image*.

    With *estimate_confidence* enabled and at least 3 samples requested, the
    model is sampled several times and the clicks are merged into a consensus
    point with a confidence score. If no sample can be parsed, a single
    deterministic decode is used instead and reported as a fallback. Without
    confidence estimation, one deterministic decode is run.

    Args:
        input_pil_image: Screenshot to localize on.
        instruction: Natural-language description of the target.
        estimate_confidence: Enable multi-sample consensus.
        num_samples: Number of stochastic decodes for the consensus.
        temperature: Sampling temperature for the stochastic decodes.
        top_p: Nucleus-sampling threshold for the stochastic decodes.
        seed: Optional seed for reproducible sampling.

    Returns:
        A 3-tuple ``(text, image, runtime_info)``: the coordinate/diagnostic
        (or error) text, the annotated image (``None`` when nothing can be
        shown), and a "device | dtype" string. Coordinates refer to the
        resized image that is returned, not the original upload. Errors are
        reported through the text rather than raised.
    """
    no_runtime = "device: n/a | dtype: n/a"
    if not model_loaded or processor is None or model is None:
        return f"Model not loaded. Error: {load_error_message}", None, no_runtime
    if input_pil_image is None:
        return "No image provided. Please upload an image.", None, no_runtime
    if not instruction or instruction.strip() == "":
        return (
            "No instruction provided. Please type an instruction.",
            input_pil_image.copy().convert("RGB"),
            no_runtime,
        )

    # Resolve the device here: under ZeroGPU, CUDA only exists inside this call.
    device = pick_device()
    dtype = pick_dtype(device)
    runtime_info = _format_runtime_info(device, dtype)

    if device == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    # Move/cast the model only when its current placement differs from the target.
    try:
        p = next(model.parameters())
        cur_dev = p.device.type
        cur_dtype = p.dtype
    except StopIteration:
        cur_dev, cur_dtype = "cpu", torch.float32

    if cur_dev != device or cur_dtype != dtype:
        model.to(device=device, dtype=dtype)
    model.eval()

    # Resize to the dimensions the image processor would use, so that the
    # predicted pixel coordinates line up with the image we draw on.
    try:
        ip = get_image_proc_params(processor)
        resized_height, resized_width = smart_resize(
            input_pil_image.height,
            input_pil_image.width,
            factor=ip["patch_size"] * ip["merge_size"],
            min_pixels=ip["min_pixels"],
            max_pixels=ip["max_pixels"],
        )
        resized_image = input_pil_image.resize(
            size=(resized_width, resized_height),
            resample=Image.Resampling.LANCZOS,
        )
    except Exception as e:
        traceback.print_exc()
        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB"), runtime_info

    messages = get_localization_prompt(resized_image, instruction)

    try:
        if estimate_confidence and num_samples >= 3:
            clicks = sample_clicks(
                messages, resized_image, device, dtype,
                n_samples=int(num_samples),
                temperature=float(temperature),
                top_p=float(top_p),
                seed=seed,
            )
            summary = cluster_and_confidence(clicks, resized_image.width, resized_image.height)

            if not summary.get("ok", False):
                # No usable sample: fall back to a single deterministic decode.
                coord_str, out_img, _ = _greedy_click(messages, resized_image, device, dtype)
                coords_text = f"{coord_str} | confidence=0.00 (fallback)"
                return coords_text, out_img, runtime_info

            x, y = int(summary["x"]), int(summary["y"])
            conf = summary["confidence"]
            inliers = summary["inlier_points"]
            outliers = summary["outlier_points"]
            sigma = summary["sigma_px"]
            thr = summary["inlier_threshold_px"]
            total = summary["total_samples"]
            valid = summary["valid_samples"]

            # sigma is None when no sample falls within the inlier threshold;
            # formatting None with ":.1f" would raise, so render a placeholder.
            sigma_text = f"{sigma:.1f}px" if sigma is not None else "n/a"

            coord_str = f"Click({x}, {y})"
            diag = (
                f"confidence={conf:.2f} | samples(valid/total)={valid}/{total} | "
                f"inliers={len(inliers)} | σ={sigma_text} | thr={thr:.1f}px | "
                f"T={temperature:.2f}, p={top_p:.2f}"
            )

            out_img = draw_samples(resized_image, (x, y), inliers, outliers)
            return f"{coord_str} | {diag}", out_img, runtime_info

        # Single deterministic decode (confidence estimation disabled).
        coord_str, out_img, parsed = _greedy_click(messages, resized_image, device, dtype)
        if not parsed:
            print(f"Could not parse 'Click(x, y)' from model output: {coord_str}")
        return coord_str, out_img, runtime_info

    except Exception as e:
        traceback.print_exc()
        return f"Error during model inference: {e}", resized_image.copy().convert("RGB"), runtime_info
| |
|
| | |
# Example shown in the UI. The image is fetched once at startup; when that
# fails, a small placeholder image is generated so the example still renders.
example_image = None
example_instruction = "Enter the server address readyforquantum.com to check its security"
example_image_url = "https://readyforquantum.com/img/screentest.jpg"
try:
    # A timeout keeps an unreachable host from blocking app startup, and the
    # context manager releases the connection once the image is decoded.
    with requests.get(example_image_url, stream=True, timeout=15) as response:
        # Fail on HTTP errors instead of trying to decode an error page.
        response.raise_for_status()
        # Have urllib3 undo any Content-Encoding before PIL reads the stream.
        response.raw.decode_content = True
        fetched_image = Image.open(response.raw)
        # Image.open is lazy; decode now while the stream is still open.
        fetched_image.load()
    example_image = fetched_image
except Exception as e:
    print(f"Could not load example image from URL: {e}")
    traceback.print_exc()
    try:
        placeholder = Image.new("RGB", (200, 150), color="lightgray")
        ImageDraw.Draw(placeholder).text((10, 10), "Example image\nfailed to load", fill="black")
        example_image = placeholder
    except Exception as placeholder_error:
        # Best effort only: without an example image the Examples panel is skipped.
        print(f"Could not create placeholder example image: {placeholder_error}")
| |
|
| | |
# Page header content.
title = "Holo1-3B: Holo1 Localization Demo (ZeroGPU-ready)"
article = f"""
<p style='text-align: center'>
Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany |
Paper: <a href='https://cdn.prod.website-files.com/67e2dbd9acff0c50d4c8a80c/683ec8095b353e8b38317f80_h_tech_report_v1.pdf' target='_blank'>HCompany Tech Report</a> |
Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a><br/>
<small>GPU (if available) is requested only during inference via @spaces.GPU.</small>
</p>
"""

if model_loaded:
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
        gr.Markdown(article)

        with gr.Row():
            # Left column: inputs and sampling controls.
            with gr.Column(scale=1):
                input_image_component = gr.Image(type="pil", label="Input UI Image", height=400)
                instruction_component = gr.Textbox(
                    label="Instruction",
                    placeholder="e.g., Click the 'Login' button",
                    info="Type the action you want the model to localize on the image.",
                )
                estimate_conf = gr.Checkbox(value=True, label="Estimate confidence (slower)")
                num_samples_slider = gr.Slider(3, 15, value=7, step=1, label="Samples (for confidence)")
                temperature_slider = gr.Slider(0.2, 1.2, value=0.6, step=0.05, label="Temperature")
                top_p_slider = gr.Slider(0.5, 0.99, value=0.9, step=0.01, label="Top-p")
                seed_box = gr.Number(value=None, precision=0, label="Seed (optional, for reproducibility)")
                submit_button = gr.Button("Localize Click", variant="primary")

            # Right column: results.
            with gr.Column(scale=1):
                output_coords_component = gr.Textbox(
                    label="Predicted Coordinates + Confidence",
                    interactive=False,
                )
                output_image_component = gr.Image(
                    type="pil",
                    label="Image with Samples (green=inliers, red=outliers) and Final Ring",
                    height=400,
                    interactive=False,
                )
                runtime_info = gr.Textbox(
                    label="Runtime Info",
                    value="device: n/a | dtype: n/a",
                    interactive=False,
                )

        # The examples panel and the submit button share the same wiring;
        # the order matches predict_click_location's parameters and return tuple.
        prediction_inputs = [
            input_image_component,
            instruction_component,
            estimate_conf,
            num_samples_slider,
            temperature_slider,
            top_p_slider,
            seed_box,
        ]
        prediction_outputs = [output_coords_component, output_image_component, runtime_info]

        if example_image:
            gr.Examples(
                examples=[[example_image, example_instruction, True, 7, 0.6, 0.9, None]],
                inputs=prediction_inputs,
                outputs=prediction_outputs,
                fn=predict_click_location,
                cache_examples="lazy",
            )

        submit_button.click(
            fn=predict_click_location,
            inputs=prediction_inputs,
            outputs=prediction_outputs,
        )
else:
    # Model failed to load: show the recorded error instead of the demo.
    with gr.Blocks() as demo:
        gr.Markdown("# <center>⚠️ Error: Model Failed to Load ⚠️</center>")
        gr.Markdown(f"<center>{load_error_message}</center>")
        gr.Markdown("<center>See logs for the full traceback.</center>")

if __name__ == "__main__":
    demo.launch(debug=True)
| |
|