import logging
from typing import Dict, List, Optional, Union

from datasets import load_dataset
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore

from .config import DATASET_CONFIGS, MODEL_CONFIG, DatasetConfig

logger = logging.getLogger(__name__)


class RAGPipeline:
    """Build retrieval-augmented prompts from an in-memory vector store.

    Documents (loaded from a Hugging Face dataset or supplied directly) are
    embedded with a SentenceTransformers model and written to an
    ``InMemoryDocumentStore``. At query time the question is embedded, the
    nearest documents are retrieved, and both are rendered into a prompt.
    """

    # Used when the dataset configuration does not supply its own template.
    DEFAULT_PROMPT_TEMPLATE = """
Given the following context, please answer the question.

Context:
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Question: {{question}}
Answer:
"""

    # How many loaded documents are echoed to the debug log after loading.
    _PREVIEW_COUNT = 10

    def __init__(
        self,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None,
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            dataset_config: Either a string key from DATASET_CONFIGS or a
                DatasetConfig object.
            documents: Optional list of documents (plain strings or haystack
                Documents) to use instead of loading from a dataset. Strings
                are wrapped as ``Document(content=...)``.
            embedding_model: Optional override for the embedding model;
                defaults to ``MODEL_CONFIG["embedding_model"]``.

        Raises:
            ValueError: If ``dataset_config`` is a string not present in
                DATASET_CONFIGS.
            TypeError: If an entry of ``documents`` is neither a str nor a
                Document.
        """
        self.config = self._resolve_config(dataset_config)

        # Provided documents take precedence over loading the dataset.
        if documents is not None:
            self.documents = [self._as_document(item) for item in documents]
        else:
            self.documents = self._load_dataset_documents(self.config)
        self._log_preview(self.documents)

        # Both embedders must use the same model so query and document
        # vectors live in the same space.
        model_name = embedding_model or MODEL_CONFIG["embedding_model"]
        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(model=model_name)
        self.text_embedder = SentenceTransformersTextEmbedder(model=model_name)
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Load the model weights before the first run() call.
        self.doc_embedder.warm_up()
        self.text_embedder.warm_up()

        self.prompt_builder = PromptBuilder(
            template=self.config.prompt_template or self.DEFAULT_PROMPT_TEMPLATE
        )

        self._index_documents(self.documents)
        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, preset_name: str, **kwargs) -> "RAGPipeline":
        """
        Create a pipeline from a preset configuration.

        Args:
            preset_name: Name of the preset configuration to use.
            **kwargs: Forwarded to the constructor (e.g. ``documents`` or
                ``embedding_model``).
        """
        return cls(dataset_config=preset_name, **kwargs)

    @staticmethod
    def _resolve_config(dataset_config: Union[str, DatasetConfig]) -> DatasetConfig:
        """Return the DatasetConfig for a preset key, or the object unchanged."""
        if not isinstance(dataset_config, str):
            return dataset_config
        try:
            return DATASET_CONFIGS[dataset_config]
        except KeyError:
            raise ValueError(
                f"Dataset config '{dataset_config}' not found. "
                f"Available configs: {list(DATASET_CONFIGS.keys())}"
            ) from None

    @staticmethod
    def _load_dataset_documents(config: DatasetConfig) -> List[Document]:
        """Load ``config``'s dataset split and convert each row to a Document.

        ``config.fields`` maps metadata keys to dataset column names; columns
        missing from a row are skipped rather than raising.
        """
        dataset = load_dataset(config.name, split=config.split)
        field_map = config.fields or {}
        loaded = []
        for row in dataset:
            meta = {
                meta_key: row[column]
                for meta_key, column in field_map.items()
                if column in row
            }
            loaded.append(Document(content=row[config.content_field], meta=meta))
        return loaded

    @staticmethod
    def _as_document(item: Union[str, Document]) -> Document:
        """Coerce a plain string to a Document; pass Documents through."""
        if isinstance(item, Document):
            return item
        if isinstance(item, str):
            return Document(content=item)
        raise TypeError(
            f"Expected a str or haystack Document, got {type(item).__name__}"
        )

    @classmethod
    def _log_preview(cls, documents: List[Document]) -> None:
        """Log the first few documents at DEBUG level for inspection."""
        for doc in documents[: cls._PREVIEW_COUNT]:
            logger.debug("Content: %s", doc.content)
            logger.debug("Metadata: %s", doc.meta)

    def _index_documents(self, documents: List[Union[str, Document]]) -> None:
        """Embed ``documents`` and write them into the document store."""
        docs = [self._as_document(item) for item in documents]
        if not docs:
            # Nothing to embed; avoid a pointless model call.
            return
        embedded = self.doc_embedder.run(documents=docs)
        self.document_store.write_documents(embedded["documents"])

    def _build_pipeline(self) -> Pipeline:
        """Wire text_embedder -> retriever -> prompt_builder into a Pipeline."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)

        # Explicit sockets: do not rely on type/name auto-matching.
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever.documents", "prompt_builder.documents")
        return pipeline

    def answer_question(self, question: str) -> str:
        """Build the RAG prompt for ``question``.

        Embeds the question, retrieves the closest documents, and renders
        them into the prompt template. Note that this returns the rendered
        prompt, not a generated answer; the caller is expected to send it to
        a language model.
        """
        embedded_question = self.text_embedder.run(text=question)
        retrieved = self.retriever.run(query_embedding=embedded_question["embedding"])

        prompt_result = self.prompt_builder.run(
            question=question,
            documents=retrieved["documents"],
        )
        return prompt_result["prompt"]