|
|
from haystack import Document, Pipeline |
|
|
from haystack.document_stores.in_memory import InMemoryDocumentStore |
|
|
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder |
|
|
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever |
|
|
from haystack.components.builders import PromptBuilder |
|
|
from datasets import load_dataset |
|
|
from haystack.dataclasses import ChatMessage |
|
|
from typing import Optional, List, Union, Dict |
|
|
from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG |
|
|
|
|
|
class RAGPipeline:
    """Retrieval-augmented generation (RAG) pipeline built on Haystack.

    Embeds a document corpus into an in-memory document store, retrieves
    the documents most similar to a question, and renders them into a
    prompt template.

    NOTE(review): no LLM/generator component is attached in this file, so
    ``answer_question`` returns the rendered prompt, not a model answer.
    """

    def __init__(
        self,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            dataset_config: Either a string key from DATASET_CONFIGS or a DatasetConfig object
            documents: Optional list of documents (raw strings or Document
                objects) to use instead of loading from a dataset
            embedding_model: Optional override for embedding model

        Raises:
            ValueError: If ``dataset_config`` is a string that is not a key
                of DATASET_CONFIGS.
        """
        if isinstance(dataset_config, str):
            if dataset_config not in DATASET_CONFIGS:
                raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
            self.config = DATASET_CONFIGS[dataset_config]
        else:
            self.config = dataset_config

        if documents is not None:
            # FIX: raw strings were previously stored as-is and later handed
            # to the document embedder, which requires Document instances.
            self.documents = [
                doc if isinstance(doc, Document) else Document(content=doc)
                for doc in documents
            ]
        else:
            self.documents = self._load_documents()

        # Resolve the embedding model once so both embedders stay in sync.
        model_name = embedding_model or MODEL_CONFIG["embedding_model"]

        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(
            model=model_name,
            progress_bar=False
        )
        # FIX: the text embedder was instantiated twice in the original;
        # a single instance suffices.
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=model_name,
            progress_bar=False
        )
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Load model weights eagerly so the first query is not slowed down.
        self.doc_embedder.warm_up()
        self.text_embedder.warm_up()

        # Fall back to a generic template when the config does not supply
        # one.  The template text (including its blank lines) is preserved
        # exactly from the original.
        self.prompt_builder = PromptBuilder(template=self.config.prompt_template or """


Given the following context, please answer the question.





Context:


{% for document in documents %}


{{ document.content }}


{% endfor %}





Question: {{question}}


Answer:


""")

        self._index_documents(self.documents)

        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, preset_name: str) -> "RAGPipeline":
        """
        Create a pipeline from a preset configuration.

        Args:
            preset_name: Name of the preset configuration to use
        """
        return cls(dataset_config=preset_name)

    def _load_documents(self) -> List[Document]:
        """Load the configured Hugging Face dataset and convert rows to Documents."""
        dataset = load_dataset(self.config.name, split=self.config.split)
        documents = []
        for row in dataset:
            # Copy the configured metadata fields, silently skipping any
            # field that is absent from this row (matches original behavior).
            meta = {}
            if self.config.fields:
                for meta_key, dataset_field in self.config.fields.items():
                    if dataset_field in row:
                        meta[meta_key] = row[dataset_field]
            documents.append(
                Document(
                    content=row[self.config.content_field],
                    meta=meta
                )
            )
        return documents

    def _index_documents(self, documents: List[Document]) -> None:
        """Embed *documents* and write them into the in-memory store."""
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])

    def _build_pipeline(self) -> Pipeline:
        """Wire embedder -> retriever -> prompt builder into a Pipeline."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)

        # Query embedding feeds retrieval; retrieved documents feed the prompt.
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever", "prompt_builder")

        return pipeline

    def answer_question(self, question: str) -> str:
        """Run the RAG pipeline to answer a question.

        FIX: the pipeline built in ``__init__`` was previously ignored and
        each component was invoked by hand; running the assembled pipeline
        keeps the wiring in one place and produces the same prompt.

        Returns:
            The rendered prompt string (no generator is attached).
        """
        result = self.pipeline.run(
            {
                "text_embedder": {"text": question},
                "prompt_builder": {"question": question},
            }
        )
        return result["prompt_builder"]["prompt"]