File size: 13,257 Bytes
cfc8e23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import logging
import pickle
import re
from pathlib import Path

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from app.models.document import Document

logger = logging.getLogger(__name__)


class BM25Searcher:
    """BM25 keyword-based search index.

    Documents are tokenized with a lower-cased word regex and scored with
    ``BM25Okapi``. The token corpus, the documents and the BM25 model can be
    persisted with :meth:`save` and restored with :meth:`load`.
    """

    # Compiled once: _tokenize runs for every document during build_index.
    _TOKEN_PATTERN = re.compile(r"\b\w+\b")

    def __init__(self):
        logger.debug("Initializing BM25Searcher")
        # One token list per indexed document, aligned with self.documents.
        self.corpus = []
        # Indexed documents; position i corresponds to BM25 score position i.
        self.documents = []
        # BM25Okapi model; None until build_index() or load() succeeds.
        self.bm25 = None
        logger.debug("BM25Searcher initialized successfully")

    def build_index(self, documents: list[Document]) -> None:
        """Build the BM25 index from ``documents``.

        Args:
            documents: Documents to index; each ``doc.content`` is tokenized.

        Raises:
            ValueError: If ``documents`` is empty (BM25Okapi cannot compute an
                average document length over an empty corpus).
        """
        logger.info("Building BM25 index with %d documents...", len(documents))

        if not documents:
            logger.error("Cannot build BM25 index from an empty document list")
            raise ValueError("Cannot build BM25 index from an empty document list")

        # Copy so later mutation of the caller's list cannot desynchronize
        # document positions from BM25 score positions.
        documents = list(documents)
        total = len(documents)

        logger.debug("Starting document tokenization")
        corpus = []
        for i, doc in enumerate(documents):
            if i > 0 and i % 1000 == 0:
                logger.debug("Tokenized %d/%d documents", i, total)
            corpus.append(self._tokenize(doc.content))

        logger.debug("Creating BM25Okapi instance")
        bm25 = BM25Okapi(corpus)

        # Commit state only after BM25Okapi succeeds so a failed build does
        # not leave a half-updated index behind.
        self.documents = documents
        self.corpus = corpus
        self.bm25 = bm25
        logger.info("BM25 index built successfully")

    def search(self, query: str, k: int = 10) -> list[tuple[Document, float]]:
        """Return up to ``k`` documents ranked by BM25 score for ``query``.

        Only documents with a strictly positive score are returned, so the
        result may be shorter than ``k``.

        Args:
            query: Free-text query, tokenized the same way as documents.
            k: Maximum number of results; ``k <= 0`` yields an empty list.

        Returns:
            ``(document, score)`` pairs in descending score order.

        Raises:
            ValueError: If no index has been built or loaded.
        """
        logger.debug("BM25 search initiated with query: '%s', k=%s", query, k)

        if self.bm25 is None:
            logger.error("BM25 index not built - cannot perform search")
            raise ValueError("Index not built. Call build_index() first.")

        if k <= 0:
            # A negative k would otherwise slice from the end ([:-1]) and
            # return almost every document.
            logger.info("BM25 search completed: 0 results returned")
            return []

        logger.debug("Tokenizing search query")
        query_tokens = self._tokenize(query)
        logger.debug("Query tokens: %s", query_tokens)

        logger.debug("Computing BM25 scores")
        scores = self.bm25.get_scores(query_tokens)
        logger.debug("Computed scores for %d documents", len(scores))

        # argsort is ascending; reverse so the highest score comes first.
        top_indices = np.argsort(scores)[::-1][:k]
        logger.debug("Top %d indices: %s", len(top_indices), top_indices.tolist())

        results = []
        for idx in top_indices:
            score = float(scores[idx])
            if score > 0:
                results.append((self.documents[idx], score))
                logger.debug("Added result: doc_idx=%d, score=%.4f", idx, score)

        logger.info("BM25 search completed: %d results returned", len(results))
        return results

    def _tokenize(self, text: str) -> list[str]:
        """Lower-case ``text`` and split it into regex word-character runs."""
        tokens = self._TOKEN_PATTERN.findall(text.lower())
        # Lazy %-args: this runs once per document, so avoid building the
        # message when DEBUG is disabled.
        logger.debug("Tokenized text into %d tokens", len(tokens))
        return tokens

    def save(self, path: str) -> None:
        """Write corpus, documents and BM25 model to ``<path>/bm25_index.pkl``.

        The directory ``path`` is created (with parents) if it does not exist.
        """
        logger.info("Saving BM25 index to %s", path)
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        logger.debug("Created directory structure: %s", save_path)

        logger.debug("Serializing BM25 index data")
        with open(save_path / "bm25_index.pkl", "wb") as f:
            pickle.dump(
                {"corpus": self.corpus, "documents": self.documents, "bm25": self.bm25},
                f,
            )

        logger.info("BM25 index saved to %s", path)

    def load(self, path: str) -> None:
        """Restore an index previously written by :meth:`save`.

        Security: this unpickles ``bm25_index.pkl``. Unpickling can execute
        arbitrary code, so only load index files from a trusted source.

        Raises:
            FileNotFoundError: If ``<path>/bm25_index.pkl`` does not exist.
        """
        logger.info("Loading BM25 index from %s", path)
        index_file = Path(path) / "bm25_index.pkl"

        if not index_file.exists():
            logger.error("BM25 index file not found at %s", path)
            raise FileNotFoundError(f"Index file not found at {path}")

        logger.debug("Deserializing BM25 index data")
        with open(index_file, "rb") as f:
            data = pickle.load(f)
        self.corpus = data["corpus"]
        self.documents = data["documents"]
        self.bm25 = data["bm25"]

        logger.debug(
            "Loaded %d documents and %d corpus entries",
            len(self.documents),
            len(self.corpus),
        )
        logger.info("BM25 index loaded from %s", path)


