VibecoderMcSwaggins committed on
Commit
f985224
·
1 Parent(s): c2f7da2

fix: address CodeRabbit feedback and P0 blockers

Browse files

Code Fixes (HIGH priority):
- Add API key/provider validation to prevent silent auth failures
- Fix hardcoded manager model in orchestrator_magentic.py (now uses settings.openai_model)
- Add bounds checking to JSON extraction in judges.py (prevents IndexError)
- Fix fragile test assertion in test_judges_hf.py

Code Quality (MEDIUM priority):
- Add explicit type annotation for models_to_try: list[str]
- Fix structured logging (f-string → structured params)
- Align fallback query count (3 queries) between handlers

Test Improvements:
- Add @pytest.mark.unit decorator to TestHFInferenceJudgeHandler

Documentation Sync:
- Update Phase 3 docs to match actual implementation:
- __init__ signature (simplified, no inline imports)
- _extract_json (string split with bounds checking)
- _call_with_retry (tenacity decorator, asyncio.get_running_loop())
- assess method (simplified model loop)
- Update Phase 4 docs with ChatInterface additional_inputs for BYOK

All 104 tests pass.

docs/implementation/03_phase_judge.md CHANGED
@@ -374,272 +374,167 @@ class HFInferenceJudgeHandler:
374
  "HuggingFaceH4/zephyr-7b-beta", # Ungated fallback
375
  ]
376
 
377
- def __init__(self, model_id: str | None = None):
378
  """
379
  Initialize with HF Inference client.
380
 
381
  Args:
382
- model_id: HuggingFace model ID. If None, uses fallback chain.
383
- Will automatically use HF_TOKEN from env if available.
384
  """
385
- from huggingface_hub import InferenceClient
386
- import os
387
-
388
- self.model_id = model_id or self.FALLBACK_MODELS[0]
389
- self._fallback_models = self.FALLBACK_MODELS.copy()
390
-
391
- # InferenceClient auto-reads HF_TOKEN from env
392
- self.client = InferenceClient(model=self.model_id)
393
- self._has_token = bool(os.getenv("HF_TOKEN"))
394
-
395
  self.call_count = 0
396
- self.last_question = None
397
- self.last_evidence = None
398
-
399
- logger.info(
400
- "HFInferenceJudgeHandler initialized",
401
- model=self.model_id,
402
- has_token=self._has_token,
403
- )
404
 
405
- def _extract_json(self, response: str) -> dict | None:
406
  """
407
- Robustly extract JSON from LLM response.
408
-
409
- Handles:
410
- - Raw JSON: {"key": "value"}
411
- - Markdown code blocks: ```json\n{"key": "value"}\n```
412
- - Preamble text: "Here is the JSON:\n{"key": "value"}"
413
- - Nested braces: {"outer": {"inner": "value"}}
414
-
415
- Returns:
416
- Parsed dict or None if extraction fails
417
  """
418
- import json
419
- import re
420
-
421
- # Strategy 1: Try markdown code block first
422
- code_block_match = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", response)
423
- if code_block_match:
424
- try:
425
- return json.loads(code_block_match.group(1))
426
- except json.JSONDecodeError:
427
- pass
428
-
429
- # Strategy 2: Find outermost JSON object with brace matching
430
- # This handles nested objects correctly
431
- start = response.find("{")
432
- if start == -1:
 
 
 
 
433
  return None
434
 
435
- depth = 0
436
- end = start
437
  in_string = False
438
- escape_next = False
439
-
440
- for i, char in enumerate(response[start:], start):
441
- if escape_next:
442
- escape_next = False
443
- continue
444
-
445
- if char == "\\":
446
- escape_next = True
447
- continue
448
-
449
- if char == '"' and not escape_next:
450
- in_string = not in_string
451
- continue
452
 
 
453
  if in_string:
454
- continue
455
-
456
- if char == "{":
457
- depth += 1
 
 
 
 
 
 
458
  elif char == "}":
459
- depth -= 1
460
- if depth == 0:
461
- end = i + 1
462
- break
463
-
464
- if depth == 0 and end > start:
465
- try:
466
- return json.loads(response[start:end])
467
- except json.JSONDecodeError:
468
- pass
469
 
470
  return None
471
 
472
- async def _call_with_retry(
473
- self,
474
- messages: list[dict],
475
- max_retries: int = 3,
476
- ) -> str:
477
- """
478
- Call HF Inference with exponential backoff retry.
479
-
480
- Args:
481
- messages: Chat messages in OpenAI format
482
- max_retries: Max retry attempts
483
-
484
- Returns:
485
- Response text
486
-
487
- Raises:
488
- Exception if all retries fail
489
- """
490
- import asyncio
491
- import time
492
-
493
- last_error = None
494
 
495
- for attempt in range(max_retries):
496
- try:
497
- loop = asyncio.get_event_loop()
498
- response = await loop.run_in_executor(
499
- None,
500
- lambda: self.client.chat_completion(
501
- messages=messages,
502
- max_tokens=1024,
503
- temperature=0.1,
504
- )
505
- )
506
- return response.choices[0].message.content
507
 
508
- except Exception as e:
509
- last_error = e
510
- error_str = str(e).lower()
511
-
512
- # Check if rate limited or service unavailable
513
- is_rate_limit = "429" in error_str or "rate" in error_str
514
- is_unavailable = "503" in error_str or "unavailable" in error_str
515
- is_auth_error = "401" in error_str or "403" in error_str
516
-
517
- if is_auth_error:
518
- # Gated model without token - try fallback immediately
519
- logger.warning("Auth error, trying fallback model", error=str(e))
520
- if self._try_fallback_model():
521
- continue
522
- raise
523
-
524
- if is_rate_limit or is_unavailable:
525
- # Exponential backoff: 1s, 2s, 4s
526
- wait_time = 2 ** attempt
527
- logger.warning(
528
- "Rate limited, retrying",
529
- attempt=attempt + 1,
530
- wait=wait_time,
531
- error=str(e),
532
- )
533
- await asyncio.sleep(wait_time)
534
- continue
535
-
536
- # Other errors - raise immediately
537
- raise
538
-
539
- # All retries failed - try fallback model
540
- if self._try_fallback_model():
541
- return await self._call_with_retry(messages, max_retries=1)
542
-
543
- raise last_error or Exception("All retries failed")
544
-
545
- def _try_fallback_model(self) -> bool:
546
- """
547
- Try to switch to a fallback model.
548
 
549
- Returns:
550
- True if successfully switched, False if no fallbacks left
551
- """
552
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
553
 
554
- # Remove current model from fallbacks
555
- if self.model_id in self._fallback_models:
556
- self._fallback_models.remove(self.model_id)
 
557
 
558
- if not self._fallback_models:
559
- return False
 
 
560
 
561
- # Switch to next model
562
- self.model_id = self._fallback_models[0]
563
- self.client = InferenceClient(model=self.model_id)
564
- logger.info("Switched to fallback model", model=self.model_id)
565
- return True
566
 
567
  async def assess(
568
  self,
569
  question: str,
570
- evidence: List[Evidence],
571
  ) -> JudgeAssessment:
572
  """
573
  Assess evidence using HuggingFace Inference API.
574
-
575
- Uses chat_completion API for model-agnostic prompts.
576
- Includes retry logic and fallback model chain.
577
-
578
- Args:
579
- question: The user's research question
580
- evidence: List of Evidence objects from search
581
-
582
- Returns:
583
- JudgeAssessment with evaluation results
584
  """
585
  self.call_count += 1
586
  self.last_question = question
587
  self.last_evidence = evidence
588
 
589
- # Format the prompt
590
  if evidence:
591
  user_prompt = format_user_prompt(question, evidence)
592
  else:
593
  user_prompt = format_empty_evidence_prompt(question)
594
 
