from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from lime.lime_text import LimeTextExplainer
import numpy as np
import os

app = FastAPI(title="MedGuard API")

# --- NUCLEAR CORS FIX ---
# Allow EVERYTHING. This rules out CORS as the problem.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
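# Note: the CORS spec forbids "Access-Control-Allow-Origin: *" on credentialed
# requests, so once CORS is ruled out, switch to an explicit allow-list, e.g.
# (hypothetical origin, not from the original code):
#   allow_origins=["https://medguard-frontend.example"]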

# --- CONFIGURATION ---
MODEL_PATH = "./model"
DEVICE = "cpu"

print(f"Loading Model from {MODEL_PATH}...")
model = None
tokenizer = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()
    print("Model Loaded Successfully!")
except Exception as e:
    print(f"Error loading local model: {e}")
    # Fall back to the base checkpoint. Its classification head is freshly
    # initialized, so predictions are meaningless until the model is fine-tuned.
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
    model.to(DEVICE)
    model.eval()

# --- DATA MODELS ---
class QueryRequest(BaseModel):
    text: str

class PredictionResponse(BaseModel):
    label: str
    confidence: float
    probs: dict
    explanation: Optional[list] = None

LABELS = ["Highly Relevant", "Partially Relevant", "Not Relevant"]

def predict_proba_lime(texts):
    """Batch classifier for LIME: maps a list of strings to an (n, num_classes) array."""
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    return F.softmax(outputs.logits, dim=-1).cpu().numpy()
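
# Sanity check (hypothetical, not in the original code): LIME perturbs the input
# text and calls the function above on batches of strings, so it must return a
# probability matrix of shape (len(texts), 3):
#   assert predict_proba_lime(["demo"]).shape == (1, 3)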

@app.get("/")  # Health-check path assumed; align with whatever the frontend calls.
def health_check():
    return {"status": "active", "model": "MedGuard v1.0"}

@app.post("/predict", response_model=PredictionResponse)  # Path assumed; align with the frontend.
def predict(request: QueryRequest):
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))

        # LIME (reduced to 20 samples for speed testing)
        explainer = LimeTextExplainer(class_names=LABELS, split_expression=lambda x: x.split())
        exp = explainer.explain_instance(request.text, predict_proba_lime, num_features=6, num_samples=20, labels=[pred_idx])
        lime_features = exp.as_list(label=pred_idx)

        return {
            "label": LABELS[pred_idx],
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features,
        }
    except Exception as e:
        print(f"Server Error: {e}")  # Print error to backend terminal
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    # Bind to localhost specifically
    uvicorn.run(app, host="127.0.0.1", port=8000)
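
# --- Quick smoke test (hypothetical; assumes this file is saved as app.py) ---
# Start the server:
#   python app.py            # or: uvicorn app:app --host 127.0.0.1 --port 8000
# Then, from another shell:
#   curl http://127.0.0.1:8000/
#   curl -X POST http://127.0.0.1:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "sample query text"}'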