Commit · 499170b
Parent(s): 32fcf60
feat(phase2): implement search slice (PubMed, Web, Orchestrator) (#4)
- src/tools/__init__.py +9 -0
- src/tools/base.py +31 -0
- src/tools/pubmed.py +186 -0
- src/tools/search_handler.py +95 -0
- src/tools/websearch.py +59 -0
- src/utils/models.py +45 -0
- tests/unit/tools/test_pubmed.py +99 -0
- tests/unit/tools/test_search_handler.py +74 -0
- tests/unit/tools/test_websearch.py +36 -0
src/tools/__init__.py
CHANGED
@@ -0,0 +1,9 @@
+"""Search tools package."""
+
+from src.tools.base import SearchTool
+from src.tools.pubmed import PubMedTool
+from src.tools.search_handler import SearchHandler
+from src.tools.websearch import WebTool
+
+# Re-export
+__all__ = ["PubMedTool", "SearchHandler", "SearchTool", "WebTool"]
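
Note: a minimal wiring sketch of how these exports compose (hypothetical usage, not part of this commit), assuming the concrete tools need no extra configuration:

import asyncio

from src.tools import PubMedTool, SearchHandler, WebTool

async def main() -> None:
    # SearchHandler fans the query out to every tool in parallel.
    handler = SearchHandler(tools=[PubMedTool(), WebTool()])
    result = await handler.execute("metformin alzheimer")
    print(result.total_found, result.sources_searched)

asyncio.run(main())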
src/tools/base.py
ADDED
@@ -0,0 +1,31 @@
+"""Base classes and protocols for search tools."""
+
+from typing import Protocol
+
+from src.utils.models import Evidence
+
+
+class SearchTool(Protocol):
+    """Protocol defining the interface for all search tools."""
+
+    @property
+    def name(self) -> str:
+        """Human-readable name of this tool."""
+        ...
+
+    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
+        """
+        Execute a search and return evidence.
+
+        Args:
+            query: The search query string
+            max_results: Maximum number of results to return
+
+        Returns:
+            List of Evidence objects
+
+        Raises:
+            SearchError: If the search fails
+            RateLimitError: If we hit rate limits
+        """
+        ...
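
Note: since SearchTool is a typing.Protocol, conformance is structural; a tool never has to inherit from it. A minimal sketch with a hypothetical NullTool (not in this commit):

from src.tools.base import SearchTool
from src.utils.models import Evidence

class NullTool:
    """Hypothetical tool: a matching name/search shape is enough."""

    @property
    def name(self) -> str:
        return "null"

    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
        return []

tool: SearchTool = NullTool()  # type-checks without inheritance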
src/tools/pubmed.py
ADDED
@@ -0,0 +1,186 @@
+"""PubMed search tool using NCBI E-utilities."""
+
+import asyncio
+from typing import Any
+
+import httpx
+import xmltodict
+from tenacity import retry, stop_after_attempt, wait_exponential
+
+from src.utils.config import settings
+from src.utils.exceptions import RateLimitError, SearchError
+from src.utils.models import Citation, Evidence
+
+
+class PubMedTool:
+    """Search tool for PubMed/NCBI."""
+
+    BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
+    RATE_LIMIT_DELAY = 0.34  # ~3 requests/sec without API key
+    HTTP_TOO_MANY_REQUESTS = 429
+
+    def __init__(self, api_key: str | None = None) -> None:
+        self.api_key = api_key or getattr(settings, "ncbi_api_key", None)
+        self._last_request_time = 0.0
+
+    @property
+    def name(self) -> str:
+        return "pubmed"
+
+    async def _rate_limit(self) -> None:
+        """Enforce NCBI rate limiting."""
+        now = asyncio.get_event_loop().time()
+        elapsed = now - self._last_request_time
+        if elapsed < self.RATE_LIMIT_DELAY:
+            await asyncio.sleep(self.RATE_LIMIT_DELAY - elapsed)
+        self._last_request_time = asyncio.get_event_loop().time()
+
+    def _build_params(self, **kwargs: Any) -> dict[str, Any]:
+        """Build request params with optional API key."""
+        params = {**kwargs, "retmode": "json"}
+        if self.api_key:
+            params["api_key"] = self.api_key
+        return params
+
+    @retry(  # type: ignore[misc]
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=10),
+        reraise=True,
+    )
+    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
+        """
+        Search PubMed and return evidence.
+
+        1. ESearch: Get PMIDs matching query
+        2. EFetch: Get abstracts for those PMIDs
+        3. Parse and return Evidence objects
+        """
+        await self._rate_limit()
+
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            # Step 1: Search for PMIDs
+            search_params = self._build_params(
+                db="pubmed",
+                term=query,
+                retmax=max_results,
+                sort="relevance",
+            )
+
+            try:
+                search_resp = await client.get(
+                    f"{self.BASE_URL}/esearch.fcgi",
+                    params=search_params,
+                )
+                search_resp.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                if e.response.status_code == self.HTTP_TOO_MANY_REQUESTS:
+                    raise RateLimitError("PubMed rate limit exceeded") from e
+                raise SearchError(f"PubMed search failed: {e}") from e
+
+            search_data = search_resp.json()
+            pmids = search_data.get("esearchresult", {}).get("idlist", [])
+
+            if not pmids:
+                return []
+
+            # Step 2: Fetch abstracts
+            await self._rate_limit()
+            fetch_params = self._build_params(
+                db="pubmed",
+                id=",".join(pmids),
+                rettype="abstract",
+            )
+            # Use XML for fetch (more reliable parsing)
+            fetch_params["retmode"] = "xml"
+
+            fetch_resp = await client.get(
+                f"{self.BASE_URL}/efetch.fcgi",
+                params=fetch_params,
+            )
+            fetch_resp.raise_for_status()
+
+        # Step 3: Parse XML to Evidence
+        return self._parse_pubmed_xml(fetch_resp.text)
+
+    def _parse_pubmed_xml(self, xml_text: str) -> list[Evidence]:
+        """Parse PubMed XML into Evidence objects."""
+        try:
+            data = xmltodict.parse(xml_text)
+        except Exception as e:
+            raise SearchError(f"Failed to parse PubMed XML: {e}") from e
+
+        articles = data.get("PubmedArticleSet", {}).get("PubmedArticle", [])
+
+        # Handle single article (xmltodict returns dict instead of list)
+        if isinstance(articles, dict):
+            articles = [articles]
+
+        evidence_list = []
+        for article in articles:
+            try:
+                evidence = self._article_to_evidence(article)
+                if evidence:
+                    evidence_list.append(evidence)
+            except Exception:
+                continue  # Skip malformed articles
+
+        return evidence_list
+
+    def _article_to_evidence(self, article: dict[str, Any]) -> Evidence | None:
+        """Convert a single PubMed article to Evidence."""
+        medline = article.get("MedlineCitation", {})
+        article_data = medline.get("Article", {})
+
+        # Extract PMID
+        pmid = medline.get("PMID", {})
+        if isinstance(pmid, dict):
+            pmid = pmid.get("#text", "")
+
+        # Extract title
+        title = article_data.get("ArticleTitle", "")
+        if isinstance(title, dict):
+            title = title.get("#text", str(title))
+
+        # Extract abstract
+        abstract_data = article_data.get("Abstract", {}).get("AbstractText", "")
+        if isinstance(abstract_data, list):
+            abstract = " ".join(
+                item.get("#text", str(item)) if isinstance(item, dict) else str(item)
+                for item in abstract_data
+            )
+        elif isinstance(abstract_data, dict):
+            abstract = abstract_data.get("#text", str(abstract_data))
+        else:
+            abstract = str(abstract_data)
+
+        if not abstract or not title:
+            return None
+
+        # Extract date
+        pub_date = article_data.get("Journal", {}).get("JournalIssue", {}).get("PubDate", {})
+        year = pub_date.get("Year", "Unknown")
+        month = pub_date.get("Month", "01")
+        day = pub_date.get("Day", "01")
+        date_str = f"{year}-{month}-{day}" if year != "Unknown" else "Unknown"
+
+        # Extract authors
+        author_list = article_data.get("AuthorList", {}).get("Author", [])
+        if isinstance(author_list, dict):
+            author_list = [author_list]
+        authors = []
+        for author in author_list[:5]:  # Limit to 5 authors
+            last = author.get("LastName", "")
+            first = author.get("ForeName", "")
+            if last:
+                authors.append(f"{last} {first}".strip())
+
+        return Evidence(
+            content=abstract[:2000],  # Truncate long abstracts
+            citation=Citation(
+                source="pubmed",
+                title=title[:500],
+                url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
+                date=date_str,
+                authors=authors,
+            ),
+        )
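
Note: for reference, the two E-utilities calls above send parameter sets like the following (illustrative values only; "api_key" is appended just when one is configured):

# Step 1 - GET {BASE_URL}/esearch.fcgi, returns a JSON list of PMIDs
search_params = {
    "db": "pubmed",
    "term": "metformin alzheimer",  # example query
    "retmax": 10,
    "sort": "relevance",
    "retmode": "json",
}
# Step 2 - GET {BASE_URL}/efetch.fcgi, returns article XML
fetch_params = {
    "db": "pubmed",
    "id": "12345678",  # comma-joined PMIDs from step 1
    "rettype": "abstract",
    "retmode": "xml",
}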
src/tools/search_handler.py
ADDED
@@ -0,0 +1,95 @@
+"""Search handler - orchestrates multiple search tools."""
+
+import asyncio
+from typing import Literal, cast
+
+import structlog
+
+from src.tools.base import SearchTool
+from src.utils.exceptions import SearchError
+from src.utils.models import Evidence, SearchResult
+
+logger = structlog.get_logger()
+
+
+def flatten(nested: list[list[Evidence]]) -> list[Evidence]:
+    """Flatten a list of lists into a single list."""
+    return [item for sublist in nested for item in sublist]
+
+
+class SearchHandler:
+    """Orchestrates parallel searches across multiple tools."""
+
+    def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
+        """
+        Initialize the search handler.
+
+        Args:
+            tools: List of search tools to use
+            timeout: Timeout for each search in seconds
+        """
+        self.tools = tools
+        self.timeout = timeout
+
+    async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
+        """
+        Execute search across all tools in parallel.
+
+        Args:
+            query: The search query
+            max_results_per_tool: Max results from each tool
+
+        Returns:
+            SearchResult containing all evidence and metadata
+        """
+        logger.info("Starting search", query=query, tools=[t.name for t in self.tools])
+
+        # Create tasks for parallel execution
+        tasks = [
+            self._search_with_timeout(tool, query, max_results_per_tool) for tool in self.tools
+        ]
+
+        # Gather results (don't fail if one tool fails)
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Process results
+        all_evidence: list[Evidence] = []
+        sources_searched: list[Literal["pubmed", "web"]] = []
+        errors: list[str] = []
+
+        for tool, result in zip(self.tools, results, strict=True):
+            if isinstance(result, Exception):
+                errors.append(f"{tool.name}: {result!s}")
+                logger.warning("Search tool failed", tool=tool.name, error=str(result))
+            else:
+                # Cast result to list[Evidence] as we know it succeeded
+                success_result = cast(list[Evidence], result)
+                all_evidence.extend(success_result)
+
+                # Cast tool.name to the expected Literal
+                tool_name = cast(Literal["pubmed", "web"], tool.name)
+                sources_searched.append(tool_name)
+                logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
+
+        return SearchResult(
+            query=query,
+            evidence=all_evidence,
+            sources_searched=sources_searched,
+            total_found=len(all_evidence),
+            errors=errors,
+        )
+
+    async def _search_with_timeout(
+        self,
+        tool: SearchTool,
+        query: str,
+        max_results: int,
+    ) -> list[Evidence]:
+        """Execute a single tool search with timeout."""
+        try:
+            return await asyncio.wait_for(
+                tool.search(query, max_results),
+                timeout=self.timeout,
+            )
+        except TimeoutError as e:
+            raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e
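
Note: the handler's fault isolation rests on gather(return_exceptions=True), which hands failures back as exception objects in the results list instead of raising. A self-contained illustration (plain coroutines standing in for tools, not part of this commit):

import asyncio

async def ok() -> str:
    return "evidence"

async def boom() -> str:
    raise RuntimeError("API down")

async def demo() -> None:
    results = await asyncio.gather(ok(), boom(), return_exceptions=True)
    for r in results:
        print(type(r).__name__)  # str, then RuntimeError

asyncio.run(demo())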
src/tools/websearch.py
CHANGED
@@ -0,0 +1,59 @@
+"""Web search tool using DuckDuckGo."""
+
+import asyncio
+from typing import Any
+
+from duckduckgo_search import DDGS
+
+from src.utils.exceptions import SearchError
+from src.utils.models import Citation, Evidence
+
+
+class WebTool:
+    """Search tool for general web search via DuckDuckGo."""
+
+    def __init__(self) -> None:
+        pass
+
+    @property
+    def name(self) -> str:
+        return "web"
+
+    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
+        """
+        Search DuckDuckGo and return evidence.
+
+        Note: duckduckgo-search is synchronous, so we run it in executor.
+        """
+        loop = asyncio.get_event_loop()
+        try:
+            results = await loop.run_in_executor(
+                None,
+                lambda: self._sync_search(query, max_results),
+            )
+            return results
+        except Exception as e:
+            raise SearchError(f"Web search failed: {e}") from e
+
+    def _sync_search(self, query: str, max_results: int) -> list[Evidence]:
+        """Synchronous search implementation."""
+        evidence_list = []
+
+        with DDGS() as ddgs:
+            results: list[dict[str, Any]] = list(ddgs.text(query, max_results=max_results))
+
+        for result in results:
+            evidence_list.append(
+                Evidence(
+                    content=result.get("body", "")[:1000],
+                    citation=Citation(
+                        source="web",
+                        title=result.get("title", "Unknown")[:500],
+                        url=result.get("href", ""),
+                        date="Unknown",
+                        authors=[],
+                    ),
+                )
+            )
+
+        return evidence_list
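
Note (design aside, not a change in this commit): on Python 3.9+, asyncio.to_thread is an equivalent shorthand for the run_in_executor(None, ...) pattern used above; both keep the blocking DDGS call off the event loop. A sketch with a stand-in blocking function:

import asyncio
import time

def blocking_search(query: str) -> list[str]:
    time.sleep(0.1)  # stands in for the synchronous DDGS call
    return [f"result for {query}"]

async def main() -> None:
    results = await asyncio.to_thread(blocking_search, "drug repurposing")
    print(results)

asyncio.run(main())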
src/utils/models.py
ADDED
@@ -0,0 +1,45 @@
+"""Data models for the Search feature."""
+
+from typing import ClassVar, Literal
+
+from pydantic import BaseModel, Field
+
+
+class Citation(BaseModel):
+    """A citation to a source document."""
+
+    source: Literal["pubmed", "web"] = Field(description="Where this came from")
+    title: str = Field(min_length=1, max_length=500)
+    url: str = Field(description="URL to the source")
+    date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
+    authors: list[str] = Field(default_factory=list)
+
+    MAX_AUTHORS_IN_CITATION: ClassVar[int] = 3
+
+    @property
+    def formatted(self) -> str:
+        """Format as a citation string."""
+        author_str = ", ".join(self.authors[: self.MAX_AUTHORS_IN_CITATION])
+        if len(self.authors) > self.MAX_AUTHORS_IN_CITATION:
+            author_str += " et al."
+        return f"{author_str} ({self.date}). {self.title}. {self.source.upper()}"
+
+
+class Evidence(BaseModel):
+    """A piece of evidence retrieved from search."""
+
+    content: str = Field(min_length=1, description="The actual text content")
+    citation: Citation
+    relevance: float = Field(default=0.0, ge=0.0, le=1.0, description="Relevance score 0-1")
+
+    model_config = {"frozen": True}
+
+
+class SearchResult(BaseModel):
+    """Result of a search operation."""
+
+    query: str
+    evidence: list[Evidence]
+    sources_searched: list[Literal["pubmed", "web"]]
+    total_found: int
+    errors: list[str] = Field(default_factory=list)
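
Note: a worked example of Citation.formatted (illustrative values): the author list is cut at MAX_AUTHORS_IN_CITATION names and "et al." is appended when more remain.

from src.utils.models import Citation

c = Citation(
    source="pubmed",
    title="Metformin in Alzheimer's Disease: A Systematic Review",
    url="https://pubmed.ncbi.nlm.nih.gov/12345678/",
    date="2024-01-01",
    authors=["Smith John", "Doe Jane", "Roe Ann", "Poe Max"],
)
print(c.formatted)
# Smith John, Doe Jane, Roe Ann et al. (2024-01-01). Metformin in
# Alzheimer's Disease: A Systematic Review. PUBMED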
tests/unit/tools/test_pubmed.py
ADDED
@@ -0,0 +1,99 @@
+"""Unit tests for PubMed tool."""
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from src.tools.pubmed import PubMedTool
+
+# Sample PubMed XML response for mocking
+SAMPLE_PUBMED_XML = """<?xml version="1.0" ?>
+<PubmedArticleSet>
+    <PubmedArticle>
+        <MedlineCitation>
+            <PMID>12345678</PMID>
+            <Article>
+                <ArticleTitle>Metformin in Alzheimer's Disease: A Systematic Review</ArticleTitle>
+                <Abstract>
+                    <AbstractText>Metformin shows neuroprotective properties...</AbstractText>
+                </Abstract>
+                <AuthorList>
+                    <Author>
+                        <LastName>Smith</LastName>
+                        <ForeName>John</ForeName>
+                    </Author>
+                </AuthorList>
+                <Journal>
+                    <JournalIssue>
+                        <PubDate>
+                            <Year>2024</Year>
+                            <Month>01</Month>
+                        </PubDate>
+                    </JournalIssue>
+                </Journal>
+            </Article>
+        </MedlineCitation>
+    </PubmedArticle>
+</PubmedArticleSet>
+"""
+
+
+class TestPubMedTool:
+    """Tests for PubMedTool."""
+
+    @pytest.mark.asyncio
+    async def test_search_returns_evidence(self, mocker):
+        """PubMedTool should return Evidence objects from search."""
+        # Mock the HTTP responses
+        mock_search_response = MagicMock()
+        mock_search_response.json.return_value = {"esearchresult": {"idlist": ["12345678"]}}
+        mock_search_response.raise_for_status = MagicMock()
+
+        mock_fetch_response = MagicMock()
+        mock_fetch_response.text = SAMPLE_PUBMED_XML
+        mock_fetch_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.get = AsyncMock(side_effect=[mock_search_response, mock_fetch_response])
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=None)
+
+        mocker.patch("httpx.AsyncClient", return_value=mock_client)
+
+        # Act
+        tool = PubMedTool()
+        results = await tool.search("metformin alzheimer")
+
+        # Assert
+        assert len(results) == 1
+        assert results[0].citation.source == "pubmed"
+        assert "Metformin" in results[0].citation.title
+        assert "12345678" in results[0].citation.url
+
+    @pytest.mark.asyncio
+    async def test_search_empty_results(self, mocker):
+        """PubMedTool should return empty list when no results."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"esearchresult": {"idlist": []}}
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.get = AsyncMock(return_value=mock_response)
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=None)
+
+        mocker.patch("httpx.AsyncClient", return_value=mock_client)
+
+        tool = PubMedTool()
+        results = await tool.search("xyznonexistentquery123")
+
+        assert results == []
+
+    def test_parse_pubmed_xml(self):
+        """PubMedTool should correctly parse XML."""
+        tool = PubMedTool()
+        results = tool._parse_pubmed_xml(SAMPLE_PUBMED_XML)
+
+        assert len(results) == 1
+        assert results[0].citation.source == "pubmed"
+        assert "Smith John" in results[0].citation.authors
tests/unit/tools/test_search_handler.py
ADDED
@@ -0,0 +1,74 @@
+"""Unit tests for SearchHandler."""
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+from src.tools.search_handler import SearchHandler
+from src.utils.exceptions import SearchError
+from src.utils.models import Citation, Evidence
+
+
+class TestSearchHandler:
+    """Tests for SearchHandler."""
+
+    @pytest.mark.asyncio
+    async def test_execute_aggregates_results(self):
+        """SearchHandler should aggregate results from all tools."""
+        # Create mock tools
+        mock_tool_1 = AsyncMock()
+        mock_tool_1.name = "pubmed"
+        mock_tool_1.search = AsyncMock(
+            return_value=[
+                Evidence(
+                    content="Result 1",
+                    citation=Citation(source="pubmed", title="T1", url="u1", date="2024"),
+                )
+            ]
+        )
+
+        mock_tool_2 = AsyncMock()
+        mock_tool_2.name = "web"
+        mock_tool_2.search = AsyncMock(
+            return_value=[
+                Evidence(
+                    content="Result 2",
+                    citation=Citation(source="web", title="T2", url="u2", date="2024"),
+                )
+            ]
+        )
+
+        handler = SearchHandler(tools=[mock_tool_1, mock_tool_2])
+        result = await handler.execute("test query")
+
+        expected_total = 2
+        assert result.total_found == expected_total
+        assert "pubmed" in result.sources_searched
+        assert "web" in result.sources_searched
+        assert len(result.errors) == 0
+
+    @pytest.mark.asyncio
+    async def test_execute_handles_tool_failure(self):
+        """SearchHandler should continue if one tool fails."""
+        mock_tool_ok = AsyncMock()
+        mock_tool_ok.name = "pubmed"
+        mock_tool_ok.search = AsyncMock(
+            return_value=[
+                Evidence(
+                    content="Good result",
+                    citation=Citation(source="pubmed", title="T", url="u", date="2024"),
+                )
+            ]
+        )
+
+        mock_tool_fail = AsyncMock()
+        mock_tool_fail.name = "web"
+        mock_tool_fail.search = AsyncMock(side_effect=SearchError("API down"))
+
+        handler = SearchHandler(tools=[mock_tool_ok, mock_tool_fail])
+        result = await handler.execute("test")
+
+        assert result.total_found == 1
+        assert "pubmed" in result.sources_searched
+        assert len(result.errors) == 1
+        assert "web" in result.errors[0]
tests/unit/tools/test_websearch.py
ADDED
@@ -0,0 +1,36 @@
+"""Unit tests for WebTool."""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from src.tools.websearch import WebTool
+
+
+class TestWebTool:
+    """Tests for WebTool."""
+
+    @pytest.mark.asyncio
+    async def test_search_returns_evidence(self, mocker):
+        """WebTool should return Evidence objects from search."""
+        mock_results = [
+            {
+                "title": "Drug Repurposing Article",
+                "href": "https://example.com/article",
+                "body": "Some content about drug repurposing...",
+            }
+        ]
+
+        mock_ddgs = MagicMock()
+        mock_ddgs.__enter__ = MagicMock(return_value=mock_ddgs)
+        mock_ddgs.__exit__ = MagicMock(return_value=None)
+        mock_ddgs.text = MagicMock(return_value=mock_results)
+
+        mocker.patch("src.tools.websearch.DDGS", return_value=mock_ddgs)
+
+        tool = WebTool()
+        results = await tool.search("drug repurposing")
+
+        assert len(results) == 1
+        assert results[0].citation.source == "web"
+        assert "Drug Repurposing" in results[0].citation.title