{ "cells": [ { "cell_type": "markdown", "id": "3ecb7e4b-c220-438f-9afd-5d959f3235de", "metadata": { "tags": [] }, "source": [ "# Install packages, need to write a requirement later\n", "!pip install instructorembedding sentence-transformers gradio langchain unstructured chromadb pdf2image pdfminer pdfminer.six" ] }, { "cell_type": "code", "execution_count": 1, "id": "4d086ed6-eb66-4ded-b701-dac062e19521", "metadata": { "tags": [] }, "outputs": [], "source": [ "import boto3\n", "import sagemaker\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.serializers import JSONSerializer\n", "from sagemaker.deserializers import JSONDeserializer\n", "from langchain.embeddings import HuggingFaceInstructEmbeddings\n", "from langchain.document_loaders import UnstructuredURLLoader, UnstructuredPDFLoader\n", "from langchain.document_loaders.csv_loader import CSVLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "from langchain.vectorstores import Chroma\n", "import json\n", "import gradio as gr" ] }, { "cell_type": "code", "execution_count": 2, "id": "a2d81908-6026-47b6-bd0d-ade2771eacdd", "metadata": { "tags": [] }, "outputs": [], "source": [ "def loadCleanDocsearch(embeddings):\n", " print(\"Getting fresh docsearch...\")\n", "\n", " # define URL sources\n", " urls = [\n", " 'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',\n", " 'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',\n", " 'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',\n", " 'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization'\n", " ]\n", "\n", " # load and split\n", " loaders = UnstructuredURLLoader(urls=urls)\n", " data = loaders.load()\n", " text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)\n", " texts = text_splitter.split_documents(data)\n", " print(\"Sources split into the following number of \\\"texts\\\":\", len(texts))\n", "\n", " # get object\n", " docsearch = Chroma.from_texts([t.page_content for t in texts],\n", " metadatas=[{\"src\": \"DSS\"} for t in texts],\n", " embedding=embeddings)\n", " print(\"Done getting fresh docsearch.\")\n", "\n", " return docsearch\n", "\n", "def resetDocsearch():\n", " global docsearch\n", "\n", " foreignIDs = docsearch.get(where= {\"src\":\"foreign\"})['ids']\n", "\n", " if foreignIDs != []:\n", " docsearch.delete(ids=foreignIDs)\n", " \n", " clearStuff()\n", " \n", "def addURLsource(url):\n", " print(\"Adding new source...\")\n", " \n", " global docsearch\n", "\n", " # load and split\n", " loaders = UnstructuredURLLoader(urls=[url])\n", " data = loaders.load()\n", " text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", " texts = text_splitter.split_documents(data)\n", " print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n", "\n", " # add new sources\n", " docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n", " \n", " # restart convo, as the old messages confuse the AI\n", " clearStuff()\n", "\n", " print(\"Done adding new source.\")\n", " \n", " return None, None\n", "\n", "def addCSVsource(url):\n", " print(\"Adding new source...\")\n", " \n", " global docsearch\n", "\n", " # load and split\n", " loaders = CSVLoader(urls=[url])\n", " data = loaders.load()\n", " text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", " texts = text_splitter.split_documents(data)\n", " print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n", "\n", " # add new sources\n", " docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n", " \n", " # restart convo, as the old messages confuse the AI\n", " clearStuff()\n", "\n", " print(\"Done adding new source.\")\n", " \n", " return None, None\n", "\n", "def addPDFsource(url):\n", " print(\"Adding new source...\")\n", "\n", " global docsearch\n", " \n", " # load and split\n", " loaders = UnstructuredPDFLoader(url)\n", " data = loaders.load()\n", " text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", " texts = text_splitter.split_documents(data)\n", " print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n", "\n", " # add new sources\n", " docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n", " \n", " # restart convo, as the old messages confuse the AI\n", " clearStuff()\n", "\n", " print(\"Done adding new source.\")\n", " \n", " return None, None\n", "\n", "def msgs2chatbot(msgs):\n", " # the gradio chatbot object is used to display the conversation\n", " # it needs the msgs to be in List[List] format where the inner list is 2 elements: user message, chatbot response message\n", " chatbot = []\n", " \n", " for msg in msgs:\n", " if msg['role'] == 'user':\n", " chatbot.append([msg['content'], \"\"])\n", " elif msg['role'] == 'assistant':\n", " chatbot[-1][1] = msg['content']\n", "\n", " return chatbot\n", "\n", "def getPrediction(newMsg):\n", " global msgs\n", " global docsearch\n", " global predictor\n", " \n", " # add new message to msgs object\n", " msgs.append({\"role\":\"user\", \"content\": newMsg})\n", "\n", " # edit system message to include the correct context\n", " msgs[0] = {\"role\": \"system\",\n", " \"content\": f\"\"\"\n", " You are a helpful AI assistant.\n", " Use your knowledge to answer the user's question if they asked a question.\n", " If the answer to a question is not in your knowledge, just admit you do not know the answer and do not fabricate information.\n", " DO NOT use phrases like \"Based on the information provided\" or other similar phrases. \n", " Refer to the information provided below as \"your knowledge\". \n", " State all answers as if they are ground truth, DO NOT mention where you got the information.\n", " \n", " YOUR KNOWLEDGE: {\" \".join([tup[0].page_content for tup in docsearch.similarity_search_with_score(newMsg, k=5) if tup[1]<=.85])}\"\"\"}\n", "\n", " # get response from endpoint\n", "\n", " responseObject = predictor.predict({\"inputs\": [msgs],\n", " \"parameters\": {\"max_new_tokens\": 750, \"top_p\": 0.9, \"temperature\": 0.5}},\n", " initial_args={'CustomAttributes': \"accept_eula=true\"})\n", "# responseObject = predictor.predict(payload, custom_attributes=\"accept_eula=true\")\n", "\n", " \n", " responseMsg = responseObject[0]['generation']['content'].strip()\n", "\n", " # add response to msgs object\n", " msgs.append({\"role\":\"assistant\", \"content\": responseMsg})\n", " \n", " # print msgs object for debugging\n", " print(msgs)\n", " \n", " # convert msgs to chatbot object to be displayed\n", " chatbot = msgs2chatbot(msgs)\n", "\n", " return chatbot, \"\"\n", "\n", "def clearStuff():\n", " global msgs\n", " msgs = [{}]\n", " return None" ] }, { "cell_type": "code", "execution_count": 3, "id": "c56c588d-1bca-448f-bba9-96f61e5bab33", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "load INSTRUCTOR_Transformer\n", "max_seq_length 512\n", "Getting fresh docsearch...\n", "Sources split into the following number of \"texts\": 19\n", "Done getting fresh docsearch.\n" ] } ], "source": [ "# Create a SageMaker client\n", "sagemaker_client = boto3.client('sagemaker')\n", "sagemaker_session = sagemaker.Session()\n", "\n", "# Create a predictor object\n", "predictor = Predictor(endpoint_name='meta-textgeneration-llama-2-13b-f-2023-08-08-23-37-15-947',\n", " sagemaker_session=sagemaker_session,\n", " serializer=JSONSerializer(),\n", " deserializer=JSONDeserializer())\n", "\n", "embeddings = HuggingFaceInstructEmbeddings(model_name=\"hkunlp/instructor-xl\")\n", "\n", "# Create a docsearch object\n", "docsearch = loadCleanDocsearch(embeddings)\n", "\n", "# Create messages list with system message\n", "msgs = [{}]" ] }, { "cell_type": "code", "execution_count": 4, "id": "0572e3b3-2805-4db5-9d23-dac40842c58c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7865\n", "Running on public URL: https://7a86829a731298a3ba.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "