File size: 4,462 Bytes
4a448eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os

import streamlit as st
from haystack import Pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
from haystack.nodes.retriever.web import WebRetriever


# Preset example questions about the SVB (Silicon Valley Bank) collapse.
# The set_q1..set_q5 helpers in this file pick entries by position, so the
# order of this list matters.
QUERIES = [
    "Did SVB collapse?",
    "Why did SVB collapse?",
    "What does SVB failure mean for our economy?",
    "Who is responsible for SVB collapse?",  # previously misspelled "SVC"
    "When did SVB collapse?",
]

def ChangeWidgetFontSize(wgt_txt, wch_font_size = '12px'):
    """Build a ``<script>`` snippet that restyles page elements by their text.

    The generated JavaScript walks every element of
    ``window.parent.document`` and sets ``style.fontSize`` on each element
    whose ``innerText`` equals *wgt_txt*.

    Args:
        wgt_txt: Exact text of the element(s) whose font size should change.
        wch_font_size: CSS font-size value to apply (default ``'12px'``).

    Returns:
        The HTML/JavaScript snippet as a string. This function only builds
        the snippet; it does not render it, so the caller has to embed the
        returned markup in the page for it to take effect.
    """
    import json  # local import: the file's import block is left untouched

    def _js_literal(value):
        # json.dumps yields a valid, fully escaped JavaScript string literal,
        # so quotes or backslashes in the value cannot break the script.
        # "</" is additionally escaped so the value cannot close the
        # surrounding <script> element early.
        return json.dumps(str(value)).replace("</", "<\\/")

    target = _js_literal(wgt_txt)
    size = _js_literal(wch_font_size)

    htmlstr = (
        "<script>var elements = window.parent.document.querySelectorAll('*'), i;\n"
        "for (i = 0; i < elements.length; ++i) { "
        "if (elements[i].innerText == " + target + ") "
        "{ elements[i].style.fontSize = " + size + "; } } </script>"
    )
    # The original built this string and then dropped it; hand it back so the
    # work is not wasted.
    return htmlstr


def get_plain_pipeline():
    """Build a one-node pipeline that sends the question straight to OpenAI.

    The OpenAI key is read from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        A ``Pipeline`` whose single node, ``"prompt_node"``, takes its input
        directly from ``"Query"``.
    """
    openai_model = PromptModel(
        model_name_or_path="text-davinci-003",
        api_key=st.secrets["OPENAI_API_KEY"],
    )
    # No retrieval step here: the template forwards the bare question.
    question_only_template = PromptTemplate(
        name="plain_llm",
        prompt_text="Answer the following question: $query",
    )
    prompt_node = PromptNode(
        openai_model,
        default_prompt_template=question_only_template,
        max_length=300,
    )

    plain = Pipeline()
    plain.add_node(component=prompt_node, name="prompt_node", inputs=["Query"])
    return plain


def get_retrieval_augmented_pipeline():
    """Build the pipeline that answers from a local FAISS index.

    The chain is ``Query -> retriever -> shaper -> prompt_node``. The FAISS
    index and its config are loaded from the ``data/`` directory, and the
    OpenAI key is read from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        The assembled ``Pipeline``.
    """
    store = FAISSDocumentStore(
        faiss_index_path="data/my_faiss_index.faiss",
        faiss_config_path="data/my_faiss_index.json",
    )
    embedding_retriever = EmbeddingRetriever(
        document_store=store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    # Shaper applies "join_documents" to the retriever output before prompting.
    joiner = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        outputs=["documents"],
    )

    qa_prompt = (
        "Given the context please answer the question. "
        "Context: $documents; Question: $query; Answer:"
    )
    qa_template = PromptTemplate(name="question-answering", prompt_text=qa_prompt)
    answer_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    # Wire the stages in order; each one consumes the previous stage's output.
    stages = (
        (embedding_retriever, "retriever", ["Query"]),
        (joiner, "shaper", ["retriever"]),
        (answer_node, "prompt_node", ["shaper"]),
    )
    rag = Pipeline()
    for component, node_name, upstream in stages:
        rag.add_node(component=component, name=node_name, inputs=upstream)
    return rag


def get_web_retrieval_augmented_pipeline():
    """Build the pipeline that answers from web search results.

    The chain is ``Query -> retriever -> shaper -> prompt_node``, where the
    retriever is a ``WebRetriever`` using the "SerperDev" provider. Keys are
    read from ``st.secrets`` (``WEBRET_API_KEY`` and ``OPENAI_API_KEY``).

    Returns:
        The assembled ``Pipeline``.
    """
    searcher = WebRetriever(
        api_key=st.secrets["WEBRET_API_KEY"],
        search_engine_provider="SerperDev",
    )
    # Shaper applies "join_documents" to the retriever output before prompting.
    joiner = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        outputs=["documents"],
    )

    qa_prompt = (
        "Given the context please answer the question. "
        "Context: $documents; Question: $query; Answer:"
    )
    qa_template = PromptTemplate(name="question-answering", prompt_text=qa_prompt)
    answer_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    web_rag = Pipeline()
    web_rag.add_node(component=searcher, name="retriever", inputs=["Query"])
    web_rag.add_node(component=joiner, name="shaper", inputs=["retriever"])
    web_rag.add_node(component=answer_node, name="prompt_node", inputs=["shaper"])
    return web_rag


@st.cache_resource(show_spinner=False)
def app_init():
    """Export the OpenAI key to the environment and build all three pipelines.

    Decorated with ``st.cache_resource`` (spinner disabled).

    Returns:
        A 3-tuple ``(plain, retrieval_augmented, web_retrieval_augmented)``
        of pipelines, in that order.
    """
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

    # Builders run in this order, which is also the order of the result.
    builders = (
        get_plain_pipeline,
        get_retrieval_augmented_pipeline,
        get_web_retrieval_augmented_pipeline,
    )
    return tuple(build() for build in builders)


# Seed the 'query' session-state entry with an empty string if it is not
# already present, so later reads of it never hit a missing key.
if "query" not in st.session_state:
    st.session_state.query = ""


def set_question():
    """Copy the ``'q_drop_down'`` session-state value into ``'query'``."""
    chosen = st.session_state['q_drop_down']
    st.session_state['query'] = chosen


def set_q1():
    """Store the first preset from ``QUERIES`` under the ``'query'`` key."""
    preset = QUERIES[0]
    st.session_state['query'] = preset


def set_q2():
    """Store the second preset from ``QUERIES`` under the ``'query'`` key."""
    preset = QUERIES[1]
    st.session_state['query'] = preset


def set_q3():
    """Store the third preset from ``QUERIES`` under the ``'query'`` key."""
    preset = QUERIES[2]
    st.session_state['query'] = preset


def set_q4():
    """Store the fourth preset from ``QUERIES`` under the ``'query'`` key."""
    preset = QUERIES[3]
    st.session_state['query'] = preset


def set_q5():
    """Store the fifth preset from ``QUERIES`` under the ``'query'`` key."""
    preset = QUERIES[4]
    st.session_state['query'] = preset