Redmind commited on
Commit
f08b1e6
·
verified ·
1 Parent(s): e4950ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -52
app.py CHANGED
@@ -9,9 +9,26 @@ from PIL import Image
9
  import base64
10
  from io import BytesIO
11
  import os
 
 
 
12
  import requests
13
  import gradio as gr
14
- #import nltk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  from langchain_core.prompts import ChatPromptTemplate
16
  from langchain_core.output_parsers import StrOutputParser
17
  from langchain_core.runnables import RunnableSequence, RunnableLambda
@@ -27,21 +44,30 @@ from PyPDF2 import PdfReader
27
  from nltk.tokenize import sent_tokenize
28
  from sqlalchemy import create_engine
29
  from sqlalchemy.sql import text
30
- import json
31
- import nltk
 
 
 
 
 
 
 
32
  nltk.download('punkt')
33
 
34
  open_api_key_token = os.environ['OPEN_AI_API']
35
 
36
  os.environ['OPENAI_API_KEY'] = open_api_key_token
 
 
37
  db_uri = 'postgresql+psycopg2://postgres:postpass@193.203.162.39:5432/warehouseAi'
38
  # Database setup
39
 
40
  db = SQLDatabase.from_uri(db_uri)
41
 
42
  # LLM setup
43
- #llm = ChatOpenAI(model="gpt-3.5-turbo-0125",max_tokens=150,temperature=0.1)
44
- llm = ChatOpenAI(model="gpt-4o-mini",max_tokens=250,temperature=0.1)
45
 
46
  def get_schema(_):
47
  schema_info = db.get_table_info() # This should be a string of your SQL schema
@@ -69,7 +95,7 @@ def generate_sql_query(question):
69
  def run_query(query):
70
  # Clean the query by removing markdown symbols and trimming whitespace
71
  clean_query = query.replace("```sql", "").replace("```", "").strip()
72
- #print(f"Executing SQL Query: {clean_query}")
73
  try:
74
  result = db.run(clean_query)
75
  return result
@@ -83,7 +109,7 @@ def run_query(query):
83
  def database_tool(question):
84
  # print(question)
85
  sql_query = generate_sql_query(question)
86
- ##print(sql_query)
87
  return run_query(sql_query)
88
 
89
  def get_ASN_data(question):
@@ -149,7 +175,7 @@ def create_vector_store(texts):
149
 
150
  def query_vector_store(vector_store, query):
151
  docs = vector_store.similarity_search(query, k=5)
152
- #print(f"Vector store return: {docs}")
153
  return docs
154
 
155
  def summarize_document(docs):
@@ -167,13 +193,12 @@ def summarize_document(docs):
167
  summarized_content = doc_content
168
  summarized_docs.append(summarized_content)
169
  return '\n\n'.join(summarized_docs)
170
- pdf_path = "Inbound.pdf"
171
- #pdf_path = r"D:\rajesh\python\chat_agent\Inbound.pdf"
172
  texts = load_and_split_pdf(pdf_path)
173
  vector_store = create_vector_store(texts)
174
 
175
  def document_data_tool(question):
176
- #print(f"Document data tool enter: {question}")
177
  # query_string = question['tags'][0] if 'tags' in question and question['tags'] else ""
178
  query_response = query_vector_store(vector_store, question)
179
  print("query****")
@@ -182,6 +207,45 @@ def document_data_tool(question):
182
  #print("summary***")
183
  #print(summarized_response)
184
  return query_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  def make_api_request(url, params):
187
  import requests
@@ -203,26 +267,23 @@ apis = [
203
  "url": "http://193.203.162.39:9090/nxt-wms/userWarehouse/fetchWarehouseForUserId?",
204
  "params": {"query": name, "userId": "164"}
205
  },
206
- #fetch customer id
207
- {
208
- "url": "http://193.203.162.39:9090/nxt-wms/userCustomer/fetchCustomerForUserId?",
209
- "params": {"query": "TESTING 123", "userId": "164", "status": "Active"}
210
- },
211
  #Stock summary based on warehouse id
212
  {
213
  "url": "http://193.203.162.39:9090/nxt-wms/transactionHistory/stockSummary?",
214
- "params": {"branchId": "343", "onDate": "2024-08-06", "warehouseId" : warehouse_id }
215
  }
