Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json, io
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torch, torch.nn as nn
|
| 6 |
+
from torchvision import models, transforms, datasets
|
| 7 |
+
from cam_utils import grad_cam
|
| 8 |
+
|
| 9 |
+
# -------- Config --------
|
| 10 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "efficientnet_b0")
|
| 11 |
+
NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "2"))
|
| 12 |
+
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "224"))
|
| 13 |
+
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "checkpoints/best.pt") # you will upload this file
|
| 14 |
+
CLASS_NAMES = ["Parasitized", "Uninfected"] if NUM_CLASSES == 2 else [str(i) for i in range(NUM_CLASSES)]
|
| 15 |
+
|
| 16 |
+
# -------- Model loading --------
|
| 17 |
+
def build_model(name: str, num_classes: int):
|
| 18 |
+
name = name.lower()
|
| 19 |
+
if name == "efficientnet_b0":
|
| 20 |
+
m = models.efficientnet_b0(weights=None)
|
| 21 |
+
m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
|
| 22 |
+
return m
|
| 23 |
+
elif name == "resnet50":
|
| 24 |
+
m = models.resnet50(weights=None)
|
| 25 |
+
m.fc = nn.Linear(m.fc.in_features, num_classes)
|
| 26 |
+
return m
|
| 27 |
+
else:
|
| 28 |
+
raise ValueError(f"Unsupported model_name: {name}")
|
| 29 |
+
|
| 30 |
+
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
+
_model = build_model(MODEL_NAME, NUM_CLASSES).to(_device)
|
| 32 |
+
|
| 33 |
+
if not os.path.exists(WEIGHTS_PATH):
|
| 34 |
+
raise FileNotFoundError(
|
| 35 |
+
f"Missing weights at {WEIGHTS_PATH}. "
|
| 36 |
+
"Upload your trained file (e.g., checkpoints/best.pt) to the Space repo."
|
| 37 |
+
)
|
| 38 |
+
_state = torch.load(WEIGHTS_PATH, map_location=_device)
|
| 39 |
+
_model.load_state_dict(_state)
|
| 40 |
+
_model.eval()
|
| 41 |
+
|
| 42 |
+
# -------- Inference helpers --------
|
| 43 |
+
_pre = transforms.Compose([
|
| 44 |
+
transforms.Resize(int(IMAGE_SIZE*1.15)),
|
| 45 |
+
transforms.CenterCrop(IMAGE_SIZE),
|
| 46 |
+
transforms.ToTensor(),
|
| 47 |
+
])
|
| 48 |
+
|
| 49 |
+
def predict(image: Image.Image, show_cam: bool):
|
| 50 |
+
if image is None:
|
| 51 |
+
return None, None, None
|
| 52 |
+
img = image.convert("RGB")
|
| 53 |
+
x = _pre(img).unsqueeze(0).to(_device)
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
logits = _model(x).cpu().numpy().squeeze()
|
| 56 |
+
probs = np.exp(logits - logits.max()); probs = probs / probs.sum()
|
| 57 |
+
pred_idx = int(np.argmax(probs))
|
| 58 |
+
label = CLASS_NAMES[pred_idx]
|
| 59 |
+
probs_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
|
| 60 |
+
|
| 61 |
+
overlay = None
|
| 62 |
+
if show_cam:
|
| 63 |
+
cam = grad_cam(_model, img, img_size=IMAGE_SIZE, device=_device)
|
| 64 |
+
overlay = Image.fromarray((cam["overlay"]*255).astype("uint8"))
|
| 65 |
+
|
| 66 |
+
return label, probs_dict, overlay
|
| 67 |
+
|
| 68 |
+
# -------- Optional: quick validation on an uploaded ZIP of a val folder --------
|
| 69 |
+
# Structure expected: root_dir/ClassA/*.png, root_dir/ClassB/*.png, ...
|
| 70 |
+
def validate(zip_or_folder):
|
| 71 |
+
import tempfile, zipfile, shutil
|
| 72 |
+
if zip_or_folder is None:
|
| 73 |
+
return "Upload a .zip of your validation set."
|
| 74 |
+
|
| 75 |
+
tmp = tempfile.mkdtemp()
|
| 76 |
+
path = zip_or_folder.name
|
| 77 |
+
if path.endswith(".zip"):
|
| 78 |
+
with zipfile.ZipFile(zip_or_folder.name, 'r') as zf:
|
| 79 |
+
zf.extractall(tmp)
|
| 80 |
+
root = tmp
|
| 81 |
+
else:
|
| 82 |
+
root = path # folder
|
| 83 |
+
|
| 84 |
+
ds = datasets.ImageFolder(root, transform=_pre)
|
| 85 |
+
dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False, num_workers=2)
|
| 86 |
+
ys, ps = [], []
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
for xb, yb in dl:
|
| 89 |
+
preds = _model(xb.to(_device)).argmax(1).cpu().numpy()
|
| 90 |
+
ys.extend(yb.numpy()); ps.extend(preds)
|
| 91 |
+
import sklearn.metrics as sk
|
| 92 |
+
rep = sk.classification_report(ys, ps, target_names=ds.classes, output_dict=True)
|
| 93 |
+
cm = sk.confusion_matrix(ys, ps)
|
| 94 |
+
|
| 95 |
+
# render small CM image
|
| 96 |
+
import matplotlib.pyplot as plt
|
| 97 |
+
import seaborn as sns
|
| 98 |
+
fig, ax = plt.subplots(figsize=(4.5,4))
|
| 99 |
+
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
|
| 100 |
+
xticklabels=ds.classes, yticklabels=ds.classes, ax=ax)
|
| 101 |
+
ax.set_xlabel("Predicted"); ax.set_ylabel("True"); fig.tight_layout()
|
| 102 |
+
buf = io.BytesIO(); fig.savefig(buf, format="png", dpi=160); buf.seek(0)
|
| 103 |
+
cm_img = buf
|
| 104 |
+
|
| 105 |
+
return json.dumps(rep, indent=2), cm_img
|
| 106 |
+
|
| 107 |
+
# -------- Gradio UI --------
|
| 108 |
+
with gr.Blocks(title="Malaria Diagnostic Assistant") as demo:
|
| 109 |
+
gr.Markdown("# 🩸 Malaria Diagnostic Assistant")
|
| 110 |
+
gr.Markdown("Prototype — energy-efficient triage with human-in-the-loop (Adaptive Sparse Training)")
|
| 111 |
+
|
| 112 |
+
with gr.Tab("🔍 Inference"):
|
| 113 |
+
with gr.Row():
|
| 114 |
+
with gr.Column():
|
| 115 |
+
img_in = gr.Image(type="pil", label="Upload blood smear image")
|
| 116 |
+
show_cam = gr.Checkbox(value=True, label="Show Grad-CAM")
|
| 117 |
+
btn = gr.Button("Predict", variant="primary")
|
| 118 |
+
with gr.Column():
|
| 119 |
+
label_out = gr.Label(num_top_classes=2, label="Prediction & Probabilities")
|
| 120 |
+
cam_out = gr.Image(type="pil", label="Grad-CAM overlay")
|
| 121 |
+
btn.click(fn=predict, inputs=[img_in, show_cam], outputs=[label_out, label_out, cam_out])
|
| 122 |
+
|
| 123 |
+
with gr.Tab("✅ Validation (optional)"):
|
| 124 |
+
gr.Markdown("Upload a **.zip** containing a folder with class subfolders (e.g., `Parasitized/`, `Uninfected/`).")
|
| 125 |
+
val_zip = gr.File(label="Validation ZIP", file_types=[".zip"])
|
| 126 |
+
run_eval = gr.Button("Compute report + confusion matrix")
|
| 127 |
+
rep_out = gr.Textbox(label="classification_report (JSON)")
|
| 128 |
+
cm_img = gr.Image(type="filepath", label="Confusion Matrix")
|
| 129 |
+
run_eval.click(fn=validate, inputs=[val_zip], outputs=[rep_out, cm_img])
|
| 130 |
+
|
| 131 |
+
gr.Markdown("---")
|
| 132 |
+
gr.Markdown("Built with EfficientNet-B0 + Adaptive Sparse Training (AST) — not a diagnostic device.")
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
|
| 135 |
+
demo.launch()
|