import gradio as gr
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
from PIL import Image, ImageDraw
import re
# --- Configuration ---
BASE_MODEL = "Qwen/Qwen3-VL-8B-Instruct"
LORA_MODEL = "sonali3/qwen3vl-reasoning-cxr"
# Detect device
if torch.cuda.is_available():
device = "cuda"
dtype = torch.float16
else:
device = "cpu"
dtype = torch.float32
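# Note: on Ampere-or-newer GPUs, torch.bfloat16 is often the more numerically
# stable half-precision choice; float16 is kept here for broader GPU compatibility.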
print(f"Using device: {device}")
# Load processor
processor = AutoProcessor.from_pretrained(
BASE_MODEL,
trust_remote_code=True
)
# Load model (simple loading, no quantization)
model = Qwen3VLForConditionalGeneration.from_pretrained(
BASE_MODEL,
torch_dtype=dtype,
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True
)
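# Rough sizing note: an 8B-parameter model holds on the order of 16 GB of
# weights in float16; with device_map="auto", accelerate offloads layers to
# CPU RAM when the GPU is too small, trading speed for the ability to load.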
# Attach LoRA weights
model = PeftModel.from_pretrained(model, LORA_MODEL)
model.eval()
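# Optional: merging the LoRA adapters into the base weights removes the small
# adapter overhead at inference time (merge_and_unload() is a standard
# PeftModel method). Left commented out so the adapter stays swappable:
# model = model.merge_and_unload()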
def extract_coordinates(text):
    """Extract all (x=NUMBER,y=NUMBER) coordinates from text as (x, y) float pairs."""
    pattern = r'\(x=([0-9.]+),y=([0-9.]+)\)'
    return [(float(m.group(1)), float(m.group(2))) for m in re.finditer(pattern, text)]
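# Usage example for the grounding format the model emits:
#   extract_coordinates("Opacity at (x=312.5,y=480).") -> [(312.5, 480.0)]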
def draw_numbered_dots(image, coordinates):
"""Draw small blue numbered circles."""
annotated = image.copy()
draw = ImageDraw.Draw(annotated)
img_w, img_h = image.size
BLUE = '#1E88E5'
WHITE = '#FFFFFF'
if not coordinates:
return annotated, 0
    # If the model emitted coordinates beyond the image bounds, scale them back
    # into frame (with 5% headroom); in-bounds coordinates keep a scale of 1.
    max_x = max(c[0] for c in coordinates)
    max_y = max(c[1] for c in coordinates)
    scale_x = img_w / max(max_x * 1.05, img_w)
    scale_y = img_h / max(max_y * 1.05, img_h)
    # Dot radius scales with the image's shorter side, with a 10 px floor.
    radius = max(10, int(min(img_w, img_h) * 0.015))
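    # Possible refinement: PIL's default bitmap font is small on large
    # radiographs; a scalable font could be loaded here instead (the font name
    # below is an assumption, hence the fallback):
    # from PIL import ImageFont
    # try:
    #     font = ImageFont.truetype("DejaVuSans-Bold.ttf", radius)
    # except OSError:
    #     font = ImageFont.load_default()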
    for i, (x, y) in enumerate(coordinates):
        # Scale, then clamp so the full dot stays inside the image.
        sx = min(max(x * scale_x, radius), img_w - radius)
        sy = min(max(y * scale_y, radius), img_h - radius)
# Draw blue circle
draw.ellipse(
[sx - radius, sy - radius, sx + radius, sy + radius],
fill=BLUE,
outline=WHITE,
width=1
)
        # Roughly center the label over the dot; the offsets assume PIL's
        # small default bitmap font.
        num = str(i + 1)
        draw.text((sx - len(num) * 3, sy - 6), num, fill=WHITE)
return annotated, len(coordinates)
def analyze_xray(image, extra_info):
    """Generate grounded diagnostic reasoning for one X-ray and annotate the cited points."""
    if image is None:
        return None, "Please upload a chest X-ray image."
    # gr.Image(type="pil") delivers a PIL image, but accept numpy arrays defensively.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
prompt = "Given the chest X-ray and patient metadata, generate the full visually grounded diagnostic chain-of-thought reasoning, linking all findings to the numbered spatial coordinates included in the output."
if extra_info and extra_info.strip():
prompt += f"\n\nPatient information: {extra_info}"
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device)  # model.device is the entry device chosen by device_map="auto"
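    # Greedy decoding (do_sample=False) keeps the generated coordinates
    # deterministic for a given image and prompt.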
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=1024,
do_sample=False,
use_cache=True,
)
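    # Drop the prompt tokens and decode only the newly generated continuation.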
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
response = processor.decode(generated_ids, skip_special_tokens=True)
# Extract coordinates and draw dots
coordinates = extract_coordinates(response)
annotated_image, count = draw_numbered_dots(image, coordinates)
if count > 0:
summary = f"\n\n---\n📍 {count} points annotated"
else:
summary = "\n\n---\n⚠️ No coordinates found"
return annotated_image, response + summary
demo = gr.Interface(
fn=analyze_xray,
inputs=[
gr.Image(type="pil", label="Upload Chest X-Ray"),
gr.Textbox(label="Patient Info (optional)", placeholder="History, demographics...", lines=2)
],
outputs=[
gr.Image(type="pil", label="Annotated X-Ray"),
gr.Textbox(label="Analysis", lines=25)
],
title="CXR",
description="Upload a chest X-ray.",
flagging_mode="never"
)
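# Note: on Hugging Face Spaces the defaults are sufficient. demo.queue() is a
# standard Gradio method that can be chained before launch() to control request
# concurrency for the single loaded model.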
if __name__ == "__main__":
demo.launch()