Fangyu Liu commited on
Commit
f67adcd
1 Parent(s): ae2c89d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -6
app.py CHANGED
@@ -2,14 +2,86 @@ import gradio as gr
2
  # from PIL import Image
3
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
4
 
5
- model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
6
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def process_document(image, question):
9
  # image = Image.open(image)
10
- inputs = processor(images=image, text=question, return_tensors="pt")
11
- predictions = model.generate(**inputs)
12
- return processor.decode(predictions[0], skip_special_tokens=True)
 
 
 
 
13
 
14
  description = "Demo for pix2struct fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
15
  article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2210.03347.pdf' target='_blank'>PIX2STRUCT: SCREENSHOT PARSING AS PRETRAINING FOR VISUAL LANGUAGE UNDERSTANDING</a></p>"
@@ -18,7 +90,7 @@ demo = gr.Interface(
18
  fn=process_document,
19
  inputs=["image", "text"],
20
  outputs="text",
21
- title="Demo: pix2struct for DocVQA",
22
  description=description,
23
  article=article,
24
  enable_queue=True,
 
2
  # from PIL import Image
3
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
4
 
5
+ def _add_markup(table):
6
+ parts = [p.strip() for p in table.splitlines(keepends=False)]
7
+ if parts[0].startswith('TITLE'):
8
+ result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
9
+ rows = parts[1:]
10
+ else:
11
+ result = ''
12
+ rows = parts
13
+ prefixes = ['Header: '] + [f'Row {i+1}: ' for i in range(len(rows) - 1)]
14
+ return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))
15
+
16
+
17
+ _TABLE = """Year | Democrats | Republicans | Independents
18
+ 2004 | 68.1% | 45.0% | 53.0%
19
+ 2006 | 58.0% | 42.0% | 53.0%
20
+ 2007 | 59.0% | 38.0% | 45.0%
21
+ 2009 | 72.0% | 49.0% | 60.0%
22
+ 2011 | 71.0% | 51.2% | 58.0%
23
+ 2012 | 70.0% | 48.0% | 53.0%
24
+ 2013 | 72.0% | 41.0% | 60.0%"""
25
+
26
+ _INSTRUCTION = 'Read the table below to answer the following questions.'
27
+
28
+
29
+ _TEMPLATE = f"""{_INSTRUCTION}
30
+ {_add_markup(_TABLE)}
31
+ Q: In which year republicans have the lowest favor rate?
32
+ A: Let's find the column of republicans. Then let's extract the favor rates, they [45.0, 42.0, 38.0, 49.0, 51.2, 48.0, 41.0]. The smallest number is 38.0, that's Row 3. Row 3 is year 2007. The answer is 2007.
33
+ Q: What is the sum of Democrats' favor rates of 2004, 2012, and 2013?
34
+ A: Let's find the rows of years 2004, 2012, and 2013. We find Row 1, 6, 7. The favor dates of Demoncrats on that 3 rows are 68.1, 70.0, and 72.0. 68.1+70.0+72=210.1. The answer is 210.1.
35
+ Q: By how many points do Independents surpass Republicans in the year of 2011?
36
+ A: Let's find the row with year = 2011. We find Row 5. We extract Independents and Republicans' numbers. They are 58.0 and 51.2. 58.0-51.2=6.8. The answer is 6.8.
37
+ Q: Which group has the overall worst performance?
38
+ A: Let's sample a couple of years. In Row 1, year 2004, we find Republicans having the lowest favor rate 45.0 (since 45.0<68.1, 45.0<53.0). In year 2006, Row 2, we find Republicans having the lowest favor rate 42.0 (42.0<58.0, 42.0<53.0). The trend continues to other years. The answer is Republicans.
39
+ Q: Which party has the second highest favor rates in 2007?
40
+ A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on Row 3: [59.0, 38.0, 45.0]. 45.0 is the second highest. 45.0 is the number of Independents. The answer is Independents.
41
+ {_INSTRUCTION}"""
42
+
43
+ def text_generate(prompt, table, problem):
44
+ p = prompt + "\n" + _INSTRUCTION + "\n" + table + "\n" + "Q: " + problem
45
+ # print(f"Final prompt is : {p}")
46
+ json_ = {"inputs": p,
47
+ "parameters":
48
+ {
49
+ "top_p": 0.9,
50
+ "temperature": 1.1,
51
+ "max_new_tokens": 64,
52
+ "return_full_text": True
53
+ }, "options":
54
+ {
55
+ "use_cache": True,
56
+ "wait_for_model":True
57
+ },}
58
+ response = requests.post(API_URL, headers=headers, json=json_)
59
+ print(f"Response is : {response}")
60
+ output = response.json()
61
+ print(f"output is : {output}") #{output}")
62
+ output_tmp = output[0]['generated_text']
63
+ print(f"output_tmp is: {output_tmp}")
64
+ #solution = output_tmp.split("\nQ:")[0] #output[0]['generated_text'].split("Q:")[0] # +"."
65
+ #print(f"Final response after splits is: {solution}")
66
+
67
+ #return solution
68
+ return output_tmp
69
+
70
+
71
+
72
+
73
+ model_deplot = Pix2StructForConditionalGeneration.from_pretrained("belkada/deplot")
74
+ processor_deplot = Pix2StructProcessor.from_pretrained("belkada/deplot")
75
 
76
  def process_document(image, question):
77
  # image = Image.open(image)
78
+ inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt")
79
+ predictions = model_deplot.generate(**inputs)
80
+ table = processor_deplot.decode(predictions[0], skip_special_tokens=True)
81
+
82
+ # send prompt+table to LLM
83
+ res = text_generate(_TEMPLATE, table, question)
84
+ print (res)
85
 
86
  description = "Demo for pix2struct fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
87
  article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2210.03347.pdf' target='_blank'>PIX2STRUCT: SCREENSHOT PARSING AS PRETRAINING FOR VISUAL LANGUAGE UNDERSTANDING</a></p>"
 
90
  fn=process_document,
91
  inputs=["image", "text"],
92
  outputs="text",
93
+ title="Demo: deplot+llm test",
94
  description=description,
95
  article=article,
96
  enable_queue=True,