import json

import boto3
import gradio as gr
import sagemaker
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.document_loaders import UnstructuredURLLoader, UnstructuredPDFLoader, S3FileLoader
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma


def loadCleanDocsearch(embeddings):
    print("Getting fresh docsearch...")

    # define URL sources with some stock articles from the public DSS website
    urls = [
        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',
        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',
        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',
        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization'
    ]

    # load and split
    loaders = UnstructuredURLLoader(urls=urls)
    data = loaders.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    texts = text_splitter.split_documents(data)
    print("Sources split into the following number of \"texts\":", len(texts))

    # build the Chroma vector store, tagging the stock articles as "DSS" sources
    docsearch = Chroma.from_texts(
        [t.page_content for t in texts],
        metadatas=[{"src": "DSS"} for t in texts],
        embedding=embeddings)

    print("Done getting fresh docsearch.")
    return docsearch


def resetDocsearch():
    global docsearch

    # remove every document added at runtime (tagged "foreign"),
    # keeping only the stock DSS articles
    foreignIDs = docsearch.get(where={"src": "foreign"})['ids']

    if foreignIDs != []:
        docsearch.delete(ids=foreignIDs)

    clearStuff()


def addURLsource(url):
    print("Adding new source...")

    global docsearch

    # load and split
    loaders = UnstructuredURLLoader(urls=[url])
    data = loaders.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(data)
    print("New source split into the following number of \"texts\":", len(texts))

    # add new sources
    docsearch.add_texts(
        [t.page_content for t in texts],
        metadatas=[{"src": "foreign"} for t in texts])

    # restart convo, as the old messages confuse the AI
    clearStuff()

    print("Done adding new source.")
    return None, None


# def addCSVsource(url):
#     print("Adding new source...")
#
#     global docsearch
#
#     # load and split
#     loaders = CSVLoader(urls=[url])
#     data = loaders.load()
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
#     texts = text_splitter.split_documents(data)
#     print("New source split into the following number of \"texts\":", len(texts))
#
#     # add new sources
#     docsearch.add_texts(
#         [t.page_content for t in texts],
#         metadatas=[{"src": "foreign"} for t in texts])
#
#     # restart convo, as the old messages confuse the AI
#     clearStuff()
#
#     print("Done adding new source.")
#     return None, None


def addPDFsource(url):
    print("Adding new source...")

    global docsearch

    # load and split
    try:
        # assume the path points to a local PDF
        data = UnstructuredPDFLoader(url).load()
    except Exception:
        # not local, so treat the path as an S3 URI and parse out bucket and key
        if '://' in url:
            scheme, path = url.split('://', 1)
            bucket, key = path.split('/', 1)
        else:
            raise ValueError('Invalid S3 URI')
        data = S3FileLoader(bucket, key).load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(data)
    print("New source split into the following number of \"texts\":", len(texts))

    # add new sources
    docsearch.add_texts(
        [t.page_content for t in texts],
        metadatas=[{"src": "foreign"} for t in texts])

    # restart convo, as the old messages confuse the AI
    clearStuff()

    print("Done adding new source.")
    return None, None


def msgs2chatbot(msgs):
    # the gradio chatbot object is used to display the conversation
    # it needs the msgs to be in List[List] format where the inner list is 2 elements: user message, chatbot response message
    chatbot = []

    for msg in msgs:
        if msg['role'] == 'user':
            chatbot.append([msg['content'], ""])
        elif msg['role'] == 'assistant':
            chatbot[-1][1] = msg['content']

    return chatbot


def getPrediction(newMsg):
    global msgs
    global docsearch
    global predictor

    # add new message to msgs object
    msgs.append({"role": "user", "content": newMsg})

    # edit system message to include the correct context
    msgs[0] = {"role": "system", "content": f"""
    You are a helpful AI assistant.
    Use your knowledge to answer the user's question if they asked a question.
    If the answer to a question is not in your knowledge, just admit you do not know the answer and do not fabricate information.
    DO NOT use phrases like "Based on the information provided" or other similar phrases.
    Refer to the information provided below as "your knowledge".
    State all answers as if they are ground truth, DO NOT mention where you got the information.

    YOUR KNOWLEDGE: {" ".join([tup[0].page_content for tup in docsearch.similarity_search_with_score(newMsg, k=5) if tup[1] <= .85])}"""}

    # get response from endpoint
    responseObject = predictor.predict(
        {"inputs": [msgs],
         "parameters": {"max_new_tokens": 750, "top_p": 0.9, "temperature": 0.5}},
        initial_args={'CustomAttributes': "accept_eula=true"})
    # responseObject = predictor.predict(payload, custom_attributes="accept_eula=true")

    responseMsg = responseObject[0]['generation']['content'].strip()

    # add response to msgs object
    msgs.append({"role": "assistant", "content": responseMsg})

    # print msgs object for debugging
    print(msgs)

    # convert msgs to chatbot object to be displayed
    chatbot = msgs2chatbot(msgs)

    return chatbot, ""


def clearStuff():
    global msgs
    msgs = [{}]
    return None


# Create a SageMaker client
sagemaker_client = boto3.client('sagemaker')
sagemaker_session = sagemaker.Session()

# Create a predictor object
predictor = Predictor(
    endpoint_name='meta-textgeneration-llama-2-13b-f-2023-08-08-23-37-15-947',
    sagemaker_session=sagemaker_session,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer())

embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")

# Create a docsearch object
docsearch = loadCleanDocsearch(embeddings)

# Create messages list with system message
msgs = [{}]

with gr.Blocks() as demo:
    gr.HTML("")
    gr.Markdown("## DSS LLM Demo: Chat with Llama 2")

    with gr.Column():
        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column():
                newMsg = gr.Textbox(label="New Message Box", placeholder="New Message", show_label=False)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.Button("Clear")

        with gr.Row():
            with gr.Column():
                newSRC = gr.Textbox(label="New source link/path Box", placeholder="New source link/path", show_label=False)
            with gr.Column():
                with gr.Row():
                    addURL = gr.Button("Add URL Source")
                    addPDF = gr.Button("Add PDF Source")
                    # uploadFile = gr.UploadButton(file_types=[".pdf", ".csv", ".doc"])
                    reset = gr.Button("Reset Sources")

    submit.click(getPrediction, [newMsg], [chatbot, newMsg])
    clear.click(clearStuff, None, chatbot, queue=False)
    addURL.click(addURLsource, newSRC, [newSRC, chatbot])
    addPDF.click(addPDFsource, newSRC, [newSRC, chatbot])
    # uploadFile.click(getOut, uploadFile, None)
    reset.click(resetDocsearch, None, chatbot)

    gr.Markdown("""*Note:* To add a URL source, place a full hyperlink in the bottom textbox and click the 'Add URL Source' button.
                To add a PDF source, place either (1) the filepath relative to the current directory or (2) the full S3 URI in the bottom textbox and click the 'Add PDF Source' button.
                Upon initialization, the database for contextualization includes four public DSS website articles.
                When the 'Reset Sources' button is clicked, the database is wiped back to those initial articles.
                (Some knowledge may be preserved through the conversation history if it is left uncleared.)""")

demo.queue()
demo.launch(share=True)