Spaces:
Sleeping
Sleeping
File size: 5,091 Bytes
3bcd4a2 808989c 92e4e8d 3bcd4a2 92e4e8d 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 808989c 92e4e8d 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 808989c 92e4e8d 3bcd4a2 92e4e8d 3bcd4a2 808989c 203435a 92e4e8d 203435a 808989c 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 92e4e8d 3bcd4a2 808989c 3bcd4a2 808989c 3bcd4a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import time
import logging
from typing import Optional
# =============================
# Hugging Face cache fix for Spaces
# =============================
# Redirect the HF cache to /tmp because the Spaces container's home dir is
# read-only at runtime. These env vars MUST be set before `transformers`
# is imported below, or the library reads the default locations.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases in favor of HF_HOME/HF_HUB_CACHE — verify against the pinned
# transformers version before removing either line.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface/transformers"
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.makedirs("/tmp/.cache/huggingface/transformers", exist_ok=True)
# =============================
# Imports
# =============================
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# =============================
# Logging
# =============================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("biogpt_chatbot")
# =============================
# PROMPT TEMPLATES
# =============================
# System-prompt templates keyed by chat mode. "dermatology" and "general"
# are selectable via ChatRequest.mode (unknown modes fall back to
# "general"); "disclaimer" is appended to every generated answer.
# The template text is part of the runtime prompt — do not edit casually.
MEDICAL_PROMPTS = {
    # Specialist persona used by default.
    "dermatology": """
You are DermX-AI, a specialized medical AI assistant trained in dermatology.
Your role is to provide clear, evidence-based information about skin conditions,
diagnostic insights, and treatment options.
- Use simple but professional language, suitable for both patients and clinicians.
- When explaining, balance medical accuracy with user-friendly clarity.
- For any uncertain or critical cases, clearly advise consultation with a dermatologist.
- Always include safety reminders and disclaimers.
""",
    # Generic medical persona; also the fallback for unrecognized modes.
    "general": """
You are a medical AI assistant designed to provide helpful, evidence-based health information.
When answering:
- Ensure accuracy and clarity in medical explanations.
- Provide actionable lifestyle and preventive care suggestions where applicable.
- Avoid giving definitive diagnoses or prescriptions—always emphasize professional medical consultation.
- Be empathetic, supportive, and professional in tone.
""",
    # Safety notice unconditionally appended to each /chat response.
    "disclaimer": """
⚠️ Important: I am an AI medical assistant, not a licensed healthcare professional.
The information provided is for educational purposes only and should not be
considered a substitute for professional medical advice, diagnosis, or treatment.
Please consult a dermatologist or qualified healthcare provider for personalized care.
""",
}
# =============================
# FastAPI setup
# =============================
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    question: str  # required; rejected with HTTP 400 if blank after stripping
    context: Optional[str] = None  # optional text prepended to the final prompt
    mode: Optional[str] = "dermatology"  # "dermatology" | "general" (case-insensitive)
    max_new_tokens: Optional[int] = 100  # generation length cap passed to the pipeline
    temperature: Optional[float] = 0.7  # sampling temperature
    top_p: Optional[float] = 0.9  # nucleus-sampling cutoff
class ChatResponse(BaseModel):
    """Response body returned by POST /chat."""

    answer: str  # generated text with the standard disclaimer appended
    model: str  # model identifier used for generation
    took_seconds: float  # wall-clock generation time, rounded to 2 decimals
    confidence: int  # heuristic length-based score (70-95), not a probability
    sources: list  # provenance labels, e.g. ["HuggingFace", MODEL_ID]
app = FastAPI(title="BioGPT-Large Medical Chatbot")
# Model is overridable via the MODEL_ID env var; defaults to BioGPT-Large.
MODEL_ID = os.environ.get("MODEL_ID", "microsoft/BioGPT-Large")
# Populated by the startup hook; stays None if the model fails to load.
generator = None
# =============================
# Load model on startup
# =============================
@app.on_event("startup")
def load_model():
    """Load the text-generation pipeline once at application startup.

    Sets the module-level ``generator``. On failure it is left as ``None``
    so endpoints can report the model as unavailable instead of crashing
    the whole app.
    """
    # NOTE(review): @app.on_event is deprecated in recent FastAPI versions;
    # consider migrating to a lifespan handler when upgrading FastAPI.
    global generator
    try:
        # Lazy %-style args so the message is only formatted if emitted.
        logger.info("Loading Hugging Face model via pipeline: %s", MODEL_ID)
        # device=-1 forces CPU inference.
        generator = pipeline("text-generation", model=MODEL_ID, device=-1)
        logger.info("Model loaded successfully.")
    except Exception:
        # logger.exception records the full traceback; keep the app alive.
        logger.exception("Failed to load model")
        generator = None
# =============================
# Root endpoint
# =============================
@app.get("/")
def root():
    """Health-check endpoint: reports service status and model readiness."""
    payload = {
        "status": "ok",
        "model_loaded": generator is not None,
        "model": MODEL_ID,
    }
    return payload
# =============================
# Chat endpoint
# =============================
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Generate a medical answer for the submitted question.

    Raises:
        HTTPException(500): model not loaded, or generation failed.
        HTTPException(400): question is empty or whitespace-only.
    """
    if generator is None:
        raise HTTPException(status_code=500, detail="Model not available.")
    if not req.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    prompt = _build_prompt(req)
    # Lazy %-style args avoid formatting when the level is disabled.
    logger.info("Generating answer for question: %s...", req.question[:80])
    t0 = time.time()
    try:
        outputs = generator(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            return_full_text=False,  # drop the echoed prompt from the output
            num_return_sequences=1,
        )
        answer = outputs[0]["generated_text"].strip()
        # Always append the safety disclaimer to the model output.
        final_answer = f"{answer}\n\n{MEDICAL_PROMPTS['disclaimer']}"
        took = time.time() - t0
        # Heuristic: longer answers score higher, capped at 95. Not a
        # calibrated probability.
        confidence = min(95, 70 + int(len(answer) / 50))
        return ChatResponse(
            answer=final_answer,
            model=MODEL_ID,
            took_seconds=round(took, 2),
            confidence=confidence,
            sources=["HuggingFace", MODEL_ID],
        )
    except Exception as e:
        logger.exception("Generation failed")
        # Chain the cause so tracebacks retain the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e


def _build_prompt(req: ChatRequest) -> str:
    """Assemble system prompt + optional context + user question."""
    mode = req.mode.lower() if req.mode else "dermatology"
    # Unknown modes fall back to the general-purpose prompt.
    system_prompt = MEDICAL_PROMPTS.get(mode, MEDICAL_PROMPTS["general"])
    prompt = f"{system_prompt}\n\nUser Question: {req.question.strip()}\n\nAI Answer:"
    if req.context:
        prompt = req.context.strip() + "\n\n" + prompt
    return prompt
|