VibecoderMcSwaggins committed on
Commit
499170b
·
1 Parent(s): 32fcf60

feat(phase2): implement search slice (PubMed, Web, Orchestrator) (#4)

Browse files
src/tools/__init__.py CHANGED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Search tools package."""
2
+
3
+ from src.tools.base import SearchTool
4
+ from src.tools.pubmed import PubMedTool
5
+ from src.tools.search_handler import SearchHandler
6
+ from src.tools.websearch import WebTool
7
+
8
+ # Re-export
9
+ __all__ = ["PubMedTool", "SearchHandler", "SearchTool", "WebTool"]
src/tools/base.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base classes and protocols for search tools."""
2
+
3
+ from typing import Protocol
4
+
5
+ from src.utils.models import Evidence
6
+
7
+
8
class SearchTool(Protocol):
    """Structural interface that every search tool satisfies.

    Being a `typing.Protocol`, conformance is by shape rather than by
    inheritance: any object exposing a `name` property and an async
    `search` method (e.g. `PubMedTool`, `WebTool`) is a `SearchTool`.
    """

    @property
    def name(self) -> str:
        """Human-readable name of this tool."""
        ...

    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
        """
        Execute a search and return evidence.

        Args:
            query: The search query string
            max_results: Maximum number of results to return

        Returns:
            List of Evidence objects

        Raises:
            SearchError: If the search fails
            RateLimitError: If we hit rate limits
        """
        ...
src/tools/pubmed.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PubMed search tool using NCBI E-utilities."""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ import httpx
7
+ import xmltodict
8
+ from tenacity import retry, stop_after_attempt, wait_exponential
9
+
10
+ from src.utils.config import settings
11
+ from src.utils.exceptions import RateLimitError, SearchError
12
+ from src.utils.models import Citation, Evidence
13
+
14
+
15
class PubMedTool:
    """Search tool for PubMed via NCBI E-utilities.

    Implements the two-step ESearch -> EFetch flow and converts the
    fetched article XML into Evidence objects. A small inter-request
    delay keeps us under NCBI's rate limit (~3 requests/sec without an
    API key).
    """

    BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    RATE_LIMIT_DELAY = 0.34  # ~3 requests/sec without API key
    HTTP_TOO_MANY_REQUESTS = 429

    def __init__(self, api_key: str | None = None) -> None:
        """Initialize the tool.

        Args:
            api_key: NCBI API key; falls back to `settings.ncbi_api_key`.
        """
        self.api_key = api_key or getattr(settings, "ncbi_api_key", None)
        self._last_request_time = 0.0

    @property
    def name(self) -> str:
        """Identifier used in logs and citations."""
        return "pubmed"

    async def _rate_limit(self) -> None:
        """Enforce NCBI rate limiting by sleeping between requests."""
        # get_running_loop() is the non-deprecated call inside a coroutine
        # (get_event_loop() is deprecated for this use since Python 3.10).
        now = asyncio.get_running_loop().time()
        elapsed = now - self._last_request_time
        if elapsed < self.RATE_LIMIT_DELAY:
            await asyncio.sleep(self.RATE_LIMIT_DELAY - elapsed)
        self._last_request_time = asyncio.get_running_loop().time()

    def _build_params(self, **kwargs: Any) -> dict[str, Any]:
        """Build request params with optional API key."""
        params = {**kwargs, "retmode": "json"}
        if self.api_key:
            params["api_key"] = self.api_key
        return params

    @retry(  # type: ignore[misc]
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
    )
    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
        """
        Search PubMed and return evidence.

        1. ESearch: Get PMIDs matching query
        2. EFetch: Get abstracts for those PMIDs
        3. Parse and return Evidence objects

        Raises:
            RateLimitError: If NCBI responds with HTTP 429.
            SearchError: If either request or XML parsing fails.
        """
        await self._rate_limit()

        async with httpx.AsyncClient(timeout=30.0) as client:
            # Step 1: Search for PMIDs
            search_params = self._build_params(
                db="pubmed",
                term=query,
                retmax=max_results,
                sort="relevance",
            )

            try:
                search_resp = await client.get(
                    f"{self.BASE_URL}/esearch.fcgi",
                    params=search_params,
                )
                search_resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                if e.response.status_code == self.HTTP_TOO_MANY_REQUESTS:
                    raise RateLimitError("PubMed rate limit exceeded") from e
                raise SearchError(f"PubMed search failed: {e}") from e

            search_data = search_resp.json()
            pmids = search_data.get("esearchresult", {}).get("idlist", [])

            if not pmids:
                return []

            # Step 2: Fetch abstracts
            await self._rate_limit()
            fetch_params = self._build_params(
                db="pubmed",
                id=",".join(pmids),
                rettype="abstract",
            )
            # Use XML for fetch (more reliable parsing)
            fetch_params["retmode"] = "xml"

            try:
                fetch_resp = await client.get(
                    f"{self.BASE_URL}/efetch.fcgi",
                    params=fetch_params,
                )
                fetch_resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                # Map efetch failures the same way as esearch above, so a
                # 429 on the second request also surfaces as RateLimitError
                # instead of a raw httpx.HTTPStatusError.
                if e.response.status_code == self.HTTP_TOO_MANY_REQUESTS:
                    raise RateLimitError("PubMed rate limit exceeded") from e
                raise SearchError(f"PubMed fetch failed: {e}") from e

            # Step 3: Parse XML to Evidence
            return self._parse_pubmed_xml(fetch_resp.text)

    def _parse_pubmed_xml(self, xml_text: str) -> list[Evidence]:
        """Parse PubMed EFetch XML into Evidence objects.

        Individual malformed articles are skipped (best-effort); a
        document that cannot be parsed at all raises SearchError.
        """
        try:
            data = xmltodict.parse(xml_text)
        except Exception as e:
            raise SearchError(f"Failed to parse PubMed XML: {e}") from e

        articles = data.get("PubmedArticleSet", {}).get("PubmedArticle", [])

        # Handle single article (xmltodict returns dict instead of list)
        if isinstance(articles, dict):
            articles = [articles]

        evidence_list = []
        for article in articles:
            try:
                evidence = self._article_to_evidence(article)
                if evidence:
                    evidence_list.append(evidence)
            except Exception:
                continue  # Skip malformed articles (deliberate best-effort)

        return evidence_list

    def _article_to_evidence(self, article: dict[str, Any]) -> Evidence | None:
        """Convert a single PubMed article dict to Evidence.

        Returns None when the article lacks a title or an abstract.
        """
        medline = article.get("MedlineCitation", {})
        article_data = medline.get("Article", {})

        # Extract PMID (xmltodict wraps attributed text nodes in a dict)
        pmid = medline.get("PMID", {})
        if isinstance(pmid, dict):
            pmid = pmid.get("#text", "")

        # Extract title
        title = article_data.get("ArticleTitle", "")
        if isinstance(title, dict):
            title = title.get("#text", str(title))

        # Extract abstract: may be a plain string, a dict, or a list of sections
        abstract_data = article_data.get("Abstract", {}).get("AbstractText", "")
        if isinstance(abstract_data, list):
            abstract = " ".join(
                item.get("#text", str(item)) if isinstance(item, dict) else str(item)
                for item in abstract_data
            )
        elif isinstance(abstract_data, dict):
            abstract = abstract_data.get("#text", str(abstract_data))
        else:
            abstract = str(abstract_data)

        if not abstract or not title:
            return None

        # Extract date
        pub_date = article_data.get("Journal", {}).get("JournalIssue", {}).get("PubDate", {})
        year = pub_date.get("Year", "Unknown")
        month = pub_date.get("Month", "01")
        day = pub_date.get("Day", "01")
        date_str = f"{year}-{month}-{day}" if year != "Unknown" else "Unknown"

        # Extract authors
        author_list = article_data.get("AuthorList", {}).get("Author", [])
        if isinstance(author_list, dict):
            author_list = [author_list]
        authors = []
        for author in author_list[:5]:  # Limit to 5 authors
            last = author.get("LastName", "")
            first = author.get("ForeName", "")
            if last:
                authors.append(f"{last} {first}".strip())

        return Evidence(
            content=abstract[:2000],  # Truncate long abstracts
            citation=Citation(
                source="pubmed",
                title=title[:500],
                url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                date=date_str,
                authors=authors,
            ),
        )
src/tools/search_handler.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Search handler - orchestrates multiple search tools."""
2
+
3
+ import asyncio
4
+ from typing import Literal, cast
5
+
6
+ import structlog
7
+
8
+ from src.tools.base import SearchTool
9
+ from src.utils.exceptions import SearchError
10
+ from src.utils.models import Evidence, SearchResult
11
+
12
+ logger = structlog.get_logger()
13
+
14
+
15
def flatten(nested: list[list[Evidence]]) -> list[Evidence]:
    """Concatenate the sublists of *nested* into a single flat list."""
    flat: list[Evidence] = []
    for group in nested:
        flat.extend(group)
    return flat
