File size: 5,382 Bytes
91f974c
 
 
 
0a1d4cf
91f974c
 
 
 
 
 
 
 
 
 
0a1d4cf
91f974c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a1d4cf
91f974c
0a1d4cf
91f974c
 
0a1d4cf
 
 
 
 
 
 
 
 
 
 
91f974c
 
 
 
 
 
 
 
0a1d4cf
91f974c
 
 
 
 
 
0a1d4cf
91f974c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a1d4cf
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from haystack import Document, Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import PromptBuilder
from datasets import load_dataset
from haystack.dataclasses import ChatMessage
from typing import Optional, List, Union, Dict
from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG

class RAGPipeline:
    """Retrieval-augmented generation pipeline built on Haystack components.

    Embeds a document corpus into an in-memory store; for each question it
    embeds the query, retrieves relevant documents, and renders them into a
    prompt template whose output is meant to be passed to an LLM downstream.
    """

    def __init__(
        self,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            dataset_config: Either a string key from DATASET_CONFIGS or a
                DatasetConfig object.
            documents: Optional list of documents (plain strings or Haystack
                ``Document`` objects) to use instead of loading from a dataset.
            embedding_model: Optional override for the embedding model name;
                falls back to ``MODEL_CONFIG["embedding_model"]``.

        Raises:
            ValueError: If ``dataset_config`` is a string that is not a key
                of DATASET_CONFIGS.
        """
        # Resolve configuration: string keys are looked up in the registry,
        # DatasetConfig objects are used directly.
        if isinstance(dataset_config, str):
            if dataset_config not in DATASET_CONFIGS:
                raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
            self.config = DATASET_CONFIGS[dataset_config]
        else:
            self.config = dataset_config

        # Load documents either from the provided list or from the dataset.
        if documents is not None:
            # FIX: the signature accepts plain strings, but the document
            # embedder requires Document instances — previously raw strings
            # were stored unchanged and would fail at indexing time.
            self.documents = [
                doc if isinstance(doc, Document) else Document(content=doc)
                for doc in documents
            ]
        else:
            dataset = load_dataset(self.config.name, split=self.config.split)
            # Build Document objects, copying the configured metadata fields
            # (skipping any field missing from a given dataset row).
            self.documents = []
            for doc in dataset:
                meta = {}
                if self.config.fields:
                    for meta_key, dataset_field in self.config.fields.items():
                        if dataset_field in doc:
                            meta[meta_key] = doc[dataset_field]

                self.documents.append(
                    Document(
                        content=doc[self.config.content_field],
                        meta=meta,
                    )
                )

        # Debug preview: print the first 10 documents.
        for doc in self.documents[:10]:
            print(f"Content: {doc.content}")
            print(f"Metadata: {doc.meta}")
            print("-"*100)

        # Initialize components. Document and text embedders must use the
        # same model so query/document vectors live in the same space.
        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(
            model=embedding_model or MODEL_CONFIG["embedding_model"]
        )
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=embedding_model or MODEL_CONFIG["embedding_model"]
        )
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Warm up the embedders (loads model weights before first use).
        self.doc_embedder.warm_up()
        self.text_embedder.warm_up()

        # Initialize the prompt template; the config may supply its own,
        # otherwise fall back to a generic context+question template.
        self.prompt_builder = PromptBuilder(template=self.config.prompt_template or """
        Given the following context, please answer the question.
        
        Context:
        {% for document in documents %}
            {{ document.content }}
        {% endfor %}
        
        Question: {{question}}
        Answer:
        """)

        # Embed and index the corpus into the document store.
        self._index_documents(self.documents)

        # Build the retrieval pipeline graph.
        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, preset_name: str):
        """
        Create a pipeline from a preset configuration.

        Args:
            preset_name: Name of the preset configuration to use
                (a key of DATASET_CONFIGS).
        """
        return cls(dataset_config=preset_name)

    def _index_documents(self, documents):
        """Embed ``documents`` and write them to the document store."""
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])

    def _build_pipeline(self):
        """Assemble the embed -> retrieve -> prompt Haystack pipeline."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)

        # Connect components: query embedding feeds the retriever, whose
        # documents feed the prompt builder.
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever", "prompt_builder")

        return pipeline

    def answer_question(self, question: str) -> str:
        """Run the RAG steps for ``question`` and return the built prompt.

        NOTE(review): this invokes the components directly rather than via
        ``self.pipeline.run`` — the result is equivalent for this linear
        graph, but the assembled pipeline object itself goes unused here.
        """
        # First, embed the question and retrieve relevant documents.
        embedded_question = self.text_embedder.run(text=question)
        retrieved_docs = self.retriever.run(query_embedding=embedded_question["embedding"])

        # Then, build the prompt with the retrieved documents.
        prompt_result = self.prompt_builder.run(
            question=question,
            documents=retrieved_docs["documents"]
        )

        # Return the formatted prompt (to be processed by the main AI).
        return prompt_result["prompt"]