ilia_khristoforov commited on
Commit
23d5842
1 Parent(s): ad079c2

Charts tab

Browse files
Files changed (1) hide show
  1. app.py +36 -8
app.py CHANGED
@@ -9,6 +9,11 @@ from langchain.embeddings import OpenAIEmbeddings
9
  from langchain.vectorstores import Chroma
10
  from langchain.chains import ConversationalRetrievalChain
11
  from langchain import PromptTemplate
 
 
 
 
 
12
 
13
 
14
  # _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
@@ -28,6 +33,26 @@ from langchain import PromptTemplate
28
  # =========
29
  # Answer in Markdown:"""
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def loading_pdf():
32
  return "Loading..."
33
 
@@ -169,11 +194,10 @@ def respond(message, chat_history):
169
  return "", chat_history
170
 
171
 
172
-
173
  with gr.Blocks() as demo:
174
  with gr.Column(elem_id="col-container"):
175
  gr.HTML(title)
176
- openai_key = gr.Textbox(
177
  show_label=False,
178
  placeholder="Your OpenAI key",
179
  type = 'password',
@@ -209,7 +233,7 @@ with gr.Blocks() as demo:
209
  clr_btn = gr.Button("Clear!")
210
 
211
  load_pdf.click(loading_pdf, None, langchain_status, queue=False)
212
- load_pdf.click(pdf_changes, inputs=[pdf_doc, openai_key], outputs=[langchain_status], queue=True)
213
  question.submit(add_text, [chatbot, question], [chatbot, question]).then(
214
  bot, chatbot, chatbot
215
  )
@@ -244,16 +268,20 @@ with gr.Blocks() as demo:
244
  clr_btn = gr.Button("Clear!")
245
 
246
  load_table.click(load_file, None, status_sh, queue=False)
247
- load_table.click(table_loader, inputs=[raw_table, openai_key], outputs=[status_sh], queue=False)
248
 
249
  question_sh.submit(respond, [question_sh, chatbot_sh], [question_sh, chatbot_sh])
250
  clr_btn.click(lambda: None, None, chatbot_sh, queue=False)
251
 
252
-
253
  with gr.Tab("Charts"):
254
- gr.Text('Soon!')
255
-
256
-
 
 
 
 
257
 
258
  demo.queue(concurrency_count=3)
259
  demo.launch()
 
9
  from langchain.vectorstores import Chroma
10
  from langchain.chains import ConversationalRetrievalChain
11
  from langchain import PromptTemplate
12
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
13
+ import requests
14
+ from PIL import Image
15
+ import torch
16
+
17
 
18
 
19
  # _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
 
33
  # =========
34
  # Answer in Markdown:"""
35
 
36
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/20294671002019.png', 'chart_example.png')
37
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/multi_col_1081.png', 'chart_example_2.png')
38
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/18143564004789.png', 'chart_example_3.png')
39
+ torch.hub.download_url_to_file('https://sharkcoder.com/files/article/matplotlib-bar-plot.png', 'chart_example_4.png')
40
+
41
+
42
+ model_name = "google/matcha-chartqa"
43
+ model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
44
+ processor = Pix2StructProcessor.from_pretrained(model_name)
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ model.to(device)
47
+
48
+ def filter_output(output):
49
+ return output.replace("<0x0A>", "")
50
+
51
+ def chart_qa(image, question):
52
+ inputs = processor(images=image, text=question, return_tensors="pt").to(device)
53
+ predictions = model.generate(**inputs, max_new_tokens=512)
54
+ return filter_output(processor.decode(predictions[0], skip_special_tokens=True))
55
+
56
  def loading_pdf():
57
  return "Loading..."
58
 
 
194
  return "", chat_history
195
 
196
 
 
197
  with gr.Blocks() as demo:
198
  with gr.Column(elem_id="col-container"):
199
  gr.HTML(title)
200
+ key = gr.Textbox(
201
  show_label=False,
202
  placeholder="Your OpenAI key",
203
  type = 'password',
 
233
  clr_btn = gr.Button("Clear!")
234
 
235
  load_pdf.click(loading_pdf, None, langchain_status, queue=False)
236
+ load_pdf.click(pdf_changes, inputs=[pdf_doc, key], outputs=[langchain_status], queue=True)
237
  question.submit(add_text, [chatbot, question], [chatbot, question]).then(
238
  bot, chatbot, chatbot
239
  )
 
268
  clr_btn = gr.Button("Clear!")
269
 
270
  load_table.click(load_file, None, status_sh, queue=False)
271
+ load_table.click(table_loader, inputs=[raw_table, key], outputs=[status_sh], queue=False)
272
 
273
  question_sh.submit(respond, [question_sh, chatbot_sh], [question_sh, chatbot_sh])
274
  clr_btn.click(lambda: None, None, chatbot_sh, queue=False)
275
 
276
+
277
  with gr.Tab("Charts"):
278
+ image = gr.Image(type="pil", label="Chart")
279
+ question = gr.Textbox(label="Question")
280
+ load_chart = gr.Button("Load chart and question!")
281
+ answer = gr.Textbox(label="Model Output")
282
+
283
+ load_chart.click(chart_qa, [image, question], answer)
284
+
285
 
286
  demo.queue(concurrency_count=3)
287
  demo.launch()