import logging
import pickle
import re
from pathlib import Path

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from app.models.document import Document

logger = logging.getLogger(__name__)


class BM25Searcher:
    """BM25 keyword-based search index"""

    def __init__(self):
        logger.debug("Initializing BM25Searcher")
        self.corpus = []
        self.documents = []
        self.bm25 = None
        logger.debug("BM25Searcher initialized successfully")

    def build_index(self, documents: list[Document]) -> None:
        """Build BM25 index from documents"""
        logger.info(f"Building BM25 index with {len(documents)} documents...")
        self.documents = documents
        self.corpus = []

        logger.debug("Starting document tokenization")
        for i, doc in enumerate(documents):
            if i > 0 and i % 1000 == 0:
                logger.debug(f"Tokenized {i}/{len(documents)} documents")
            tokens = self._tokenize(doc.content)
            self.corpus.append(tokens)

        logger.debug("Creating BM25Okapi instance")
        self.bm25 = BM25Okapi(self.corpus)
        logger.info("BM25 index built successfully")

    def search(self, query: str, k: int = 10) -> list[tuple[Document, float]]:
        """Search documents using BM25 scoring"""
        logger.debug(f"BM25 search initiated with query: '{query}', k={k}")

        if self.bm25 is None:
            logger.error("BM25 index not built - cannot perform search")
            raise ValueError("Index not built. Call build_index() first.")

        logger.debug("Tokenizing search query")
        query_tokens = self._tokenize(query)
        logger.debug(f"Query tokens: {query_tokens}")

        logger.debug("Computing BM25 scores")
        scores = self.bm25.get_scores(query_tokens)
        logger.debug(f"Computed scores for {len(scores)} documents")

        top_indices = np.argsort(scores)[::-1][:k]
        logger.debug(f"Top {len(top_indices)} indices: {top_indices.tolist()}")

        results = []
        for idx in top_indices:
            if scores[idx] > 0:
                results.append((self.documents[idx], float(scores[idx])))
                logger.debug(f"Added result: doc_idx={idx}, score={scores[idx]:.4f}")

        logger.info(f"BM25 search completed: {len(results)} results returned")
        return results

    def _tokenize(self, text: str) -> list[str]:
        """Simple word tokenization"""
        tokens = re.findall(r"\b\w+\b", text.lower())
        logger.debug(f"Tokenized text into {len(tokens)} tokens")
        return tokens

    def save(self, path: str) -> None:
        """Save index to disk"""
        logger.info(f"Saving BM25 index to {path}")
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Created directory structure: {save_path}")

        logger.debug("Serializing BM25 index data")
        with open(save_path / "bm25_index.pkl", "wb") as f:
            pickle.dump(
                {"corpus": self.corpus, "documents": self.documents, "bm25": self.bm25},
                f,
            )
        logger.info(f"BM25 index saved to {path}")

    def load(self, path: str) -> None:
        """Load index from disk"""
        logger.info(f"Loading BM25 index from {path}")
        load_path = Path(path)

        if not (load_path / "bm25_index.pkl").exists():
            logger.error(f"BM25 index file not found at {path}")
            raise FileNotFoundError(f"Index file not found at {path}")

        logger.debug("Deserializing BM25 index data")
        with open(load_path / "bm25_index.pkl", "rb") as f:
            data = pickle.load(f)
            self.corpus = data["corpus"]
            self.documents = data["documents"]
            self.bm25 = data["bm25"]

        logger.debug(
            f"Loaded {len(self.documents)} documents and {len(self.corpus)} corpus entries"
        )
        logger.info(f"BM25 index loaded from {path}")


class ColBERTSearcher:
    """ColBERT-style dense retrieval using sentence transformers"""

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        device: str | None = None,
    ):
        logger.debug(f"Initializing ColBERTSearcher with model: {model_name}")
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.model = None
        self.tokenizer = None
        self.documents = []
        self.document_embeddings = []
        self._model_loaded = False
        self._is_sentence_transformer = False
        logger.debug("ColBERTSearcher initialized successfully")

    def _load_model(self):
        """Load the ColBERT model"""
        logger.info(f"Loading model: {self.model_name}")
        try:
            logger.debug("Attempting to load as SentenceTransformer")
            self.model = SentenceTransformer(self.model_name)
            self._is_sentence_transformer = True
            self._model_loaded = True
            logger.info(f"Loaded {self.model_name} as SentenceTransformer")
        except Exception as e:
            logger.warning(f"Failed to load as SentenceTransformer: {e}")
            try:
                logger.debug("Attempting to load as AutoModel")
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModel.from_pretrained(self.model_name)
                self.model.to(self.device)
                self.model.eval()
                self._is_sentence_transformer = False
                self._model_loaded = True
                logger.info(f"Loaded {self.model_name} as AutoModel")
            except Exception as e:
                logger.error(f"Failed to load model {self.model_name}: {e}")
                raise RuntimeError(f"Could not load model {self.model_name}") from e

    def encode_passages(
        self, passages: list[str], batch_size: int = 32
    ) -> list[np.ndarray]:
        """Encode passages into embeddings"""
        logger.info(f"Encoding {len(passages)} passages with batch_size={batch_size}")

        if not self._model_loaded:
            logger.debug("Model not loaded, loading now")
            self._load_model()

        if self.model is None:
            logger.error("Model failed to load during encoding")
            raise RuntimeError("Model failed to load")

        all_embeddings = []

        if isinstance(self.model, SentenceTransformer):
            logger.debug("Using SentenceTransformer for encoding")
            embeddings = self.model.encode(
                passages,
                batch_size=batch_size,
                show_progress_bar=True,
                convert_to_numpy=True,
            )
            all_embeddings = list(embeddings)
            logger.debug(
                f"Generated {len(all_embeddings)} embeddings using SentenceTransformer"
            )
        else:
            logger.debug("Using AutoModel for encoding")
            if self.tokenizer is None:
                logger.error("Tokenizer not initialized for AutoModel")
                raise ValueError("Tokenizer not initialized for AutoModel")

            num_batches = (len(passages) + batch_size - 1) // batch_size
            logger.debug(f"Processing {num_batches} batches")

            for i in range(0, len(passages), batch_size):
                batch_num = i // batch_size + 1
                logger.debug(f"Processing batch {batch_num}/{num_batches}")
                batch = passages[i : i + batch_size]
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model(**inputs)
                    embeddings = outputs.last_hidden_state.mean(dim=1)
                    all_embeddings.extend(embeddings.cpu().numpy())

        logger.info(f"Successfully encoded {len(all_embeddings)} passages")
        return all_embeddings

    def build_index(self, documents: list[Document]) -> None:
        """Build ColBERT index from documents"""
        logger.info(f"Building ColBERT index with {len(documents)} documents...")
        self.documents = documents

        logger.debug("Extracting content from documents")
        passages = [doc.content for doc in documents]

        logger.debug("Starting passage encoding")
        self.document_embeddings = self.encode_passages(passages)
        logger.info("ColBERT index built successfully")
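    # Note on scoring: both encoding paths above produce a single vector per
    # passage (the SentenceTransformer output, or a mean-pool over token
    # embeddings in the AutoModel fallback), and search() below ranks by plain
    # cosine similarity between those single vectors. True ColBERT retrieval
    # keeps per-token embeddings and scores with MaxSim late interaction; this
    # class deliberately uses the simpler single-vector approximation.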
logger.debug(f"ColBERT search initiated with query: '{query}', k={k}") if not self.documents: logger.error("ColBERT index not built - cannot perform search") raise ValueError("Index not built. Call build_index() first.") if not self._model_loaded: logger.debug("Model not loaded for search, loading now") self._load_model() if self.model is None: logger.error("Model failed to load during search") raise RuntimeError("Model failed to load") logger.debug("Encoding search query") if isinstance(self.model, SentenceTransformer): logger.debug("Using SentenceTransformer for query encoding") query_embedding = self.model.encode([query], convert_to_numpy=True)[0] else: logger.debug("Using AutoModel for query encoding") if self.tokenizer is None: logger.error("Tokenizer not initialized for AutoModel during search") raise ValueError("Tokenizer not initialized for AutoModel") inputs = self.tokenizer( [query], padding=True, truncation=True, max_length=512, return_tensors="pt", ).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0] logger.debug( f"Computing similarities with {len(self.document_embeddings)} document embeddings" ) similarities = [] for i, doc_emb in enumerate(self.document_embeddings): sim = np.dot(query_embedding, doc_emb) / ( np.linalg.norm(query_embedding) * np.linalg.norm(doc_emb) ) similarities.append(sim) if i > 0 and i % 1000 == 0: logger.debug( f"Computed similarity for {i}/{len(self.document_embeddings)} documents" ) top_indices = np.argsort(similarities)[::-1][:k] logger.debug(f"Top {len(top_indices)} indices: {top_indices.tolist()}") results = [] for idx in top_indices: results.append((self.documents[idx], float(similarities[idx]))) logger.debug( f"Added result: doc_idx={idx}, similarity={similarities[idx]:.4f}" ) logger.info(f"ColBERT search completed: {len(results)} results returned") return results def save(self, path: str) -> None: """Save index to disk""" logger.info(f"Saving ColBERT index to {path}") save_path = Path(path) save_path.mkdir(parents=True, exist_ok=True) logger.debug(f"Created directory structure: {save_path}") logger.debug("Saving documents") with open(save_path / "documents.pkl", "wb") as f: pickle.dump(self.documents, f) logger.debug("Saving embeddings") with open(save_path / "embeddings.pkl", "wb") as f: pickle.dump(self.document_embeddings, f) logger.info(f"ColBERT index saved to {path}") def load(self, path: str) -> None: """Load index from disk""" logger.info(f"Loading ColBERT index from {path}") load_path = Path(path) if not (load_path / "documents.pkl").exists(): logger.error(f"Documents file not found at {path}") raise FileNotFoundError(f"Documents file not found at {path}") if not (load_path / "embeddings.pkl").exists(): logger.error(f"Embeddings file not found at {path}") raise FileNotFoundError(f"Embeddings file not found at {path}") logger.debug("Loading documents") with open(load_path / "documents.pkl", "rb") as f: self.documents = pickle.load(f) logger.debug("Loading embeddings") with open(load_path / "embeddings.pkl", "rb") as f: self.document_embeddings = pickle.load(f) logger.debug( f"Loaded {len(self.documents)} documents and {len(self.document_embeddings)} embeddings" ) logger.debug("Loading model for loaded index") self._load_model() logger.info(f"ColBERT index loaded from {path}")