import streamlit as st
from streamlit_chat import message
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from typing import Dict
import json
import datetime

endpoint_name = "falcon-40b-instruct-gates3"
aws_region = "us-east-1"


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"
    len_prompt = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        self.len_prompt = len(prompt)
        input_str = json.dumps({
            "inputs": prompt,
            "parameters": {
                "do_sample": True,
                "top_p": 0.9,
                "temperature": 0.8,
                "max_new_tokens": 1024,
                "repetition_penalty": 1.03,
                "stop": ["\n\n", "Human:", "<|endoftext|>"],
            },
        })
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        res = json.loads(output.read())
        # strip the echoed prompt, then truncate at any hallucinated "Human" turn
        ans = res[0]["generated_text"][self.len_prompt:]
        cut = ans.rfind("Human")
        if cut != -1:
            ans = ans[:cut]
        return ans.strip()


@st.cache_resource
def getDocsearchOnce():
    print("Getting docsearch...")

    # define URL sources
    urls = [
        'https://www.dssinc.com/blog/2022/6/21/suicide-prevention-manager-enabling-the-veterans-affairs-to-achieve-high-reliability-in-suicide-risk-identification',
        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',
        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',
        'https://www.dssinc.com/blog/2023/5/24/supporting-the-vas-high-reliability-organization-journey-through-suicide-prevention',
        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',
        'https://www.dssinc.com/blog/2022/9/21/dss-inc-chosen-for-phase-two-of-mission-daybreak-vas-suicide-prevention-challenge',
        'https://www.dssinc.com/blog/2022/9/19/crescenz-va-medical-center-cmcvamc-deploys-the-dss-iconic-data-patient-case-manager-pcm-solution',
        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization']

    # load and split; keep the chunking parameters consistent with any
    # previously computed embeddings, so do not reassign them
    loaders = UnstructuredURLLoader(urls=urls)
    data = loaders.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(data)
    print("Sources split into the following number of \"texts\":", len(texts))

    # load the embedding model and index the chunks in Chroma
    print("Loading embedding model...")
    embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
    docsearch = Chroma.from_texts([t.page_content for t in texts], embeddings)
    return docsearch


@st.cache_resource
def getChainOnce():
    print("Getting chain...")
    content_handler = ContentHandler()

    # get the LangChain-wrapped predictor endpoint object
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=aws_region,
        content_handler=content_handler
    )

    # configure the prompt object
    prompt = PromptTemplate(
        input_variables=["history", "context", "new_message"],
        template="""The following is a given context and the history of a conversation between a human and a helpful AI assistant. The AI assistant answers the human's messages with full, informative responses, inferring from the information given in the context, but does not extrapolate any facts that are not given. When an answer is uncertain, the AI assistant qualifies it as such. The AI assistant pretends it is also human, treats the context as objectively true, and therefore avoids mentioning the context as a source.

Context: {context}

{history}
Human: {new_message}
AI:"""
    )

    return load_qa_chain(llm=llm,
                         chain_type="stuff",
                         verbose=True,
                         memory=ConversationBufferMemory(memory_key="history", input_key="new_message"),
                         prompt=prompt)


def getAIresponse(chain, docsearch, query):
    print("Getting AI response... @ ", datetime.datetime.now().strftime("%H:%M:%S"))
    # retrieve the three most similar chunks and run them through the "stuff" chain
    return chain({"input_documents": docsearch.similarity_search(query, k=3),
                  "new_message": query},
                 return_only_outputs=True)['output_text'].strip()


st.title("DSS Prototype LLM 💬")

# THREE VARIABLES NEED TO BE PERSISTENT ACROSS RERUNS:

# 1. the conversational chain
if "chain" not in st.session_state:
    st.session_state["chain"] = getChainOnce()

# 2. the docsearch object
if "docsearch" not in st.session_state:
    st.session_state["docsearch"] = getDocsearchOnce()

# 3. the messages for the UI
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# DRAW THE ACTUAL UI AND IMPLEMENT FUNCTIONALITY
# (some formatting is handled by streamlit_chat)

# draw the input box
with st.form("chat_input", clear_on_submit=True):
    a, b = st.columns([4, 1])
    user_input = a.text_input(
        label="Your message:",
        placeholder="What would you like to say?",
        label_visibility="collapsed",
    )
    b.form_submit_button("Send", use_container_width=True)

# handle input
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    respText = getAIresponse(st.session_state.chain, st.session_state.docsearch, user_input)
    print(respText)
    st.session_state.messages.append({"role": "assistant", "content": respText})

# draw the message history
for k, msg in enumerate(st.session_state.messages):
    message(msg["content"], is_user=msg["role"] == "user", key=k)
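# ---------------------------------------------------------------------------
# Usage note: this app is meant to be launched with Streamlit, e.g.
# `streamlit run app.py` (the filename "app.py" is an assumption). The sketch
# below shows one way to smoke-test the SageMaker endpoint outside LangChain
# and Streamlit, sending the same JSON payload shape that
# ContentHandler.transform_input builds. It is kept commented out because a
# Streamlit script re-executes top to bottom on every rerun; run it in a
# separate shell or notebook instead.
#
#   import boto3
#   import json
#
#   runtime = boto3.client("sagemaker-runtime", region_name=aws_region)
#   resp = runtime.invoke_endpoint(
#       EndpointName=endpoint_name,
#       ContentType="application/json",
#       Body=json.dumps({"inputs": "Hello!", "parameters": {"max_new_tokens": 32}}),
#   )
#   print(json.loads(resp["Body"].read())[0]["generated_text"])
# ---------------------------------------------------------------------------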