tiedong committed
Commit
a56348b
1 Parent(s): dad0102

Add application file

Files changed (1)
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
+import os
+import sys
+
+import fire
+import gradio as gr
+import torch
+import transformers
+from peft import PeftModel
+from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+
+from utils.callbacks import Iteratorize, Stream
+from utils.prompter import Prompter
+
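+# Select the compute device: prefer CUDA, then Apple Silicon (MPS), otherwise CPU.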
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+try:
+    if torch.backends.mps.is_available():
+        device = "mps"
+except Exception:
+    pass
+
+
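+# Load the base LLaMA model, attach the Goat LoRA adapter, and launch the Gradio demo.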
+def main(
+    load_8bit: bool = False,
+    base_model: str = "",
+    lora_weights: str = "tiedong/goat-lora-7b",
+    prompt_template: str = "goat",
+    server_name: str = "0.0.0.0",
+    share_gradio: bool = True,
+):
+    base_model = base_model or os.environ.get("BASE_MODEL", "")
+    assert (
+        base_model
+    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
+
+    prompter = Prompter(prompt_template)
+    tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    if device == "cuda":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=load_8bit,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            torch_dtype=torch.float16,
+        )
+    elif device == "mps":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            base_model, device_map={"": device}, low_cpu_mem_usage=True
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            device_map={"": device},
+        )
+
+    if not load_8bit:
+        model.half()
+
+    model.eval()
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
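+    # Run a single arithmetic query; optionally stream tokens back as they are generated.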
+    def evaluate(
+        instruction,
+        temperature=0.1,
+        top_p=0.75,
+        top_k=40,
+        num_beams=4,
+        max_new_tokens=512,
+        stream_output=True,
+        **kwargs,
+    ):
+        prompt = prompter.generate_prompt_inference(instruction)
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            num_beams=num_beams,
+            **kwargs,
+        )
+
+        generate_params = {
+            "input_ids": input_ids,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "output_scores": True,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if stream_output:
+            # Stream the reply one token at a time.
+            # This is based on the trick of using 'stopping_criteria' to create an iterator,
+            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs.setdefault(
+                    "stopping_criteria", transformers.StoppingCriteriaList()
+                )
+                kwargs["stopping_criteria"].append(
+                    Stream(callback_func=callback)
+                )
+                with torch.no_grad():
+                    model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(
+                    generate_with_callback, kwargs, callback=None
+                )
+
+            with generate_with_streaming(**generate_params) as generator:
+                for output in generator:
+                    # new_tokens = len(output) - len(input_ids[0])
+                    decoded_output = tokenizer.decode(output)
+
+                    if output[-1] in [tokenizer.eos_token_id]:
+                        break
+
+                    yield prompter.get_response(decoded_output)
+            return  # early return for stream_output
+
+        # Without streaming
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s, skip_special_tokens=True).strip()
+        yield prompter.get_response(output)
+
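+    # Wire everything into a Gradio interface: an instruction textbox plus decoding controls.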
+    gr.Interface(
+        fn=evaluate,
+        inputs=[
+            gr.components.Textbox(
+                lines=2,
+                label="Arithmetic",
+                placeholder="What is 63303235 + 20239503?",
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=0.1, label="Temperature"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=0.75, label="Top p"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=100, step=1, value=40, label="Top k"
+            ),
+            gr.components.Slider(
+                minimum=1, maximum=4, step=1, value=4, label="Beams"
+            ),
+            gr.components.Slider(
+                minimum=1, maximum=1024, step=1, value=512, label="Max tokens"
+            ),
+            gr.components.Checkbox(label="Stream output", value=True),
+        ],
+        outputs=[
+            gr.components.Textbox(
+                lines=5,
+                label="Output",
+            )
+        ],
+        title="Goat-LoRA-7b",
+        description="Goat-LoRA-7b is a 7B-parameter LLaMA model finetuned to perform arithmetic tasks, including addition, subtraction, multiplication, and division of integers. It is trained on a synthetic dataset (https://github.com/liutiedong/goat) and makes use of the Hugging Face LLaMA implementation. For more information, please visit [the project's website](https://github.com/liutiedong/goat).",  # noqa: E501
+    ).queue().launch(server_name=server_name, share=share_gradio)
+
+
+if __name__ == "__main__":
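+    # Command-line flags map to main()'s keyword arguments via fire, e.g.:
+    #   python app.py --base_model 'huggyllama/llama-7b'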
+    fire.Fire(main)