File size: 6,192 Bytes
4071f4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import sagemaker
import streamlit as st
from streamlit_chat import message
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from typing import Dict
import json
import chromadb
import datetime
# Name of the SageMaker endpoint; passed as `endpoint_name` to
# SagemakerEndpoint in getChainOnce().
endpoint_name = "falcon-40b-instruct-gates3"
# AWS region passed as `region_name` to SagemakerEndpoint in getChainOnce().
aws_region = "us-east-1"
class ContentHandler(LLMContentHandler):
    """Convert prompts to the endpoint's JSON request and parse its reply.

    ``transform_input`` remembers the prompt length so ``transform_output``
    can slice that many leading characters off ``generated_text``
    (presumably because the endpoint echoes the prompt -- confirm against
    the deployed model's response format).
    """

    content_type = "application/json"
    accepts = "application/json"
    # Length of the most recently sent prompt; written by transform_input
    # and read by transform_output.
    len_prompt = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize *prompt* plus fixed generation parameters as UTF-8 JSON.

        Args:
            prompt: Full prompt text sent as the ``inputs`` field.
            model_kwargs: Accepted for interface compatibility; not used,
                the generation parameters below are always sent.

        Returns:
            The JSON request body encoded as UTF-8 bytes.
        """
        self.len_prompt = len(prompt)
        input_str = json.dumps(
            {
                "inputs": prompt,
                "parameters": {
                    "do_sample": True,
                    "top_p": 0.9,
                    "temperature": 0.8,
                    "max_new_tokens": 1024,
                    "repetition_penalty": 1.03,
                    "stop": ["\n\n", "Human:", "<|endoftext|>", "</s>"],
                },
            }
        )
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Extract the assistant's answer from the endpoint response.

        Args:
            output: Response body. Annotated as ``bytes`` to match the base
                interface, but it is consumed through ``output.read()``, so
                it must be a readable stream.

        Returns:
            The generated text after the first ``len_prompt`` characters,
            truncated at the last occurrence of ``"Human"`` (if any) and
            stripped of surrounding whitespace.
        """
        response_json = output.read()
        res = json.loads(response_json)
        ans = res[0]['generated_text'][self.len_prompt:]
        # str.rfind returns -1 when the marker is absent; slicing with -1
        # would silently drop the final character of a complete answer, so
        # only truncate when a trailing "Human" turn marker was found.
        marker_pos = ans.rfind("Human")
        if marker_pos != -1:
            ans = ans[:marker_pos]
        return ans.strip()
@st.cache_resource
def getDocsearchOnce():
    """Build the Chroma vector store used for retrieval.

    Loads a fixed list of blog-post URLs, splits the loaded documents into
    overlapping chunks, embeds the chunk texts with the
    ``hkunlp/instructor-xl`` model and indexes them in Chroma.

    Returns:
        The Chroma store created from the chunk texts.
    """
    print("Getting docsearch...")

    # Source pages that make up the knowledge base.
    urls = [
        'https://www.dssinc.com/blog/2022/6/21/suicide-prevention-manager-enabling-the-veterans-affairs-to-achieve-high-reliability-in-suicide-risk-identification',
        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',
        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',
        'https://www.dssinc.com/blog/2023/5/24/supporting-the-vas-high-reliability-organization-journey-through-suicide-prevention',
        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',
        'https://www.dssinc.com/blog/2022/9/21/dss-inc-chosen-for-phase-two-of-mission-daybreak-vas-suicide-prevention-challenge',
        'https://www.dssinc.com/blog/2022/9/19/crescenz-va-medical-center-cmcvamc-deploys-the-dss-iconic-data-patient-case-manager-pcm-solution',
        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization',
    ]

    # Fetch the pages, then cut them into 1000-character chunks that overlap
    # by 200 characters.
    url_loader = UnstructuredURLLoader(urls=urls)
    raw_documents = url_loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(raw_documents)
    print("Sources split into the following number of \"texts\":", len(chunks))

    # Embed the plain chunk text (metadata is not passed on) and index it.
    print("Loading embedding model...")
    embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
    chunk_texts = [chunk.page_content for chunk in chunks]
    return Chroma.from_texts(chunk_texts, embeddings)