216
  ]
217
 
218
  def inventory_report(question):
219
 
 
 
 
 
 
 
220
 
221
- name = question.split(":")[0]
222
- #print(question)
223
- question = question.split(":")[1]
224
- #print(name)
225
- import requests
226
 
227
  data = make_api_request(apis[0]["url"], apis[0]["params"])
228
  if data:
@@ -236,11 +297,8 @@ def inventory_report(question):
236
  if "warehouseId" in api["params"]:
237
  api["params"]["warehouseId"] = warehouse_id
238
 
239
- #print(apis[2]["url"])
240
- #print(apis[2]["params"])
241
- data1 = make_api_request(apis[2]["url"], apis[2]["params"])
242
- #if data1:
243
- #print(data1)
244
 
245
  from tabulate import tabulate
246
 
@@ -268,22 +326,25 @@ def inventory_report(question):
268
  table_data.append(row)
269
 
270
 
271
- #if table_data:
272
- #print(tabulate(table_data, headers=headers, tablefmt="grid"))
273
-
274
- # Convert to pandas DataFrame
275
- import pandas as pd
276
  df = pd.DataFrame(table_data, columns=headers)
277
- import pandas as pd
278
- from pandasai.llm.openai import OpenAI
279
- from pandasai import SmartDataframe
280
- #open api key
281
- import openai
282
-
283
- llm = OpenAI()
284
- sdf = SmartDataframe(df, config={"llm": llm})
285
  #chart = sdf.chat("Can you draw a bar chart with all avaialble item name and quantity.")
286
  chart = sdf.chat(question)
 
 
 
 
 
 
 
 
 
 
 
 
287
  return chart
288
  #inventory_report("WH:can you give me a bar chart with item name and quantity for the warehouse WH")
289
 
@@ -325,7 +386,19 @@ tools = [
325
  name="dataVisualization",
326
  args_schema=QueryInput,
327
  output_schema=QueryOutput,
328
- description="Tool to generate visual output for a particular warehouse. Invoke this tool if the user wants to create charts. Process the user question and send two inputs to the tool. One input will be the warehouse name and another input to the tool will be the entire user_question itself. "
 
 
 
 
 
 
 
 
 
 
 
 
329
  )
330
  ]
331
 
@@ -342,25 +415,71 @@ llm = llm.bind()
342
  agent = create_tool_calling_agent(llm, tools, ChatPromptTemplate.from_template(prompt_template))
343
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  # Define the interface function
346
  max_iterations = 5
347
  iterations = 0
348
 
349
- def answer_question(user_question,chatbot):
350
  global iterations
351
  iterations = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
  while iterations < max_iterations:
354
- #print(user_question)
355
- response = agent_executor.invoke({"input": user_question})
356
- #print(response)
357
- if isinstance(response, dict):
358
- response_text = response.get("output", "")
359
- else:
360
- response_text = response
361
- if "invalid" not in response_text.lower():
362
- break
363
- iterations += 1
 
 
 
 
 
 
 
364
 
365
  if iterations == max_iterations:
366
  return "The agent could not generate a valid response within the iteration limit."
@@ -378,7 +497,18 @@ def answer_question(user_question,chatbot):
378
  #image = gr.Image(value=img_str)
379
  chatbot.append((user_question,img))
380
  #print(chatbot)
 
 
 
 
 
 
 
 
 
 
381
  return gr.update(value=chatbot)
 
382
 
383
  #return [(user_question,gr.Image("/home/user/app/exports/charts/temp_chart.png"))]
384
  # return "/home/user/app/exports/charts/temp_chart.png"
@@ -432,6 +562,7 @@ with gr.Blocks(css=css) as demo:
432
  with gr.Row():
433
  with gr.Column(scale=1):
434
  message = gr.Textbox(show_label=False)
 
435
  with gr.Column(scale=1):
436
  with gr.Row():
437
  button = gr.Button("Submit", elem_classes="gr-button")
@@ -458,3 +589,4 @@ with gr.Blocks(css=css) as demo:
458
 
459
 
460
  demo.launch()
 
 
