ruslanmv committed on
Commit
a6720b5
·
1 Parent(s): a91be8c

First commit

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app/services/chat_service.py +96 -38
  3. pyproject.toml +1 -1
  4. requirements.txt +1 -1
.gitignore CHANGED
@@ -33,4 +33,5 @@ Thumbs.db
33
  .faiss/
34
  /backup
35
  copy *.*
36
- * copy.*
 
 
33
  .faiss/
34
  /backup
35
  copy *.*
36
+ * copy.*
37
+ * copy *.py
app/services/chat_service.py CHANGED
@@ -6,10 +6,10 @@ import os
6
  import re
7
  import threading
8
  from pathlib import Path
9
- from typing import List, Tuple, Dict, Optional
10
 
11
  from ..core.config import Settings
12
- from ..core.inference.client import RouterRequestsClient
13
  from ..core.rag.retriever import Retriever
14
 
15
  logger = logging.getLogger(__name__)
@@ -33,11 +33,18 @@ STOP_SEQS: List[str] = [
33
  "\nUser:", "User:", "\nAssistant:", "Assistant:"
34
  ]
35
 
 
36
  # Thread-safe singleton retriever
 
37
  _retriever_instance: Optional[Retriever] = None
38
  _retriever_lock = threading.Lock()
39
 
 
40
  def get_retriever(settings: Settings) -> Optional[Retriever]:
 
 
 
 
41
  global _retriever_instance
42
  if _retriever_instance is not None:
43
  return _retriever_instance
@@ -58,16 +65,19 @@ def get_retriever(settings: Settings) -> Optional[Retriever]:
58
  _retriever_instance = None
59
  return _retriever_instance
60
 
 
61
  # ---------- anti-repetition / anti-label helpers ----------
62
  _SENT_SPLIT = re.compile(r'(?<=[\.\!\?])\s+')
63
  _NORM = re.compile(r'[^a-z0-9\s]+')
64
 
 
65
  def _norm_sentence(s: str) -> str:
66
  s = s.lower().strip()
67
  s = _NORM.sub(' ', s)
68
  s = re.sub(r'\s+', ' ', s)
69
  return s
70
 
 
71
  def _jaccard(a: str, b: str) -> float:
72
  ta = set(a.split())
73
  tb = set(b.split())
@@ -75,6 +85,7 @@ def _jaccard(a: str, b: str) -> float:
75
  return 0.0
76
  return len(ta & tb) / max(1, len(ta | tb))
77
 
 
78
  def _squash_repetition(text: str, max_sentences: int = 4, sim_threshold: float = 0.88) -> str:
79
  t = re.sub(r'\s+', ' ', text).strip()
80
  if not t:
