{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3ecb7e4b-c220-438f-9afd-5d959f3235de",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Install packages, need to write a requirement later\n",
    "!pip install instructorembedding sentence-transformers gradio langchain unstructured chromadb pdf2image pdfminer pdfminer.six"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "75cbef7c-974e-4438-addc-9c0b70be4d71",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker.predictor import Predictor\n",
    "from sagemaker.serializers import JSONSerializer\n",
    "from sagemaker.deserializers import JSONDeserializer\n",
    "from langchain.embeddings import HuggingFaceInstructEmbeddings\n",
    "from langchain.document_loaders import UnstructuredURLLoader, UnstructuredPDFLoader, S3FileLoader\n",
    "from langchain.docstore.document import Document\n",
    "from langchain.document_loaders.csv_loader import CSVLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.vectorstores import Chroma\n",
    "import json\n",
    "import gradio as gr\n",
    "\n",
    "def loadCleanDocsearch(embeddings):\n",
    "    print(\"Getting fresh docsearch...\")\n",
    "\n",
    "    # define URL sources with some stock articles from public DSS website\n",
    "    urls = [\n",
    "        'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development',\n",
    "        'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',\n",
    "        'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',\n",
    "        'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization'\n",
    "        ]\n",
    "\n",
    "    # load and split\n",
    "    loaders = UnstructuredURLLoader(urls=urls)\n",
    "    data = loaders.load()\n",
    "    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)\n",
    "    texts = text_splitter.split_documents(data)\n",
    "    print(\"Sources split into the following number of \\\"texts\\\":\", len(texts))\n",
    "\n",
    "    # get object\n",
    "    docsearch = Chroma.from_texts([t.page_content for t in texts],\n",
    "                                  metadatas=[{\"src\": \"DSS\"} for t in texts],\n",
    "                                  embedding=embeddings)\n",
    "    print(\"Done getting fresh docsearch.\")\n",
    "\n",
    "    return docsearch\n",
    "\n",
    "def resetDocsearch():\n",
    "    global docsearch\n",
    "\n",
    "    foreignIDs = docsearch.get(where= {\"src\":\"foreign\"})['ids']\n",
    "\n",
    "    if foreignIDs != []:\n",
    "        docsearch.delete(ids=foreignIDs)\n",
    "        \n",
    "    clearStuff()\n",
    "\n",
    "\n",
    "def addURLsource(url):\n",
    "    print(\"Adding new source...\")\n",
    "    \n",
    "    global docsearch\n",
    "\n",
    "    # load and split\n",
    "    loaders = UnstructuredURLLoader(urls=[url])\n",
    "    data = loaders.load()\n",
    "    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
    "    texts = text_splitter.split_documents(data)\n",
    "    print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n",
    "\n",
    "    # add new sources\n",
    "    docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n",
    "    \n",
    "    # restart convo, as the old messages confuse the AI\n",
    "    clearStuff()\n",
    "\n",
    "    print(\"Done adding new source.\")\n",
    "    \n",
    "    return None, None\n",
    "\n",
    "# def addCSVsource(url):\n",
    "#     print(\"Adding new source...\")\n",
    "    \n",
    "#     global docsearch\n",
    "\n",
    "#     # load and split\n",
    "#     loaders = CSVLoader(urls=[url])\n",
    "#     data = loaders.load()\n",
    "#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
    "#     texts = text_splitter.split_documents(data)\n",
    "#     print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n",
    "\n",
    "#     # add new sources\n",
    "#     docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n",
    "    \n",
    "#     # restart convo, as the old messages confuse the AI\n",
    "#     clearStuff()\n",
    "\n",
    "#     print(\"Done adding new source.\")\n",
    "    \n",
    "#     return None, None\n",
    "\n",
    "def addPDFsource(url):\n",
    "    print(\"Adding new source...\")\n",
    "\n",
    "    global docsearch\n",
    "    \n",
    "    # load and split\n",
    "    try: # assuming it is local\n",
    "        data = UnstructuredPDFLoader(url).load()\n",
    "    except: # not local, try S3\n",
    "        if '://' in url:\n",
    "            scheme, path = url.split('://', 1)\n",
    "            bucket, key = path.split('/', 1)\n",
    "\n",
    "        else:\n",
    "            raise ValueError('Invalid S3 URI')\n",
    "            \n",
    "        data = S3FileLoader(\"strategicinnovation\", \"testingPDFload/bitcoin.pdf\").load()\n",
    "        \n",
    "    \n",
    "    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
    "    texts = text_splitter.split_documents(data)\n",
    "    print(\"New source split into the following number of \\\"texts\\\":\", len(texts))\n",
    "\n",
    "    # add new sources\n",
    "    docsearch.add_texts([t.page_content for t in texts], metadatas=[{\"src\": \"foreign\"} for t in texts])\n",
    "    \n",
    "    # restart convo, as the old messages confuse the AI\n",
    "    clearStuff()\n",
    "\n",
    "    print(\"Done adding new source.\")\n",
    "    \n",
    "    return None, None\n",
    "\n",
    "def msgs2chatbot(msgs):\n",
    "    # the gradio chatbot object is used to display the conversation\n",
    "    # it needs the msgs to be in List[List] format where the inner list is 2 elements: user message, chatbot response message\n",
    "    chatbot = []\n",
    "    \n",
    "    for msg in msgs:\n",
    "        if msg['role'] == 'user':\n",
    "            chatbot.append([msg['content'], \"\"])\n",
    "        elif msg['role'] == 'assistant':\n",
    "            chatbot[-1][1] = msg['content']\n",
    "\n",
    "    return chatbot\n",
    "\n",
    "def getPrediction(newMsg):\n",
    "    global msgs\n",
    "    global docsearch\n",
    "    global predictor\n",
    "    \n",
    "    # add new message to msgs object\n",
    "    msgs.append({\"role\":\"user\", \"content\": newMsg})\n",
    "\n",
    "    # edit system message to include the correct context\n",
    "    msgs[0] = {\"role\": \"system\",\n",
    "               \"content\": f\"\"\"\n",
    "               You are a helpful AI assistant.\n",
    "               Use your knowledge to answer the user's question if they asked a question.\n",
    "               If the answer to a question is not in your knowledge, just admit you do not know the answer and do not fabricate information.\n",
    "               DO NOT use phrases like \"Based on the information provided\" or other similar phrases. \n",
    "               Refer to the information provided below as \"your knowledge\". \n",
    "               State all answers as if they are ground truth, DO NOT mention where you got the information.\n",
    "               \n",
    "               YOUR KNOWLEDGE: {\" \".join([tup[0].page_content for tup in docsearch.similarity_search_with_score(newMsg, k=5) if tup[1]<=.85])}\"\"\"}\n",
    "\n",
    "    # get response from endpoint\n",
    "\n",
    "    responseObject = predictor.predict({\"inputs\": [msgs],\n",
    "                       \"parameters\": {\"max_new_tokens\": 750, \"top_p\": 0.9, \"temperature\": 0.5}},\n",
    "                        initial_args={'CustomAttributes': \"accept_eula=true\"})\n",
    "#    responseObject = predictor.predict(payload, custom_attributes=\"accept_eula=true\")\n",
    "\n",
    "    \n",
    "    responseMsg = responseObject[0]['generation']['content'].strip()\n",
    "\n",
    "    # add response to msgs object\n",
    "    msgs.append({\"role\":\"assistant\", \"content\": responseMsg})\n",
    "    \n",
    "    # print msgs object for debugging\n",
    "    print(msgs)\n",
    "    \n",
    "    # convert msgs to chatbot object to be displayed\n",
    "    chatbot = msgs2chatbot(msgs)\n",
    "\n",
    "    return chatbot, \"\"\n",
    "\n",
    "def clearStuff():\n",
    "    global msgs\n",
    "    msgs = [{}]\n",
    "    return None\n",
    "\n",
    "# Create a SageMaker client\n",
    "sagemaker_client = boto3.client('sagemaker')\n",
    "sagemaker_session = sagemaker.Session()\n",
    "\n",
    "# Create a predictor object\n",
    "predictor = Predictor(endpoint_name='meta-textgeneration-llama-2-13b-f-2023-08-08-23-37-15-947',\n",
    "                      sagemaker_session=sagemaker_session,\n",
    "                      serializer=JSONSerializer(),\n",
    "                      deserializer=JSONDeserializer())\n",
    "\n",
    "embeddings = HuggingFaceInstructEmbeddings(model_name=\"hkunlp/instructor-xl\")\n",
    "\n",
    "# Create a docsearch object\n",
    "docsearch = loadCleanDocsearch(embeddings)\n",
    "\n",
    "# Create messages list with system message\n",
    "msgs = [{}]\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.HTML(\"<img src='https://images.squarespace-cdn.com/content/v1/5bab98d9f4e53108da59ae49/1537972707182-B5VGFGO3IDMB6HHSJY9H/dss_sp_logo.png?format=1500w' />\")\n",
    "    gr.Markdown(\"## DSS LLM Demo: Chat with Llama 2\")\n",
    "\n",
    "    with gr.Column():\n",
    "        chatbot = gr.Chatbot()\n",
    "        \n",
    "        with gr.Row():\n",
    "            with gr.Column():\n",
    "                newMsg = gr.Textbox(label=\"New Message Box\", placeholder=\"New Message\", show_label=False)\n",
    "            with gr.Column():\n",
    "                with gr.Row():\n",
    "                    submit = gr.Button(\"Submit\")\n",
    "                    clear = gr.Button(\"Clear\")\n",
    "        with gr.Row():\n",
    "            with gr.Column():\n",
    "                newSRC = gr.Textbox(label=\"New source link/path Box\", placeholder=\"New source link/path\", show_label=False)\n",
    "            with gr.Column():\n",
    "                with gr.Row():\n",
    "                    addURL = gr.Button(\"Add URL Source\")\n",
    "                    addPDF = gr.Button(\"Add PDF Source\")\n",
    "                    #uploadFile = gr.UploadButton(file_types=[\".pdf\",\".csv\",\".doc\"])\n",
    "                    reset = gr.Button(\"Reset Sources\")\n",
    "\n",
    "    submit.click(getPrediction, [newMsg], [chatbot, newMsg])\n",
    "    clear.click(clearStuff, None, chatbot, queue=False)\n",
    "    \n",
    "    addURL.click(addURLsource, newSRC, [newSRC, chatbot])\n",
    "    addPDF.click(addPDFsource, newSRC, [newSRC, chatbot])\n",
    "    #uploadFile.click(getOut, uploadFile, None)\n",
    "    reset.click(resetDocsearch, None, chatbot)\n",
    "\n",
    "    gr.Markdown(\"\"\"*Note:*\n",
    "    \n",
    "    To add a URL source, place a full hyperlink in the bottom textbox and click the 'Add URL Source' button.\n",
    "    \n",
    "    To add a PDF source, place either (1) the relative filepath to the current directory or (2) the full S3 URI in the bottom textbox and click the 'Add PDF Source' button.\n",
    "    \n",
    "    The database for contextualization includes 8 public DSS website articles upon initialization.\n",
    "    \n",
    "    When the 'Reset Sources' button is clicked, the database is completely wiped. (Some knowledge may be preserved through the conversation history if left uncleared.)\"\"\")\n",
    "\n",
    "\n",
    "demo.queue()\n",
    "demo.launch(share=True)"
   ]
  },
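  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e1d2c3-src-sanity-check",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Optional sanity check (a minimal sketch, not part of the app): list the distinct\n",
    "# 'src' tags currently in the Chroma store to confirm which sources are loaded\n",
    "# ('DSS' for the stock articles, 'foreign' for anything added through the UI).\n",
    "set(m['src'] for m in docsearch.get()['metadatas'])"
   ]
  },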
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e200839d-9f90-4651-8212-decc75d1e3e3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "print() # execute this after gradio cell to make this cell the std_out for console logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "027e17a8-a2cf-4dcf-9edc-22d9900ff7b8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'role': 'system', 'content': '\\n               You are a helpful AI assistant.\\n               Use your knowledge to answer the user\\'s question if they asked a question.\\n               If the answer to a question is not in your knowledge, just admit you do not know the answer and do not fabricate information.\\n               DO NOT use phrases like \"Based on the information provided\" or other similar phrases. \\n               Refer to the information provided below as \"your knowledge\". \\n               State all answers as if they are ground truth, DO NOT mention where you got the information.\\n               \\n               YOUR KNOWLEDGE: '}, {'role': 'user', 'content': 'hi, who is Brion bailey?'}, {'role': 'assistant', 'content': 'Based on your knowledge, Brion Bailey is the Director of Federal Business Development at DSS, Inc. He has over 25 years of sales and business development experience, specifically in the healthcare market, and has held leadership positions at various companies, including Chief Commercial Officer at Syft Corporation. He has a Master of Science in Marketing from St. Thomas University and a Bachelor of Business Administration from Florida International University.'}, {'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'Hello! How can I assist you today?'}]\n"
     ]
    }
   ],
   "source": [
    "len(docsearch.get()['ids'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p310",
   "language": "python",
   "name": "conda_pytorch_p310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}