9
  import base64
10
  from io import BytesIO
11
  import os
12
+ import re
13
+ import tempfile
14
+ import wave
15
  import requests
16
  import gradio as gr
17
+ import time
18
+ import shutil
19
+ import json
20
+ import nltk
21
+ #audio package
22
+ import speech_recognition as sr
23
+ from pydub import AudioSegment
24
+ from pydub.playback import play
25
+ #email library
26
+ import smtplib
27
+ from email.mime.multipart import MIMEMultipart
28
+ from email.mime.text import MIMEText
29
+ from email.mime.base import MIMEBase
30
+ from email import encoders
31
+ #langchain
32
  from langchain_core.prompts import ChatPromptTemplate
33
  from langchain_core.output_parsers import StrOutputParser
34
  from langchain_core.runnables import RunnableSequence, RunnableLambda
 
44
  from nltk.tokenize import sent_tokenize
45
  from sqlalchemy import create_engine
46
  from sqlalchemy.sql import text
47
+
48
+ #pandas
49
+ import pandas as pd
50
+ from pandasai.llm.openai import OpenAI
51
+ from pandasai import SmartDataframe
52
+
53
+
54
+
55
+
56
  nltk.download('punkt')
57
 
58
  open_api_key_token = os.environ['OPEN_AI_API']
59
 
60
  os.environ['OPENAI_API_KEY'] = open_api_key_token
61
+ pdf_path="Inbound.pdf"
62
+
63
  db_uri = 'postgresql+psycopg2://postgres:postpass@193.203.162.39:5432/warehouseAi'
64
  # Database setup
65
 
66
  db = SQLDatabase.from_uri(db_uri)
67
 
68
  # LLM setup
69
+ llm = ChatOpenAI(model="gpt-4o-mini",max_tokens=300,temperature=0.1)
70
+ llm_chart = OpenAI()
71
 
72
  def get_schema(_):
73
  schema_info = db.get_table_info() # This should be a string of your SQL schema
 
95
  def run_query(query):
96
  # Clean the query by removing markdown symbols and trimming whitespace
97
  clean_query = query.replace("```sql", "").replace("```", "").strip()
98
+ print(f"Executing SQL Query: {clean_query}")
99
  try:
100
  result = db.run(clean_query)
101
  return result
 
109
  def database_tool(question):
110
  # print(question)
111
  sql_query = generate_sql_query(question)
112
+ print(sql_query)
113
  return run_query(sql_query)
114
 
115
  def get_ASN_data(question):
 
175
 
176
  def query_vector_store(vector_store, query):
177
  docs = vector_store.similarity_search(query, k=5)
178
+ print(f"Vector store return: {docs}")
179
  return docs
180
 
181
  def summarize_document(docs):
 
193
  summarized_content = doc_content
194
  summarized_docs.append(summarized_content)
195
  return '\n\n'.join(summarized_docs)
196
+
 
197
  texts = load_and_split_pdf(pdf_path)
198
  vector_store = create_vector_store(texts)
199
 
200
  def document_data_tool(question):
201
+ print(f"Document data tool enter: {question}")
202
  # query_string = question['tags'][0] if 'tags' in question and question['tags'] else ""
203
  query_response = query_vector_store(vector_store, question)
204
  print("query****")
 
207
  #print("summary***")
208
  #print(summarized_response)
209
  return query_response
210
+
211
+ def send_email_with_attachment(recipient_email, subject, body, attachment_path):
212
+ sender_email = "redmind.uiautomation@gmail.com"
213
+ sender_password = "jymz apyc raih eubg"
214
+
215
+ # Create a multipart message
216
+ msg = MIMEMultipart()
217
+ msg['From'] = sender_email
218
+ msg['To'] = recipient_email
219
+ msg['Subject'] = subject
220
+
221
+ # Attach the body with the msg instance
222
+ msg.attach(MIMEText(body, 'plain'))
223
+
224
+ # Open the file to be sent
225
+ attachment = open(attachment_path, "rb")
226
+
227
+ # Instance of MIMEBase and named as p
228
+ part = MIMEBase('application', 'octet-stream')
229
+
230
+ # To change the payload into encoded form
231
+ part.set_payload((attachment).read())
232
+
233
+ # Encode into base64
234
+ encoders.encode_base64(part)
235
+
236
+ part.add_header('Content-Disposition', f"attachment; filename= {attachment_path}")
237
+
238
+ # Attach the instance 'part' to instance 'msg'
239
+ msg.attach(part)
240
+
241
+ # Create SMTP session for sending the mail
242
+ server = smtplib.SMTP('smtp.gmail.com', 587)
243
+ server.starttls()
244
+ server.login(sender_email, sender_password)
245
+ text = msg.as_string()
246
+ server.sendmail(sender_email, recipient_email, text)
247
+ server.quit()
248
+ #return 1
249
 
