fl399 committed
Commit 8507790
Parent: 15c5816

Update app.py

Files changed (1)
  1. app.py +19 -16
app.py CHANGED
@@ -7,6 +7,18 @@ import transformers
 from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
 from peft import PeftModel
 
+
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+try:
+    if torch.backends.mps.is_available():
+        device = "mps"
+except:
+    pass
+
 ## CoT prompts
 
 def _add_markup(table):
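Note: the device-detection block is wrapped in try/except because `torch.backends.mps` only exists on newer PyTorch builds (roughly 1.12+); on older versions the attribute lookup itself raises. A minimal sketch of an equivalent guard without the bare `except` (illustrative only, not part of this commit):

    # Hypothetical tidier variant of the block added above.
    def detect_device() -> str:
        if torch.cuda.is_available():
            return "cuda"
        mps = getattr(torch.backends, "mps", None)  # attribute absent on old builds
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    device = detect_device()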
@@ -24,7 +36,6 @@ def _add_markup(table):
         # just use the raw table if parsing fails
         return table
 
-
 _TABLE = """Year | Democrats | Republicans | Independents
 2004 | 68.1% | 45.0% | 53.0%
 2006 | 58.0% | 42.0% | 53.0%
@@ -36,7 +47,6 @@ _TABLE = """Year | Democrats | Republicans | Independents
 
 _INSTRUCTION = 'Read the table below to answer the following questions.'
 
-
 _TEMPLATE = f"""First read an example then the complete question for the second table.
 ------------
 {_INSTRUCTION}
@@ -56,7 +66,6 @@ A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on R
 
 ## alpaca-lora
 
-# debugging...
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
 ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
@@ -67,17 +76,6 @@ tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 LORA_WEIGHTS = "tloen/alpaca-lora-7b"
 
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-
-try:
-    if torch.backends.mps.is_available():
-        device = "mps"
-except:
-    pass
-
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
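Note: moving the detection to the top of the file means `device` is defined before both loaders that branch on it. The CUDA branch of the LLaMA load is truncated by this hunk; a sketch of the usual alpaca-lora loading pattern for the two branches, using the names already imported in app.py (the 8-bit and dtype kwargs are assumptions, not visible in this diff):

    # Sketch, assuming the standard alpaca-lora loading recipe.
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            BASE_MODEL,
            load_in_8bit=True,          # assumption: 8-bit so a 7B model fits one GPU
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
    else:
        # CPU/MPS fallback: full precision, no quantization
        model = LlamaForCausalLM.from_pretrained(BASE_MODEL, low_cpu_mem_usage=True)
        model = PeftModel.from_pretrained(model, LORA_WEIGHTS)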
@@ -147,7 +145,10 @@ def get_response_from_openai(prompt, model="gpt-3.5-turbo", max_output_tokens=25
     return ret
 
 ## deplot models
-model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
+if device == "cuda":
+    model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
+else:
+    model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
 processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
 
 def evaluate(
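Note: on GPU the deplot model is loaded in bfloat16 (roughly half the memory) and placed on device 0 via `.to(0)`; on CPU it keeps the default float32 weights, since bfloat16 inference is slow or unsupported off-GPU. A more compact equivalent, if one also wanted to place the model on a detected MPS device (an untested assumption for deplot, shown only as a sketch):

    # Sketch: dtype and placement derived from the detected `device` string.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    model_deplot = Pix2StructForConditionalGeneration.from_pretrained(
        "google/deplot", torch_dtype=dtype
    ).to(device)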
@@ -199,7 +200,9 @@ def evaluate(
 
 def process_document(image, question, llm):
     # image = Image.open(image)
-    inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(0, torch.bfloat16)
+    inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt")
+    if device == "cuda":
+        inputs = inputs.to(0, torch.bfloat16)
     predictions = model_deplot.generate(**inputs, max_new_tokens=512)
     table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
 
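Note: the inputs must match the model's device and dtype, otherwise `generate` fails with a placement/dtype mismatch; the added branch simply mirrors the bfloat16-on-GPU load above. The same logic as a small helper (hypothetical name, identical behaviour):

    # `inputs` is the BatchFeature returned by processor_deplot above.
    def align_inputs(inputs, device):
        if device == "cuda":
            return inputs.to(0, torch.bfloat16)  # GPU 0, match the bfloat16 weights
        return inputs  # CPU/MPS model keeps float32 host tensors

    inputs = align_inputs(inputs, device)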
 
 