lakshmivairamani committed on
Commit b7f5bd0
1 Parent(s): e4950ac

Update app.py

Files changed (1)
  1. app.py +186 -53
app.py CHANGED
@@ -9,9 +9,26 @@ from PIL import Image
  import base64
  from io import BytesIO
  import os
+ import re
+ import tempfile
+ import wave
  import requests
  import gradio as gr
- #import nltk
+ import time
+ import shutil
+ import json
+ import nltk
+ #audio package
+ import speech_recognition as sr
+ from pydub import AudioSegment
+ from pydub.playback import play
+ #email library
+ import smtplib
+ from email.mime.multipart import MIMEMultipart
+ from email.mime.text import MIMEText
+ from email.mime.base import MIMEBase
+ from email import encoders
+ #langchain
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.runnables import RunnableSequence, RunnableLambda
@@ -27,12 +44,23 @@ from PyPDF2 import PdfReader
  from nltk.tokenize import sent_tokenize
  from sqlalchemy import create_engine
  from sqlalchemy.sql import text
- import json
- import nltk
- nltk.download('punkt')
+ #google
+ from google.colab import userdata
+ from google.colab import drive
+ #pandas
+ import pandas as pd
+ from pandasai.llm.openai import OpenAI
+ from pandasai import SmartDataframe

- open_api_key_token = os.environ['OPEN_AI_API']

+
+
+ nltk.download('punkt')
+
+ drive.mount('/content/drive', force_remount=True)
+ open_api_key_token = userdata.get('OPENAI_API_KEY')
+ postgresql_connection = userdata.get('POSTGRESQL_CONNECTION')
+ pdf_path="Inbound.pdf"
  os.environ['OPENAI_API_KEY'] = open_api_key_token
  db_uri = 'postgresql+psycopg2://postgres:postpass@193.203.162.39:5432/warehouseAi'
  # Database setup
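Since this commit switches the secret lookup to google.colab.userdata and mounts Drive, the script now assumes a Colab runtime. For reference, a rough non-Colab equivalent of the same lookups via environment variables (the variable names here are illustrative, not part of the commit):

import os

# Hypothetical fallback when google.colab is not available
open_api_key_token = os.environ.get("OPENAI_API_KEY", "")
postgresql_connection = os.environ.get("POSTGRESQL_CONNECTION", "")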
@@ -40,8 +68,8 @@ db_uri = 'postgresql+psycopg2://postgres:postpass@193.203.162.39:5432/warehouseA
  db = SQLDatabase.from_uri(db_uri)

  # LLM setup
- #llm = ChatOpenAI(model="gpt-3.5-turbo-0125",max_tokens=150,temperature=0.1)
- llm = ChatOpenAI(model="gpt-4o-mini",max_tokens=250,temperature=0.1)
+ llm = ChatOpenAI(model="gpt-4o-mini",max_tokens=300,temperature=0.1)
+ llm_chart = OpenAI()

  def get_schema(_):
      schema_info = db.get_table_info() # This should be a string of your SQL schema
@@ -69,7 +97,7 @@ def generate_sql_query(question):
  def run_query(query):
      # Clean the query by removing markdown symbols and trimming whitespace
      clean_query = query.replace("```sql", "").replace("```", "").strip()
-     #print(f"Executing SQL Query: {clean_query}")
+     print(f"Executing SQL Query: {clean_query}")
      try:
          result = db.run(clean_query)
          return result
@@ -83,7 +111,7 @@ def run_query(query):
  def database_tool(question):
      # print(question)
      sql_query = generate_sql_query(question)
-     ##print(sql_query)
+     print(sql_query)
      return run_query(sql_query)

  def get_ASN_data(question):
@@ -149,7 +177,7 @@ def create_vector_store(texts):

  def query_vector_store(vector_store, query):
      docs = vector_store.similarity_search(query, k=5)
-     #print(f"Vector store return: {docs}")
+     print(f"Vector store return: {docs}")
      return docs

  def summarize_document(docs):
@@ -167,13 +195,12 @@ def summarize_document(docs):
          summarized_content = doc_content
          summarized_docs.append(summarized_content)
      return '\n\n'.join(summarized_docs)
- pdf_path = "Inbound.pdf"
- #pdf_path = r"D:\rajesh\python\chat_agent\Inbound.pdf"
+
  texts = load_and_split_pdf(pdf_path)
  vector_store = create_vector_store(texts)

  def document_data_tool(question):
-     #print(f"Document data tool enter: {question}")
+     print(f"Document data tool enter: {question}")
      # query_string = question['tags'][0] if 'tags' in question and question['tags'] else ""
      query_response = query_vector_store(vector_store, question)
      print("query****")
@@ -182,6 +209,45 @@ def document_data_tool(question):
      #print("summary***")
      #print(summarized_response)
      return query_response
+
+ def send_email_with_attachment(recipient_email, subject, body, attachment_path):
+     sender_email = "learning.rajeshthangaraj1@gmail.com"
+     sender_password = "mkeogppbcjgrdfpg"
+
+     # Create a multipart message
+     msg = MIMEMultipart()
+     msg['From'] = sender_email
+     msg['To'] = recipient_email
+     msg['Subject'] = subject
+
+     # Attach the body with the msg instance
+     msg.attach(MIMEText(body, 'plain'))
+
+     # Open the file to be sent
+     attachment = open(attachment_path, "rb")
+
+     # Instance of MIMEBase and named as p
+     part = MIMEBase('application', 'octet-stream')
+
+     # To change the payload into encoded form
+     part.set_payload((attachment).read())
+
+     # Encode into base64
+     encoders.encode_base64(part)
+
+     part.add_header('Content-Disposition', f"attachment; filename= {attachment_path}")
+
+     # Attach the instance 'part' to instance 'msg'
+     msg.attach(part)
+
+     # Create SMTP session for sending the mail
+     server = smtplib.SMTP('smtp.gmail.com', 587)
+     server.starttls()
+     server.login(sender_email, sender_password)
+     text = msg.as_string()
+     server.sendmail(sender_email, recipient_email, text)
+     server.quit()
+     #return 1

  def make_api_request(url, params):
      import requests
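A note on the new helper: as committed, send_email_with_attachment opens the attachment without closing it and keeps the Gmail app password in the source. A minimal sketch of the same attachment step written with a context manager and an environment-variable credential (the function and variable names here are illustrative, not from the commit):

import os
from email.mime.base import MIMEBase
from email import encoders

def attach_file(msg, attachment_path):
    # Read the attachment inside a with-block so the file handle is closed afterwards
    with open(attachment_path, "rb") as fh:
        part = MIMEBase('application', 'octet-stream')
        part.set_payload(fh.read())
    encoders.encode_base64(part)
    part.add_header('Content-Disposition', f"attachment; filename={os.path.basename(attachment_path)}")
    msg.attach(part)

sender_password = os.environ.get("GMAIL_APP_PASSWORD", "")  # hypothetical variable name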
@@ -203,11 +269,7 @@ apis = [
          "url": "http://193.203.162.39:9090/nxt-wms/userWarehouse/fetchWarehouseForUserId?",
          "params": {"query": name, "userId": "164"}
      },
-     #fetch customer id
-     {
-         "url": "http://193.203.162.39:9090/nxt-wms/userCustomer/fetchCustomerForUserId?",
-         "params": {"query": "TESTING 123", "userId": "164", "status": "Active"}
-     },
+
      #Stock summary based on warehouse id
      {
          "url": "http://193.203.162.39:9090/nxt-wms/transactionHistory/stockSummary?",
@@ -217,12 +279,13 @@ apis = [

  def inventory_report(question):

+     # Split the question to extract warehouse name, user question, and optional email
+     parts = question.split(":", 2)
+     name = parts[0].strip()
+     user_question = parts[1].strip()
+     user_email = parts[2].strip() if len(parts) > 2 else None
+     print(f"Warehouse: {name}, Email: {user_email}, Question: {user_question}")

-     name = question.split(":")[0]
-     #print(question)
-     question = question.split(":")[1]
-     #print(name)
-     import requests

      data = make_api_request(apis[0]["url"], apis[0]["params"])
      if data:
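The split(":", 2) call above is what enforces the tool's 'warehouse name: user question: email (if any)' convention; a small worked example with an illustrative input string:

q = "Allcargo Logistics: show item name and quantity as a bar chart: user@example.com"
parts = q.split(":", 2)
name = parts[0].strip()                                    # "Allcargo Logistics"
user_question = parts[1].strip()                           # "show item name and quantity as a bar chart"
user_email = parts[2].strip() if len(parts) > 2 else None  # "user@example.com"; None when no email part is given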
@@ -236,11 +299,8 @@ def inventory_report(question):
          if "warehouseId" in api["params"]:
              api["params"]["warehouseId"] = warehouse_id

-     #print(apis[2]["url"])
-     #print(apis[2]["params"])
-     data1 = make_api_request(apis[2]["url"], apis[2]["params"])
-     #if data1:
-         #print(data1)
+
+     data1 = make_api_request(apis[1]["url"], apis[1]["params"])

      from tabulate import tabulate

@@ -268,22 +328,25 @@ def inventory_report(question):
          table_data.append(row)


-     #if table_data:
-         #print(tabulate(table_data, headers=headers, tablefmt="grid"))
-
-     # Convert to pandas DataFrame
-     import pandas as pd
+     # Convert to pandas DataFrame
      df = pd.DataFrame(table_data, columns=headers)
-     import pandas as pd
-     from pandasai.llm.openai import OpenAI
-     from pandasai import SmartDataframe
-     #open api key
-     import openai
-
-     llm = OpenAI()
-     sdf = SmartDataframe(df, config={"llm": llm})
+
+     sdf = SmartDataframe(df, config={"llm": llm_chart})
+
      #chart = sdf.chat("Can you draw a bar chart with all avaialble item name and quantity.")
      chart = sdf.chat(question)
+
+     #email send
+     if user_email:
+         # Send email with the chart image attached
+         send_email_with_attachment(
+             recipient_email=user_email,
+             subject="Warehouse Inventory Report",
+             body="Please find the attached bar chart report for the warehouse inventory analysis.",
+             #attachment_path=chart_path
+             attachment_path="/content/exports/charts/temp_chart.png"
+         )
+
      return chart
  #inventory_report("WH:can you give me a bar chart with item name and quantity for the warehouse WH")

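For context, the pandasai call pattern used above in isolation, with a made-up DataFrame (the column names are placeholders, not the stock-summary schema):

import pandas as pd
from pandasai import SmartDataframe
from pandasai.llm.openai import OpenAI

df = pd.DataFrame({"Item Name": ["Bolt", "Nut"], "Quantity": [120, 80]})
sdf = SmartDataframe(df, config={"llm": OpenAI()})
# The generated chart is written under exports/charts/; this commit reads it back
# from /content/exports/charts/temp_chart.png
chart = sdf.chat("Plot a bar chart of item name and quantity")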
@@ -325,7 +388,19 @@ tools = [
          name="dataVisualization",
          args_schema=QueryInput,
          output_schema=QueryOutput,
-         description="Tool to generate visual output for a particular warehouse. Invoke this tool if the user wants to create charts. Process the user question and send two inputs to the tool. One input will be the warehouse name and another input to the tool will be the entire user_question itself. "
+         description = """
+         Tool to generate a visual output (such as a bar chart) for a particular warehouse based on the provided question.
+         This tool processes the user question to identify the warehouse name and the specific request. If the user specifies
+         an email, include the email in the input. The input format should be: 'warehouse name: user question: email (if any)'.
+         The tool generates the requested chart and sends it to the provided email if specified.
+
+         Examples:
+         1. Question without email: "Analyze item name and quantity in a bar chart in warehouse Allcargo Logistics"
+            Input to tool: "Allcargo Logistics: I want to analyze item name and quantity in a bar chart"
+
+         2. Question with email: "Analyze item name and quantity in a bar chart in warehouse Allcargo Logistics report to send email to example@example.com"
+            Input to tool: "Allcargo Logistics: I want to analyze item name and quantity in a bar chart: example@example.com"
+         """
      )
  ]

@@ -342,25 +417,71 @@ llm = llm.bind()
  agent = create_tool_calling_agent(llm, tools, ChatPromptTemplate.from_template(prompt_template))
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

+ def ensure_temp_chart_dir():
+     temp_chart_dir = "/content/exports/charts/"
+     if not os.path.exists(temp_chart_dir):
+         os.makedirs(temp_chart_dir)
+
+ def clean_gradio_tmp_dir():
+     tmp_dir = "/tmp/gradio/"
+     if os.path.exists(tmp_dir):
+         try:
+             shutil.rmtree(tmp_dir)
+         except Exception as e:
+             print(f"Error cleaning up /tmp/gradio/ directory: {e}")
+
  # Define the interface function
  max_iterations = 5
  iterations = 0

- def answer_question(user_question,chatbot):
+ def answer_question(user_question, chatbot, audio=None):
      global iterations
      iterations = 0
+     # Ensure the temporary chart directory exists
+     #ensure_temp_chart_dir()
+     # Clean the /tmp/gradio/ directory
+     #clean_gradio_tmp_dir()
+     # Handle audio input if provided
+     if audio is not None:
+         sample_rate, audio_data = audio
+         audio_segment = AudioSegment(
+             audio_data.tobytes(),
+             frame_rate=sample_rate,
+             sample_width=audio_data.dtype.itemsize,
+             channels=1
+         )
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
+             audio_segment.export(temp_audio_file.name, format="wav")
+             temp_audio_file_path = temp_audio_file.name
+
+         recognizer = sr.Recognizer()
+         with sr.AudioFile(temp_audio_file_path) as source:
+             audio_content = recognizer.record(source)
+             try:
+                 user_question = recognizer.recognize_google(audio_content)
+             except sr.UnknownValueError:
+                 user_question = "Sorry, I could not understand the audio."
+             except sr.RequestError:
+                 user_question = "Could not request results from Google Speech Recognition service."

      while iterations < max_iterations:
-         #print(user_question)
-         response = agent_executor.invoke({"input": user_question})
-         #print(response)
-         if isinstance(response, dict):
-             response_text = response.get("output", "")
-         else:
-             response_text = response
-         if "invalid" not in response_text.lower():
-             break
-         iterations += 1
+         print(user_question)
+         if "send email to" in user_question:
+             email_match = re.search(r"send email to ([\w\.-]+@[\w\.-]+)", user_question)
+             if email_match:
+                 user_email = email_match.group(1).strip()
+                 user_question = user_question.replace(f"send email to {user_email}", "").strip()
+                 user_question = f"{user_question}:{user_email}"
+
+         response = agent_executor.invoke({"input": user_question})
+
+         if isinstance(response, dict):
+             response_text = response.get("output", "")
+         else:
+             response_text = response
+         if "invalid" not in response_text.lower():
+             break
+         iterations += 1

      if iterations == max_iterations:
          return "The agent could not generate a valid response within the iteration limit."
@@ -378,7 +499,18 @@ def answer_question(user_question,chatbot):
              #image = gr.Image(value=img_str)
              chatbot.append((user_question,img))
              #print(chatbot)
+             if "send email to" in user_question:
+                 try:
+                     os.remove(image_path) # Clean up the temporary image file
+                 except Exception as e:
+                     print(f"Error cleaning up image file: {e}")
+         except Exception as e:
+             print(f"Error loading image file: {e}")
+             chatbot.append((user_question, "Chart generation failed. Please try again."))
+     else:
+         chatbot.append((user_question, "Chart generation failed. Please try again."))
      return gr.update(value=chatbot)
+

      #return [(user_question,gr.Image("/home/user/app/exports/charts/temp_chart.png"))]
      # return "/home/user/app/exports/charts/temp_chart.png"
@@ -432,6 +564,7 @@ with gr.Blocks(css=css) as demo:
      with gr.Row():
          with gr.Column(scale=1):
              message = gr.Textbox(show_label=False)
+             audio_input = gr.Audio(label="Record your question")
          with gr.Column(scale=1):
              with gr.Row():
                  button = gr.Button("Submit", elem_classes="gr-button")