250
  def make_api_request(url, params):
251
  import requests
 
267
  "url": "http://193.203.162.39:9090/nxt-wms/userWarehouse/fetchWarehouseForUserId?",
268
  "params": {"query": name, "userId": "164"}
269
  },
270
+
 
 
 
 
271
  #Stock summary based on warehouse id
272
  {
273
  "url": "http://193.203.162.39:9090/nxt-wms/transactionHistory/stockSummary?",
274
+ "params": {"branchId": "343", "onDate": "2024-08-09", "warehouseId" : warehouse_id }
275
  }
276
  ]
277
 
278
  def inventory_report(question):
279
 
280
+ # Split the question to extract warehouse name, user question, and optional email
281
+ parts = question.split(":", 2)
282
+ name = parts[0].strip()
283
+ user_question = parts[1].strip()
284
+ user_email = parts[2].strip() if len(parts) > 2 else None
285
+ print(f"Warehouse: {name}, Email: {user_email}, Question: {user_question}")
286
 
 
 
 
 
 
287
 
288
  data = make_api_request(apis[0]["url"], apis[0]["params"])
289
  if data:
 
297
  if "warehouseId" in api["params"]:
298
  api["params"]["warehouseId"] = warehouse_id
299
 
300
+
301
+ data1 = make_api_request(apis[1]["url"], apis[1]["params"])
 
 
 
302
 
303
  from tabulate import tabulate
304
 
 
326
  table_data.append(row)
327
 
328
 
329
+ # Convert to pandas DataFrame
 
 
 
 
330
  df = pd.DataFrame(table_data, columns=headers)
331
+
332
+ sdf = SmartDataframe(df, config={"llm": llm_chart})
333
+
 
 
 
 
 
334
  #chart = sdf.chat("Can you draw a bar chart with all avaialble item name and quantity.")
335
  chart = sdf.chat(question)
336
+
337
+ #email send
338
+ if user_email:
339
+ # Send email with the chart image attached
340
+ send_email_with_attachment(
341
+ recipient_email=user_email,
342
+ subject="Warehouse Inventory Report",
343
+ body="Please find the attached bar chart report for the warehouse inventory analysis.",
344
+ #attachment_path=chart_path
345
+ attachment_path="/home/user/app/exports/charts/temp_chart.png"
346
+ )
347
+
348
  return chart
349
  #inventory_report("WH:can you give me a bar chart with item name and quantity for the warehouse WH")
350
 
 
386
  name="dataVisualization",
387
  args_schema=QueryInput,
388
  output_schema=QueryOutput,
389
+ description = """
390
+ Tool to generate a visual output (such as a bar chart) for a particular warehouse based on the provided question.
391
+ This tool processes the user question to identify the warehouse name and the specific request. If the user specifies
392
+ an email, include the email in the input. The input format should be: 'warehouse name: user question: email (if any)'.
393
+ The tool generates the requested chart and sends it to the provided email if specified.
394
+
395
+ Examples:
396
+ 1. Question without email: "Analyze item name and quantity in a bar chart in warehouse Allcargo Logistics"
397
+ Input to tool: "Allcargo Logistics: I want to analyze item name and quantity in a bar chart"
398
+
399
+ 2. Question with email: "Analyze item name and quantity in a bar chart in warehouse Allcargo Logistics report to send email to example@example.com"
400
+ Input to tool: "Allcargo Logistics: I want to analyze item name and quantity in a bar chart: example@example.com"
401
+ """
402
  )
