File size: 5,091 Bytes
3bcd4a2
 
 
 
808989c
 
 
 
 
 
 
 
 
 
 
92e4e8d
3bcd4a2
92e4e8d
3bcd4a2
808989c
 
 
3bcd4a2
 
 
808989c
3bcd4a2
808989c
3bcd4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808989c
 
 
3bcd4a2
 
 
808989c
 
92e4e8d
 
3bcd4a2
 
 
 
 
 
 
 
 
 
808989c
3bcd4a2
 
808989c
 
 
3bcd4a2
 
 
 
808989c
92e4e8d
3bcd4a2
 
92e4e8d
3bcd4a2
 
808989c
 
 
203435a
 
92e4e8d
203435a
808989c
 
 
3bcd4a2
 
 
 
 
 
 
 
808989c
3bcd4a2
 
808989c
 
3bcd4a2
 
 
 
808989c
3bcd4a2
808989c
3bcd4a2
 
 
92e4e8d
 
 
3bcd4a2
 
 
 
808989c
3bcd4a2
 
808989c
3bcd4a2
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import time
import logging
from typing import Optional

# =============================
# Hugging Face cache fix for Spaces
# =============================
# On HF Spaces only /tmp is guaranteed writable; point all HF caches there.
# IMPORTANT: these env vars must be set BEFORE `transformers` is imported
# below, or the library will pick up its default cache location.
# NOTE(review): TRANSFORMERS_CACHE is reportedly deprecated in favor of
# HF_HOME in newer transformers releases — setting both is harmless; confirm.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface/transformers"
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.makedirs("/tmp/.cache/huggingface/transformers", exist_ok=True)

# =============================
# Imports
# =============================
# Deliberately placed after the cache env-var setup above.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# =============================
# Logging
# =============================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("biogpt_chatbot")

# =============================
# PROMPT TEMPLATES
# =============================
# System-prompt templates keyed by chat mode. "dermatology" and "general"
# are selectable via ChatRequest.mode; "disclaimer" is appended to every
# answer. These strings are part of runtime behavior — edit with care.
MEDICAL_PROMPTS = {
    # Specialist persona used when mode == "dermatology" (the default).
    "dermatology": """
You are DermX-AI, a specialized medical AI assistant trained in dermatology. 
Your role is to provide clear, evidence-based information about skin conditions, 
diagnostic insights, and treatment options. 
- Use simple but professional language, suitable for both patients and clinicians.  
- When explaining, balance medical accuracy with user-friendly clarity.  
- For any uncertain or critical cases, clearly advise consultation with a dermatologist.  
- Always include safety reminders and disclaimers.  
""",
    # Fallback persona; also used for any unrecognized mode value.
    "general": """
You are a medical AI assistant designed to provide helpful, evidence-based health information.  
When answering:  
- Ensure accuracy and clarity in medical explanations.  
- Provide actionable lifestyle and preventive care suggestions where applicable.  
- Avoid giving definitive diagnoses or prescriptions—always emphasize professional medical consultation.  
- Be empathetic, supportive, and professional in tone.  
""",
    # Safety notice appended verbatim to every generated answer.
    "disclaimer": """
⚠️ Important: I am an AI medical assistant, not a licensed healthcare professional. 
The information provided is for educational purposes only and should not be 
considered a substitute for professional medical advice, diagnosis, or treatment.  
Please consult a dermatologist or qualified healthcare provider for personalized care.  
""",
}

# =============================
# FastAPI setup
# =============================
class ChatRequest(BaseModel):
    """Request payload for POST /chat."""

    question: str  # the user's medical question; must be non-blank
    context: Optional[str] = None  # optional extra context, prepended to the prompt
    mode: Optional[str] = "dermatology"  # "dermatology" | "general"
    max_new_tokens: Optional[int] = 100  # generation length cap passed to the pipeline
    temperature: Optional[float] = 0.7  # sampling temperature (do_sample is always True)
    top_p: Optional[float] = 0.9  # nucleus-sampling cutoff

class ChatResponse(BaseModel):
    """Response payload for POST /chat."""

    answer: str  # generated text plus the safety disclaimer
    model: str  # model id that produced the answer
    took_seconds: float  # wall-clock generation time, rounded to 2 decimals
    confidence: int  # heuristic score (length-based), capped at 95
    sources: list  # provenance labels, e.g. ["HuggingFace", model id]

app = FastAPI(title="BioGPT-Large Medical Chatbot")

# Model id is overridable via the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "microsoft/BioGPT-Large")
# Text-generation pipeline; populated by the startup hook, None until loaded
# (or if loading failed — /chat checks for this).
generator = None

# =============================
# Load model on startup
# =============================
@app.on_event("startup")
def load_model():
    """Load the text-generation pipeline once at app startup.

    On failure the app still starts: ``generator`` stays None and /chat
    responds with a "Model not available" error instead of crashing.
    """
    global generator
    try:
        # Lazy %-style args: no string formatting unless INFO is enabled.
        logger.info("Loading Hugging Face model via pipeline: %s", MODEL_ID)
        # device=-1 forces CPU inference (no GPU on free Spaces tiers).
        generator = pipeline("text-generation", model=MODEL_ID, device=-1)
        logger.info("Model loaded successfully.")
    except Exception:
        # Broad catch is deliberate at this boundary; logger.exception
        # records the full traceback, so no `as e` binding is needed.
        logger.exception("Failed to load model")
        generator = None

# =============================
# Root endpoint
# =============================
@app.get("/")
def root():
    """Health check: report service status, model readiness, and model id."""
    payload = {
        "status": "ok",
        "model_loaded": generator is not None,
        "model": MODEL_ID,
    }
    return payload

# =============================
# Chat endpoint
# =============================
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Generate a medical answer for ``req.question``.

    Builds a mode-specific system prompt, runs the text-generation
    pipeline, and appends the safety disclaimer to the answer.

    Raises:
        HTTPException: 500 if the model is unavailable or generation
            fails; 400 if the question is blank.
    """
    if generator is None:
        raise HTTPException(status_code=500, detail="Model not available.")

    # Strip once and reuse (was stripped twice in earlier revisions).
    question = req.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    # Select system prompt; unknown modes fall back to the general template.
    mode = req.mode.lower() if req.mode else "dermatology"
    system_prompt = MEDICAL_PROMPTS.get(mode, MEDICAL_PROMPTS["general"])

    # Build the final prompt; caller-supplied context, if any, goes first.
    prompt = f"{system_prompt}\n\nUser Question: {question}\n\nAI Answer:"
    if req.context:
        prompt = req.context.strip() + "\n\n" + prompt

    # Lazy %-args so the message is only formatted when INFO is enabled.
    logger.info("Generating answer for question: %s...", req.question[:80])
    t0 = time.time()

    try:
        outputs = generator(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            return_full_text=False,  # return only the completion, not the prompt echo
            num_return_sequences=1,
        )

        answer = outputs[0]["generated_text"].strip()
        # Every answer carries the safety disclaimer.
        final_answer = f"{answer}\n\n{MEDICAL_PROMPTS['disclaimer']}"

        took = time.time() - t0
        # Heuristic only: longer answers score higher, capped at 95.
        confidence = min(95, 70 + int(len(answer) / 50))

        return ChatResponse(
            answer=final_answer,
            model=MODEL_ID,
            took_seconds=round(took, 2),
            confidence=confidence,
            sources=["HuggingFace", MODEL_ID],
        )
    except Exception as e:
        logger.exception("Generation failed")
        # NOTE(review): str(e) can leak internal details to the client —
        # consider returning a generic message in production.
        raise HTTPException(status_code=500, detail=str(e))