fl399 committed on
Commit
58d8b0b
1 Parent(s): 1381c4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -135,33 +135,34 @@ def evaluate(
135
  ):
136
  prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
137
  if llm == "alpaca-lora":
138
- # inputs = tokenizer(prompt, return_tensors="pt")
139
- # input_ids = inputs["input_ids"].to(device)
140
- # generation_config = GenerationConfig(
141
- # temperature=temperature,
142
- # top_p=top_p,
143
- # top_k=top_k,
144
- # num_beams=num_beams,
145
- # **kwargs,
146
- # )
147
- # with torch.no_grad():
148
- # generation_output = model.generate(
149
- # input_ids=input_ids,
150
- # generation_config=generation_config,
151
- # return_dict_in_generate=True,
152
- # output_scores=True,
153
- # max_new_tokens=max_new_tokens,
154
- # )
155
- # s = generation_output.sequences[0]
156
- # output = tokenizer.decode(s)
157
- output = query({
158
- "inputs": prompt
159
- })
160
  elif llm == "flan-ul2":
161
- output = query({
162
- "inputs": prompt
163
- })
164
-
 
165
  else:
166
  RuntimeError(f"No such LLM: {llm}")
167
 
@@ -182,8 +183,8 @@ def process_document(llm, image, question):
182
  res = evaluate(table, question, llm=llm)
183
  #return res + "\n\n" + res.split("A:")[-1]
184
  if llm == "alpaca-lora":
185
- #return [table, res.split("A:")[-1]]
186
- return [table, res]
187
  else:
188
  return [table, res]
189
 
 
135
  ):
136
  prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
137
  if llm == "alpaca-lora":
138
+ inputs = tokenizer(prompt, return_tensors="pt")
139
+ input_ids = inputs["input_ids"].to(device)
140
+ generation_config = GenerationConfig(
141
+ temperature=temperature,
142
+ top_p=top_p,
143
+ top_k=top_k,
144
+ num_beams=num_beams,
145
+ **kwargs,
146
+ )
147
+ with torch.no_grad():
148
+ generation_output = model.generate(
149
+ input_ids=input_ids,
150
+ generation_config=generation_config,
151
+ return_dict_in_generate=True,
152
+ output_scores=True,
153
+ max_new_tokens=max_new_tokens,
154
+ )
155
+ s = generation_output.sequences[0]
156
+ output = tokenizer.decode(s)
157
+ # output = query({
158
+ # "inputs": prompt
159
+ # })
160
  elif llm == "flan-ul2":
161
+ # in development...
162
+ # output = query({
163
+ # "inputs": prompt
164
+ # })
165
+ output = "in dev..."
166
  else:
167
  RuntimeError(f"No such LLM: {llm}")
168
 
 
183
  res = evaluate(table, question, llm=llm)
184
  #return res + "\n\n" + res.split("A:")[-1]
185
  if llm == "alpaca-lora":
186
+ return [table, res.split("A:")[-1]]
187
+ # return [table, res]
188
  else:
189
  return [table, res]
190