import os
import time
import logging
from typing import Optional

# =============================
# Hugging Face cache fix for Spaces
# =============================
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface/transformers"
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.makedirs("/tmp/.cache/huggingface/transformers", exist_ok=True)

# =============================
# Imports
# =============================
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# =============================
# Logging
# =============================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("biogpt_chatbot")

# =============================
# PROMPT TEMPLATES
# =============================
MEDICAL_PROMPTS = {
    "dermatology": """
You are DermX-AI, a specialized medical AI assistant trained in dermatology.
Your role is to provide clear, evidence-based information about skin conditions,
diagnostic insights, and treatment options.
- Use simple but professional language, suitable for both patients and clinicians.
- When explaining, balance medical accuracy with user-friendly clarity.
- For any uncertain or critical cases, clearly advise consultation with a dermatologist.
- Always include safety reminders and disclaimers.
""",
    "general": """
You are a medical AI assistant designed to provide helpful, evidence-based health information.
When answering:
- Ensure accuracy and clarity in medical explanations.
- Provide actionable lifestyle and preventive care suggestions where applicable.
- Avoid giving definitive diagnoses or prescriptions; always emphasize professional medical consultation.
- Be empathetic, supportive, and professional in tone.
""",
    "disclaimer": """
⚠️ Important: I am an AI medical assistant, not a licensed healthcare professional.
The information provided is for educational purposes only and should not be considered
a substitute for professional medical advice, diagnosis, or treatment.
Please consult a dermatologist or qualified healthcare provider for personalized care.
""", } # ============================= # FastAPI setup # ============================= class ChatRequest(BaseModel): question: str context: Optional[str] = None mode: Optional[str] = "dermatology" # "dermatology" | "general" max_new_tokens: Optional[int] = 100 temperature: Optional[float] = 0.7 top_p: Optional[float] = 0.9 class ChatResponse(BaseModel): answer: str model: str took_seconds: float confidence: int sources: list app = FastAPI(title="BioGPT-Large Medical Chatbot") MODEL_ID = os.environ.get("MODEL_ID", "microsoft/BioGPT-Large") generator = None # ============================= # Load model on startup # ============================= @app.on_event("startup") def load_model(): global generator try: logger.info(f"Loading Hugging Face model via pipeline: {MODEL_ID}") generator = pipeline("text-generation", model=MODEL_ID, device=-1) logger.info("Model loaded successfully.") except Exception as e: logger.exception("Failed to load model") generator = None # ============================= # Root endpoint # ============================= @app.get("/") def root(): return {"status": "ok", "model_loaded": generator is not None, "model": MODEL_ID} # ============================= # Chat endpoint # ============================= @app.post("/chat", response_model=ChatResponse) def chat(req: ChatRequest): if generator is None: raise HTTPException(status_code=500, detail="Model not available.") if not req.question.strip(): raise HTTPException(status_code=400, detail="Question cannot be empty") # Select system prompt mode = req.mode.lower() if req.mode else "dermatology" system_prompt = MEDICAL_PROMPTS.get(mode, MEDICAL_PROMPTS["general"]) # Build final prompt prompt = f"{system_prompt}\n\nUser Question: {req.question.strip()}\n\nAI Answer:" if req.context: prompt = req.context.strip() + "\n\n" + prompt logger.info(f"Generating answer for question: {req.question[:80]}...") t0 = time.time() try: outputs = generator( prompt, max_new_tokens=req.max_new_tokens, temperature=req.temperature, top_p=req.top_p, do_sample=True, return_full_text=False, num_return_sequences=1, ) answer = outputs[0]["generated_text"].strip() final_answer = f"{answer}\n\n{MEDICAL_PROMPTS['disclaimer']}" took = time.time() - t0 confidence = min(95, 70 + int(len(answer) / 50)) return ChatResponse( answer=final_answer, model=MODEL_ID, took_seconds=round(took, 2), confidence=confidence, sources=["HuggingFace", MODEL_ID], ) except Exception as e: logger.exception("Generation failed") raise HTTPException(status_code=500, detail=str(e))