rudra0410hf committed
Commit 808989c · verified · 1 Parent(s): 607b349

Update app.py

Files changed (1): app.py +39 -14
app.py CHANGED
@@ -2,16 +2,30 @@ import os
 import time
 import logging
 from typing import Optional
+
+# =============================
+# Hugging Face cache fix for Spaces
+# =============================
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface/transformers"
+os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
+os.makedirs("/tmp/.cache/huggingface/transformers", exist_ok=True)
+
+# =============================
+# Imports
+# =============================
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import pipeline
 
+# =============================
+# Logging
+# =============================
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("biogpt_chatbot")
 
-# =========================
+# =============================
 # PROMPT TEMPLATES
-# =========================
+# =============================
 MEDICAL_PROMPTS = {
     "dermatology": """
 You are DermX-AI, a specialized medical AI assistant trained in dermatology.
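The new cache block is the substantive fix in this hunk: the filesystem of a Space container is generally only writable under /tmp for the running user, so pointing the Hugging Face cache there avoids permission errors when the model downloads at startup. Placement matters too, since transformers reads these variables when it is first imported, which is why the block sits above the imports in the new file. A minimal standalone sketch of the same pattern (the CACHE_ROOT name and the setdefault calls are illustrative, not from this commit):

import os

# Any writable directory works; /tmp is writable in a default Space container.
CACHE_ROOT = "/tmp/.cache/huggingface"
os.environ.setdefault("HF_HOME", CACHE_ROOT)
# TRANSFORMERS_CACHE is read by older transformers releases; newer releases
# derive the cache path from HF_HOME, so setting both is a safe redundancy.
os.environ.setdefault("TRANSFORMERS_CACHE", CACHE_ROOT + "/transformers")
os.makedirs(CACHE_ROOT + "/transformers", exist_ok=True)

# Import only after the environment is set, since it is read at import time.
from transformers import pipeline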
@@ -38,14 +52,14 @@ Please consult a dermatologist or qualified healthcare provider for personalized
 """,
 }
 
-# =========================
-# REQUEST/RESPONSE MODELS
-# =========================
+# =============================
+# FastAPI setup
+# =============================
 class ChatRequest(BaseModel):
     question: str
     context: Optional[str] = None
-    mode: Optional[str] = "dermatology"  # dermatology | general
-    max_new_tokens: Optional[int] = 200
+    mode: Optional[str] = "dermatology"  # "dermatology" | "general"
+    max_new_tokens: Optional[int] = 100
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 0.9
 
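With the tightened default above (max_new_tokens drops from 200 to 100), a request only has to supply question; every other field falls back to its default. A hypothetical client call against a local run of this app (the host, port, and uvicorn invocation are assumptions, not part of the commit):

import requests

# Assumes the app is served locally, e.g.: uvicorn app:app --port 7860
resp = requests.post(
    "http://localhost:7860/chat",
    json={
        "question": "What could cause an itchy, scaly patch on the elbow?",
        "mode": "dermatology",  # unknown modes fall back to the "general" prompt
        "max_new_tokens": 100,
    },
    timeout=120,  # CPU generation with BioGPT-Large can take a while
)
print(resp.json())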
 
@@ -56,30 +70,35 @@ class ChatResponse(BaseModel):
     confidence: int
     sources: list
 
-# =========================
-# FASTAPI SETUP
-# =========================
 app = FastAPI(title="BioGPT-Large Medical Chatbot")
 
-MODEL_ID = "microsoft/BioGPT-Large"
+MODEL_ID = os.environ.get("MODEL_ID", "microsoft/BioGPT-Large")
 generator = None
 
+# =============================
+# Load model on startup
+# =============================
 @app.on_event("startup")
 def load_model():
     global generator
-    logger.info(f"Loading Hugging Face model via pipeline: {MODEL_ID}")
     try:
-        # Use HF hosted model (CPU is fine, HF handles backend)
+        logger.info(f"Loading Hugging Face model via pipeline: {MODEL_ID}")
         generator = pipeline("text-generation", model=MODEL_ID, device=-1)
         logger.info("Model loaded successfully.")
     except Exception as e:
         logger.exception("Failed to load model")
         generator = None
 
+# =============================
+# Root endpoint
+# =============================
 @app.get("/")
 def root():
     return {"status": "ok", "model_loaded": generator is not None, "model": MODEL_ID}
 
+# =============================
+# Chat endpoint
+# =============================
 @app.post("/chat", response_model=ChatResponse)
 def chat(req: ChatRequest):
     if generator is None:
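Two details in this hunk are easy to miss: device=-1 pins the pipeline to CPU, and reading MODEL_ID from the environment lets a deployment swap in a different checkpoint (say, the smaller microsoft/biogpt) without touching code. Note also that @app.on_event("startup") still works, though newer FastAPI releases recommend lifespan handlers instead. A standalone smoke test of the same loading call (the checkpoint and prompt are illustrative; BioGPT's tokenizer additionally needs the sacremoses package installed):

from transformers import pipeline

# Same call shape as the startup hook; device=-1 forces CPU inference.
generator = pipeline("text-generation", model="microsoft/biogpt", device=-1)

out = generator("COVID-19 is", max_new_tokens=20, do_sample=False)
print(out[0]["generated_text"])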
@@ -88,14 +107,18 @@ def chat(req: ChatRequest):
     if not req.question.strip():
         raise HTTPException(status_code=400, detail="Question cannot be empty")
 
-    # Build prompt
+    # Select system prompt
     mode = req.mode.lower() if req.mode else "dermatology"
     system_prompt = MEDICAL_PROMPTS.get(mode, MEDICAL_PROMPTS["general"])
+
+    # Build final prompt
    prompt = f"{system_prompt}\n\nUser Question: {req.question.strip()}\n\nAI Answer:"
     if req.context:
         prompt = req.context.strip() + "\n\n" + prompt
 
+    logger.info(f"Generating answer for question: {req.question[:80]}...")
     t0 = time.time()
+
     try:
         outputs = generator(
             prompt,
@@ -106,8 +129,10 @@
             return_full_text=False,
             num_return_sequences=1,
         )
+
         answer = outputs[0]["generated_text"].strip()
         final_answer = f"{answer}\n\n{MEDICAL_PROMPTS['disclaimer']}"
+
         took = time.time() - t0
         confidence = min(95, 70 + int(len(answer) / 50))
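One caveat for readers of the response fields: confidence is a length heuristic, not a model probability. It is 70 base points plus one point per 50 characters of generated text, capped at 95. A quick check of the formula as committed:

def heuristic_confidence(answer: str) -> int:
    # Mirrors the diff: 70 base + 1 point per 50 chars, capped at 95.
    return min(95, 70 + int(len(answer) / 50))

assert heuristic_confidence("x" * 500) == 80   # 70 + 500 // 50
assert heuristic_confidence("x" * 5000) == 95  # cap reached from 1,250 chars up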
 
 