403
  ]
404
 
 
415
  agent = create_tool_calling_agent(llm, tools, ChatPromptTemplate.from_template(prompt_template))
416
  agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
417
 
418
+ def ensure_temp_chart_dir():
419
+ temp_chart_dir = "/home/user/app/exports/charts/"
420
+ if not os.path.exists(temp_chart_dir):
421
+ os.makedirs(temp_chart_dir)
422
+
423
+ def clean_gradio_tmp_dir():
424
+ tmp_dir = "/tmp/gradio/"
425
+ if os.path.exists(tmp_dir):
426
+ try:
427
+ shutil.rmtree(tmp_dir)
428
+ except Exception as e:
429
+ print(f"Error cleaning up /tmp/gradio/ directory: {e}")
430
+
431
  # Define the interface function
432
  max_iterations = 5
433
  iterations = 0
434
 
435
+ def answer_question(user_question, chatbot, audio=None):
436
  global iterations
437
  iterations = 0
438
+ # Ensure the temporary chart directory exists
439
+ #ensure_temp_chart_dir()
440
+ # Clean the /tmp/gradio/ directory
441
+ #clean_gradio_tmp_dir()
442
+ # Handle audio input if provided
443
+ if audio is not None:
444
+ sample_rate, audio_data = audio
445
+ audio_segment = AudioSegment(
446
+ audio_data.tobytes(),
447
+ frame_rate=sample_rate,
448
+ sample_width=audio_data.dtype.itemsize,
449
+ channels=1
450
+ )
451
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
452
+ audio_segment.export(temp_audio_file.name, format="wav")
453
+ temp_audio_file_path = temp_audio_file.name
454
+
455
+ recognizer = sr.Recognizer()
456
+ with sr.AudioFile(temp_audio_file_path) as source:
457
+ audio_content = recognizer.record(source)
458
+ try:
459
+ user_question = recognizer.recognize_google(audio_content)
460
+ except sr.UnknownValueError:
461
+ user_question = "Sorry, I could not understand the audio."
462
+ except sr.RequestError:
463
+ user_question = "Could not request results from Google Speech Recognition service."
464
 
465
  while iterations < max_iterations:
466
+ print(user_question)
467
+ if "send email to" in user_question:
468
+ email_match = re.search(r"send email to ([\w\.-]+@[\w\.-]+)", user_question)
469
+ if email_match:
470
+ user_email = email_match.group(1).strip()
471
+ user_question = user_question.replace(f"send email to {user_email}", "").strip()
472
+ user_question = f"{user_question}:{user_email}"
473
+
474
+ response = agent_executor.invoke({"input": user_question})
475
+
476
+ if isinstance(response, dict):
477
+ response_text = response.get("output", "")
478
+ else:
479
+ response_text = response
480
+ if "invalid" not in response_text.lower():
481
+ break
482
+ iterations += 1
483
 
484
  if iterations == max_iterations:
485
  return "The agent could not generate a valid response within the iteration limit."
 
497
  #image = gr.Image(value=img_str)
498
  chatbot.append((user_question,img))
499
  #print(chatbot)
500
+ if "send email to" in user_question:
501
+ try:
502
+ os.remove(image_path) # Clean up the temporary image file
503
+ except Exception as e:
504
+ print(f"Error cleaning up image file: {e}")
505
+ except Exception as e:
506
+ print(f"Error loading image file: {e}")
507
+ chatbot.append((user_question, "Chart generation failed. Please try again."))
508
+ else:
509
+ chatbot.append((user_question, "Chart generation failed. Please try again."))
510
  return gr.update(value=chatbot)
511
+
512
 
513
  #return [(user_question,gr.Image("/home/user/app/exports/charts/temp_chart.png"))]
514
  # return "/home/user/app/exports/charts/temp_chart.png"
 
562
  with gr.Row():
563
  with gr.Column(scale=1):
564
  message = gr.Textbox(show_label=False)
565
+ audio_input = gr.Audio(label="Record your question")
566
  with gr.Column(scale=1):
567
  with gr.Row():
568
  button = gr.Button("Submit", elem_classes="gr-button")
 
589
 
590
 
591
  demo.launch()
592
+