import gradio as gr
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
from PIL import Image, ImageDraw
import re

# --- Configuration ---
BASE_MODEL = "Qwen/Qwen3-VL-8B-Instruct"
LORA_MODEL = "sonali3/qwen3vl-reasoning-cxr"

# Detect device
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
print(f"Using device: {device}")
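# Optional note (not part of the original setup): on Ampere-or-newer GPUs,
# torch.bfloat16 is a common alternative to float16 with a wider numeric
# range; torch.cuda.is_bf16_supported() reports whether the device supports it.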

# Load processor
processor = AutoProcessor.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)

# Load model (simple loading, no quantization)
model = Qwen3VLForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

# Attach LoRA weights
model = PeftModel.from_pretrained(model, LORA_MODEL)
model.eval()
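
# Optional (an assumption, not the original app's behavior): PEFT can fold
# the LoRA deltas into the base weights, which skips the adapter wrappers at
# inference time in exchange for losing the ability to detach the adapter:
#
#     model = model.merge_and_unload()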

def extract_coordinates(text):
    """Extract all (x=NUMBER,y=NUMBER) coordinate pairs from text, in order."""
    # Tolerate optional whitespace after the comma, e.g. "(x=120.5, y=340)",
    # which the model may emit alongside the compact form.
    pattern = r'\(x=([0-9.]+),\s*y=([0-9.]+)\)'
    return [(float(m.group(1)), float(m.group(2)))
            for m in re.finditer(pattern, text)]
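
# Quick sanity check of the parser (illustrative input, not real model output):
#
#     >>> extract_coordinates("Opacity at (x=512.0,y=300) and (x=98, y=410.5)")
#     [(512.0, 300.0), (98.0, 410.5)]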

def draw_numbered_dots(image, coordinates):
    """Draw small blue numbered circles at each coordinate."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    img_w, img_h = image.size
    BLUE = '#1E88E5'
    WHITE = '#FFFFFF'
    if not coordinates:
        return annotated, 0
    # If the model emits coordinates beyond the image bounds (e.g. on a
    # normalized grid larger than the pixel size), rescale them to fit;
    # otherwise both scale factors stay at 1.0 and coordinates pass through.
    max_x = max(c[0] for c in coordinates)
    max_y = max(c[1] for c in coordinates)
    scale_x = img_w / max(max_x * 1.05, img_w)
    scale_y = img_h / max(max_y * 1.05, img_h)
    # Circle size scales with the image, with a 10 px floor.
    radius = max(10, int(min(img_w, img_h) * 0.015))
    for i, (x, y) in enumerate(coordinates):
        sx = x * scale_x
        sy = y * scale_y
        # Clamp so the full circle stays inside the image.
        sx = min(max(sx, radius), img_w - radius)
        sy = min(max(sy, radius), img_h - radius)
        # Draw blue circle with a white outline
        draw.ellipse(
            [sx - radius, sy - radius, sx + radius, sy + radius],
            fill=BLUE,
            outline=WHITE,
            width=1
        )
        # Draw the 1-based index, nudged left/up so it sits roughly centered
        # in the circle with Pillow's small default bitmap font.
        num = str(i + 1)
        tx = sx - (len(num) * 3)
        ty = sy - 6
        draw.text((tx, ty), num, fill=WHITE)
    return annotated, len(coordinates)
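
# Optional (assumes a .ttf file is available on the host): Pillow's default
# bitmap font used above is small and fixed-size. For labels that scale with
# the dot radius, a TrueType font could be loaded instead, e.g.:
#
#     from PIL import ImageFont
#     font = ImageFont.truetype("DejaVuSans-Bold.ttf", radius)
#     draw.text((sx, sy), num, fill=WHITE, font=font, anchor="mm")
#
# anchor="mm" centers text on (sx, sy) but only works with TrueType fonts.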

def analyze_xray(image, extra_info):
    if image is None:
        return None, "Please upload a chest X-ray image."
    # Normalize the input to an RGB PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image).convert("RGB")
    else:
        image = image.convert("RGB")
    prompt = (
        "Given the chest X-ray and patient metadata, generate the full "
        "visually grounded diagnostic chain-of-thought reasoning, linking "
        "all findings to the numbered spatial coordinates included in the "
        "output."
    )
    if extra_info and extra_info.strip():
        prompt += f"\n\nPatient information: {extra_info}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding for reproducible reports
            use_cache=True,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(generated_ids, skip_special_tokens=True)
    # Extract coordinates and draw dots
    coordinates = extract_coordinates(response)
    annotated_image, count = draw_numbered_dots(image, coordinates)
    if count > 0:
        summary = f"\n\n---\n📍 {count} points annotated"
    else:
        summary = "\n\n---\n⚠️ No coordinates found"
    return annotated_image, response + summary
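
# Optional (an assumption; the app ships with greedy decoding): for more
# varied report phrasing, generation could sample instead, e.g.:
#
#     outputs = model.generate(**inputs, max_new_tokens=1024,
#                              do_sample=True, temperature=0.7, top_p=0.9)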

demo = gr.Interface(
    fn=analyze_xray,
    inputs=[
        gr.Image(type="pil", label="Upload Chest X-Ray"),
        gr.Textbox(label="Patient Info (optional)", placeholder="History, demographics...", lines=2)
    ],
    outputs=[
        gr.Image(type="pil", label="Annotated X-Ray"),
        gr.Textbox(label="Analysis", lines=25)
    ],
    title="CXR",
    description="Upload a chest X-ray to generate visually grounded diagnostic reasoning; each finding is linked to a numbered point drawn on the image.",
    flagging_mode="never"
)

if __name__ == "__main__":
    demo.launch()
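
# On Hugging Face Spaces, launch() needs no arguments. When running locally,
# demo.launch(share=True) additionally creates a temporary public link.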