xuxw98 committed on
Commit
b0a3abb
1 Parent(s): 1126743

Upload 3 files

Files changed (3)
  1. app.py +277 -0
  2. requirements.txt +8 -0
  3. style.css +4 -0
app.py ADDED
@@ -0,0 +1,277 @@
+ import json
+ import os
+ import glob
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Tuple
+
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+ import gradio as gr
+ import torch
+ from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+ from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
+
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
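+ # Alpaca-style prompt templates: the user's instruction (and optional input
+ # context) is wrapped in this format before being passed to the model.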
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Response:"
+     ),
+ }
+
+
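+ # Single-GPU "distributed" setup: rank 0 and world size 1, so fairscale's
+ # model-parallel layers can be initialized without a multi-process launcher.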
+ def setup_model_parallel() -> Tuple[int, int]:
+     os.environ['RANK'] = '0'
+     os.environ['WORLD_SIZE'] = '1'
+     os.environ['MP'] = '1'
+     os.environ['MASTER_ADDR'] = '127.0.0.1'
+     os.environ['MASTER_PORT'] = '2223'
+     local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+     torch.distributed.init_process_group("nccl")
+     initialize_model_parallel(world_size)
+     torch.cuda.set_device(local_rank)
+
+     # seed must be the same in all processes
+     torch.manual_seed(1)
+     return local_rank, world_size
+
+
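+ # Builds the tokenizer and Transformer, streams in the two-part LLaMA-7B
+ # checkpoint, loads the instruction and caption adapter weights plus the
+ # vision branch, and returns a ready-to-use LLaMA generator.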
+ def load(
+     ckpt0_path: str,
+     ckpt1_path: str,
+     param_path: str,
+     tokenizer_path: str,
+     instruct_adapter_path: str,
+     caption_adapter_path: str,
+     local_rank: int,
+     world_size: int,
+     max_seq_len: int,
+     max_batch_size: int,
+ ) -> LLaMA:
+     start_time = time.time()
+     print("Loading")
+     instruct_adapter_checkpoint = torch.load(
+         instruct_adapter_path, map_location="cpu")
+     caption_adapter_checkpoint = torch.load(
+         caption_adapter_path, map_location="cpu")
+     with open(param_path, "r") as f:
+         params = json.loads(f.read())
+
+     model_args: ModelArgs = ModelArgs(
+         max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+     )
+     model_args.adapter_layer = int(
+         instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
+     model_args.cap_adapter_layer = int(
+         caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
+
+     tokenizer = Tokenizer(model_path=tokenizer_path)
+     model_args.vocab_size = tokenizer.n_words
+     torch.set_default_tensor_type(torch.cuda.HalfTensor)
+     model = Transformer(model_args)
+
+     # To reduce memory usage, load and release the two checkpoint halves one at a time
+     ckpt0 = torch.load(ckpt0_path, map_location='cuda')
+     model.load_state_dict(ckpt0, strict=False)
+     del ckpt0
+     torch.cuda.empty_cache()
+
+     ckpt1 = torch.load(ckpt1_path, map_location='cuda')
+     model.load_state_dict(ckpt1, strict=False)
+     del ckpt1
+     torch.cuda.empty_cache()
+
+     vision_model = VisionModel(model_args)
+
+     torch.set_default_tensor_type(torch.FloatTensor)
+     model.load_state_dict(instruct_adapter_checkpoint, strict=False)
+     model.load_state_dict(caption_adapter_checkpoint, strict=False)
+     vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
+
+     generator = LLaMA(model, tokenizer, vision_model)
+     print(f"Loaded in {time.time() - start_time:.2f} seconds")
+     return generator
+
+
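+ # Text-only path: wrap the instruction (and optional context) in the Alpaca
+ # template and sample a single response from the generator.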
+ def instruct_generate(
+     instruct: str,
+     input: str = 'none',
+     max_gen_len=512,
+     temperature: float = 0.1,
+     top_p: float = 0.75,
+ ):
+     if input == 'none':
+         prompt = PROMPT_DICT['prompt_no_input'].format_map(
+             {'instruction': instruct, 'input': ''})
+     else:
+         prompt = PROMPT_DICT['prompt_input'].format_map(
+             {'instruction': instruct, 'input': input})
+
+     results = generator.generate(
+         [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
+     )
+     result = results[0].strip()
+     print(result)
+     return result
+
+
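+ # Image path: pass the image to the generator alongside a fixed captioning
+ # prompt and return the sampled caption.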
+ def caption_generate(
+     img: str,
+     max_gen_len=512,
+     temperature: float = 0.1,
+     top_p: float = 0.75,
+ ):
+     imgs = [Image.open(img).convert('RGB')]
+     prompts = ["Generate caption of this image :",] * len(imgs)
+
+     results = generator.generate(
+         prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
+     )
+     result = results[0].strip()
+     print(result)
+     return result
+
+
+ def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
+     if not os.path.exists(instruct_adapter_path):
+         os.system(
+             f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
+
+     if not os.path.exists(caption_adapter_path):
+         os.system(
+             f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
+
+
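+ # Checkpoint locations: the base LLaMA-7B weights are fetched from the Hugging
+ # Face Hub in two parts; the adapter checkpoints are expected as local files
+ # (download_llama_adapter above can fetch them if missing).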
+ # ckpt_path = "/data1/llma/7B/consolidated.00.pth"
+ # param_path = "/data1/llma/7B/params.json"
+ # tokenizer_path = "/data1/llma/tokenizer.model"
+ ckpt0_path = hf_hub_download(
+     repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
+ ckpt1_path = hf_hub_download(
+     repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
+ param_path = hf_hub_download(
+     repo_id="nyanko7/LLaMA-7B", filename="params.json")
+ tokenizer_path = hf_hub_download(
+     repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
+ instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
+ caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
+ max_seq_len = 512
+ max_batch_size = 1
+
+ # download models
+ # download_llama_adapter(instruct_adapter_path, caption_adapter_path)
+
+ local_rank, world_size = setup_model_parallel()
+ if local_rank > 0:
+     sys.stdout = open(os.devnull, "w")
+
+ generator = load(
+     ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
+ )
+
+
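+ # Builds the instruction-following tab: instruction/context textboxes and
+ # generation sliders wired to instruct_generate.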
+ def create_instruct_demo():
+     with gr.Blocks() as instruct_demo:
+         with gr.Row():
+             with gr.Column():
+                 instruction = gr.Textbox(lines=2, label="Instruction")
+                 input = gr.Textbox(
+                     lines=2, label="Context input", placeholder='none')
+                 max_len = gr.Slider(minimum=1, maximum=512,
+                                     value=128, label="Max length")
+                 with gr.Accordion(label='Advanced options', open=False):
+                     temp = gr.Slider(minimum=0, maximum=1,
+                                      value=0.1, label="Temperature")
+                     top_p = gr.Slider(minimum=0, maximum=1,
+                                       value=0.75, label="Top p")
+
+                 run_button = gr.Button("Run")
+
+             with gr.Column():
+                 outputs = gr.Textbox(lines=10, label="Output")
+
+         inputs = [instruction, input, max_len, temp, top_p]
+
+         examples = [
+             "Tell me about alpacas.",
+             "Write a Python program that prints the first 10 Fibonacci numbers.",
+             "Write a conversation between the sun and Pluto.",
+             "Write a theory to explain why cats never existed.",
+         ]
+         examples = [
+             [x, "none", 128, 0.1, 0.75]
+             for x in examples]
+
+         gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=instruct_generate,
+             cache_examples=os.getenv('SYSTEM') == 'spaces'
+         )
+         run_button.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
+     return instruct_demo
+
+
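+ # Builds the image-captioning tab: an image input plus the same generation
+ # sliders, wired to caption_generate; example images are read from caption_demo/.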
+ def create_caption_demo():
+     with gr.Blocks() as caption_demo:
+         with gr.Row():
+             with gr.Column():
+                 img = gr.Image(label='Input', type='filepath')
+                 max_len = gr.Slider(minimum=1, maximum=512,
+                                     value=64, label="Max length")
+                 with gr.Accordion(label='Advanced options', open=False):
+                     temp = gr.Slider(minimum=0, maximum=1,
+                                      value=0.1, label="Temperature")
+                     top_p = gr.Slider(minimum=0, maximum=1,
+                                       value=0.75, label="Top p")
+
+                 run_button = gr.Button("Run")
+
+             with gr.Column():
+                 outputs = gr.Textbox(lines=10, label="Output")
+
+         inputs = [img, max_len, temp, top_p]
+
+         examples = glob.glob("caption_demo/*.jpg")
+         examples = [
+             [x, 64, 0.1, 0.75]
+             for x in examples]
+
+         gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=caption_generate,
+             cache_examples=os.getenv('SYSTEM') == 'spaces'
+         )
+         run_button.click(fn=caption_generate, inputs=inputs, outputs=outputs)
+     return caption_demo
+
+
+ description = """
+ # LLaMA-Adapter🚀
+ The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
+ Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [GitHub repository](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
+ """
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(description)
+     with gr.TabItem("Instruction-Following"):
+         create_instruct_demo()
+     with gr.TabItem("Image Captioning"):
+         create_caption_demo()
+
+ demo.queue(api_open=True, concurrency_count=1).launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ --extra-index-url https://download.pytorch.org/whl/cu113
+ torch==1.12.0+cu113
+ fairscale
+ sentencepiece
+ Pillow
+ huggingface_hub
+ git+https://github.com/csuhan/timm_0_3_2.git
+ git+https://github.com/openai/CLIP.git
style.css ADDED
@@ -0,0 +1,4 @@
+ h1, p {
+     text-align: center;
+ }
+