# llm/main.py
import sagemaker
import streamlit as st
from streamlit_chat import message
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from typing import Dict
import json
import chromadb
import datetime

# SageMaker endpoint serving the LLM
endpoint_name = "falcon-40b-instruct-gates3"
aws_region = "us-east-1"

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"
    len_prompt = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # remember the prompt length so the echoed prompt can be stripped
        # back off in transform_output
        self.len_prompt = len(prompt)
        input_str = json.dumps(
            {"inputs": prompt,
             "parameters": {
                 "do_sample": True,
                 "top_p": 0.9,
                 "temperature": 0.8,
                 "max_new_tokens": 1024,
                 "repetition_penalty": 1.03,
                 "stop": ["\n\n", "Human:", "<|endoftext|>", "</s>"]
             }})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        response_json = output.read()
        res = json.loads(response_json)
        # drop the echoed prompt, then cut off any trailing "Human" turn the
        # model started to generate (rfind returns -1 when there is none,
        # which would otherwise chop the last character)
        ans = res[0]['generated_text'][self.len_prompt:]
        cut = ans.rfind("Human")
        if cut != -1:
            ans = ans[:cut]
        return ans.strip()
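
# A quick sketch of the round trip the handler assumes (the Hugging Face
# text-generation-inference schema, where the response is a JSON list whose
# first item holds "generated_text" = prompt + completion). The sample text
# and the io.BytesIO stand-in for the boto3 streaming body are hypothetical:
#
#     import io
#     handler = ContentHandler()
#     payload = handler.transform_input("Human: hi\nAI:", {})
#     fake_body = io.BytesIO(
#         json.dumps([{"generated_text": "Human: hi\nAI: Hello!"}]).encode("utf-8"))
#     handler.transform_output(fake_body)  # -> "Hello!"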

@st.cache_resource
def getDocsearchOnce():
    print("Getting docsearch...")

    # define URL sources
    urls = [
        'https://www.dssinc.com/blog/2022/6/21/suicide-prevention-manager-enabling-the-veterans-affairs-to-achieve-high-reliability-in-suicide-risk-identification',
        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',
        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',
        'https://www.dssinc.com/blog/2023/5/24/supporting-the-vas-high-reliability-organization-journey-through-suicide-prevention',
        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',
        'https://www.dssinc.com/blog/2022/9/21/dss-inc-chosen-for-phase-two-of-mission-daybreak-vas-suicide-prevention-challenge',
        'https://www.dssinc.com/blog/2022/9/19/crescenz-va-medical-center-cmcvamc-deploys-the-dss-iconic-data-patient-case-manager-pcm-solution',
        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization',
    ]

    # load the pages and split them into overlapping chunks; keep the chunking
    # parameters consistent with any previously generated embeddings
    loaders = UnstructuredURLLoader(urls=urls)
    data = loaders.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(data)
    print('Sources split into the following number of "texts":', len(texts))

    # embed the chunks and index them in an in-memory Chroma store
    print("Loading embedding model...")
    embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
    docsearch = Chroma.from_texts([t.page_content for t in texts], embeddings)
    return docsearch
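
# A hypothetical smoke test for the retriever (not run by the app): embed a
# query with the same instructor-xl model and fetch the nearest chunks, the
# same call getAIresponse makes below.
#
#     docsearch = getDocsearchOnce()
#     for doc in docsearch.similarity_search("suicide prevention", k=3):
#         print(doc.page_content[:120])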

@st.cache_resource
def getChainOnce():
    print("Getting chain...")
    content_handler = ContentHandler()

    # get the langchain-wrapped SageMaker endpoint object
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=aws_region,
        content_handler=content_handler
    )

    # configure the prompt object
    prompt = PromptTemplate(
        input_variables=["history", "context", "new_message"],
        template="""
The following is the given context and history of a conversation between a human and a helpful AI assistant.
The AI assistant answers the human's messages with full, informative responses, inferring from the information given in the context but never extrapolating facts that are not given.
When an answer is uncertain, the AI assistant qualifies its response as such.
The AI assistant pretends it is also human, treats the context as objectively true, and therefore avoids mentioning the context as a source.

Context: {context}

{history}
Human: {new_message}
AI:"""
    )

    return load_qa_chain(
        llm=llm,
        chain_type="stuff",
        verbose=True,
        memory=ConversationBufferMemory(memory_key="history", input_key="new_message"),
        prompt=prompt,
    )
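
# For illustration: the "stuff" chain concatenates the k retrieved chunks
# into {context} and fills the template, so the endpoint receives a single
# prompt shaped roughly like this (the question and chunks are hypothetical):
#
#     The following is the given context and history of a conversation ...
#     Context: <chunk 1> <chunk 2> <chunk 3>
#     Human: What does DSS do?
#     AI:
#
# ConversationBufferMemory then appends each Human/AI turn into {history}
# on subsequent calls.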

def getAIresponse(chain, docsearch, query):
    print("Getting AI response... @ ", datetime.datetime.now().strftime("%H:%M:%S"))
    # retrieve the 3 most similar chunks and run the QA chain over them
    docs = docsearch.similarity_search(query, k=3)
    return chain({"input_documents": docs, "new_message": query},
                 return_only_outputs=True)['output_text'].strip()

st.title("DSS Prototype LLM 💬")

# THREE VARIABLES NEED TO BE PERSISTENT:
# 1. conversational chain
if "chain" not in st.session_state:
    st.session_state["chain"] = getChainOnce()

# 2. docsearch object
if "docsearch" not in st.session_state:
    st.session_state["docsearch"] = getDocsearchOnce()

# 3. messages for the UI
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# DRAW THE ACTUAL UI AND IMPLEMENT FUNCTIONALITY
# (some formatting is handled by streamlit-chat)

# draw the input box
with st.form("chat_input", clear_on_submit=True):
    a, b = st.columns([4, 1])
    user_input = a.text_input(
        label="Your message:",
        placeholder="What would you like to say?",
        label_visibility="collapsed",
    )
    b.form_submit_button("Send", use_container_width=True)

# handle input
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    respText = getAIresponse(st.session_state.chain, st.session_state.docsearch, user_input)
    print(respText)
    st.session_state.messages.append({"role": "assistant", "content": respText})

# draw messages
for k, msg in enumerate(st.session_state.messages):
    message(msg["content"], is_user=msg["role"] == "user", key=k)
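
# To run locally (standard Streamlit CLI; assumes AWS credentials with
# permission to invoke the SageMaker endpoint are configured in the
# environment):
#
#     streamlit run llm/main.py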