595
- # Build messages in OpenAI-compatible format (works with chat_completion)
596
- json_schema = """{
597
- "details": {
598
- "mechanism_score": <int 0-10>,
599
- "mechanism_reasoning": "<string>",
600
- "clinical_evidence_score": <int 0-10>,
601
- "clinical_reasoning": "<string>",
602
- "drug_candidates": ["<string>", ...],
603
- "key_findings": ["<string>", ...]
604
- },
605
- "sufficient": <bool>,
606
- "confidence": <float 0-1>,
607
- "recommendation": "continue" | "synthesize",
608
- "next_search_queries": ["<string>", ...],
609
- "reasoning": "<string>"
610
- }"""
611
 
612
- messages = [
613
- {
614
- "role": "system",
615
- "content": f"{SYSTEM_PROMPT}\n\nIMPORTANT: Respond with ONLY valid JSON matching this schema:\n{json_schema}",
616
- },
617
- {
618
- "role": "user",
619
- "content": user_prompt,
620
- },
621
- ]
622
-
623
- try:
624
- # Call with retry and fallback
625
- response = await self._call_with_retry(messages)
626
-
627
- # Robust JSON extraction
628
- data = self._extract_json(response)
629
- if data:
630
- return JudgeAssessment(**data)
631
-
632
- # If no valid JSON, return fallback
633
- logger.warning(
634
- "HF Inference returned invalid JSON",
635
- response=response[:200],
636
- model=self.model_id,
637
- )
638
- return self._create_fallback_assessment(question, "Invalid JSON response")
639
 
640
- except Exception as e:
641
- logger.error("HF Inference failed", error=str(e), model=self.model_id)
642
- return self._create_fallback_assessment(question, str(e))
643
 
644
  def _create_fallback_assessment(
645
  self,
 
374
  "HuggingFaceH4/zephyr-7b-beta", # Ungated fallback
375
  ]
376
 
377
+ def __init__(self, model_id: str | None = None) -> None:
378
  """
379
  Initialize with HF Inference client.
380
 
381
  Args:
382
+ model_id: Optional specific model ID. If None, uses FALLBACK_MODELS chain.
 
383
  """
384
+ self.model_id = model_id
385
+ # Will automatically use HF_TOKEN from env if available
386
+ self.client = InferenceClient()
 
 
 
 
 
 
 
387
  self.call_count = 0
388
+ self.last_question: str | None = None
389
+ self.last_evidence: list[Evidence] | None = None
 
 
 
 
 
 
390
 
391
+ def _extract_json(self, text: str) -> dict[str, Any] | None:
392
  """
393
+ Robust JSON extraction that handles markdown blocks and nested braces.
 
 
 
 
 
 
 
 
 
394
  """
395
+ text = text.strip()
396
+
397
+ # Remove markdown code blocks if present (with bounds checking)
398
+ if "```json" in text:
399
+ parts = text.split("```json", 1)
400
+ if len(parts) > 1:
401
+ inner_parts = parts[1].split("```", 1)
402
+ text = inner_parts[0]
403
+ elif "```" in text:
404
+ parts = text.split("```", 1)
405
+ if len(parts) > 1:
406
+ inner_parts = parts[1].split("```", 1)
407
+ text = inner_parts[0]
408
+
409
+ text = text.strip()
410
+
411
+ # Find first '{'
412
+ start_idx = text.find("{")
413
+ if start_idx == -1:
414
  return None
415
 
416
+ # Stack-based parsing ignoring chars in strings
417
+ count = 0
418
  in_string = False
419
+ escape = False
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
+ for i, char in enumerate(text[start_idx:], start=start_idx):
422
  if in_string:
423
+ if escape:
424
+ escape = False
425
+ elif char == "\\":
426
+ escape = True
427
+ elif char == '"':
428
+ in_string = False
429
+ elif char == '"':
430
+ in_string = True
431
+ elif char == "{":
432
+ count += 1
433
  elif char == "}":
434
+ count -= 1
435
+ if count == 0:
436
+ try:
437
+ result = json.loads(text[start_idx : i + 1])
438
+ if isinstance(result, dict):
439
+ return result
440
+ return None
441
+ except json.JSONDecodeError:
442
+ return None
 
