fl399 commited on
Commit
70fa6be
·
1 Parent(s): 40d28d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -58
app.py CHANGED
@@ -114,19 +114,40 @@ if torch.__version__ >= "2":
114
 
115
 
116
  ## FLAN-UL2
117
- # in dev...
118
  TOKEN = os.environ.get("API_TOKEN", None)
119
  API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
120
  headers = {"Authorization": f"Bearer {TOKEN}"}
121
  def query(payload):
122
  response = requests.post(API_URL, headers=headers, json=payload)
123
  return response.json()
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def evaluate(
126
  table,
127
  question,
128
  llm="alpaca-lora",
129
- num_shot="1-shot",
130
  input=None,
131
  temperature=0.1,
132
  top_p=0.75,
@@ -138,10 +159,7 @@ def evaluate(
138
  prompt_0shot = _INSTRUCTION + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
139
  prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
140
  if llm == "alpaca-lora":
141
- if num_shot == "1-shot":
142
- inputs = tokenizer(prompt, return_tensors="pt")
143
- else:
144
- inputs = tokenizer(prompt_0shot, return_tensors="pt")
145
  input_ids = inputs["input_ids"].to(device)
146
  generation_config = GenerationConfig(
147
  temperature=temperature,
@@ -161,24 +179,15 @@ def evaluate(
161
  s = generation_output.sequences[0]
162
  output = tokenizer.decode(s)
163
  elif llm == "flan-ul2":
164
- if num_shot == "1-shot":
165
- output = query({
166
- "inputs": prompt
167
- })[0]["generated_text"]
168
- else:
169
- output = query({
170
- "inputs": prompt_0shot
171
- })[0]["generated_text"]
172
  else:
173
  RuntimeError(f"No such LLM: {llm}")
174
 
175
  return output
176
 
177
 
178
- ## deplot models
179
- model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
180
- processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
181
-
182
  def process_document(image, question, llm, num_shot):
183
  # image = Image.open(image)
184
  inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(0, torch.bfloat16)
@@ -191,45 +200,75 @@ def process_document(image, question, llm, num_shot):
191
  return [table, res.split("A:")[-1]]
192
  else:
193
  return [table, res]
194
-
195
- description = "Demo for DePlot+LLM for QA and summarisation. [DePlot](https://arxiv.org/abs/2212.10505) is an image-to-text model that converts plots and charts into a textual sequence. The sequence then is used to prompt LLM for chain-of-thought reasoning. The current underlying LLMs are [alpaca-lora](https://huggingface.co/spaces/tloen/alpaca-lora) and [flan-ul2](https://huggingface.co/google/flan-ul2). To use it, simply upload your image and type a question or instruction and click 'submit', or click one of the examples to load them. Read more at the links below."
196
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"
197
-
198
- demo = gr.Interface(
199
- fn=process_document,
200
- inputs=[
201
- "image",
202
- "text",
203
- gr.Dropdown(
204
- ["alpaca-lora", "flan-ul2"], label="LLM", info="Will add more LLMs later!"
205
- ),
206
- gr.Dropdown(
207
- ["0-shot", "1-shot"], label="#shots", info="How many example tables in the prompt?"
208
- ),
209
- ],
210
- outputs=[
211
- gr.inputs.Textbox(
212
- lines=8,
213
- label="Intermediate Table",
214
- ),
215
- gr.inputs.Textbox(
216
- lines=5,
217
- label="Output",
218
- )
219
- ],
220
- title="DePlot+LLM (Multimodal chain-of-thought reasoning on plots)",
221
- description=description,
222
- article=article,
223
- enable_queue=True,
224
- examples=[["deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step.", "alpaca-lora", "1-shot"],
225
- ["deplot_case_study_m1.png", "Summarise the chart for me please.", "alpaca-lora", "0-shot"],
226
- ["deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step.", "alpaca-lora", "1-shot"],
227
- ["deplot_case_study_4.png", "How many papers are submitted in 2020?", "alpaca-lora", "1-shot"],
228
- ["deplot_case_study_x2.png", "Summarise the chart for me please.", "alpaca-lora", "0-shot"],
229
- ["deplot_case_study_4.png", "How many papers are submitted in 2020?", "flan-ul2", "0-shot"],
230
- ["deplot_case_study_4.png", "acceptance rate = # accepted / #submitted . What is the acceptance rate of 2010?", "flan-ul2", "0-shot"],
231
- ["deplot_case_study_m1.png", "Summarise the chart for me please.", "flan-ul2", "0-shot"],
 
 
 
 
 
 
 
 
 
 
 
232
  ],
233
- cache_examples=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  demo.launch(debug=True)
 
114
 
115
 
116
  ## FLAN-UL2
 
117
  TOKEN = os.environ.get("API_TOKEN", None)
118
  API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
119
  headers = {"Authorization": f"Bearer {TOKEN}"}
120
  def query(payload):
121
  response = requests.post(API_URL, headers=headers, json=payload)
122
  return response.json()
123
+
124
+ ## OpenAI models
125
+ def set_openai_api_key(api_key):
126
+ if api_key and api_key.startswith("sk-") and len(api_key) > 50:
127
+ openai.api_key = api_key
128
+
129
+ def get_response_from_openai(prompt, model="gpt-3.5-turbo", max_output_tokens=128):
130
+ messages = [{"role": "assistant", "content": prompt}]
131
+ response = openai.ChatCompletion.create(
132
+ model=model,
133
+ messages=messages,
134
+ temperature=0.7,
135
+ max_tokens=max_output_tokens,
136
+ top_p=1,
137
+ frequency_penalty=0,
138
+ presence_penalty=0,
139
+ )
140
+ ret = response.choices[0].message['content']
141
+ return ret
142
+
143
+ ## deplot models
144
+ model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
145
+ processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
146
+
147
  def evaluate(
148
  table,
149
  question,
150
  llm="alpaca-lora",
 
151
  input=None,
152
  temperature=0.1,
153
  top_p=0.75,
 
159
  prompt_0shot = _INSTRUCTION + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
160
  prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
161
  if llm == "alpaca-lora":
162
+ inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
163
  input_ids = inputs["input_ids"].to(device)
164
  generation_config = GenerationConfig(
165
  temperature=temperature,
 
179
  s = generation_output.sequences[0]
180
  output = tokenizer.decode(s)
181
  elif llm == "flan-ul2":
182
+ output = query({"inputs": prompt_0shot})[0]["generated_text"]
183
+ elif llm == "gpt-3.5-turbo":
184
+ output = get_response_from_openai(prompt_0shot)
 
 
 
 
 
185
  else:
186
  RuntimeError(f"No such LLM: {llm}")
187
 
188
  return output
189
 
190
 
 
 
 
 
191
  def process_document(image, question, llm, num_shot):
192
  # image = Image.open(image)
193
  inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(0, torch.bfloat16)
 
200
  return [table, res.split("A:")[-1]]
201
  else:
202
  return [table, res]
203
+
204
+ theme = gr.themes.Monochrome(
205
+ primary_hue="indigo",
206
+ secondary_hue="blue",
207
+ neutral_hue="slate",
208
+ radius_size=gr.themes.sizes.radius_sm,
209
+ font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
210
+ )
211
+
212
+ with gr.Blocks(theme=theme) as demo:
213
+ with gr.Column():
214
+ gr.Markdown(
215
+ """<h1><center>DePlot+LLM: Multimodal chain-of-thought reasoning on plots</center></h1>
216
+ <p>
217
+ "This is a demo for DePlot+LLM for QA and summarisation. <a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot</a> is an image-to-text model that converts plots and charts into a textual sequence. The sequence then is used to prompt LLM for chain-of-thought reasoning. The current underlying LLMs are <a href='https://huggingface.co/spaces/tloen/alpaca-lora' target='_blank'>alpaca-lora</a> and <a href='https://huggingface.co/google/flan-ul2' target='_blank'>flan-ul2</a>. To use it, simply upload your image and type a question or instruction and click 'submit', or click one of the examples to load them. Read more at the links below."
218
+ </p>
219
+ """
220
+ )
221
+ # #with gr.Row():
222
+ # llm = gr.Dropdown(
223
+ # ["alpaca-lora", "flan-ul2"], label="LLM", info="We will add more LLMs.")
224
+ # num_shot = gr.Dropdown(
225
+ # ["0-shot", "1-shot"], label="shots", info="How many example tables in the prompt?")
226
+ # openai_api = gr.Textbox(label="openai api (if using OpenAI models, otherwise leave empty)")
227
+
228
+ with gr.Row():
229
+ with gr.Column(scale=2):
230
+ input_image = gr.Image(label="Input Image", type="pil", interactive=True)
231
+ #input_image.style(height=512, width=512)
232
+ instruction = gr.Textbox(placeholder="Enter your instruction/question...", label="Question/Instruction")
233
+ llm = gr.Dropdown(["alpaca-lora", "flan-ul2", "gpt-3.5-turbo"], label="LLM")
234
+ openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter (if using OpenAI models, otherwise leave empty)",
235
+ show_label=False, lines=1, type='password')
236
+ submit = gr.Button("Submit", variant="primary")
237
+
238
+ with gr.Column(scale=2):
239
+ with gr.Accordion("Show intermediate table", open=False):
240
+ output_table = gr.Textbox(lines=8)
241
+ output_text = gr.Textbox(lines=8,label="Output")
242
+
243
+ gr.Examples(
244
+ examples=[["deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step.", "alpaca-lora"],
245
+ ["deplot_case_study_m1.png", "Summarise the chart for me please.", "alpaca-lora"],
246
+ ["deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step.", "alpaca-lora"],
247
+ ["deplot_case_study_4.png", "How many papers are submitted in 2020?", "alpaca-lora"],
248
+ ["deplot_case_study_x2.png", "Summarise the chart for me please.", "alpaca-lora"],
249
+ ["deplot_case_study_4.png", "How many papers are submitted in 2020?", "flan-ul2"],
250
+ ["deplot_case_study_4.png", "acceptance rate = # accepted / #submitted . What is the acceptance rate of 2010?", "flan-ul2"],
251
+ ["deplot_case_study_m1.png", "Summarise the chart for me please.", "flan-ul2"],
252
  ],
253
+ cache_examples=True,
254
+ inputs=[input_image, instruction, llm],
255
+ outputs=[output_table, output_text],
256
+ fn=process_document
257
+ )
258
+
259
+ gr.Markdown(
260
+ """<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"""
261
+ )
262
+ openai_api_key_textbox.change(set_openai_api_key,
263
+ inputs=[openai_api_key_textbox],
264
+ outputs=[])
265
+ openai_api_key_textbox.submit(set_openai_api_key,
266
+ inputs=[openai_api_key_textbox],
267
+ outputs=[])
268
+ submit.click(process_document, inputs=[input_image, instruction, llm], outputs=[output_table, output_text])
269
+ instruction.submit(
270
+ process_document, inputs=[input_image, instruction, llm], outputs=[output_table, output_text]
271
+ )
272
+
273
 
274
  demo.launch(debug=True)