mgbam commited on
Commit
d93eac7
·
verified ·
1 Parent(s): e0076d2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, io
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch, torch.nn as nn
6
+ from torchvision import models, transforms, datasets
7
+ from cam_utils import grad_cam
8
+
9
+ # -------- Config --------
10
+ MODEL_NAME = os.environ.get("MODEL_NAME", "efficientnet_b0")
11
+ NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "2"))
12
+ IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "224"))
13
+ WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "checkpoints/best.pt") # you will upload this file
14
+ CLASS_NAMES = ["Parasitized", "Uninfected"] if NUM_CLASSES == 2 else [str(i) for i in range(NUM_CLASSES)]
15
+
16
+ # -------- Model loading --------
17
+ def build_model(name: str, num_classes: int):
18
+ name = name.lower()
19
+ if name == "efficientnet_b0":
20
+ m = models.efficientnet_b0(weights=None)
21
+ m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
22
+ return m
23
+ elif name == "resnet50":
24
+ m = models.resnet50(weights=None)
25
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
26
+ return m
27
+ else:
28
+ raise ValueError(f"Unsupported model_name: {name}")
29
+
30
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ _model = build_model(MODEL_NAME, NUM_CLASSES).to(_device)
32
+
33
+ if not os.path.exists(WEIGHTS_PATH):
34
+ raise FileNotFoundError(
35
+ f"Missing weights at {WEIGHTS_PATH}. "
36
+ "Upload your trained file (e.g., checkpoints/best.pt) to the Space repo."
37
+ )
38
+ _state = torch.load(WEIGHTS_PATH, map_location=_device)
39
+ _model.load_state_dict(_state)
40
+ _model.eval()
41
+
42
+ # -------- Inference helpers --------
43
+ _pre = transforms.Compose([
44
+ transforms.Resize(int(IMAGE_SIZE*1.15)),
45
+ transforms.CenterCrop(IMAGE_SIZE),
46
+ transforms.ToTensor(),
47
+ ])
48
+
49
+ def predict(image: Image.Image, show_cam: bool):
50
+ if image is None:
51
+ return None, None, None
52
+ img = image.convert("RGB")
53
+ x = _pre(img).unsqueeze(0).to(_device)
54
+ with torch.no_grad():
55
+ logits = _model(x).cpu().numpy().squeeze()
56
+ probs = np.exp(logits - logits.max()); probs = probs / probs.sum()
57
+ pred_idx = int(np.argmax(probs))
58
+ label = CLASS_NAMES[pred_idx]
59
+ probs_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
60
+
61
+ overlay = None
62
+ if show_cam:
63
+ cam = grad_cam(_model, img, img_size=IMAGE_SIZE, device=_device)
64
+ overlay = Image.fromarray((cam["overlay"]*255).astype("uint8"))
65
+
66
+ return label, probs_dict, overlay
67
+
68
+ # -------- Optional: quick validation on an uploaded ZIP of a val folder --------
69
+ # Structure expected: root_dir/ClassA/*.png, root_dir/ClassB/*.png, ...
70
+ def validate(zip_or_folder):
71
+ import tempfile, zipfile, shutil
72
+ if zip_or_folder is None:
73
+ return "Upload a .zip of your validation set."
74
+
75
+ tmp = tempfile.mkdtemp()
76
+ path = zip_or_folder.name
77
+ if path.endswith(".zip"):
78
+ with zipfile.ZipFile(zip_or_folder.name, 'r') as zf:
79
+ zf.extractall(tmp)
80
+ root = tmp
81
+ else:
82
+ root = path # folder
83
+
84
+ ds = datasets.ImageFolder(root, transform=_pre)
85
+ dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False, num_workers=2)
86
+ ys, ps = [], []
87
+ with torch.no_grad():
88
+ for xb, yb in dl:
89
+ preds = _model(xb.to(_device)).argmax(1).cpu().numpy()
90
+ ys.extend(yb.numpy()); ps.extend(preds)
91
+ import sklearn.metrics as sk
92
+ rep = sk.classification_report(ys, ps, target_names=ds.classes, output_dict=True)
93
+ cm = sk.confusion_matrix(ys, ps)
94
+
95
+ # render small CM image
96
+ import matplotlib.pyplot as plt
97
+ import seaborn as sns
98
+ fig, ax = plt.subplots(figsize=(4.5,4))
99
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
100
+ xticklabels=ds.classes, yticklabels=ds.classes, ax=ax)
101
+ ax.set_xlabel("Predicted"); ax.set_ylabel("True"); fig.tight_layout()
102
+ buf = io.BytesIO(); fig.savefig(buf, format="png", dpi=160); buf.seek(0)
103
+ cm_img = buf
104
+
105
+ return json.dumps(rep, indent=2), cm_img
106
+
107
+ # -------- Gradio UI --------
108
+ with gr.Blocks(title="Malaria Diagnostic Assistant") as demo:
109
+ gr.Markdown("# 🩸 Malaria Diagnostic Assistant")
110
+ gr.Markdown("Prototype — energy-efficient triage with human-in-the-loop (Adaptive Sparse Training)")
111
+
112
+ with gr.Tab("🔍 Inference"):
113
+ with gr.Row():
114
+ with gr.Column():
115
+ img_in = gr.Image(type="pil", label="Upload blood smear image")
116
+ show_cam = gr.Checkbox(value=True, label="Show Grad-CAM")
117
+ btn = gr.Button("Predict", variant="primary")
118
+ with gr.Column():
119
+ label_out = gr.Label(num_top_classes=2, label="Prediction & Probabilities")
120
+ cam_out = gr.Image(type="pil", label="Grad-CAM overlay")
121
+ btn.click(fn=predict, inputs=[img_in, show_cam], outputs=[label_out, label_out, cam_out])
122
+
123
+ with gr.Tab("✅ Validation (optional)"):
124
+ gr.Markdown("Upload a **.zip** containing a folder with class subfolders (e.g., `Parasitized/`, `Uninfected/`).")
125
+ val_zip = gr.File(label="Validation ZIP", file_types=[".zip"])
126
+ run_eval = gr.Button("Compute report + confusion matrix")
127
+ rep_out = gr.Textbox(label="classification_report (JSON)")
128
+ cm_img = gr.Image(type="filepath", label="Confusion Matrix")
129
+ run_eval.click(fn=validate, inputs=[val_zip], outputs=[rep_out, cm_img])
130
+
131
+ gr.Markdown("---")
132
+ gr.Markdown("Built with EfficientNet-B0 + Adaptive Sparse Training (AST) — not a diagnostic device.")
133
+
134
+ if __name__ == "__main__":
135
+ demo.launch()