{
"cells": [
{
"cell_type": "markdown",
"id": "3ecb7e4b-c220-438f-9afd-5d959f3235de",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Install packages, need to write a requirement later\n",
"!pip install instructorembedding sentence-transformers gradio langchain unstructured chromadb pdf2image pdfminer pdfminer.six"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "75cbef7c-974e-4438-addc-9c0b70be4d71",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import boto3\n",
"import sagemaker\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker.serializers import JSONSerializer\n",
"from sagemaker.deserializers import JSONDeserializer\n",
"from langchain.embeddings import HuggingFaceInstructEmbeddings\n",
"from langchain.document_loaders import UnstructuredURLLoader, UnstructuredPDFLoader, S3FileLoader\n",
"from langchain.docstore.document import Document\n",
"from langchain.document_loaders.csv_loader import CSVLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.vectorstores import Chroma\n",
"import json\n",
"import gradio as gr\n",
"\n",
"def loadCleanDocsearch(embeddings):\n",
" print(\"Getting fresh docsearch...\")\n",
"\n",
" # define URL sources with some stock articles from public DSS website\n",
" urls = [\n",
" 'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',\n",
" 'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',\n",
" 'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',\n",
" 'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization'\n",
" ]\n",
"\n",
" # load and split\n",
" loaders = UnstructuredURLLoader(urls=urls)\n",
" data = loaders.load()\n",
" text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)\n",
" texts = text_splitter.split_documents(data)\n",
" print(\"Sources split into the following number of \\\"texts\\\":\", len(texts))\n",
"\n",
" # get object\n",
" docsearch = Chroma.from_texts([t.page_content for t in texts],\n",
" metadatas=[{\"src\": \"DSS\"} for t in texts],\n",
" embedding=embeddings)\n",
" print(\"Done getting fresh docsearch.\")\n",
"\n",
" return docsearch\n",
"\n",
"def resetDocsearch():\n",
" global docsearch\n",
"\n",
" foreignIDs = docsearch.get(where= {\"src\":\"foreign\"})['ids']\n",
"\n",
" if foreignIDs != []:\n",
" docsearch.delete(ids=foreignIDs)\n",
" \n",
" clearStuff()\n",
"\n",
"\n",
"def addURLsource(url):\n",
" print(\"Adding new source...\")\n",
" \n",
" global docsearch\n",
"\n",
" # load and split\n",
" loaders = UnstructuredURLLoader(urls=[url])\n",
" data = loaders.load()\n",
" text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
" texts = text_splitter.split_documents(data)\n",
" print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n",
"\n",
" # add new sources\n",
" docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n",
" \n",
" # restart convo, as the old messages confuse the AI\n",
" clearStuff()\n",
"\n",
" print(\"Done adding new source.\")\n",
" \n",
" return None, None\n",
"\n",
"# def addCSVsource(url):\n",
"# print(\"Adding new source...\")\n",
" \n",
"# global docsearch\n",
"\n",
"# # load and split\n",
"# loaders = CSVLoader(urls=[url])\n",
"# data = loaders.load()\n",
"# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
"# texts = text_splitter.split_documents(data)\n",
"# print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n",
"\n",
"# # add new sources\n",
"# docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n",
" \n",
"# # restart convo, as the old messages confuse the AI\n",
"# clearStuff()\n",
"\n",
"# print(\"Done adding new source.\")\n",
" \n",
"# return None, None\n",
"\n",
"def addPDFsource(url):\n",
" print(\"Adding new source...\")\n",
"\n",
" global docsearch\n",
" \n",
" # load and split\n",
" try: # assuming it is local\n",
" data = UnstructuredPDFLoader(url).load()\n",
" except: # not local, try S3\n",
" if '://' in url:\n",
" scheme, path = url.split('://', 1)\n",
" bucket, key = path.split('/', 1)\n",
"\n",
" else:\n",
" raise ValueError('Invalid S3 URI')\n",
" \n",
" data = S3FileLoader(\"strategicinnovation\", \"testingPDFload/bitcoin.pdf\").load()\n",
" \n",
" \n",
" text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
" texts = text_splitter.split_documents(data)\n",
" print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n",
"\n",
" # add new sources\n",
" docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n",
" \n",
" # restart convo, as the old messages confuse the AI\n",
" clearStuff()\n",
"\n",
" print(\"Done adding new source.\")\n",
" \n",
" return None, None\n",
"\n",
"def msgs2chatbot(msgs):\n",
" # the gradio chatbot object is used to display the conversation\n",
" # it needs the msgs to be in List[List] format where the inner list is 2 elements: user message, chatbot response message\n",
" chatbot = []\n",
" \n",
" for msg in msgs:\n",
" if msg['role'] == 'user':\n",
" chatbot.append([msg['content'], \"\"])\n",
" elif msg['role'] == 'assistant':\n",
" chatbot[-1][1] = msg['content']\n",
"\n",
" return chatbot\n",
"\n",
"def getPrediction(newMsg):\n",
" global msgs\n",
" global docsearch\n",
" global predictor\n",
" \n",
" # add new message to msgs object\n",
" msgs.append({\"role\":\"user\", \"content\": newMsg})\n",
"\n",
" # edit system message to include the correct context\n",
" msgs[0] = {\"role\": \"system\",\n",
" \"content\": f\"\"\"\n",
" You are a helpful AI assistant.\n",
" Use your knowledge to answer the user's question if they asked a question.\n",
" If the answer to a question is not in your knowledge, just admit you do not know the answer and do not fabricate information.\n",
" DO NOT use phrases like \"Based on the information provided\" or other similar phrases. \n",
" Refer to the information provided below as \"your knowledge\". \n",
" State all answers as if they are ground truth, DO NOT mention where you got the information.\n",
" \n",
" YOUR KNOWLEDGE: {\" \".join([tup[0].page_content for tup in docsearch.similarity_search_with_score(newMsg, k=5) if tup[1]<=.85])}\"\"\"}\n",
"\n",
" # get response from endpoint\n",
"\n",
" responseObject = predictor.predict({\"inputs\": [msgs],\n",
" \"parameters\": {\"max_new_tokens\": 750, \"top_p\": 0.9, \"temperature\": 0.5}},\n",
" initial_args={'CustomAttributes': \"accept_eula=true\"})\n",
"# responseObject = predictor.predict(payload, custom_attributes=\"accept_eula=true\")\n",
"\n",
" \n",
" responseMsg = responseObject[0]['generation']['content'].strip()\n",
"\n",
" # add response to msgs object\n",
" msgs.append({\"role\":\"assistant\", \"content\": responseMsg})\n",
" \n",
" # print msgs object for debugging\n",
" print(msgs)\n",
" \n",
" # convert msgs to chatbot object to be displayed\n",
" chatbot = msgs2chatbot(msgs)\n",
"\n",
" return chatbot, \"\"\n",
"\n",
"def clearStuff():\n",
" global msgs\n",
" msgs = [{}]\n",
" return None\n",
"\n",
"# Create a SageMaker client\n",
"sagemaker_client = boto3.client('sagemaker')\n",
"sagemaker_session = sagemaker.Session()\n",
"\n",
"# Create a predictor object\n",
"predictor = Predictor(endpoint_name='meta-textgeneration-llama-2-13b-f-2023-08-08-23-37-15-947',\n",
" sagemaker_session=sagemaker_session,\n",
" serializer=JSONSerializer(),\n",
" deserializer=JSONDeserializer())\n",
"\n",
"embeddings = HuggingFaceInstructEmbeddings(model_name=\"hkunlp/instructor-xl\")\n",
"\n",
"# Create a docsearch object\n",
"docsearch = loadCleanDocsearch(embeddings)\n",
"\n",
"# Create messages list with system message\n",
"msgs = [{}]\n",
"\n",
"with gr.Blocks() as demo:\n",
" gr.HTML(\"<img src='https://images.squarespace-cdn.com/content/v1/5bab98d9f4e53108da59ae49/1537972707182-B5VGFGO3IDMB6HHSJY9H/dss_sp_logo.png?format=1500w' />\")\n",
" gr.Markdown(\"## DSS LLM Demo: Chat with Llama 2\")\n",
"\n",
" with gr.Column():\n",
" chatbot = gr.Chatbot()\n",
" \n",
" with gr.Row():\n",
" with gr.Column():\n",
" newMsg = gr.Textbox(label=\"New Message Box\", placeholder=\"New Message\", show_label=False)\n",
" with gr.Column():\n",
" with gr.Row():\n",
" submit = gr.Button(\"Submit\")\n",
" clear = gr.Button(\"Clear\")\n",
" with gr.Row():\n",
" with gr.Column():\n",
" newSRC = gr.Textbox(label=\"New source link/path Box\", placeholder=\"New source link/path\", show_label=False)\n",
" with gr.Column():\n",
" with gr.Row():\n",
" addURL = gr.Button(\"Add URL Source\")\n",
" addPDF = gr.Button(\"Add PDF Source\")\n",
" #uploadFile = gr.UploadButton(file_types=[\".pdf\",\".csv\",\".doc\"])\n",
" reset = gr.Button(\"Reset Sources\")\n",
"\n",
" submit.click(getPrediction, [newMsg], [chatbot, newMsg])\n",
" clear.click(clearStuff, None, chatbot, queue=False)\n",
" \n",
" addURL.click(addURLsource, newSRC, [newSRC, chatbot])\n",
" addPDF.click(addPDFsource, newSRC, [newSRC, chatbot])\n",
" #uploadFile.click(getOut, uploadFile, None)\n",
" reset.click(resetDocsearch, None, chatbot)\n",
"\n",
" gr.Markdown(\"\"\"*Note:*\n",
" \n",
" To add a URL source, place a full hyperlink in the bottom textbox and click the 'Add URL Source' button.\n",
" \n",
" To add a PDF source, place either (1) the relative filepath to the current directory or (2) the full S3 URI in the bottom textbox and click the 'Add PDF Source' button.\n",
" \n",
" The database for contextualization includes 8 public DSS website articles upon initialization.\n",
" \n",
" When the 'Reset Sources' button is clicked, the database is completely wiped. (Some knowledge may be preserved through the conversation history if left uncleared.)\"\"\")\n",
"\n",
"\n",
"demo.queue()\n",
"demo.launch(share=True)"
]
},
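{
"cell_type": "markdown",
"id": "7f3a2d1c-5b6e-4c8a-9f10-2e4d6a8b0c1f",
"metadata": {
"tags": []
},
"source": [
"*Optional sanity check (illustrative, not part of the demo):* a minimal sketch of the retrieval step used inside `getPrediction`. It queries the Chroma store directly with `similarity_search_with_score` and shows which chunks would pass the 0.85 distance cutoff. The sample question below is only an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b8c7d6e-5f4a-4b3c-8d2e-1a0f9e8d7c6b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# minimal sketch: preview the context getPrediction would retrieve for a sample question\n",
"sampleQuestion = \"Who is Brion Bailey?\"  # illustrative query\n",
"for doc, score in docsearch.similarity_search_with_score(sampleQuestion, k=5):\n",
"    kept = \"kept\" if score <= .85 else \"dropped\"\n",
"    print(f\"{kept} (distance={score:.3f}): {doc.page_content[:80]}\")"
]
},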
{
"cell_type": "code",
"execution_count": 7,
"id": "e200839d-9f90-4651-8212-decc75d1e3e3",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print() # execute this after gradio cell to make this cell the std_out for console logging"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "027e17a8-a2cf-4dcf-9edc-22d9900ff7b8",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"19"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'role': 'system', 'content': '\\n You are a helpful AI assistant.\\n Use your knowledge to answer the user\\'s question if they asked a question.\\n If the answer to a question is not in your knowledge, just admit you do not know the answer and do not fabricate information.\\n DO NOT use phrases like \"Based on the information provided\" or other similar phrases. \\n Refer to the information provided below as \"your knowledge\". \\n State all answers as if they are ground truth, DO NOT mention where you got the information.\\n \\n YOUR KNOWLEDGE: '}, {'role': 'user', 'content': 'hi, who is Brion bailey?'}, {'role': 'assistant', 'content': 'Based on your knowledge, Brion Bailey is the Director of Federal Business Development at DSS, Inc. He has over 25 years of sales and business development experience, specifically in the healthcare market, and has held leadership positions at various companies, including Chief Commercial Officer at Syft Corporation. He has a Master of Science in Marketing from St. Thomas University and a Bachelor of Business Administration from Florida International University.'}, {'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'Hello! How can I assist you today?'}]\n"
]
}
],
"source": [
"len(docsearch.get()['ids'])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p310",
"language": "python",
"name": "conda_pytorch_p310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}