443
 
444
  return None
445
 
446
+ @retry(
447
+ stop=stop_after_attempt(3),
448
+ wait=wait_exponential(multiplier=1, min=1, max=4),
449
+ retry=retry_if_exception_type(Exception),
450
+ reraise=True,
451
+ )
452
+ async def _call_with_retry(self, model: str, prompt: str, question: str) -> JudgeAssessment:
453
+ """Make API call with retry logic using chat_completion."""
454
+ loop = asyncio.get_running_loop()
 
 
 
 
 
 
 
 
 
 
 
 
 
455
 
456
+ # Build messages for chat_completion (model-agnostic)
457
+ messages = [
458
+ {
459
+ "role": "system",
460
+ "content": f"""{SYSTEM_PROMPT}
 
 
 
 
 
 
 
461
 
462
+ IMPORTANT: Respond with ONLY valid JSON matching this schema:
463
+ {{
464
+ "details": {{
465
+ "mechanism_score": <int 0-10>,
466
+ "mechanism_reasoning": "<string>",
467
+ "clinical_evidence_score": <int 0-10>,
468
+ "clinical_reasoning": "<string>",
469
+ "drug_candidates": ["<string>", ...],
470
+ "key_findings": ["<string>", ...]
471
+ }},
472
+ "sufficient": <bool>,
473
+ "confidence": <float 0-1>,
474
+ "recommendation": "continue" | "synthesize",
475
+ "next_search_queries": ["<string>", ...],
476
+ "reasoning": "<string>"
477
+ }}""",
478
+ },
479
+ {"role": "user", "content": prompt},
480
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
+ # Use chat_completion (conversational task - supported by all models)
483
+ response = await loop.run_in_executor(
484
+ None,
485
+ lambda: self.client.chat_completion(
486
+ messages=messages,
487
+ model=model,
488
+ max_tokens=1024,
489
+ temperature=0.1,
490
+ ),
491
+ )
492
 
493
+ # Extract content from response
494
+ content = response.choices[0].message.content
495
+ if not content:
496
+ raise ValueError("Empty response from model")
497
 
498
+ # Extract and parse JSON
499
+ json_data = self._extract_json(content)
500
+ if not json_data:
501
+ raise ValueError("No valid JSON found in response")
502
 
503
+ return JudgeAssessment(**json_data)
 
 
 
 
504
 
505
  async def assess(
506
  self,
507
  question: str,
508
+ evidence: list[Evidence],
509
  ) -> JudgeAssessment:
510
  """
511
  Assess evidence using HuggingFace Inference API.
512
+ Attempts models in order until one succeeds.
 
 
 
 
 
 
 
 
 
513
  """
514
  self.call_count += 1
515
  self.last_question = question
516
  self.last_evidence = evidence
517
 
518
+ # Format the user prompt
519
  if evidence:
520
  user_prompt = format_user_prompt(question, evidence)
521
  else:
522
  user_prompt = format_empty_evidence_prompt(question)
523
 
524
+ models_to_try: list[str] = [self.model_id] if self.model_id else self.FALLBACK_MODELS
525
+ last_error: Exception | None = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
+ for model in models_to_try:
528
+ try:
529
+ return await self._call_with_retry(model, user_prompt, question)
530
+ except Exception as e:
531
+ logger.warning("Model failed", model=model, error=str(e))
532
+ last_error = e
533
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
+ # All models failed
536
+ logger.error("All HF models failed", error=str(last_error))
537
+ return self._create_fallback_assessment(question, str(last_error))
538
 
539
  def _create_fallback_assessment(
540
  self,
docs/implementation/04_phase_ui.md CHANGED
@@ -573,19 +573,43 @@ def create_demo() -> gr.Blocks:
573
  - "What existing medications show promise for Long COVID?"
574
  """)
575
 
576
- chatbot = gr.ChatInterface(
 
577
  fn=research_agent,
578
- type="messages",
579
- title="",
580
  examples=[
581
- "What drugs could be repurposed for Alzheimer's disease?",
582
- "Is metformin effective for treating cancer?",
583
- "What medications show promise for Long COVID treatment?",
584
- "Can statins be repurposed for neurological conditions?",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  ],
586
- retry_btn="🔄 Retry",
587
- undo_btn="↩️ Undo",
588
- clear_btn="🗑️ Clear",
589
  )
590
 
591
  gr.Markdown("""
 
573
  - "What existing medications show promise for Long COVID?"
574
  """)
575
 
576
+ # Note: additional_inputs render in an accordion below the chat input
577
+ gr.ChatInterface(
578
  fn=research_agent,
 
 
579
  examples=[
580
+ [
581
+ "What drugs could be repurposed for Alzheimer's disease?",
582
+ "simple",
583
+ "",
584
+ "openai",
585
+ ],
586
+ [
587
+ "Is metformin effective for treating cancer?",
588
+ "simple",
589
+ "",
590
+ "openai",
591
+ ],
592
+ ],
593
+ additional_inputs=[
594
+ gr.Radio(
595
+ choices=["simple", "magentic"],
596
+ value="simple",
597
+ label="Orchestrator Mode",
598
+ info="Simple: Linear | Magentic: Multi-Agent (OpenAI)",
599
+ ),
600
+ gr.Textbox(
601
+ label="API Key (Optional - Bring Your Own Key)",
602
+ placeholder="sk-... or sk-ant-...",
603
+ type="password",
604
+ info="Enter your own API key for full AI analysis. Never stored.",
605
+ ),
606
+ gr.Radio(
607
+ choices=["openai", "anthropic"],
608
+ value="openai",
609
+ label="API Provider",
610
+ info="Select the provider for your API key",
611
+ ),
612
  ],
 
 
 
613
  )
614
 
615
  gr.Markdown("""
src/agent_factory/judges.py CHANGED
@@ -195,14 +195,14 @@ class HFInferenceJudgeHandler:
195
  else:
196
  user_prompt = format_empty_evidence_prompt(question)
197
 
198
- models_to_try = [self.model_id] if self.model_id else self.FALLBACK_MODELS
199
- last_error = None
200
 
201
  for model in models_to_try:
202
  try:
203
  return await self._call_with_retry(model, user_prompt, question)
204
  except Exception as e:
205
- logger.warning(f"Model {model} failed", error=str(e))
206
  last_error = e
207
  continue
208
 
@@ -275,11 +275,17 @@ IMPORTANT: Respond with ONLY valid JSON matching this schema:
275
  """
276
  text = text.strip()
277
 
278
- # Remove markdown code blocks if present
279
  if "```json" in text:
280
- text = text.split("```json")[1].split("```")[0]
 
 
 
281
  elif "```" in text:
282
- text = text.split("```")[1].split("```")[0]
 
 
 
283
 
284
  text = text.strip()
285
 
@@ -339,6 +345,7 @@ IMPORTANT: Respond with ONLY valid JSON matching this schema:
339
  next_search_queries=[
340
  f"{question} mechanism",
341
  f"{question} clinical trials",
 
342
  ],
343
  reasoning=f"HF Inference failed: {error}. Recommend configuring OpenAI/Anthropic key.",
344
  )
 
195
  else:
196
  user_prompt = format_empty_evidence_prompt(question)
197
 
198
+ models_to_try: list[str] = [self.model_id] if self.model_id else self.FALLBACK_MODELS
199
+ last_error: Exception | None = None
200
 
201
  for model in models_to_try:
202
  try:
203
  return await self._call_with_retry(model, user_prompt, question)
204
  except Exception as e:
205
+ logger.warning("Model failed", model=model, error=str(e))
206
  last_error = e
207
  continue
208
 
 
275
  """
276
  text = text.strip()
277
 
278
+ # Remove markdown code blocks if present (with bounds checking)
279
  if "```json" in text:
280
+ parts = text.split("```json", 1)
281
+ if len(parts) > 1:
282
+ inner_parts = parts[1].split("```", 1)
283
+ text = inner_parts[0]
284
  elif "```" in text:
285
+ parts = text.split("```", 1)
286
+ if len(parts) > 1:
287
+ inner_parts = parts[1].split("```", 1)
288
+ text = inner_parts[0]
289
 
290
  text = text.strip()
291
 
 
345
  next_search_queries=[
346
  f"{question} mechanism",
347
  f"{question} clinical trials",
348
+ f"{question} drug candidates",
349
  ],
350
  reasoning=f"HF Inference failed: {error}. Recommend configuring OpenAI/Anthropic key.",
351
  )
src/app.py CHANGED
@@ -74,6 +74,14 @@ def configure_orchestrator(
74
  ):
75
  model: AnthropicModel | OpenAIModel | None = None
76
  if user_api_key:
 
 
 
 
 
 
 
 
77
  if api_provider == "anthropic":
78
  anthropic_provider = AnthropicProvider(api_key=user_api_key)
79
  model = AnthropicModel(settings.anthropic_model, provider=anthropic_provider)
 
74
  ):