18
+
19
+
20
class SearchHandler:
    """Orchestrates parallel searches across multiple tools."""

    def __init__(self, tools: list[SearchTool], timeout: float = 30.0) -> None:
        """
        Initialize the search handler.

        Args:
            tools: List of search tools to use
            timeout: Timeout for each search in seconds
        """
        self.tools = tools
        self.timeout = timeout

    async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchResult:
        """
        Execute search across all tools in parallel.

        Tool failures are recorded in the result's `errors` list rather
        than raised, so one failing backend never discards the others'
        results.

        Args:
            query: The search query
            max_results_per_tool: Max results from each tool

        Returns:
            SearchResult containing all evidence and metadata
        """
        logger.info("Starting search", query=query, tools=[t.name for t in self.tools])

        # Create tasks for parallel execution
        tasks = [
            self._search_with_timeout(tool, query, max_results_per_tool) for tool in self.tools
        ]

        # Gather results (don't fail if one tool fails)
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results
        all_evidence: list[Evidence] = []
        sources_searched: list[Literal["pubmed", "web"]] = []
        errors: list[str] = []

        for tool, result in zip(self.tools, results, strict=True):
            if isinstance(result, Exception):
                errors.append(f"{tool.name}: {result!s}")
                logger.warning("Search tool failed", tool=tool.name, error=str(result))
            else:
                # Cast result to list[Evidence] as we know it succeeded
                success_result = cast(list[Evidence], result)
                all_evidence.extend(success_result)

                # Cast tool.name to the expected Literal
                tool_name = cast(Literal["pubmed", "web"], tool.name)
                sources_searched.append(tool_name)
                logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))

        return SearchResult(
            query=query,
            evidence=all_evidence,
            sources_searched=sources_searched,
            total_found=len(all_evidence),
            errors=errors,
        )

    async def _search_with_timeout(
        self,
        tool: SearchTool,
        query: str,
        max_results: int,
    ) -> list[Evidence]:
        """Execute a single tool search, converting timeouts to SearchError."""
        try:
            return await asyncio.wait_for(
                tool.search(query, max_results),
                timeout=self.timeout,
            )
        except asyncio.TimeoutError as e:
            # Catch asyncio.TimeoutError explicitly: on Python 3.10
            # wait_for raises asyncio.TimeoutError, which is NOT the
            # builtin TimeoutError (they are aliases only from 3.11 on),
            # so `except TimeoutError` would miss it there.
            raise SearchError(f"{tool.name} search timed out after {self.timeout}s") from e
