import re

import gradio as gr
import torch
from peft import PeftModel
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

# --- Configuration ---
BASE_MODEL = "Qwen/Qwen3-VL-8B-Instruct"
LORA_MODEL = "sonali3/qwen3vl-reasoning-cxr"

# Detect device
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

print(f"Using device: {device}")

# Load processor
processor = AutoProcessor.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)

# Load model (simple loading, no quantization)
model = Qwen3VLForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

# Attach LoRA weights
model = PeftModel.from_pretrained(model, LORA_MODEL)
model.eval()


def extract_coordinates(text):
    """Extract all (x=NUMBER,y=NUMBER) coordinates from text."""
    pattern = r'\(x=([0-9.]+),y=([0-9.]+)\)'
    matches = list(re.finditer(pattern, text))
    if not matches:
        return []
    results = []
    for match in matches:
        x = float(match.group(1))
        y = float(match.group(2))
        results.append((x, y))
    return results


def draw_numbered_dots(image, coordinates):
    """Draw small blue numbered circles at the given coordinates."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    img_w, img_h = image.size

    BLUE = '#1E88E5'
    WHITE = '#FFFFFF'

    if not coordinates:
        return annotated, 0

    # Calculate scale based on max coordinates: if any point falls outside
    # the image (with a 5% margin), shrink all points uniformly to fit;
    # otherwise the scale stays at 1 and points are left untouched.
    max_x = max(c[0] for c in coordinates)
    max_y = max(c[1] for c in coordinates)
    scale_x = img_w / max(max_x * 1.05, img_w)
    scale_y = img_h / max(max_y * 1.05, img_h)

    # Circle size, proportional to the image but never tiny
    radius = max(10, int(min(img_w, img_h) * 0.015))

    for i, (x, y) in enumerate(coordinates):
        sx = x * scale_x
        sy = y * scale_y
        # Clamp so each circle stays fully inside the image bounds
        sx = min(max(sx, radius), img_w - radius)
        sy = min(max(sy, radius), img_h - radius)

        # Draw blue circle
        draw.ellipse(
            [sx - radius, sy - radius, sx + radius, sy + radius],
            fill=BLUE,
            outline=WHITE,
            width=1
        )

        # Draw number, roughly centered on the dot (the default PIL bitmap
        # font is about 6 px per character)
        num = str(i + 1)
        tx = sx - (len(num) * 3)
        ty = sy - 6
        draw.text((tx, ty), num, fill=WHITE)

    return annotated, len(coordinates)


def analyze_xray(image, extra_info):
    if image is None:
        return None, "Please upload a chest X-ray image."

    # gr.Image(type="pil") normally hands us a PIL image already;
    # the array branch is defensive.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image).convert("RGB")
    else:
        image = image.convert("RGB")

    prompt = (
        "Given the chest X-ray and patient metadata, generate the full "
        "visually grounded diagnostic chain-of-thought reasoning, linking "
        "all findings to the numbered spatial coordinates included in the "
        "output."
    )
    if extra_info and extra_info.strip():
        prompt += f"\n\nPatient information: {extra_info}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            use_cache=True,
        )

    # Slice off the prompt tokens so only the newly generated text is decoded
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(generated_ids, skip_special_tokens=True)

    # Extract coordinates and draw dots
    coordinates = extract_coordinates(response)
    annotated_image, count = draw_numbered_dots(image, coordinates)

    if count > 0:
        summary = f"\n\n---\n📍 {count} point{'s' if count != 1 else ''} annotated"
    else:
        summary = "\n\n---\n⚠️ No coordinates found"

    return annotated_image, response + summary


demo = gr.Interface(
    fn=analyze_xray,
    inputs=[
        gr.Image(type="pil", label="Upload Chest X-Ray"),
        gr.Textbox(
            label="Patient Info (optional)",
            placeholder="History, demographics...",
            lines=2
        )
    ],
    outputs=[
        gr.Image(type="pil", label="Annotated X-Ray"),
        gr.Textbox(label="Analysis", lines=25)
    ],
    title="CXR",
    description="Upload a chest X-ray to generate visually grounded diagnostic reasoning with numbered coordinate annotations.",
    flagging_mode="never"
)

if __name__ == "__main__":
    demo.launch()
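
# A minimal sketch of the coordinate round-trip this app relies on. The
# sample model output below is illustrative only (not a real generation):
# extract_coordinates() pulls every "(x=...,y=...)" pair from the response,
# and draw_numbered_dots() renders them as blue circles labeled 1, 2, ...
# on the uploaded image, rescaling if any point exceeds the image bounds.
#
#   sample = ("1. Patchy opacity in the right lower zone (x=612.5,y=841.0). "
#             "2. Blunted left costophrenic angle (x=188.0,y=910.2).")
#   extract_coordinates(sample)
#   # -> [(612.5, 841.0), (188.0, 910.2)]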