fl399 committed
Commit 08fb2c9
Parent: 6e32116

Update app.py

Files changed (1):
  app.py  +24 -14
app.py CHANGED
@@ -126,6 +126,7 @@ def evaluate(
     table,
     question,
     llm="alpaca-lora",
+    shot="1-shot",
     input=None,
     temperature=0.1,
     top_p=0.75,
@@ -134,9 +135,13 @@ def evaluate(
     max_new_tokens=128,
     **kwargs,
 ):
+    prompt_0shot = _INSTRUCTION + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
     prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
     if llm == "alpaca-lora":
-        inputs = tokenizer(prompt, return_tensors="pt")
+        if shot == "1-shot":
+            inputs = tokenizer(prompt, return_tensors="pt")
+        else:
+            inputs = tokenizer(prompt_0shot, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
             temperature=temperature,
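This hunk is the core of the commit: `shot` now selects between two prompts. `_INSTRUCTION`, `_TEMPLATE`, and `_add_markup` are defined earlier in app.py and are outside this diff; judging by their use here, `_INSTRUCTION` carries the bare instruction for the 0-shot case and `_TEMPLATE` additionally embeds one worked example table for the 1-shot case. A minimal sketch of the selection logic, with placeholder values standing in for the real constants:

# Sketch only: _INSTRUCTION, _TEMPLATE, and _add_markup live elsewhere in
# app.py; the placeholder values below are assumptions for illustration.
_INSTRUCTION = "Read the table below and answer the question."    # assumed
_TEMPLATE = _INSTRUCTION + "\n<one worked table + Q/A example>"    # assumed

def _add_markup(table):
    return table  # stand-in; the real helper adds table markup

def build_prompt(table, question, shot="1-shot"):
    # 1-shot prepends a worked example; 0-shot uses the instruction alone.
    header = _TEMPLATE if shot == "1-shot" else _INSTRUCTION
    return header + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"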
@@ -155,15 +160,15 @@ def evaluate(
         )
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
-        # output = query({
-        #     "inputs": prompt
-        # })
     elif llm == "flan-ul2":
-        # in development...
-        output = query({
-            "inputs": prompt
-        })  # [0]["generated_text"]
-        # output = "in dev..."
+        if shot == "1-shot":
+            output = query({
+                "inputs": prompt
+            })[0]["generated_text"]
+        else:
+            output = query({
+                "inputs": prompt_0shot
+            })[0]["generated_text"]
     else:
         RuntimeError(f"No such LLM: {llm}")
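Both branches of the flan-ul2 path rely on `query`, which this diff does not define. For a Space calling a hosted checkpoint, such a helper conventionally wraps the Hugging Face Inference API, as in the sketch below (the endpoint URL and token wiring are assumptions); the API responds with a list of dicts, which is why the new code indexes `[0]["generated_text"]`. Note also, unchanged by this commit, that the final `else` constructs a `RuntimeError` without `raise`, so an unrecognised `llm` value falls through silently.

import os
import requests

# Sketch of the query helper the diff assumes; the real definition lives
# elsewhere in app.py. Endpoint URL and token handling are illustrative.
API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
HEADERS = {"Authorization": f"Bearer {os.environ.get('HF_TOKEN', '')}"}

def query(payload):
    # The Inference API returns a list like [{"generated_text": "..."}].
    response = requests.post(API_URL, headers=HEADERS, json=payload)
    return response.json()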
 
@@ -200,6 +205,11 @@ demo = gr.Interface(
         ),
         "image",
         "text"],
+        gr.Dropdown(
+            ["0-shot", "1-shot"], label="#shots", info="How many example tables in the prompt?"
+        ),
+        "image",
+        "text"],
     outputs=[
         gr.inputs.Textbox(
             lines=8,
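As committed, the new Dropdown and a second "image", "text" pair are appended after the existing ones, leaving duplicated entries in the inputs list. Judging by the examples below (ordered llm, shot, image, question), the intended list is presumably the following sketch, not the committed code:

import gradio as gr

# Presumed intent: one dropdown per argument, in the order the examples
# supply values. The LLM dropdown's exact info string is not shown in
# this diff, so it is omitted here.
inputs = [
    gr.Dropdown(["alpaca-lora", "flan-ul2"], label="LLM"),
    gr.Dropdown(["0-shot", "1-shot"], label="#shots",
                info="How many example tables in the prompt?"),
    "image",
    "text",
]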
@@ -214,11 +224,11 @@ demo = gr.Interface(
     description=description,
     article=article,
     enable_queue=True,
-    examples=[["alpaca-lora", "deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step."],
-              ["alpaca-lora", "deplot_case_study_m1.png", "Summarise the chart for me please."],
-              ["alpaca-lora", "deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step."],
-              ["alpaca-lora", "deplot_case_study_4.png", "How many papers are submitted in 2020?"],
-              ["alpaca-lora", "deplot_case_study_x2.png", "Summarise the chart for me please."]],
+    examples=[["alpaca-lora", "1-shot", "deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step."],
+              ["alpaca-lora", "1-shot", "deplot_case_study_m1.png", "Summarise the chart for me please."],
+              ["alpaca-lora", "1-shot", "deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step."],
+              ["alpaca-lora", "1-shot", "deplot_case_study_4.png", "How many papers are submitted in 2020?"],
+              ["alpaca-lora", "1-shot", "deplot_case_study_x2.png", "Summarise the chart for me please."]],
     cache_examples=True)
 
 demo.launch(debug=True)
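For completeness, a sketch of driving the updated evaluate directly, outside the Gradio UI. It assumes app.py's module-level model/tokenizer/query setup has already run; the table string is illustrative:

# Sketch: exercising the new shot argument directly.
table = "Year | Papers\n2019 | 1500\n2020 | 1800"   # illustrative table
print(evaluate(table, "How many papers are submitted in 2020?",
               llm="flan-ul2", shot="0-shot"))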
 