"""Streamlit chat app: retrieval-augmented Q&A over DSS blog posts, answered by a
SageMaker-hosted Falcon-40B-Instruct endpoint."""

import datetime
import json
from typing import Dict

import streamlit as st
from streamlit_chat import message

from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import UnstructuredURLLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Both values must match your own SageMaker deployment of the model.
endpoint_name = "falcon-40b-instruct-gates3"
aws_region = "us-east-1"
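
# The endpoint above is assumed to be already deployed, e.g. with the SageMaker
# Python SDK's Hugging Face LLM (TGI) container (sketch; role/instance values
# are hypothetical and depend on your account):
#   from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
#   image = get_huggingface_llm_image_uri("huggingface")
#   model = HuggingFaceModel(role=role, image_uri=image,
#                            env={"HF_MODEL_ID": "tiiuae/falcon-40b-instruct",
#                                 "SM_NUM_GPUS": "8"})
#   model.deploy(initial_instance_count=1, instance_type="ml.g5.48xlarge",
#                endpoint_name=endpoint_name)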


class ContentHandler(LLMContentHandler):
    """Serializes prompts for, and parses completions from, the TGI endpoint."""

    content_type = "application/json"
    accepts = "application/json"
    len_prompt = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Remember the prompt length so transform_output can strip the
        # echoed prompt from the generated text.
        self.len_prompt = len(prompt)
        input_str = json.dumps({
            "inputs": prompt,
            "parameters": {
                "do_sample": True,
                "top_p": 0.9,
                "temperature": 0.8,
                "max_new_tokens": 1024,
                "repetition_penalty": 1.03,
                "stop": ["\n\n", "Human:", "<|endoftext|>", "</s>"],
            },
        })
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # The TGI container responds with [{"generated_text": prompt + completion}];
        # drop the echoed prompt, then trim any hallucinated "Human:" turn.
        response_json = output.read()
        res = json.loads(response_json)
        ans = res[0]["generated_text"][self.len_prompt:]
        cut = ans.rfind("Human")
        if cut != -1:  # rfind returns -1 when absent; slicing with -1 would eat a character
            ans = ans[:cut]
        return ans.strip()
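
# Illustrative round trip (values hypothetical; no endpoint is called):
#   handler = ContentHandler()
#   handler.transform_input("Hello", {})  # -> b'{"inputs": "Hello", ...}'
#   # Reading the endpoint's response body then yields
#   # b'[{"generated_text": "Hello <completion>"}]', from which
#   # transform_output returns just "<completion>".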


@st.cache_resource
def getDocsearchOnce():
    """Build the Chroma vector store once per session (cached by Streamlit)."""
    print("Getting docsearch...")

    # DSS blog posts that form the retrieval corpus.
    urls = [
        'https://www.dssinc.com/blog/2022/6/21/suicide-prevention-manager-enabling-the-veterans-affairs-to-achieve-high-reliability-in-suicide-risk-identification',
        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',
        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',
        'https://www.dssinc.com/blog/2023/5/24/supporting-the-vas-high-reliability-organization-journey-through-suicide-prevention',
        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',
        'https://www.dssinc.com/blog/2022/9/21/dss-inc-chosen-for-phase-two-of-mission-daybreak-vas-suicide-prevention-challenge',
        'https://www.dssinc.com/blog/2022/9/19/crescenz-va-medical-center-cmcvamc-deploys-the-dss-iconic-data-patient-case-manager-pcm-solution',
        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization',
    ]

    loader = UnstructuredURLLoader(urls=urls)
    data = loader.load()

    # Split pages into overlapping ~1000-character chunks so retrieved
    # context fits comfortably in the prompt.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(data)
    print('Sources split into the following number of "texts":', len(texts))

    print("Loading embedding model...")
    embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")

    # Embed each chunk and index it in an in-memory Chroma collection.
    docsearch = Chroma.from_texts([t.page_content for t in texts], embeddings)

    return docsearch
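
# Illustrative retrieval (query hypothetical):
#   getDocsearchOnce().similarity_search("VA suicide prevention", k=3)
#   # -> the three chunks whose instructor-xl embeddings sit closest to the query's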


@st.cache_resource
def getChainOnce():
    """Build the QA chain (endpoint LLM + prompt + memory) once per session."""
    print("Getting chain...")

    content_handler = ContentHandler()

    # LangChain client for the deployed SageMaker endpoint.
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=aws_region,
        content_handler=content_handler,
    )

    prompt = PromptTemplate(
        input_variables=["history", "context", "new_message"],
        template="""
The following is a given context and the history of a conversation between a human and a helpful AI assistant.
The AI assistant answers the human's messages with full, informative responses, inferring from the information given in the context, but does not extrapolate any facts that are not given.
When an answer is uncertain, the AI assistant qualifies its response as such.
The AI assistant pretends it is also human, treats the context as objectively true, and therefore avoids mentioning the context as a source.

Context: {context}

{history}
Human: {new_message}
AI:""",
    )

    # "stuff" packs every retrieved document into {context}; the buffer memory
    # replays earlier turns through {history}.
    return load_qa_chain(
        llm=llm,
        chain_type="stuff",
        verbose=True,
        memory=ConversationBufferMemory(memory_key="history", input_key="new_message"),
        prompt=prompt,
    )
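
# Illustrative final prompt the endpoint receives (turns hypothetical):
#   <instructions from the template above>
#   Context: <chunk 1> <chunk 2> <chunk 3>
#   Human: <earlier question>
#   AI: <earlier answer>
#   Human: <new question>
#   AI: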


def getAIresponse(chain, docsearch, query):
    """Retrieve the three most relevant chunks and run them through the chain."""
    print("Getting AI response... @", datetime.datetime.now().strftime("%H:%M:%S"))
    result = chain(
        {"input_documents": docsearch.similarity_search(query, k=3), "new_message": query},
        return_only_outputs=True,
    )
    return result["output_text"].strip()


st.title("DSS Prototype LLM 💬")

# One-time, cached setup plus the chat transcript for this session.
if "chain" not in st.session_state:
    st.session_state["chain"] = getChainOnce()

if "docsearch" not in st.session_state:
    st.session_state["docsearch"] = getDocsearchOnce()

if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Input form; clear_on_submit empties the text box after each send.
with st.form("chat_input", clear_on_submit=True):
    a, b = st.columns([4, 1])
    user_input = a.text_input(
        label="Your message:",
        placeholder="What would you like to say?",
        label_visibility="collapsed",
    )
    b.form_submit_button("Send", use_container_width=True)

if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    respText = getAIresponse(st.session_state.chain, st.session_state.docsearch, user_input)
    print(respText)
    st.session_state.messages.append({"role": "assistant", "content": respText})

# Render the transcript; streamlit_chat needs a unique key per bubble.
for k, msg in enumerate(st.session_state.messages):
    message(msg["content"], is_user=msg["role"] == "user", key=k)