File size: 5,489 Bytes
91f974c 0a1d4cf 91f974c 0a1d4cf 91f974c 8946f02 91f974c 8946f02 91f974c 8946f02 91f974c 0a1d4cf 91f974c 0a1d4cf 91f974c 0a1d4cf 91f974c 0a1d4cf 91f974c 0a1d4cf 91f974c 0a1d4cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from haystack import Document, Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import PromptBuilder
from datasets import load_dataset
from haystack.dataclasses import ChatMessage
from typing import Optional, List, Union, Dict
from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG
class RAGPipeline:
    """Retrieval-Augmented Generation pipeline built on Haystack components.

    Wires together an in-memory document store, sentence-transformer
    embedders, an embedding retriever, and a prompt builder. Documents are
    embedded and indexed at construction time; ``answer_question`` then
    retrieves relevant documents for a query and returns the rendered prompt
    (the prompt is consumed downstream by the main model, not answered here).
    """

    # Fallback Jinja2 template used when the dataset config supplies none.
    _DEFAULT_TEMPLATE = """
        Given the following context, please answer the question.
        Context:
        {% for document in documents %}
        {{ document.content }}
        {% endfor %}
        Question: {{question}}
        Answer:
        """

    def __init__(
        self,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            dataset_config: Either a string key from DATASET_CONFIGS or a
                DatasetConfig object.
            documents: Optional list of documents to use instead of loading
                from a dataset.
            embedding_model: Optional override for the embedding model name;
                falls back to MODEL_CONFIG["embedding_model"].

        Raises:
            ValueError: If ``dataset_config`` is a string not present in
                DATASET_CONFIGS.
        """
        # Resolve configuration: accept either a preset key or a config object.
        if isinstance(dataset_config, str):
            if dataset_config not in DATASET_CONFIGS:
                raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
            self.config = DATASET_CONFIGS[dataset_config]
        else:
            self.config = dataset_config

        # Use caller-supplied documents verbatim, otherwise load from the
        # configured dataset.
        if documents is not None:
            self.documents = documents
        else:
            self.documents = self._load_documents_from_dataset()

        # Resolve the embedding model name once; the original code evaluated
        # this expression three times and constructed the text embedder twice
        # (the first instance was discarded — a wasted model load).
        model_name = embedding_model or MODEL_CONFIG["embedding_model"]

        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(
            model=model_name,
            progress_bar=False
        )
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=model_name,
            progress_bar=False
        )
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Warm up the embedders so the first query doesn't pay model-load cost.
        self.doc_embedder.warm_up()
        self.text_embedder.warm_up()

        # Prompt template: prefer the config's template, fall back to default.
        self.prompt_builder = PromptBuilder(
            template=self.config.prompt_template or self._DEFAULT_TEMPLATE
        )

        # Embed + index documents, then assemble the query pipeline.
        self._index_documents(self.documents)
        self.pipeline = self._build_pipeline()

    def _load_documents_from_dataset(self) -> List[Document]:
        """Load and convert dataset rows into Haystack Documents.

        Uses ``self.config`` for the dataset name/split, the content field,
        and the (optional) metadata field mapping. Dataset fields missing
        from a row are silently skipped in that row's metadata.
        """
        dataset = load_dataset(self.config.name, split=self.config.split)
        docs: List[Document] = []
        for row in dataset:
            # Map configured metadata keys to dataset fields, skipping
            # fields absent from this row.
            meta = {}
            if self.config.fields:
                for meta_key, dataset_field in self.config.fields.items():
                    if dataset_field in row:
                        meta[meta_key] = row[dataset_field]
            docs.append(Document(content=row[self.config.content_field], meta=meta))
        return docs

    @classmethod
    def from_preset(cls, preset_name: str):
        """
        Create a pipeline from a preset configuration.

        Args:
            preset_name: Name of the preset configuration (a DATASET_CONFIGS key).
        """
        return cls(dataset_config=preset_name)

    def _index_documents(self, documents):
        """Embed the given documents and write them to the document store."""
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])

    def _build_pipeline(self):
        """Assemble the query-time pipeline: embed -> retrieve -> build prompt."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        # Route the query embedding into the retriever, and the retrieved
        # documents into the prompt builder.
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever", "prompt_builder")
        return pipeline

    def answer_question(self, question: str) -> str:
        """Run the RAG steps for a question and return the rendered prompt.

        Note: this runs the components directly rather than via
        ``self.pipeline``; the returned string is the filled-in prompt,
        intended to be answered by a downstream model.
        """
        # Embed the question and retrieve relevant documents.
        embedded_question = self.text_embedder.run(text=question)
        retrieved_docs = self.retriever.run(query_embedding=embedded_question["embedding"])
        # Render the prompt with the retrieved documents.
        prompt_result = self.prompt_builder.run(
            question=question,
            documents=retrieved_docs["documents"]
        )
        return prompt_result["prompt"]