def getChainOnce():
    """Build a question-answering chain with its own conversation memory.

    This function is intentionally NOT decorated with ``st.cache_resource``:
    that cache hands the same object to every user session, and the chain
    owns a ``ConversationBufferMemory``, so caching it would mix the chat
    history of all users. The caller keeps the result in
    ``st.session_state``, which already limits construction to once per
    session.

    Returns:
        A "stuff" QA chain bound to the SageMaker endpoint, expecting the
        inputs ``input_documents`` and ``new_message``.
    """
    print("Getting chain...")
    content_handler = ContentHandler()
    # get langchain-wrapped Predictor Endpoint Object
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=aws_region,
        content_handler=content_handler,
    )
    # configure the prompt object
    prompt = PromptTemplate(
        input_variables=["history", "context", "new_message"],
        template="""
The following is a given context and history of a conversation between a human and a helpful AI assistant.
The AI assistant answers the human's messages with full, informational responses by inferring on the information given in the context, but does not extrapolate any facts that are not given.
When answering questions with uncertain answers, the AI assistant qualifies responses as such.
The AI assistant pretends it is also human and that the context is objectively true and therefore avoids mentioning the context as a source.
Context: {context}
{history}
Human: {new_message}
AI:"""
    )
    # Fresh memory per call so each session keeps a private history.
    memory = ConversationBufferMemory(memory_key="history", input_key="new_message")
    return load_qa_chain(llm=llm, chain_type="stuff", verbose=True, memory=memory, prompt=prompt)
def getAIresponse(chain, docsearch, query):
    """Answer *query* using the three most similar indexed documents.

    Args:
        chain: Callable invoked with the input dict and
            ``return_only_outputs=True``; its result must contain
            ``"output_text"``.
        docsearch: Object exposing ``similarity_search(query, k=...)``.
        query: The user's message.

    Returns:
        The chain's ``output_text`` with surrounding whitespace removed.
    """
    timestamp = datetime.datetime.now().strftime("%H:%M:%S")
    print("Getting AI response... @ ", timestamp)

    # Retrieve supporting context first, then hand everything to the chain.
    relevant_docs = docsearch.similarity_search(query, k=3)
    chain_inputs = {"input_documents": relevant_docs, "new_message": query}
    result = chain(chain_inputs, return_only_outputs=True)
    return result['output_text'].strip()
st.title("DSS Prototype LLM 💬")

# Three objects are kept in the session state and created only when missing:
#   chain     - the conversational QA chain
#   docsearch - the vector store used for retrieval
#   messages  - the transcript rendered in the UI
session = st.session_state
if "chain" not in session:
    session["chain"] = getChainOnce()
if "docsearch" not in session:
    session["docsearch"] = getDocsearchOnce()
if "messages" not in session:
    session["messages"] = []

# Input row: a wide text box next to a narrow "Send" button. The form clears
# the text box after each submit.
with st.form("chat_input", clear_on_submit=True):
    text_col, button_col = st.columns([4, 1])
    user_input = text_col.text_input(
        label="Your message:",
        placeholder="What would you like to say?",
        label_visibility="collapsed",
    )
    button_col.form_submit_button("Send", use_container_width=True)

# On a non-empty message: record it, ask the model, record the reply.
if user_input:
    transcript = session.messages
    transcript.append({"role": "user", "content": user_input})
    reply = getAIresponse(session.chain, session.docsearch, user_input)
    print(reply)
    transcript.append({"role": "assistant", "content": reply})

# Render the whole transcript; the list index is used as the widget key.
for idx, entry in enumerate(session.messages):
    from_user = entry["role"] == "user"
    message(entry["content"], is_user=from_user, key=idx)
|