75
  model: AnthropicModel | OpenAIModel | None = None
76
  if user_api_key:
77
+ # Validate key/provider match to prevent silent auth failures
78
+ if api_provider == "openai" and user_api_key.startswith("sk-ant-"):
79
+ raise ValueError("Anthropic key provided but OpenAI provider selected")
80
+ is_openai_key = user_api_key.startswith("sk-") and not user_api_key.startswith(
81
+ "sk-ant-"
82
+ )
83
+ if api_provider == "anthropic" and is_openai_key:
84
+ raise ValueError("OpenAI key provided but Anthropic provider selected")
85
  if api_provider == "anthropic":
86
  anthropic_provider = AnthropicProvider(api_key=user_api_key)
87
  model = AnthropicModel(settings.anthropic_model, provider=anthropic_provider)
src/orchestrator_magentic.py CHANGED
@@ -82,7 +82,7 @@ class MagenticOrchestrator:
82
 
83
  # Manager chat client (orchestrates the agents)
84
  manager_client = OpenAIChatClient(
85
- model_id="gpt-4o", # Good model for planning/coordination
86
  api_key=settings.openai_api_key,
87
  )
88
 
 
82
 
83
  # Manager chat client (orchestrates the agents)
84
  manager_client = OpenAIChatClient(
85
+ model_id=settings.openai_model, # Use configured model
86
  api_key=settings.openai_api_key,
87
  )
