fl399 commited on
Commit
a6237eb
1 Parent(s): 62edd91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -19,22 +19,22 @@ model.to(device)
19
  def filter_output(output):
20
  return output.replace("<0x0A>", "")
21
 
22
- def summarize_chart(image):
23
- inputs = processor(images=image, text="", return_tensors="pt").to(device)
24
  predictions = model.generate(**inputs, max_new_tokens=512)
25
  return filter_output(processor.decode(predictions[0], skip_special_tokens=True))
26
 
27
 
28
  image = gr.inputs.Image(type="pil", label="Chart")
29
- answer = gr.outputs.Textbox(label="Chart Summary")
30
  examples = [["chart_example.png", "Which country has the second highest death rate?"], ]
31
  #["chart_example_2.png"], ["chart_example_3.png"], ["chart_example_4.png"]]
32
 
33
  title = "Interactive demo: chart QA"
34
  description = "Gradio Demo for matcha model, fine-tuned on the ChartQA dataset. To use it, simply upload your image and click 'submit', or click one of the examples to load them."
35
 
36
- interface = gr.Interface(fn=summarize_chart,
37
- inputs=[image],
38
  outputs=answer,
39
  examples=examples,
40
  title=title,
 
19
  def filter_output(output):
20
  return output.replace("<0x0A>", "")
21
 
22
+ def chart_qa(image, question):
23
+ inputs = processor(images=image, text=question, return_tensors="pt").to(device)
24
  predictions = model.generate(**inputs, max_new_tokens=512)
25
  return filter_output(processor.decode(predictions[0], skip_special_tokens=True))
26
 
27
 
28
  image = gr.inputs.Image(type="pil", label="Chart")
29
+ answer = gr.outputs.Textbox(label="Model Output")
30
  examples = [["chart_example.png", "Which country has the second highest death rate?"], ]
31
  #["chart_example_2.png"], ["chart_example_3.png"], ["chart_example_4.png"]]
32
 
33
  title = "Interactive demo: chart QA"
34
  description = "Gradio Demo for matcha model, fine-tuned on the ChartQA dataset. To use it, simply upload your image and click 'submit', or click one of the examples to load them."
35
 
36
+ interface = gr.Interface(fn=chart_qa,
37
+ inputs=[image, question],
38
  outputs=answer,
39
  examples=examples,
40
  title=title,