File size: 7,204 Bytes
3bacbf8
 
 
 
 
 
 
 
 
 
506a9c0
3bacbf8
 
20c7bad
3bacbf8
 
 
 
20c7bad
3bacbf8
 
cd004e1
20c7bad
 
 
 
3bacbf8
 
20c7bad
 
 
3bacbf8
 
 
 
 
 
 
 
 
 
 
59afc84
 
3bacbf8
 
 
59afc84
 
3bacbf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59afc84
 
 
 
 
 
 
 
 
 
 
3bacbf8
 
59afc84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bacbf8
59afc84
 
 
3bacbf8
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Embedding service for semantic search.

IMPORTANT: All public methods are async to avoid blocking the event loop.
The sentence-transformers model is CPU-bound, so we use run_in_executor().
"""

import asyncio
from typing import Any

import chromadb
import structlog
from sentence_transformers import SentenceTransformer

from src.utils.config import settings
from src.utils.models import Evidence


class EmbeddingService:
    """Handles text embedding and vector storage using local sentence-transformers.

    All embedding operations run in a thread pool (via ``run_in_executor``)
    to avoid blocking the async event loop.

    Note:
        Uses local sentence-transformers models (no API key required).
        Model is configured via settings.local_embedding_model.
    """

    def __init__(self, model_name: str | None = None):
        """Load the embedding model and set up an in-memory vector store.

        Args:
            model_name: Optional sentence-transformers model name. Falls back
                to ``settings.local_embedding_model`` when not provided.
        """
        self._model_name = model_name or settings.local_embedding_model
        self._model = SentenceTransformer(self._model_name)
        self._client = chromadb.Client()  # In-memory for hackathon
        # get_or_create_collection is idempotent: if a client ever ends up
        # shared (or this service is constructed twice against the same
        # backing store), we reuse the existing collection instead of
        # raising "collection already exists".
        self._collection = self._client.get_or_create_collection(
            name="evidence", metadata={"hnsw:space": "cosine"}
        )

    # ─────────────────────────────────────────────────────────────────
    # Sync internal methods (run in thread pool)
    # ─────────────────────────────────────────────────────────────────

    def _sync_embed(self, text: str) -> list[float]:
        """Synchronous embedding - DO NOT call directly from async code."""
        result: list[float] = self._model.encode(text).tolist()
        return result

    def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]:
        """Batch embedding for efficiency - DO NOT call directly from async code."""
        embeddings = self._model.encode(texts)
        return [e.tolist() for e in embeddings]

    # ─────────────────────────────────────────────────────────────────
    # Async public methods (safe for event loop)
    # ─────────────────────────────────────────────────────────────────

    async def embed(self, text: str) -> list[float]:
        """Embed a single text (async-safe).

        Uses run_in_executor to avoid blocking the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_embed, text)

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Batch embed multiple texts (async-safe, more efficient)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_batch_embed, texts)

    async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None:
        """Add evidence to vector store (async-safe).

        Args:
            evidence_id: Unique id for the stored document.
            content: Raw evidence text; embedded before insertion.
            metadata: Flat metadata dict stored alongside the document.
        """
        embedding = await self.embed(content)
        # ChromaDB operations are fast, but wrap for consistency
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self._collection.add(
                ids=[evidence_id],
                embeddings=[embedding],  # type: ignore[arg-type]
                metadatas=[metadata],
                documents=[content],
            ),
        )

    async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]:
        """Find semantically similar evidence (async-safe).

        Args:
            query: Text to search for.
            n_results: Maximum number of matches to return.

        Returns:
            List of dicts with keys ``id``, ``content``, ``metadata`` and
            ``distance`` (cosine distance: 0 = identical). Empty list when
            the store has no matches.
        """
        query_embedding = await self.embed(query)

        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            None,
            lambda: self._collection.query(
                query_embeddings=[query_embedding],  # type: ignore[arg-type]
                n_results=n_results,
            ),
        )

        # Handle empty results gracefully
        ids = results.get("ids")
        docs = results.get("documents")
        metas = results.get("metadatas")
        dists = results.get("distances")

        if not ids or not ids[0] or not docs or not metas or not dists:
            return []

        # doc_id instead of `id` — don't shadow the builtin.
        return [
            {"id": doc_id, "content": doc, "metadata": meta, "distance": dist}
            for doc_id, doc, meta, dist in zip(
                ids[0],
                docs[0],
                metas[0],
                dists[0],
                strict=False,
            )
        ]

    async def deduplicate(
        self, new_evidence: list[Evidence], threshold: float = 0.9
    ) -> list[Evidence]:
        """Remove semantically duplicate evidence (async-safe).

        Args:
            new_evidence: List of evidence items to deduplicate
            threshold: Similarity threshold (0.9 = 90% similar is duplicate).
                      ChromaDB cosine distance: 0=identical, 2=opposite.
                      We consider duplicate if distance < (1 - threshold).

        Returns:
            List of unique evidence items (not already in vector store).
        """
        # Resolve the logger once, not per failing item inside the loop.
        logger = structlog.get_logger()
        unique: list[Evidence] = []
        for evidence in new_evidence:
            try:
                similar = await self.search_similar(evidence.content, n_results=1)
                # ChromaDB cosine distance: 0 = identical, 2 = opposite
                # threshold=0.9 means distance < 0.1 is considered duplicate
                is_duplicate = similar and similar[0]["distance"] < (1 - threshold)

                if not is_duplicate:
                    unique.append(evidence)
                    # Store FULL citation metadata for reconstruction later.
                    # NOTE(review): ChromaDB rejects None metadata values — if
                    # citation.date/title can be None, this add raises and the
                    # item is kept in `unique` but never stored (so later calls
                    # can't dedupe against it). Confirm against the Evidence
                    # model whether these fields are always populated.
                    await self.add_evidence(
                        evidence_id=evidence.citation.url,
                        content=evidence.content,
                        metadata={
                            "source": evidence.citation.source,
                            "title": evidence.citation.title,
                            "date": evidence.citation.date,
                            "authors": ",".join(evidence.citation.authors or []),
                        },
                    )
            except Exception as e:
                # Log but don't fail entire deduplication for one bad item
                logger.warning(
                    "Failed to process evidence in deduplicate",
                    url=evidence.citation.url,
                    error=str(e),
                )
                # Still add to unique list - better to have duplicates than lose data
                unique.append(evidence)

        return unique


_embedding_service: EmbeddingService | None = None


def get_embedding_service() -> EmbeddingService:
    """Lazily construct and return the process-wide EmbeddingService.

    The instance is cached in the module-level ``_embedding_service``
    global, so repeated calls share one model and one vector store.
    """
    global _embedding_service  # noqa: PLW0603
    service = _embedding_service
    if service is None:
        service = EmbeddingService()
        _embedding_service = service
    return service