class ColBERTSearcher:
    """Dense retrieval over single-vector passage embeddings.

    Each passage and query is reduced to one vector, either by a
    SentenceTransformer or by attention-masked mean pooling of an AutoModel's
    last hidden state. Documents are ranked by cosine similarity to the query
    (one score per document, not token-level late interaction).
    """

    def __init__(
        self, model_name: str = "colbert-ir/colbertv2.0", device: str | None = None
    ):
        logger.debug("Initializing ColBERTSearcher with model: %s", model_name)
        self.model_name = model_name
        # An explicit device wins; otherwise prefer CUDA when available.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)
        self.model = None
        # Only populated when the AutoModel fallback is used.
        self.tokenizer = None
        self.documents = []
        # One 1-D embedding per document, aligned with self.documents.
        self.document_embeddings = []
        self._model_loaded = False
        self._is_sentence_transformer = False
        logger.debug("ColBERTSearcher initialized successfully")

    def _load_model(self):
        """Load ``self.model_name``, preferring SentenceTransformer.

        Falls back to ``AutoTokenizer``/``AutoModel`` when the
        SentenceTransformer load raises.

        Raises:
            RuntimeError: If neither loader succeeds.
        """
        logger.info("Loading model: %s", self.model_name)

        try:
            logger.debug("Attempting to load as SentenceTransformer")
            # Pass the device so an explicit ``device`` is honored on this
            # path as well, not only on the AutoModel fallback.
            self.model = SentenceTransformer(self.model_name, device=self.device)
        except Exception as st_error:
            logger.warning("Failed to load as SentenceTransformer: %s", st_error)
        else:
            self._is_sentence_transformer = True
            self._model_loaded = True
            logger.info("Loaded %s as SentenceTransformer", self.model_name)
            return

        try:
            logger.debug("Attempting to load as AutoModel")
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModel.from_pretrained(self.model_name)
            model.to(self.device)
            model.eval()
        except Exception as auto_error:
            logger.error("Failed to load model %s: %s", self.model_name, auto_error)
            raise RuntimeError(f"Could not load model {self.model_name}") from auto_error

        self.tokenizer = tokenizer
        self.model = model
        self._is_sentence_transformer = False
        self._model_loaded = True
        logger.info("Loaded %s as AutoModel", self.model_name)

    @staticmethod
    def _mean_pool(last_hidden_state, attention_mask):
        """Average token embeddings over sequence positions, skipping padding.

        Args:
            last_hidden_state: Model output, expected (per Hugging Face
                ``AutoModel``) as ``(batch, seq_len, hidden)``.
            attention_mask: ``(batch, seq_len)`` mask with 1 for real tokens
                and 0 for padding, or ``None`` to average every position.
        """
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        # Clamp guards against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def _encode_with_automodel(self, texts: list[str]) -> np.ndarray:
        """Encode ``texts`` with the AutoModel fallback; one row per text.

        Padding positions are excluded from pooling, so a text's embedding
        does not depend on the other texts in its batch.
        """
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            pooled = self._mean_pool(
                outputs.last_hidden_state, inputs.get("attention_mask")
            )
        return pooled.cpu().numpy()

    @staticmethod
    def _cosine_similarities(query_embedding, document_embeddings) -> np.ndarray:
        """Cosine similarity of ``query_embedding`` against every document.

        Zero-norm vectors get similarity 0.0 instead of NaN. NaN would sort
        last in ``np.argsort`` and therefore rank *first* after reversal.
        """
        if len(document_embeddings) == 0:
            return np.zeros(0, dtype=float)
        doc_matrix = np.vstack(document_embeddings)
        query = np.asarray(query_embedding).ravel()
        dots = doc_matrix @ query
        norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query)
        return np.divide(
            dots, norms, out=np.zeros_like(dots, dtype=float), where=norms > 0
        )

    def encode_passages(
        self, passages: list[str], batch_size: int = 32
    ) -> list[np.ndarray]:
        """Encode passages into one embedding vector each.

        Loads the model on first use.

        Args:
            passages: Texts to encode.
            batch_size: Number of passages per forward pass; must be >= 1.

        Returns:
            One 1-D numpy embedding per passage, in input order.

        Raises:
            ValueError: If ``batch_size`` < 1, or the AutoModel tokenizer is
                missing.
            RuntimeError: If the model cannot be loaded.
        """
        logger.info(
            "Encoding %d passages with batch_size=%s", len(passages), batch_size
        )

        if batch_size < 1:
            # A negative step would make the batching range empty and
            # silently return no embeddings.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if not self._model_loaded:
            logger.debug("Model not loaded, loading now")
            self._load_model()

        if self.model is None:
            logger.error("Model failed to load during encoding")
            raise RuntimeError("Model failed to load")

        if not passages:
            logger.info("Successfully encoded 0 passages")
            return []

        if isinstance(self.model, SentenceTransformer):
            logger.debug("Using SentenceTransformer for encoding")
            embeddings = self.model.encode(
                passages,
                batch_size=batch_size,
                show_progress_bar=True,
                convert_to_numpy=True,
            )
            all_embeddings = list(embeddings)
            logger.debug(
                "Generated %d embeddings using SentenceTransformer",
                len(all_embeddings),
            )
        else:
            logger.debug("Using AutoModel for encoding")
            if self.tokenizer is None:
                logger.error("Tokenizer not initialized for AutoModel")
                raise ValueError("Tokenizer not initialized for AutoModel")

            num_batches = (len(passages) + batch_size - 1) // batch_size
            logger.debug("Processing %d batches", num_batches)

            all_embeddings = []
            for batch_num, start in enumerate(
                range(0, len(passages), batch_size), start=1
            ):
                logger.debug("Processing batch %d/%d", batch_num, num_batches)
                batch = passages[start : start + batch_size]
                all_embeddings.extend(self._encode_with_automodel(batch))

        logger.info("Successfully encoded %d passages", len(all_embeddings))
        return all_embeddings

    def build_index(self, documents: list[Document]) -> None:
        """Encode every ``doc.content`` and store documents with embeddings."""
        logger.info("Building ColBERT index with %d documents...", len(documents))

        # Copy so later mutation of the caller's list cannot desynchronize
        # documents from their embeddings.
        documents = list(documents)
        logger.debug("Extracting content from documents")
        passages = [doc.content for doc in documents]

        logger.debug("Starting passage encoding")
        embeddings = self.encode_passages(passages)

        # Commit only after encoding succeeds so documents and embeddings
        # always stay aligned.
        self.documents = documents
        self.document_embeddings = embeddings
        logger.info("ColBERT index built successfully")

    def search(self, query: str, k: int = 10) -> list[tuple[Document, float]]:
        """Return the ``k`` documents most cosine-similar to ``query``.

        Args:
            query: Free-text query, encoded with the same model as documents.
            k: Maximum number of results; ``k <= 0`` yields an empty list.

        Returns:
            ``(document, similarity)`` pairs in descending similarity order.

        Raises:
            ValueError: If no index is built, or the AutoModel tokenizer is
                missing.
            RuntimeError: If the model cannot be loaded.
        """
        logger.debug("ColBERT search initiated with query: '%s', k=%s", query, k)

        if not self.documents:
            logger.error("ColBERT index not built - cannot perform search")
            raise ValueError("Index not built. Call build_index() first.")

        if not self._model_loaded:
            logger.debug("Model not loaded for search, loading now")
            self._load_model()

        if self.model is None:
            logger.error("Model failed to load during search")
            raise RuntimeError("Model failed to load")

        if k <= 0:
            # A negative k would otherwise slice from the end ([:-1]).
            logger.info("ColBERT search completed: 0 results returned")
            return []

        logger.debug("Encoding search query")
        if isinstance(self.model, SentenceTransformer):
            logger.debug("Using SentenceTransformer for query encoding")
            query_embedding = self.model.encode([query], convert_to_numpy=True)[0]
        else:
            logger.debug("Using AutoModel for query encoding")
            if self.tokenizer is None:
                logger.error("Tokenizer not initialized for AutoModel during search")
                raise ValueError("Tokenizer not initialized for AutoModel")
            query_embedding = self._encode_with_automodel([query])[0]

        logger.debug(
            "Computing similarities with %d document embeddings",
            len(self.document_embeddings),
        )
        similarities = self._cosine_similarities(
            query_embedding, self.document_embeddings
        )

        # argsort is ascending; reverse so the most similar comes first.
        top_indices = np.argsort(similarities)[::-1][:k]
        logger.debug("Top %d indices: %s", len(top_indices), top_indices.tolist())

        results = []
        for idx in top_indices:
            similarity = float(similarities[idx])
            results.append((self.documents[idx], similarity))
            logger.debug("Added result: doc_idx=%d, similarity=%.4f", idx, similarity)

        logger.info("ColBERT search completed: %d results returned", len(results))
        return results

    def save(self, path: str) -> None:
        """Write ``documents.pkl`` and ``embeddings.pkl`` under ``path``.

        The directory ``path`` is created (with parents) if it does not exist.
        The model itself is not saved; :meth:`load` reloads it by name.
        """
        logger.info("Saving ColBERT index to %s", path)
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        logger.debug("Created directory structure: %s", save_path)

        logger.debug("Saving documents")
        with open(save_path / "documents.pkl", "wb") as f:
            pickle.dump(self.documents, f)

        logger.debug("Saving embeddings")
        with open(save_path / "embeddings.pkl", "wb") as f:
            pickle.dump(self.document_embeddings, f)

        logger.info("ColBERT index saved to %s", path)

    def load(self, path: str) -> None:
        """Restore an index written by :meth:`save`, then load the model.

        Security: this unpickles both files. Unpickling can execute arbitrary
        code, so only load index files from a trusted source.

        Raises:
            FileNotFoundError: If either pickle file is missing.
            RuntimeError: If the model cannot be loaded.
        """
        logger.info("Loading ColBERT index from %s", path)
        load_path = Path(path)
        documents_file = load_path / "documents.pkl"
        embeddings_file = load_path / "embeddings.pkl"

        if not documents_file.exists():
            logger.error("Documents file not found at %s", path)
            raise FileNotFoundError(f"Documents file not found at {path}")

        if not embeddings_file.exists():
            logger.error("Embeddings file not found at %s", path)
            raise FileNotFoundError(f"Embeddings file not found at {path}")

        logger.debug("Loading documents")
        with open(documents_file, "rb") as f:
            self.documents = pickle.load(f)

        logger.debug("Loading embeddings")
        with open(embeddings_file, "rb") as f:
            self.document_embeddings = pickle.load(f)

        logger.debug(
            "Loaded %d documents and %d embeddings",
            len(self.documents),
            len(self.document_embeddings),
        )

        logger.debug("Loading model for loaded index")
        self._load_model()
        logger.info("ColBERT index loaded from %s", path)