88
 
tests/unit/agent_factory/test_judges_hf.py CHANGED
@@ -8,6 +8,7 @@ from src.agent_factory.judges import HFInferenceJudgeHandler
8
  from src.utils.models import Citation, Evidence
9
 
10
 
 
11
  class TestHFInferenceJudgeHandler:
12
  """Tests for HFInferenceJudgeHandler."""
13
 
@@ -102,9 +103,9 @@ class TestHFInferenceJudgeHandler:
102
 
103
  # Should have tried all 3 fallback models
104
  assert mock_call.call_count == 3
105
- assert result.sufficient is False # Fallback assessment
106
- error_msg = "All HF models failed"
107
- assert error_msg in str(mock_call.side_effect) or "failed" in result.reasoning
108
 
109
  def test_extract_json_robustness(self, handler):
110
  """Test JSON extraction with various inputs."""
 
8
  from src.utils.models import Citation, Evidence
9
 
10
 
11
+ @pytest.mark.unit
12
  class TestHFInferenceJudgeHandler:
13
  """Tests for HFInferenceJudgeHandler."""
14
 
 
103
 
104
  # Should have tried all 3 fallback models
105
  assert mock_call.call_count == 3
106
+ # Fallback assessment should indicate failure
107
+ assert result.sufficient is False
108
+ assert "failed" in result.reasoning.lower() or "error" in result.reasoning.lower()
109
 
110
  def test_extract_json_robustness(self, handler):
111
  """Test JSON extraction with various inputs."""