import logging
import os
import sys

import streamlit as st
from haystack import Pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
from haystack.nodes.retriever.web import WebRetriever

logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s %(asctime)s %(name)s:%(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

def get_plain_pipeline():
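    """Build a pipeline that sends the query directly to the OpenAI model, with no retrieval."""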
    prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])

    # Build a single PromptNode that answers the query with the OpenAI model and a plain prompt:
    plain_llm_template = PromptTemplate(name="plain_llm", prompt_text="Answer the following question: $query")
    node_openai = PromptNode(prompt_open_ai, default_prompt_template=plain_llm_template, max_length=300)

    pipeline = Pipeline()
    pipeline.add_node(component=node_openai, name="prompt_node", inputs=["Query"])
    return pipeline


def get_ret_aug_pipeline():
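    """Build a retrieval-augmented pipeline over a local FAISS index.

    Assumes my_faiss_index.faiss and my_faiss_index.json were created beforehand
    (see the sketch after this function).
    """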
    ds = FAISSDocumentStore(faiss_index_path="my_faiss_index.faiss",
                            faiss_config_path="my_faiss_index.json")

    retriever = EmbeddingRetriever(
        document_store=ds,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2
    )
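    # join_documents merges the retrieved documents into a single document whose text fills $documents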
    shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])

    default_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
                    "$query; Answer:",
    )
    # Initialize the PromptNode that consumes the joined documents
    node = PromptNode("text-davinci-003", default_prompt_template=default_template,
                      api_key=st.secrets["OPENAI_API_KEY"], max_length=500)

    # Chain retriever, shaper, and prompt node into a query pipeline
    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
    pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
    return pipe
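
# Note: get_ret_aug_pipeline() expects "my_faiss_index.faiss" and "my_faiss_index.json"
# to exist on disk. A minimal, hypothetical sketch of how such an index could be built
# offline with the same embedding model (assumes plain-text files under a "data/" folder,
# which is not part of this app):
#
#     import glob
#     from haystack.document_stores import FAISSDocumentStore
#     from haystack.nodes import EmbeddingRetriever
#     from haystack.schema import Document
#
#     ds = FAISSDocumentStore(embedding_dim=768, faiss_index_factory_str="Flat")
#     ds.write_documents([Document(content=open(path).read()) for path in glob.glob("data/*.txt")])
#     retriever = EmbeddingRetriever(
#         document_store=ds,
#         embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
#         model_format="sentence_transformers",
#     )
#     ds.update_embeddings(retriever)
#     ds.save(index_path="my_faiss_index.faiss", config_path="my_faiss_index.json")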


def get_web_ret_pipeline():
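    """Build a retrieval-augmented pipeline that retrieves context via web search (SerperDev)."""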
    search_key = st.secrets["WEBRET_API_KEY"]
    web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
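    # As in the FAISS pipeline, join the retrieved web results into a single document for $documents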
    shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
    default_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
                    "$query; Answer:",
    )
    # Initialize the PromptNode that consumes the joined documents
    node = PromptNode("text-davinci-003", default_prompt_template=default_template,
                      api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
    # Chain web retriever, shaper, and prompt node into a query pipeline
    pipe = Pipeline()
    pipe.add_node(component=web_retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
    pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
    return pipe

def app_init():
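    """Export the OpenAI key from Streamlit secrets and build the three pipelines."""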
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
    p1 = get_plain_pipeline()
    p2 = get_ret_aug_pipeline()
    p3 = get_web_ret_pipeline()
    return p1, p2, p3


def main():
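    """Render the Streamlit UI and run the plain and retrieval-augmented pipelines on the query."""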
    p1, p2, p3 = app_init()
    st.title("Haystack Demo")
    query = st.text_input("Query ...", "Did SVB collapse?")

    query_type = st.radio("Type",
                          ("Retrieval Augmented", "Retrieval Augmented with Web Search"))
    col_1, col_2 = st.columns(2)

    with col_1:
        st.text("PLAIN")
        answers = p1.run(query=query)
        st.text(answers['results'][0])

    with col_2:
        st.write(query_type.upper())
        if query_type == "Retrieval Augmented":
            answers_2 = p2.run(query=query)
        else:
            answers_2 = p3.run(query=query)
        st.text(answers_2['results'][0])


if __name__ == "__main__":
    main()