# AI Chaperone — an Alzheimer's-awareness RAG chatbot built with Haystack.
from haystack import Document
from haystack.utils import Secret
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.builders import PromptBuilder
from haystack.components.generators import HuggingFaceTGIGenerator
from haystack import Pipeline
import sys
import subprocess
def install(name):
    """Install a package into the running interpreter's environment via pip.

    Args:
        name: pip package specifier to install (e.g. ``"haystack-ai"``).

    Returns:
        The pip process exit code: 0 on success, non-zero on failure.
        (The original discarded this, silently ignoring failed installs.)
    """
    # sys.executable -m pip targets the interpreter actually running this
    # script, not whichever `pip` happens to be first on PATH.
    return subprocess.call([sys.executable, '-m', 'pip', 'install', name])
def init_doc_store(path, files):
    """Load text files into an in-memory Haystack document store.

    Args:
        path: Directory prefix prepended to each file name. NOTE: names are
            joined by plain string concatenation, so this must end with a
            path separator (e.g. ``"data/"``) — TODO confirm with callers.
        files: Iterable of file names (relative to ``path``) to load.

    Returns:
        An ``InMemoryDocumentStore`` holding one ``Document`` per file, with
        the file name stored under the ``'name'`` metadata key.
    """
    docs = []
    for file_name in files:
        # Read as UTF-8 explicitly so behaviour does not vary with the
        # platform's default locale encoding.
        with open(path + file_name, 'r', encoding='utf-8') as f:
            docs.append(Document(content=f.read(), meta={'name': file_name}))
    document_store = InMemoryDocumentStore()
    document_store.write_documents(docs)
    return document_store
def define_components(document_store, api_key):
    """Build the three RAG components: retriever, prompt builder, generator.

    Args:
        document_store: Store the BM25 retriever will search.
        api_key: Hugging Face token authenticating the hosted TGI endpoint.

    Returns:
        A ``(retriever, prompt_builder, generator)`` tuple; the generator is
        already warmed up.
    """
    # Jinja template: retrieved documents are interpolated as context above
    # the user's question.
    template = """
You are a Chatbot designed to spread Awareness about Alzheimer's Disease. You are AI Chaperone.
You will be provided information about Alzheimer's Disease as context for each question. Given the following information, answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{question}}
Answer:
"""
    retriever = InMemoryBM25Retriever(document_store, top_k=3)
    prompt_builder = PromptBuilder(template=template)
    # Keep answers short and mildly creative.
    gen_kwargs = {'max_new_tokens': 50, 'temperature': 0.7}
    generator = HuggingFaceTGIGenerator(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=Secret.from_token(api_key),
        generation_kwargs=gen_kwargs,
    )
    # Warm up now so the first user query does not pay initialisation latency.
    generator.warm_up()
    return retriever, prompt_builder, generator
def define_pipeline(retreiver, prompt_builder, generator):
    """Wire the components into a linear retrieve -> prompt -> generate pipeline.

    Note: the first parameter keeps its original (misspelled) name
    ``retreiver`` so existing keyword callers are not broken.

    Args:
        retreiver: Retriever component feeding documents to the prompt.
        prompt_builder: Prompt builder that renders the RAG template.
        generator: LLM generator that answers from the rendered prompt.

    Returns:
        The connected Haystack ``Pipeline``.
    """
    rag_pipeline = Pipeline()
    # Register each component under the name the connections below refer to.
    for component_name, component in (
        ("retriever", retreiver),
        ("prompt_builder", prompt_builder),
        ("llm", generator),
    ):
        rag_pipeline.add_component(component_name, component)
    rag_pipeline.connect("retriever", "prompt_builder.documents")
    rag_pipeline.connect("prompt_builder", "llm")
    return rag_pipeline