Fangyu Liu commited on
Commit
8fc7477
1 Parent(s): 6663c9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -27
app.py CHANGED
@@ -46,32 +46,105 @@ Q: Which party has the second highest favor rates in 2007?
46
  A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on Row 3: [59.0, 38.0, 45.0]. 45.0 is the second highest. 45.0 is the number of Independents. The answer is Independents.
47
  {_INSTRUCTION}"""
48
 
49
- def text_generate(prompt, table, problem):
50
- p = prompt + "\n" + table + "\n" + "Q: " + problem
51
- # print(f"Final prompt is : {p}")
52
- json_ = {"inputs": p,
53
- "parameters":
54
- {
55
- "top_p": 0.9,
56
- "temperature": 1.1,
57
- "max_new_tokens": 128,
58
- "return_full_text": True
59
- }, "options":
60
- {
61
- "use_cache": True,
62
- "wait_for_model":True
63
- },}
64
- response = requests.post(API_URL, headers=headers, json=json_)
65
- print(f"Response is : {response}")
66
- output = response.json()
67
- print(f"output is : {output}") #{output}")
68
- output_tmp = output['generated_text']
69
- print(f"output_tmp is: {output_tmp}")
70
- #solution = output_tmp.split("\nQ:")[0] #output[0]['generated_text'].split("Q:")[0] # +"."
71
- #print(f"Final response after splits is: {solution}")
72
-
73
- #return solution
74
- return output_tmp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
 
@@ -86,7 +159,7 @@ def process_document(image, question):
86
  table = processor_deplot.decode(predictions[0], skip_special_tokens=True)
87
 
88
  # send prompt+table to LLM
89
- res = text_generate(_TEMPLATE, table, question)
90
  print (res)
91
 
92
  description = "Demo for deplot+llm for QA or summarisation. To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
 
46
  A: Let's find the row of year 2007, that's Row 3. Let's extract the numbers on Row 3: [59.0, 38.0, 45.0]. 45.0 is the second highest. 45.0 is the number of Independents. The answer is Independents.
47
  {_INSTRUCTION}"""
48
 
49
+
50
+ import torch
51
+ from peft import PeftModel
52
+ import transformers
53
+ import gradio as gr
54
+
55
+ assert (
56
+ "LlamaTokenizer" in transformers._import_structure["models.llama"]
57
+ ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
58
+ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
59
+
60
+ tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
61
+
62
+ BASE_MODEL = "decapoda-research/llama-7b-hf"
63
+ LORA_WEIGHTS = "tloen/alpaca-lora-7b"
64
+
65
+ if torch.cuda.is_available():
66
+ device = "cuda"
67
+ else:
68
+ device = "cpu"
69
+
70
+ try:
71
+ if torch.backends.mps.is_available():
72
+ device = "mps"
73
+ except:
74
+ pass
75
+
76
+ if device == "cuda":
77
+ model = LlamaForCausalLM.from_pretrained(
78
+ BASE_MODEL,
79
+ load_in_8bit=False,
80
+ torch_dtype=torch.float16,
81
+ device_map="auto",
82
+ )
83
+ model = PeftModel.from_pretrained(
84
+ model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
85
+ )
86
+ elif device == "mps":
87
+ model = LlamaForCausalLM.from_pretrained(
88
+ BASE_MODEL,
89
+ device_map={"": device},
90
+ torch_dtype=torch.float16,
91
+ )
92
+ model = PeftModel.from_pretrained(
93
+ model,
94
+ LORA_WEIGHTS,
95
+ device_map={"": device},
96
+ torch_dtype=torch.float16,
97
+ )
98
+ else:
99
+ model = LlamaForCausalLM.from_pretrained(
100
+ BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
101
+ )
102
+ model = PeftModel.from_pretrained(
103
+ model,
104
+ LORA_WEIGHTS,
105
+ device_map={"": device},
106
+ )
107
+
108
+
109
+ if device != "cpu":
110
+ model.half()
111
+ model.eval()
112
+ if torch.__version__ >= "2":
113
+ model = torch.compile(model)
114
+
115
+
116
+ def evaluate(
117
+ table,
118
+ question,
119
+ input=None,
120
+ temperature=0.1,
121
+ top_p=0.75,
122
+ top_k=40,
123
+ num_beams=4,
124
+ max_new_tokens=128,
125
+ **kwargs,
126
+ ):
127
+ prompt = _TEMPLATE + "\n" + table + "\n" + "Q: " + question
128
+ inputs = tokenizer(prompt, return_tensors="pt")
129
+ input_ids = inputs["input_ids"].to(device)
130
+ generation_config = GenerationConfig(
131
+ temperature=temperature,
132
+ top_p=top_p,
133
+ top_k=top_k,
134
+ num_beams=num_beams,
135
+ **kwargs,
136
+ )
137
+ with torch.no_grad():
138
+ generation_output = model.generate(
139
+ input_ids=input_ids,
140
+ generation_config=generation_config,
141
+ return_dict_in_generate=True,
142
+ output_scores=True,
143
+ max_new_tokens=max_new_tokens,
144
+ )
145
+ s = generation_output.sequences[0]
146
+ output = tokenizer.decode(s)
147
+ return output.split("### Response:")[1].strip()
148
 
149
 
150
 
 
159
  table = processor_deplot.decode(predictions[0], skip_special_tokens=True)
160
 
161
  # send prompt+table to LLM
162
+ res = evaluate(table, question)
163
  print (res)
164
 
165
  description = "Demo for deplot+llm for QA or summarisation. To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."