src/tools/websearch.py CHANGED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web search tool using DuckDuckGo."""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ from duckduckgo_search import DDGS
7
+
8
+ from src.utils.exceptions import SearchError
9
+ from src.utils.models import Citation, Evidence
10
+
11
+
12
class WebTool:
    """Search tool for general web search via DuckDuckGo."""

    @property
    def name(self) -> str:
        """Identifier used in logs and citations."""
        return "web"

    async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
        """
        Search DuckDuckGo and return evidence.

        Note: duckduckgo-search is synchronous, so the blocking call runs
        in the default thread-pool executor to keep the event loop free.

        Raises:
            SearchError: If the underlying search raises for any reason.
        """
        # get_running_loop() is the non-deprecated call inside a coroutine
        # (get_event_loop() is deprecated for this use since Python 3.10).
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(
                None,
                lambda: self._sync_search(query, max_results),
            )
        except Exception as e:
            raise SearchError(f"Web search failed: {e}") from e

    def _sync_search(self, query: str, max_results: int) -> list[Evidence]:
        """Synchronous search implementation (runs off the event loop)."""
        with DDGS() as ddgs:
            results: list[dict[str, Any]] = list(ddgs.text(query, max_results=max_results))

        return [
            Evidence(
                content=result.get("body", "")[:1000],
                citation=Citation(
                    source="web",
                    title=result.get("title", "Unknown")[:500],
                    url=result.get("href", ""),
                    date="Unknown",
                    authors=[],
                ),
            )
            for result in results
        ]
src/utils/models.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models for the Search feature."""
2
+
3
+ from typing import ClassVar, Literal
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class Citation(BaseModel):
    """Provenance record pointing at a single source document."""

    source: Literal["pubmed", "web"] = Field(description="Where this came from")
    title: str = Field(min_length=1, max_length=500)
    url: str = Field(description="URL to the source")
    date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
    authors: list[str] = Field(default_factory=list)

    MAX_AUTHORS_IN_CITATION: ClassVar[int] = 3

    @property
    def formatted(self) -> str:
        """Render as a human-readable citation string."""
        shown = self.authors[: self.MAX_AUTHORS_IN_CITATION]
        author_str = ", ".join(shown)
        if len(self.authors) > len(shown):
            author_str = f"{author_str} et al."
        return f"{author_str} ({self.date}). {self.title}. {self.source.upper()}"
