aivre / app /services /indexes.py
Vedang Barhate
chore: copied from assist repo
cfc8e23
raw
history blame
13.3 kB
import logging
import pickle
import re
from pathlib import Path
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
from app.models.document import Document
# Module-level logger used by the searcher classes defined in this file.
logger = logging.getLogger(__name__)
class BM25Searcher:
    """BM25 keyword-based search index.

    Documents are tokenized into lowercase word tokens and ranked with
    ``rank_bm25.BM25Okapi``. The index can be persisted to and restored
    from a directory via :meth:`save` and :meth:`load`.
    """

    # Compiled once; used for both documents and queries.
    _TOKEN_PATTERN = re.compile(r"\b\w+\b")

    def __init__(self):
        logger.debug("Initializing BM25Searcher")
        self.corpus = []  # tokenized documents, parallel to self.documents
        self.documents = []  # original Document objects
        self.bm25 = None  # BM25Okapi instance, or None until an index is built
        logger.debug("BM25Searcher initialized successfully")

    def build_index(self, documents: list[Document]) -> None:
        """Build the BM25 index from ``documents``.

        Args:
            documents: Documents to index; each one's ``content`` is tokenized.

        An empty ``documents`` list leaves the searcher in the "not built"
        state (a later :meth:`search` raises ``ValueError``).
        """
        logger.info(f"Building BM25 index with {len(documents)} documents...")

        if not documents:
            # rank_bm25's BM25Okapi cannot be constructed from an empty corpus
            # (it divides by the corpus size), so treat "nothing to index" as
            # "index not built" instead of failing with an opaque arithmetic error.
            logger.warning("No documents supplied - BM25 index left unbuilt")
            self.documents = []
            self.corpus = []
            self.bm25 = None
            return

        logger.debug("Starting document tokenization")
        corpus = []
        for i, doc in enumerate(documents):
            if i > 0 and i % 1000 == 0:
                logger.debug(f"Tokenized {i}/{len(documents)} documents")
            corpus.append(self._tokenize(doc.content))

        logger.debug("Creating BM25Okapi instance")
        bm25 = BM25Okapi(corpus)

        # Publish the new state only after construction succeeded so a failure
        # above cannot leave documents/corpus/bm25 out of sync.
        self.documents = documents
        self.corpus = corpus
        self.bm25 = bm25
        logger.info("BM25 index built successfully")

    def search(self, query: str, k: int = 10) -> list[tuple[Document, float]]:
        """Return up to ``k`` documents ranked by BM25 score for ``query``.

        Args:
            query: Free-text query; tokenized the same way as the documents.
            k: Maximum number of results. Values ``<= 0`` yield an empty list.

        Returns:
            ``(document, score)`` pairs in descending score order. Documents
            whose score is not strictly positive are omitted.

        Raises:
            ValueError: If no index has been built or loaded.
        """
        logger.debug(f"BM25 search initiated with query: '{query}', k={k}")
        if self.bm25 is None:
            logger.error("BM25 index not built - cannot perform search")
            raise ValueError("Index not built. Call build_index() first.")

        if k <= 0:
            # A negative k would otherwise slice "[:k]" and return all but the
            # last |k| hits, which is never what the caller asked for.
            logger.info("BM25 search completed: 0 results returned")
            return []

        logger.debug("Tokenizing search query")
        query_tokens = self._tokenize(query)
        logger.debug(f"Query tokens: {query_tokens}")

        logger.debug("Computing BM25 scores")
        scores = self.bm25.get_scores(query_tokens)
        logger.debug(f"Computed scores for {len(scores)} documents")

        top_indices = np.argsort(scores)[::-1][:k]
        logger.debug(f"Top {len(top_indices)} indices: {top_indices.tolist()}")

        results = []
        for idx in top_indices:
            # Keep only documents that actually matched at least one query term.
            if scores[idx] > 0:
                results.append((self.documents[idx], float(scores[idx])))
                logger.debug(f"Added result: doc_idx={idx}, score={scores[idx]:.4f}")

        logger.info(f"BM25 search completed: {len(results)} results returned")
        return results

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into lowercase word tokens."""
        # Deliberately not logged: this runs once per document while indexing,
        # and a per-call debug record would flood the log.
        return self._TOKEN_PATTERN.findall(text.lower())

    def save(self, path: str) -> None:
        """Pickle the corpus, documents and BM25 model into ``path``.

        Args:
            path: Directory to write ``bm25_index.pkl`` into; created if missing.
        """
        logger.info(f"Saving BM25 index to {path}")
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Created directory structure: {save_path}")

        logger.debug("Serializing BM25 index data")
        payload = {
            "corpus": self.corpus,
            "documents": self.documents,
            "bm25": self.bm25,
        }
        with open(save_path / "bm25_index.pkl", "wb") as f:
            pickle.dump(payload, f)
        logger.info(f"BM25 index saved to {path}")

    def load(self, path: str) -> None:
        """Restore an index previously written by :meth:`save`.

        Args:
            path: Directory containing ``bm25_index.pkl``.

        Raises:
            FileNotFoundError: If ``bm25_index.pkl`` is not present in ``path``.
        """
        logger.info(f"Loading BM25 index from {path}")
        index_file = Path(path) / "bm25_index.pkl"
        if not index_file.exists():
            logger.error(f"BM25 index file not found at {path}")
            raise FileNotFoundError(f"Index file not found at {path}")

        logger.debug("Deserializing BM25 index data")
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load index files produced by this application from trusted storage.
        with open(index_file, "rb") as f:
            data = pickle.load(f)

        self.corpus = data["corpus"]
        self.documents = data["documents"]
        self.bm25 = data["bm25"]
        logger.debug(
            f"Loaded {len(self.documents)} documents and {len(self.corpus)} corpus entries"
        )
        logger.info(f"BM25 index loaded from {path}")
class ColBERTSearcher:
    """Dense retrieval over one embedding vector per document.

    The model named by ``model_name`` is loaded lazily, first as a
    ``SentenceTransformer`` and, if that fails, as a Hugging Face
    ``AutoModel`` whose token states are mean-pooled. Documents are ranked
    by cosine similarity between the query vector and each document vector.
    """

    def __init__(
        self, model_name: str = "colbert-ir/colbertv2.0", device: str | None = None
    ):
        """Record configuration; the model itself is loaded on first use.

        Args:
            model_name: Model identifier passed to the model loaders.
            device: Torch device string. Defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.
        """
        logger.debug(f"Initializing ColBERTSearcher with model: {model_name}")
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.model = None
        self.tokenizer = None  # only set on the AutoModel fallback path
        self.documents = []
        self.document_embeddings = []  # parallel to self.documents
        self._model_loaded = False
        self._is_sentence_transformer = False
        logger.debug("ColBERTSearcher initialized successfully")

    def _load_model(self):
        """Load the model, preferring SentenceTransformer over AutoModel.

        Raises:
            RuntimeError: If neither loader can load ``self.model_name``.
        """
        logger.info(f"Loading model: {self.model_name}")
        try:
            logger.debug("Attempting to load as SentenceTransformer")
            # Pass the device explicitly so the device chosen in __init__ is
            # honoured on this path too, not only on the AutoModel fallback.
            self.model = SentenceTransformer(self.model_name, device=self.device)
            self._is_sentence_transformer = True
            self._model_loaded = True
            logger.info(f"Loaded {self.model_name} as SentenceTransformer")
        except Exception as st_error:
            # Best-effort fallback: any failure here means "try plain transformers".
            logger.warning(f"Failed to load as SentenceTransformer: {st_error}")
            try:
                logger.debug("Attempting to load as AutoModel")
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModel.from_pretrained(self.model_name)
                self.model.to(self.device)
                self.model.eval()
                self._is_sentence_transformer = False
                self._model_loaded = True
                logger.info(f"Loaded {self.model_name} as AutoModel")
            except Exception as hf_error:
                logger.error(f"Failed to load model {self.model_name}: {hf_error}")
                raise RuntimeError(
                    f"Could not load model {self.model_name}"
                ) from hf_error

    def _ensure_model(self, stage: str) -> None:
        """Load the model on first use; ``stage`` only labels the log messages."""
        if not self._model_loaded:
            logger.debug(f"Model not loaded for {stage}, loading now")
            self._load_model()
        if self.model is None:
            logger.error(f"Model failed to load during {stage}")
            raise RuntimeError("Model failed to load")

    def _encode_texts(
        self, texts: list[str], batch_size: int, show_progress_bar: bool
    ) -> list[np.ndarray]:
        """Embed ``texts`` with whichever model type is loaded.

        Shared by passage and query encoding so both are embedded identically.
        """
        if isinstance(self.model, SentenceTransformer):
            logger.debug("Using SentenceTransformer for encoding")
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=show_progress_bar,
                convert_to_numpy=True,
            )
            return list(embeddings)

        logger.debug("Using AutoModel for encoding")
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized for AutoModel")
            raise ValueError("Tokenizer not initialized for AutoModel")

        num_batches = (len(texts) + batch_size - 1) // batch_size
        logger.debug(f"Processing {num_batches} batches")

        vectors = []
        for start in range(0, len(texts), batch_size):
            logger.debug(f"Processing batch {start // batch_size + 1}/{num_batches}")
            batch = texts[start : start + batch_size]
            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Average only over real tokens. A plain mean over the sequence
            # axis would include padding positions, making a text's embedding
            # depend on the other texts that share its batch.
            hidden = outputs.last_hidden_state
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1.0)
            pooled = (hidden * mask).sum(dim=1) / token_counts
            vectors.extend(pooled.cpu().numpy())
        return vectors

    def encode_passages(
        self, passages: list[str], batch_size: int = 32
    ) -> list[np.ndarray]:
        """Encode passages into embedding vectors.

        Args:
            passages: Texts to embed.
            batch_size: Number of passages encoded per model call.

        Returns:
            One numpy vector per passage, in input order.

        Raises:
            RuntimeError: If the model cannot be loaded.
            ValueError: If the AutoModel path is active without a tokenizer.
        """
        logger.info(f"Encoding {len(passages)} passages with batch_size={batch_size}")
        self._ensure_model("encoding")
        all_embeddings = self._encode_texts(
            passages, batch_size=batch_size, show_progress_bar=True
        )
        logger.info(f"Successfully encoded {len(all_embeddings)} passages")
        return all_embeddings

    def build_index(self, documents: list[Document]) -> None:
        """Embed every document's ``content`` and store the vectors.

        Args:
            documents: Documents to index.
        """
        logger.info(f"Building ColBERT index with {len(documents)} documents...")
        self.documents = documents
        logger.debug("Extracting content from documents")
        passages = [doc.content for doc in documents]
        logger.debug("Starting passage encoding")
        self.document_embeddings = self.encode_passages(passages)
        logger.info("ColBERT index built successfully")

    def search(self, query: str, k: int = 10) -> list[tuple[Document, float]]:
        """Return up to ``k`` documents ranked by cosine similarity to ``query``.

        Args:
            query: Free-text query.
            k: Maximum number of results. Values ``<= 0`` yield an empty list.

        Returns:
            ``(document, similarity)`` pairs in descending similarity order.

        Raises:
            ValueError: If no index is available, if documents and embeddings
                differ in length, or if the AutoModel path lacks a tokenizer.
            RuntimeError: If the model cannot be loaded.
        """
        logger.debug(f"ColBERT search initiated with query: '{query}', k={k}")
        if not self.documents:
            logger.error("ColBERT index not built - cannot perform search")
            raise ValueError("Index not built. Call build_index() first.")
        if len(self.document_embeddings) != len(self.documents):
            # Positions must line up one-to-one, otherwise results would be
            # attributed to the wrong document (or index out of range).
            logger.error("ColBERT index is inconsistent - cannot perform search")
            raise ValueError(
                f"Index is inconsistent: {len(self.documents)} documents but "
                f"{len(self.document_embeddings)} embeddings."
            )

        self._ensure_model("search")

        if k <= 0:
            # A negative k would otherwise slice "[:k]" and drop only the
            # last |k| hits instead of returning nothing.
            logger.info("ColBERT search completed: 0 results returned")
            return []

        logger.debug("Encoding search query")
        query_embedding = np.asarray(
            self._encode_texts([query], batch_size=1, show_progress_bar=False)[0]
        )

        logger.debug(
            f"Computing similarities with {len(self.document_embeddings)} document embeddings"
        )
        doc_matrix = np.asarray(self.document_embeddings)
        dots = doc_matrix @ query_embedding
        norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_embedding)
        # Zero-norm vectors get similarity 0 instead of NaN; NaN would sort to
        # the top of the reversed argsort and surface as the best match.
        similarities = np.divide(
            dots,
            norms,
            out=np.zeros(dots.shape, dtype=np.float64),
            where=norms > 0,
        )

        top_indices = np.argsort(similarities)[::-1][:k]
        logger.debug(f"Top {len(top_indices)} indices: {top_indices.tolist()}")

        results = []
        for idx in top_indices:
            results.append((self.documents[idx], float(similarities[idx])))
            logger.debug(
                f"Added result: doc_idx={idx}, similarity={similarities[idx]:.4f}"
            )

        logger.info(f"ColBERT search completed: {len(results)} results returned")
        return results

    def save(self, path: str) -> None:
        """Pickle documents and embeddings into ``path``.

        Args:
            path: Directory to write ``documents.pkl`` and ``embeddings.pkl``
                into; created if missing. The model itself is not saved.
        """
        logger.info(f"Saving ColBERT index to {path}")
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Created directory structure: {save_path}")

        logger.debug("Saving documents")
        with open(save_path / "documents.pkl", "wb") as f:
            pickle.dump(self.documents, f)

        logger.debug("Saving embeddings")
        with open(save_path / "embeddings.pkl", "wb") as f:
            pickle.dump(self.document_embeddings, f)

        logger.info(f"ColBERT index saved to {path}")

    def load(self, path: str) -> None:
        """Restore an index written by :meth:`save` and load the model.

        Args:
            path: Directory containing ``documents.pkl`` and ``embeddings.pkl``.

        Raises:
            FileNotFoundError: If either pickle file is missing.
            RuntimeError: If the model cannot be loaded.
        """
        logger.info(f"Loading ColBERT index from {path}")
        load_path = Path(path)
        documents_file = load_path / "documents.pkl"
        embeddings_file = load_path / "embeddings.pkl"

        if not documents_file.exists():
            logger.error(f"Documents file not found at {path}")
            raise FileNotFoundError(f"Documents file not found at {path}")
        if not embeddings_file.exists():
            logger.error(f"Embeddings file not found at {path}")
            raise FileNotFoundError(f"Embeddings file not found at {path}")

        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load index files produced by this application from trusted storage.
        logger.debug("Loading documents")
        with open(documents_file, "rb") as f:
            self.documents = pickle.load(f)

        logger.debug("Loading embeddings")
        with open(embeddings_file, "rb") as f:
            self.document_embeddings = pickle.load(f)

        logger.debug(
            f"Loaded {len(self.documents)} documents and {len(self.document_embeddings)} embeddings"
        )
        if len(self.documents) != len(self.document_embeddings):
            # Loading still succeeds; search() rejects the mismatched index.
            logger.warning(
                "Loaded ColBERT index is inconsistent: document and embedding counts differ"
            )

        logger.debug("Loading model for loaded index")
        self._load_model()
        logger.info(f"ColBERT index loaded from {path}")