fl399 committed on
Commit 4cfb376
1 Parent(s): fd2e81d

Update app.py

Files changed (1)
  1. app.py +48 -23
app.py CHANGED
@@ -112,9 +112,18 @@ if torch.__version__ >= "2":
     model = torch.compile(model)
 
 
+## FLAN-UL2
+TOKEN = os.environ.get("API_TOKEN", None)
+API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
+headers = {"Authorization": f"Bearer {TOKEN}"}
+def query(payload):
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
 def evaluate(
     table,
     question,
+    llm="alpaca-lora",
     input=None,
     temperature=0.1,
     top_p=0.75,
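The new `query` helper returns `response.json()` unchecked, but the hosted Inference API can answer with an error object (for instance while the model is still loading) instead of a list of generations. A more defensive variant, shown here as a sketch with a hypothetical `query_with_retry` name and retry policy that are not part of this commit, could look like:

```python
import os
import time

import requests

API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
HEADERS = {"Authorization": f"Bearer {os.environ.get('API_TOKEN', '')}"}

def query_with_retry(payload: dict, retries: int = 3, backoff: float = 5.0) -> dict:
    """POST to the Inference API, retrying while the model is loading."""
    data = {"error": "no attempts made"}
    for attempt in range(retries):
        response = requests.post(API_URL, headers=HEADERS, json=payload)
        data = response.json()
        # Error payloads are dicts like {"error": ..., "estimated_time": ...};
        # successful text-generation calls return a list of generations.
        if isinstance(data, list):
            return data[0]  # e.g. {"generated_text": "..."}
        time.sleep(backoff * (attempt + 1))
    raise RuntimeError(f"Inference API request failed: {data}")
```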
@@ -124,26 +133,34 @@ def evaluate(
     **kwargs,
 ):
     prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
-    inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"].to(device)
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        num_beams=num_beams,
-        **kwargs,
-    )
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-        )
-    s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
-    #return output.split("A:")[-1].strip()
+    if llm == "alpaca-lora":
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            num_beams=num_beams,
+            **kwargs,
+        )
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s)
+    elif llm == "flan-ul2":
+        output = query({
+            "inputs": prompt
+        })[0]["generated_text"]
+
+    else:
+        raise RuntimeError(f"No such LLM: {llm}")
+
     return output
 
 
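Both backends consume the same prompt string: the few-shot `_TEMPLATE`, the table passed through `_add_markup`, and the question in `Q: ... A:` form. The snippet below illustrates the assembled format with simplified stand-ins for `_TEMPLATE` and `_add_markup`, whose real definitions live earlier in `app.py` and are not shown in this diff:

```python
# Simplified stand-ins; the real _TEMPLATE holds worked Q/A examples and
# the real _add_markup annotates the table before prompting.
_TEMPLATE = "Read the table below to answer the following questions."

def _add_markup(table: str) -> str:
    return table

table = "Year | Sales\n2020 | 100\n2021 | 150\n2022 | 225"
question = "Which year had the highest sales?"

prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
print(prompt)
# Read the table below to answer the following questions.
# Year | Sales
# 2020 | 100
# 2021 | 150
# 2022 | 225
# Q: Which year had the highest sales?
# A:
```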
@@ -151,23 +168,31 @@ def evaluate(
 model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
 processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
 
-def process_document(image, question):
+def process_document(llm, image, question):
     # image = Image.open(image)
     inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(0, torch.bfloat16)
     predictions = model_deplot.generate(**inputs, max_new_tokens=512)
     table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
 
     # send prompt+table to LLM
-    res = evaluate(table, question)
+    res = evaluate(table, question, llm=llm)
     #return res + "\n\n" + res.split("A:")[-1]
-    return [table, res.split("A:")[-1]]
+    if llm == "alpaca-lora":
+        return [table, res.split("A:")[-1]]
+    else:
+        return [table, res]
 
 description = "Demo for DePlot+LLM for QA and summarisation. [DePlot](https://arxiv.org/abs/2212.10505) is an image-to-text model that converts plots and charts into a textual sequence. The sequence then is used to prompt LLM for chain-of-thought reasoning. The current underlying LLM is [alpaca-lora](https://huggingface.co/spaces/tloen/alpaca-lora). To use it, simply upload your image and type a question or instruction and click 'submit', or click one of the examples to load them. Read more at the links below."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"
 
 demo = gr.Interface(
     fn=process_document,
-    inputs=["image", "text"],
+    inputs=[
+        gr.Dropdown(
+            ["alpaca-lora", "flan-ul2"], label="LLM", info="Will add more LLMs later!"
+        ),
+        "image",
+        "text"],
     outputs=[
         gr.inputs.Textbox(
             lines=8,
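On the vision side, `process_document` runs DePlot once per image to linearize the chart into a table before any LLM is involved. The core of that step in isolation, as a CPU-only sketch (the Space itself loads the model in bfloat16 on GPU, and `chart.png` is a placeholder path):

```python
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
processor = Pix2StructProcessor.from_pretrained("google/deplot")

image = Image.open("chart.png")
inputs = processor(
    images=image,
    text="Generate the underlying data table for the figure below:",
    return_tensors="pt",
)
predictions = model.generate(**inputs, max_new_tokens=512)
# DePlot emits "<0x0A>" byte tokens for newlines; map them back.
table = processor.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
print(table)
```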
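One behavioural difference between the two backends drives the branching in `process_document`: `tokenizer.decode` on the alpaca-lora side returns the prompt together with the completion, so only the text after the final `A:` marker is the answer, whereas the Inference API already returns just the completion. The post-processing rule in isolation, with a hypothetical helper name:

```python
def extract_answer(decoded: str) -> str:
    """Keep only the completion that follows the echoed prompt.

    Decoder-only models such as alpaca-lora echo the prompt, so the answer
    is whatever follows the last "A:" marker; FLAN-UL2 responses from the
    Inference API arrive prompt-free and need no stripping.
    """
    return decoded.split("A:")[-1].strip()

assert extract_answer("Q: Which year had the highest sales?\nA: 2022") == "2022"
```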
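The dropdown becomes the first positional input, which is why `process_document` now takes `llm` as its first parameter. A stripped-down, self-contained sketch of the same wiring, using a dummy handler and the current `gr.Textbox` API in place of the deprecated `gr.inputs.Textbox` still used in the file:

```python
import gradio as gr

def process_document(llm: str, image, question: str):
    # Dummy stand-in for DePlot + LLM: echo what would be dispatched.
    table = "Year | Sales\n2020 | 100\n2021 | 150"
    return [table, f"({llm}) would answer: {question!r}"]

demo = gr.Interface(
    fn=process_document,
    inputs=[
        gr.Dropdown(["alpaca-lora", "flan-ul2"], label="LLM"),
        "image",
        "text",
    ],
    outputs=[
        gr.Textbox(lines=8, label="Table"),
        gr.Textbox(lines=4, label="Answer"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```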