26
+
27
+
28
class Evidence(BaseModel):
    """A piece of evidence retrieved from search."""

    # The retrieved text snippet (abstract or page excerpt); never empty.
    content: str = Field(min_length=1, description="The actual text content")
    # Provenance for this snippet.
    citation: Citation
    # Relevance score in [0, 1]; defaults to 0.0 until scored.
    relevance: float = Field(default=0.0, ge=0.0, le=1.0, description="Relevance score 0-1")

    # Frozen: instances are immutable after creation.
    model_config = {"frozen": True}
36
+
37
+
38
class SearchResult(BaseModel):
    """Result of a search operation."""

    # The query that was executed.
    query: str
    # Evidence aggregated across all tools that succeeded.
    evidence: list[Evidence]
    # Names of the tools that completed without error.
    sources_searched: list[Literal["pubmed", "web"]]
    # Number of evidence items (SearchHandler sets this to len(evidence)).
    total_found: int
    # "tool: message" entries for tools that failed; empty when all succeeded.
    errors: list[str] = Field(default_factory=list)
tests/unit/tools/test_pubmed.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for PubMed tool."""
2
+
3
+ from unittest.mock import AsyncMock, MagicMock
4
+
5
+ import pytest
6
+
7
+ from src.tools.pubmed import PubMedTool
8
+
9
# Minimal single-article PubMed EFetch XML fixture used to mock HTTP
# responses in the tests below (one author, title, abstract, and date).
SAMPLE_PUBMED_XML = """<?xml version="1.0" ?>
<PubmedArticleSet>
<PubmedArticle>
<MedlineCitation>
<PMID>12345678</PMID>
<Article>
<ArticleTitle>Metformin in Alzheimer's Disease: A Systematic Review</ArticleTitle>
<Abstract>
<AbstractText>Metformin shows neuroprotective properties...</AbstractText>
</Abstract>
<AuthorList>
<Author>
<LastName>Smith</LastName>
<ForeName>John</ForeName>
</Author>
</AuthorList>
<Journal>
<JournalIssue>
<PubDate>
<Year>2024</Year>
<Month>01</Month>
</PubDate>
</JournalIssue>
</Journal>
</Article>
</MedlineCitation>
</PubmedArticle>
</PubmedArticleSet>
"""
39
+
40
+
41
class TestPubMedTool:
    """Tests for PubMedTool."""

    @pytest.mark.asyncio
    async def test_search_returns_evidence(self, mocker):
        """PubMedTool should return Evidence objects from search."""
        # Canned HTTP responses: first for esearch, then for efetch.
        esearch_resp = MagicMock()
        esearch_resp.raise_for_status = MagicMock()
        esearch_resp.json.return_value = {"esearchresult": {"idlist": ["12345678"]}}

        efetch_resp = MagicMock()
        efetch_resp.raise_for_status = MagicMock()
        efetch_resp.text = SAMPLE_PUBMED_XML

        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=None)
        client.get = AsyncMock(side_effect=[esearch_resp, efetch_resp])

        mocker.patch("httpx.AsyncClient", return_value=client)

        evidence = await PubMedTool().search("metformin alzheimer")

        assert len(evidence) == 1
        citation = evidence[0].citation
        assert citation.source == "pubmed"
        assert "Metformin" in citation.title
        assert "12345678" in citation.url

    @pytest.mark.asyncio
    async def test_search_empty_results(self, mocker):
        """PubMedTool should return empty list when no results."""
        empty_resp = MagicMock()
        empty_resp.raise_for_status = MagicMock()
        empty_resp.json.return_value = {"esearchresult": {"idlist": []}}

        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=None)
        client.get = AsyncMock(return_value=empty_resp)

        mocker.patch("httpx.AsyncClient", return_value=client)

        assert await PubMedTool().search("xyznonexistentquery123") == []

    def test_parse_pubmed_xml(self):
        """PubMedTool should correctly parse XML."""
        parsed = PubMedTool()._parse_pubmed_xml(SAMPLE_PUBMED_XML)

        assert len(parsed) == 1
        assert parsed[0].citation.source == "pubmed"
        assert "Smith John" in parsed[0].citation.authors
tests/unit/tools/test_search_handler.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for SearchHandler."""
2
+
3
+ from unittest.mock import AsyncMock
4
+
5
+ import pytest
6
+
7
+ from src.tools.search_handler import SearchHandler
8
+ from src.utils.exceptions import SearchError
9
+ from src.utils.models import Citation, Evidence
10
+
11
+
12
class TestSearchHandler:
    """Tests for SearchHandler."""

    @pytest.mark.asyncio
    async def test_execute_aggregates_results(self):
        """SearchHandler should aggregate results from all tools."""
        pubmed_evidence = Evidence(
            content="Result 1",
            citation=Citation(source="pubmed", title="T1", url="u1", date="2024"),
        )
        web_evidence = Evidence(
            content="Result 2",
            citation=Citation(source="web", title="T2", url="u2", date="2024"),
        )

        pubmed_tool = AsyncMock()
        pubmed_tool.name = "pubmed"
        pubmed_tool.search = AsyncMock(return_value=[pubmed_evidence])

        web_tool = AsyncMock()
        web_tool.name = "web"
        web_tool.search = AsyncMock(return_value=[web_evidence])

        result = await SearchHandler(tools=[pubmed_tool, web_tool]).execute("test query")

        expected_total = 2
        assert result.total_found == expected_total
        assert "pubmed" in result.sources_searched
        assert "web" in result.sources_searched
        assert len(result.errors) == 0

    @pytest.mark.asyncio
    async def test_execute_handles_tool_failure(self):
        """SearchHandler should continue if one tool fails."""
        good_evidence = Evidence(
            content="Good result",
            citation=Citation(source="pubmed", title="T", url="u", date="2024"),
        )

        good_tool = AsyncMock()
        good_tool.name = "pubmed"
        good_tool.search = AsyncMock(return_value=[good_evidence])

        failing_tool = AsyncMock()
        failing_tool.name = "web"
        failing_tool.search = AsyncMock(side_effect=SearchError("API down"))

        result = await SearchHandler(tools=[good_tool, failing_tool]).execute("test")

        assert result.total_found == 1
        assert "pubmed" in result.sources_searched
        assert len(result.errors) == 1
        assert "web" in result.errors[0]
tests/unit/tools/test_websearch.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for WebTool."""
2
+
3
+ from unittest.mock import MagicMock
4
+
5
+ import pytest
6
+
7
+ from src.tools.websearch import WebTool
8
+
9
+
10
class TestWebTool:
    """Tests for WebTool."""

    @pytest.mark.asyncio
    async def test_search_returns_evidence(self, mocker):
        """WebTool should return Evidence objects from search."""
        fake_hits = [
            {
                "title": "Drug Repurposing Article",
                "href": "https://example.com/article",
                "body": "Some content about drug repurposing...",
            }
        ]

        ddgs = MagicMock()
        ddgs.__enter__ = MagicMock(return_value=ddgs)
        ddgs.__exit__ = MagicMock(return_value=None)
        ddgs.text = MagicMock(return_value=fake_hits)
        mocker.patch("src.tools.websearch.DDGS", return_value=ddgs)

        results = await WebTool().search("drug repurposing")

        assert len(results) == 1
        hit = results[0]
        assert hit.citation.source == "web"
        assert "Drug Repurposing" in hit.citation.title