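"""Streamlit chat prototype: retrieval-augmented Q&A over DSS Inc. blog posts.

Blog URLs are loaded and chunked, embedded with Instructor-XL into an in-memory
Chroma vector store, and the most relevant chunks are stuffed into the prompt of
a Falcon-40B-Instruct model hosted on a SageMaker endpoint.
"""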
import streamlit as st
from streamlit_chat import message
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from typing import Dict
import json
import datetime

# deployment-specific SageMaker endpoint hosting the Falcon-40B-Instruct model
endpoint_name = "falcon-40b-instruct-gates3"
aws_region = "us-east-1"


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"
    len_prompt = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # remember the prompt length so the echoed prompt can be stripped in transform_output
        self.len_prompt = len(prompt)
        input_str = json.dumps(
            {"inputs": prompt,
             "parameters": {
                "do_sample": True,
                "top_p": 0.9,
                "temperature": 0.8,
                "max_new_tokens": 1024,
                "repetition_penalty": 1.03,
                "stop": ["\n\n", "Human:", "<|endoftext|>", "</s>"]
              }})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        # LangChain hands over the boto3 response body (a StreamingBody), so read it first
        response_json = output.read()
        res = json.loads(response_json)
        # the endpoint echoes the prompt; keep only the newly generated text
        ans = res[0]['generated_text'][self.len_prompt:]
        # drop any hallucinated continuation of the dialogue, if one appears
        cut = ans.rfind("Human")
        if cut != -1:
            ans = ans[:cut]
        return ans.strip()
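
# For reference, a sketch of the request/response shapes the handler above assumes
# (matching HuggingFace TGI-style containers; the exact schema depends on the image):
#   request:  b'{"inputs": "<prompt>", "parameters": {...}}'
#   response: [{"generated_text": "<prompt><completion>"}]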


@st.cache_resource
def getDocsearchOnce():
    print("Getting docsearch...")

    # define URL sources
    urls = [
        'https://www.dssinc.com/blog/2022/6/21/suicide-prevention-manager-enabling-the-veterans-affairs-to-achieve-high-reliability-in-suicide-risk-identification',
        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development', 
        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',
        'https://www.dssinc.com/blog/2023/5/24/supporting-the-vas-high-reliability-organization-journey-through-suicide-prevention',
        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',
        'https://www.dssinc.com/blog/2022/9/21/dss-inc-chosen-for-phase-two-of-mission-daybreak-vas-suicide-prevention-challenge',
        'https://www.dssinc.com/blog/2022/9/19/crescenz-va-medical-center-cmcvamc-deploys-the-dss-iconic-data-patient-case-manager-pcm-solution',
        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization']

    # load and split
    loaders = UnstructuredURLLoader(urls=urls)
    data = loaders.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    # keep these chunking parameters stable so the stored embeddings remain consistent
    texts = text_splitter.split_documents(data)
    print("Sources split into the following number of \"texts\":", len(texts))

    # load embedding model
    print("Loading embedding model...")
    embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")

    # build an in-memory Chroma index from the chunk texts (document metadata is dropped)
    docsearch = Chroma.from_texts([t.page_content for t in texts], embeddings)

    return docsearch
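
# Usage sketch (hypothetical query): the cached store is queried at answer time, e.g.
#   docsearch.similarity_search("What is the Suicide Prevention Manager?", k=3)
# returns the k most semantically similar chunks as LangChain Document objects.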


@st.cache_resource
def getChainOnce():
    print("Getting chain...")

    content_handler = ContentHandler()

    # get langchain-wrapped Predictor Endpoint Object
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=aws_region,
        content_handler=content_handler
    )

    # configure the prompt object
    prompt = PromptTemplate(
            input_variables=["history", "context", "new_message"],
            template="""
The following is a given context and the history of a conversation between a human and a helpful AI assistant.
The AI assistant answers the human's messages with full, informative responses based on the information given in the context, and does not invent facts that are not given.
When an answer is uncertain, the AI assistant qualifies its response as such.
The AI assistant pretends that it is also human and that the context is objectively true, and therefore avoids mentioning the context as a source.

Context: {context}

{history}
Human: {new_message}
AI:"""
        )
    # "stuff" all retrieved documents into {context}; the conversation memory fills {history}
    return load_qa_chain(
        llm=llm,
        chain_type="stuff",
        verbose=True,
        memory=ConversationBufferMemory(memory_key="history", input_key="new_message"),
        prompt=prompt,
    )


def getAIresponse(chain, docsearch, query):
    print("Getting AI response... @", datetime.datetime.now().strftime("%H:%M:%S"))
    # retrieve the three most relevant chunks, then answer with the stuffing QA chain
    docs = docsearch.similarity_search(query, k=3)
    result = chain({"input_documents": docs, "new_message": query}, return_only_outputs=True)
    return result['output_text'].strip()


st.title("DSS Prototype LLM 💬")

# THREE VARIABLES NEED TO BE PERSISTENT:
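# (@st.cache_resource shares one instance across all sessions and reruns, while
# st.session_state is kept per browser session; the chat messages must stay per-session)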

# 1. Conversational Chain
if "chain" not in st.session_state:
    st.session_state["chain"] = getChainOnce()

# 2. docsearch object
if "docsearch" not in st.session_state:
    st.session_state["docsearch"] = getDocsearchOnce()

# 3. messages for UI
if "messages" not in st.session_state:
    st.session_state["messages"] = []


# DRAW THE ACTUAL UI AND IMPLEMENT FUNCTIONALITY
# (some formatting is handled by streamlit chat)

# draw input box
with st.form("chat_input", clear_on_submit=True):
    a, b = st.columns([4, 1])
    user_input = a.text_input(
        label="Your message:",
        placeholder="What would you like to say?",
        label_visibility="collapsed",
    )
    b.form_submit_button("Send", use_container_width=True)

# handle input
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    respText = getAIresponse(st.session_state.chain, st.session_state.docsearch, user_input)
    print(respText)
    st.session_state.messages.append({"role": "assistant", "content": respText})

# draw messages
for k, msg in enumerate(st.session_state.messages):
    message(msg["content"], is_user=msg["role"] == "user", key=k)
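
# To run locally (a sketch, assuming this file is saved as app.py, AWS credentials
# are configured, and the SageMaker endpoint named above is deployed):
#   streamlit run app.py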