@@ -94,10 +105,12 @@ def _squash_repetition(text: str, max_sentences: int = 4, sim_threshold: float =
94
  break
95
  return ' '.join(out).strip()
96
 
 
97
  # Strip common label patterns
98
  _LABEL_PREFIX = re.compile(r'^\s*(?:Answer:|A:)\s*', re.IGNORECASE)
99
  _LABEL_INLINE_Q = re.compile(r'\s*(?:Question:|Q:)\s*$', re.IGNORECASE)
100
 
 
101
  def _strip_labels(text: str) -> str:
102
  s = _LABEL_PREFIX.sub('', text)
103
  # If the model tries to end with "Question:" remove that tail prompt
@@ -106,6 +119,7 @@ def _strip_labels(text: str) -> str:
106
  s = re.sub(r'\b(?:Answer:|A:)\s*', '', s, flags=re.IGNORECASE)
107
  return s.strip()
108
 
 
109
  # ---------- RAG utilities (ranking & snippets) ----------
110
  _ALIAS_TABLE: Dict[str, List[str]] = {
111
  "matrixhub": ["matrix hub", "hub api", "catalog", "registry", "cas"],
@@ -114,9 +128,11 @@ _ALIAS_TABLE: Dict[str, List[str]] = {
114
  }
115
  _WORD_RE = re.compile(r"[A-Za-z0-9_]+")
116
 
 
117
  def _normalize(text: str) -> List[str]:
118
  return [t.lower() for t in _WORD_RE.findall(text)]
119
 
 
120
  def _expand_query(q: str) -> str:
121
  ql = q.lower()
122
  extras: List[str] = []
@@ -127,6 +143,7 @@ def _expand_query(q: str) -> str:
127
  return q + " | " + " ".join(sorted(set(extras)))
128
  return q
129
 
 
130
  def _keyword_overlap_score(query: str, text: str) -> float:
131
  q_tokens = set(_normalize(query))
132
  d_tokens = set(_normalize(text))
@@ -136,6 +153,7 @@ def _keyword_overlap_score(query: str, text: str) -> float:
136
  union = len(q_tokens | d_tokens)
137
  return inter / max(1, union)
138
 
 
139
  def _domain_boost(text: str) -> float:
140
  t = text.lower()
141
  boost = 0.0
@@ -144,6 +162,7 @@ def _domain_boost(text: str) -> float:
144
  boost += 0.05
145
  return min(boost, 0.25)
146
 
 
147
  def _best_paragraphs(text: str, query: str, max_chars: int = 700) -> str:
148
  paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
149
  if not paras:
@@ -161,6 +180,7 @@ def _best_paragraphs(text: str, query: str, max_chars: int = 700) -> str:
161
  break
162
  return "\n".join(picked)
163
 
 
164
  def _cross_encoder_scores(model: Optional["CrossEncoder"], query: str, docs: List[Dict], max_pairs: int = 50) -> Optional[List[float]]:
165
  if not model:
166
  return None
@@ -171,6 +191,7 @@ def _cross_encoder_scores(model: Optional["CrossEncoder"], query: str, docs: Lis
171
  logger.warning("Cross-encoder scoring failed; continuing without it (%s)", e)
172
  return None
173
 
 
174
  def _rerank_docs(docs: List[Dict], query: str, k_final: int, reranker: Optional["CrossEncoder"] = None) -> List[Dict]:
175
  if not docs:
176
  return []
@@ -203,6 +224,7 @@ def _rerank_docs(docs: List[Dict], query: str, k_final: int, reranker: Optional[
203
  merged.sort(key=lambda x: x[0], reverse=True)
204
  return [d for _s, d in merged[:k_final]]
205
 
 
206
  def _build_context_from_docs(docs: List[Dict], query: str, max_blocks: int = 4) -> Tuple[str, List[str]]:
207
  blocks: List[str] = []
208
  sources: List[str] = []
@@ -216,22 +238,33 @@ def _build_context_from_docs(docs: List[Dict], query: str, max_blocks: int = 4)
216
  prelude = "CONTEXT (use only these facts; if missing, say you don't know):"
217
  return prelude + "\n\n" + "\n\n".join(blocks), sources
218
 
 
219
  # ----------------------------
220
  # Service
221
  # ----------------------------
222
  class ChatService:
 
 
 
 
 
223
  def __init__(self, settings: Settings):
224
  self.settings = settings
225
- self.client = RouterRequestsClient(
226
- model=settings.model.name,
227
- fallback=settings.model.fallback,
228
- provider=getattr(settings.model, "provider", None),
229
- max_retries=2,
230
- connect_timeout=10.0,
231
- read_timeout=60.0,
232
- )
 
 
 
 
233
  self.retriever = get_retriever(settings)
234
 
 
235
  self.reranker = None
236
  use_rerank = os.getenv("RAG_RERANK", "true").lower() in ("1", "true", "yes")
237
  if use_rerank and CrossEncoder is not None:
@@ -272,44 +305,69 @@ class ChatService:
272
 
273
  # ---------- Non-stream ----------
274
  def answer_with_sources(self, query: str) -> Tuple[str, List[str]]:
 
 
 
 
275
  user_msg, sources = self._augment(query)
276
- text = self.client.chat_nonstream(
277
- SYSTEM_PROMPT,
278
- user_msg,
279
- max_tokens=self.settings.model.max_new_tokens,
 
 
280
  temperature=self.settings.model.temperature,
281
- stop=STOP_SEQS,
282
- frequency_penalty=0.2, # mild anti-repeat
283
- presence_penalty=0.0,
284
  )
 
285
  text = _strip_labels(_squash_repetition(text, max_sentences=4, sim_threshold=0.88))
286
  return text, sources
287
 
288
  # ---------- Stream ----------
289
- def stream_answer(self, query: str):
 
 
 
 
290
  user_msg, _ = self._augment(query)
291
- raw = self.client.chat_stream(
292
- SYSTEM_PROMPT,
293
- user_msg,
294
- max_tokens=self.settings.model.max_new_tokens,
 
 
295
  temperature=self.settings.model.temperature,
296
- stop=STOP_SEQS,
297
- frequency_penalty=0.2,
298
- presence_penalty=0.0,
299
  )
300
 
 
 
 
 
 
 
 
 
 
301
  buf = ""
302
  emitted = ""
303
- for token in raw:
304
- if not token:
305
- continue
306
- buf += token
307
- cleaned = _squash_repetition(buf, max_sentences=4, sim_threshold=0.88)
308
- cleaned = _strip_labels(cleaned)
309
- if len(cleaned) < len(emitted):
310
- emitted = cleaned
311
- continue
312
- delta = cleaned[len(emitted):]
313
- if delta:
314
- emitted = cleaned
315
- yield delta
 
 
 
 
 
 
6
  import re
7
  import threading
8
  from pathlib import Path
9
+ from typing import List, Tuple, Dict, Optional, Iterable, Generator
10
 
11
  from ..core.config import Settings
12
+ from ..core.inference.client import ChatClient # ← multi-provider cascade (GROQ→Gemini→HF)
13
  from ..core.rag.retriever import Retriever
14
 
15
  logger = logging.getLogger(__name__)
 
33
  "\nUser:", "User:", "\nAssistant:", "Assistant:"
34
  ]
35
 
36
+ # ----------------------------
37
  # Thread-safe singleton retriever
38
+ # ----------------------------
39
  _retriever_instance: Optional[Retriever] = None
40
  _retriever_lock = threading.Lock()
41
 
42
+
43
  def get_retriever(settings: Settings) -> Optional[Retriever]:
44
+ """
45
+ Initialize and cache the Retriever once (thread-safe).
46
+ If no KB is present, returns None and logs that we run LLM-only.
47
+ """
48
  global _retriever_instance
49
  if _retriever_instance is not None:
50
  return _retriever_instance
 
65
  _retriever_instance = None
66
  return _retriever_instance
67
 
68
+
69
  # ---------- anti-repetition / anti-label helpers ----------
70
  _SENT_SPLIT = re.compile(r'(?<=[\.\!\?])\s+')
71
  _NORM = re.compile(r'[^a-z0-9\s]+')
72
 
73
+
74
  def _norm_sentence(s: str) -> str:
75
  s = s.lower().strip()
76
  s = _NORM.sub(' ', s)
77
  s = re.sub(r'\s+', ' ', s)
78
  return s
79
 
80
+
81
  def _jaccard(a: str, b: str) -> float:
82
  ta = set(a.split())
83
  tb = set(b.split())
 
85
  return 0.0
86
  return len(ta & tb) / max(1, len(ta | tb))
87
 
88
+
89
  def _squash_repetition(text: str, max_sentences: int = 4, sim_threshold: float = 0.88) -> str:
90
  t = re.sub(r'\s+', ' ', text).strip()
91
  if not t:
 
105
  break
106
  return ' '.join(out).strip()
107
 
108
+
109
  # Strip common label patterns
110
  _LABEL_PREFIX = re.compile(r'^\s*(?:Answer:|A:)\s*', re.IGNORECASE)
111
  _LABEL_INLINE_Q = re.compile(r'\s*(?:Question:|Q:)\s*$', re.IGNORECASE)
112
 
113
+
114
  def _strip_labels(text: str) -> str:
115
  s = _LABEL_PREFIX.sub('', text)
116
  # If the model tries to end with "Question:" remove that tail prompt
 
119
  s = re.sub(r'\b(?:Answer:|A:)\s*', '', s, flags=re.IGNORECASE)
120
  return s.strip()
121
 
122
+
123
  # ---------- RAG utilities (ranking & snippets) ----------
124
  _ALIAS_TABLE: Dict[str, List[str]] = {
125
  "matrixhub": ["matrix hub", "hub api", "catalog", "registry", "cas"],
 
128
  }
129
  _WORD_RE = re.compile(r"[A-Za-z0-9_]+")
130
 
131
+
132
  def _normalize(text: str) -> List[str]:
133
  return [t.lower() for t in _WORD_RE.findall(text)]
134
 
135
+
136
  def _expand_query(q: str) -> str:
137
  ql = q.lower()
138
  extras: List[str] = []
 
143
  return q + " | " + " ".join(sorted(set(extras)))
144
  return q
145
 
146
+
147
  def _keyword_overlap_score(query: str, text: str) -> float:
148
  q_tokens = set(_normalize(query))
149
  d_tokens = set(_normalize(text))
 
153
  union = len(q_tokens | d_tokens)
154
  return inter / max(1, union)
155
 
156
+
157
  def _domain_boost(text: str) -> float:
158
  t = text.lower()
159
  boost = 0.0
 
162
  boost += 0.05
163
  return min(boost, 0.25)
164
 
165
+
166
  def _best_paragraphs(text: str, query: str, max_chars: int = 700) -> str:
167
  paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
168
  if not paras:
 
180
  break
181
  return "\n".join(picked)
182
 
183
+
184
  def _cross_encoder_scores(model: Optional["CrossEncoder"], query: str, docs: List[Dict], max_pairs: int = 50) -> Optional[List[float]]:
185
  if not model:
186
  return None
 
191
  logger.warning("Cross-encoder scoring failed; continuing without it (%s)", e)
192
  return None
193
 
194
+
195
  def _rerank_docs(docs: List[Dict], query: str, k_final: int, reranker: Optional["CrossEncoder"] = None) -> List[Dict]:
196
  if not docs:
197
  return []
 
224
  merged.sort(key=lambda x: x[0], reverse=True)
225
  return [d for _s, d in merged[:k_final]]
226
 
227
+
228
  def _build_context_from_docs(docs: List[Dict], query: str, max_blocks: int = 4) -> Tuple[str, List[str]]:
229
  blocks: List[str] = []
230
  sources: List[str] = []
 
238
  prelude = "CONTEXT (use only these facts; if missing, say you don't know):"
239
  return prelude + "\n\n" + "\n\n".join(blocks), sources
240
 
241
+
242
  # ----------------------------
243
  # Service
244
  # ----------------------------
245
  class ChatService:
246
+ """
247
+ High-level Q&A service with optional RAG. Uses the multi-provider ChatClient,
248
+ honoring provider_order from configs/settings.yaml (e.g., groq → gemini → router).
249
+ """
250
+
251
  def __init__(self, settings: Settings):
252
  self.settings = settings
253
+
254
+ # Log backend + provider order for traceability
255
+ try:
256
+ order = getattr(settings, "provider_order", ["router"])
257
+ logger.info("Chat backend=%s | Provider order=%s", settings.chat_backend, order)
258
+ except Exception:
259
+ logger.info("Chat backend=%s", getattr(settings, "chat_backend", "unknown"))
260
+
261
+ # Use the multi-provider cascade: GROQ → Gemini → HF Router
262
+ self.client = ChatClient(settings)
263
+
264
+ # RAG components
265
  self.retriever = get_retriever(settings)
266
 
267
+ # Optional cross-encoder reranker
268
  self.reranker = None
269
  use_rerank = os.getenv("RAG_RERANK", "true").lower() in ("1", "true", "yes")
270
  if use_rerank and CrossEncoder is not None:
 
305
 
306
  # ---------- Non-stream ----------
307
  def answer_with_sources(self, query: str) -> Tuple[str, List[str]]:
308
+ """
309
+ Returns a concise answer and the list of source identifiers (if any).
310
+ Uses the cascade in non-streaming mode (always returns a string).
311
+ """
312
  user_msg, sources = self._augment(query)
313
+ messages = [
314
+ {"role": "system", "content": SYSTEM_PROMPT},
315
+ {"role": "user", "content": user_msg},
316
+ ]
317
+ text = self.client.chat(
318
+ messages,
319
  temperature=self.settings.model.temperature,
320
+ max_new_tokens=self.settings.model.max_new_tokens,
321
+ stream=False,
 
322
  )
323
+ # Post-process for brevity and cleanliness
324
  text = _strip_labels(_squash_repetition(text, max_sentences=4, sim_threshold=0.88))
325
  return text, sources
326
 
327
  # ---------- Stream ----------
328
+ def stream_answer(self, query: str) -> Iterable[str]:
329
+ """
330
+ Yields chunks of text as they are produced.
331
+ On GROQ, this is true token streaming; on Gemini/HF, it may yield once.
332
+ """
333
  user_msg, _ = self._augment(query)
334
+ messages = [
335
+ {"role": "system", "content": SYSTEM_PROMPT},
336
+ {"role": "user", "content": user_msg},
337
+ ]
338
+ raw = self.client.chat(
339
+ messages,
340
  temperature=self.settings.model.temperature,
341
+ max_new_tokens=self.settings.model.max_new_tokens,
342
+ stream=True,
 
343
  )
344
 
345
+ # Normalize to a generator of strings
346
+ def _iter_chunks(gen_or_text: Generator[str, None, None] | str) -> Generator[str, None, None]:
347
+ if isinstance(gen_or_text, str):
348
+ yield gen_or_text
349
+ else:
350
+ for chunk in gen_or_text:
351
+ if chunk:
352
+ yield chunk
353
+
354
  buf = ""
355
  emitted = ""
356
+ try:
357
+ for token in _iter_chunks(raw):
358
+ buf += token
359
+ cleaned = _squash_repetition(buf, max_sentences=4, sim_threshold=0.88)
360
+ cleaned = _strip_labels(cleaned)
361
+ if len(cleaned) < len(emitted):
362
+ # Cleaning shortened text; wait for more tokens
363
+ continue
364
+ delta = cleaned[len(emitted):]
365
+ if delta:
366
+ emitted = cleaned
367
+ yield delta
368
+ except Exception as e:
369
+ logger.error("Streaming error: %s", e)
370
+ # Best-effort final flush
371
+ final = _strip_labels(_squash_repetition(buf, max_sentences=4, sim_threshold=0.88)).strip()
372
+ if final and final != emitted:
373
+ yield final[len(emitted):]
pyproject.toml CHANGED
@@ -11,7 +11,7 @@ requires-python = ">=3.11"
11
  license = { text = "Apache-2.0" }
12
  dependencies = [
13
  "fastapi==0.111.0",
14
- "groq==0.9.0",
15
  "uvicorn[standard]==0.29.0",
16
  "httpx==0.28.1",
17
  "pydantic==2.7.1",
 
11
  license = { text = "Apache-2.0" }
12
  dependencies = [
13
  "fastapi==0.111.0",
14
+ "groq==0.32.0",
15
  "uvicorn[standard]==0.29.0",
16
  "httpx==0.28.1",
17
  "pydantic==2.7.1",
requirements.txt CHANGED
@@ -20,7 +20,7 @@ mypy
20
  pytest-asyncio
21
 
22
  # Additional libraries for extended functionality
23
- groq==0.9.0
24
  python-dotenv==1.0.1
25
  google-genai==1.39.1
26
 
 
20
  pytest-asyncio
21
 
22
  # Additional libraries for extended functionality
23
+ groq==0.32.0
24
  python-dotenv==1.0.1
25
  google-genai==1.39.1
26