fl399 commited on
Commit
9308288
1 Parent(s): c225b4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -148,7 +148,7 @@ def get_response_from_openai(prompt, model="gpt-3.5-turbo", max_output_tokens=25
148
  if device == "cuda":
149
  model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
150
  else:
151
- model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
152
  processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
153
 
154
  def evaluate(
@@ -200,9 +200,9 @@ def evaluate(
200
 
201
  def process_document(image, question, llm):
202
  # image = Image.open(image)
203
- inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt")
204
  if device == "cuda":
205
- inputs = inputs.to(0, torch.bfloat16)
206
  predictions = model_deplot.generate(**inputs, max_new_tokens=512)
207
  table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
208
 
 
148
  if device == "cuda":
149
  model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
150
  else:
151
+ model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16)
152
  processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
153
 
154
  def evaluate(
 
200
 
201
  def process_document(image, question, llm):
202
  # image = Image.open(image)
203
+ inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(torch.bfloat16)
204
  if device == "cuda":
205
+ inputs = inputs.to(0)
206
  predictions = model_deplot.generate(**inputs, max_new_tokens=512)
207
  table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
208