File size: 2,170 Bytes
cc32819
 
1995a08
 
 
cc32819
0cd1a30
1995a08
ca1abf7
 
 
 
 
 
1995a08
 
 
0c96c7c
1995a08
 
 
 
 
 
 
7f5b663
1995a08
 
 
afb5ebd
 
1995a08
 
 
 
 
 
 
 
 
 
 
7f5b663
cc32819
8a2282e
 
afb5ebd
8a2282e
 
7f5b663
1995a08
682ac8f
1995a08
 
 
 
3064f8a
1995a08
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import subprocess
import sys
from pathlib import Path

from haystack import Document
from haystack import Pipeline
from haystack.utils import Secret
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.builders import PromptBuilder
from haystack.components.generators import HuggingFaceTGIGenerator

def install(name):
    """Install a package into the current interpreter's environment via pip.

    Uses ``subprocess.run`` (the modern API, Python 3.5+) instead of the
    legacy ``subprocess.call``. The command is an argument list (no shell),
    so *name* is never shell-interpreted.

    :param name: package specifier to install (e.g. ``"haystack-ai"``).
    :return: pip's exit code (0 on success). The original implementation
        discarded it, so returning it is backward compatible.
    """
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', name],
        check=False,  # preserve original best-effort behavior: never raise
    )
    return result.returncode

def init_doc_store(path, files):
    """Read each named file under *path* into a haystack ``Document`` and
    load them all into a fresh ``InMemoryDocumentStore``.

    :param path: directory containing the files; works with or without a
        trailing separator (paths are joined via ``pathlib``).
    :param files: iterable of file names to ingest.
    :return: an ``InMemoryDocumentStore`` holding one ``Document`` per file,
        each with its file name stored in ``meta['name']``.
    """
    base = Path(path)
    docs = []
    for name in files:
        # ``base / name`` fixes the original ``path + file`` concatenation,
        # which produced a wrong path whenever ``path`` lacked a trailing
        # separator. Explicit encoding avoids platform-dependent decoding.
        with open(base / name, 'r', encoding='utf-8') as f:
            docs.append(Document(content=f.read(), meta={'name': name}))

    document_store = InMemoryDocumentStore()
    document_store.write_documents(docs)
    return document_store

def define_components(document_store, api_key):
    """Build the three RAG components: a BM25 retriever over
    *document_store*, a prompt builder carrying the chatbot template, and
    a warmed-up hosted Mistral generator authenticated with *api_key*.

    :return: the tuple ``(retriever, prompt_builder, generator)``.
    """
    # Prompt template: injects the retrieved document contents plus the
    # user question into one instruction for the LLM.
    template = """
    You are a Chatbot designed to spread Awareness about Alzheimer's Disease. You are AI Chaperone.
    You will be provided information about Alzheimer's Disease as context for each question. Given the following information, answer the question.
    
    Context:
    {% for document in documents %}
        {{ document.content }}
    {% endfor %}
    
    Question: {{question}}
    Answer:
    """
    prompt_builder = PromptBuilder(template=template)

    # Keyword-search retriever; surfaces the 3 best-matching documents.
    retriever = InMemoryBM25Retriever(document_store, top_k=3)

    # Hosted Mistral-7B-Instruct via HuggingFace TGI; capped at 50 new
    # tokens per answer, sampled at moderate temperature.
    llm_kwargs = {'max_new_tokens': 50, 'temperature': 0.7}
    generator = HuggingFaceTGIGenerator(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=Secret.from_token(api_key),
        generation_kwargs=llm_kwargs,
    )
    generator.warm_up()

    return retriever, prompt_builder, generator

def define_pipeline(retreiver, prompt_builder, generator):
    """Assemble the basic RAG pipeline: retriever -> prompt builder -> LLM.

    NOTE(review): the ``retreiver`` parameter name is misspelled; kept
    as-is so existing keyword-argument callers are not broken.

    :return: a connected haystack ``Pipeline``.
    """
    rag = Pipeline()

    # Register each component under the name the connections below use.
    for label, component in (
        ("retriever", retreiver),
        ("prompt_builder", prompt_builder),
        ("llm", generator),
    ):
        rag.add_component(label, component)

    # Retrieved documents feed the prompt; the built prompt feeds the LLM.
    rag.connect("retriever", "prompt_builder.documents")
    rag.connect("prompt_builder", "llm")

    return rag