csuhan committed
Commit 59ffc68
1 Parent(s): a4d3a1d
.gitignore ADDED
@@ -0,0 +1,6 @@
+ example*
+ *bak
+ flagged
+ *.sh
+ __pycache__/
+ *.pth
README.md CHANGED
@@ -9,4 +9,7 @@ app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ### LLaMA-Adapter
+ The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
+ Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
+
app.py ADDED
@@ -0,0 +1,274 @@
+ import json
+ import os
+ import glob
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Tuple
+
+ from PIL import Image
+ import gradio as gr
+ import torch
+ from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+ from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
+
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Response:"
+     ),
+ }
+
+
+ def setup_model_parallel() -> Tuple[int, int]:
+     os.environ['RANK'] = '0'
+     os.environ['WORLD_SIZE'] = '1'
+     os.environ['MP'] = '1'
+     os.environ['MASTER_ADDR'] = '127.0.0.1'
+     os.environ['MASTER_PORT'] = '2223'
+     local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+     torch.distributed.init_process_group("nccl")
+     initialize_model_parallel(world_size)
+     torch.cuda.set_device(local_rank)
+
+     # seed must be the same in all processes
+     torch.manual_seed(1)
+     return local_rank, world_size
+
+
+ def load(
+     ckpt_dir: str,
+     tokenizer_path: str,
+     instruct_adapter_path: str,
+     caption_adapter_path: str,
+     local_rank: int,
+     world_size: int,
+     max_seq_len: int,
+     max_batch_size: int,
+ ) -> LLaMA:
+     start_time = time.time()
+     checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+     assert world_size == len(
+         checkpoints
+     ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
+     ckpt_path = checkpoints[local_rank]
+     print("Loading")
+     checkpoint = torch.load(ckpt_path, map_location="cuda")
+     instruct_adapter_checkpoint = torch.load(
+         instruct_adapter_path, map_location="cuda")
+     caption_adapter_checkpoint = torch.load(
+         caption_adapter_path, map_location="cuda")
+     with open(Path(ckpt_dir) / "params.json", "r") as f:
+         params = json.loads(f.read())
+
+     model_args: ModelArgs = ModelArgs(
+         max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+     )
+     model_args.adapter_layer = int(
+         instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
+     model_args.cap_adapter_layer = int(
+         caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
+
+     tokenizer = Tokenizer(model_path=tokenizer_path)
+     model_args.vocab_size = tokenizer.n_words
+     torch.set_default_tensor_type(torch.cuda.HalfTensor)
+     model = Transformer(model_args)
+     vision_model = VisionModel(model_args)
+
+     torch.set_default_tensor_type(torch.FloatTensor)
+     model.load_state_dict(checkpoint, strict=False)
+     model.load_state_dict(instruct_adapter_checkpoint, strict=False)
+     model.load_state_dict(caption_adapter_checkpoint, strict=False)
+     vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
+
+     generator = LLaMA(model, tokenizer, vision_model)
+     print(f"Loaded in {time.time() - start_time:.2f} seconds")
+     return generator
+
+
+ def instruct_generate(
+     instruct: str,
+     input: str = 'none',
+     max_gen_len=512,
+     temperature: float = 0.1,
+     top_p: float = 0.75,
+ ):
+     if input == 'none':
+         prompt = PROMPT_DICT['prompt_no_input'].format_map(
+             {'instruction': instruct, 'input': ''})
+     else:
+         prompt = PROMPT_DICT['prompt_input'].format_map(
+             {'instruction': instruct, 'input': input})
+
+     results = generator.generate(
+         [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
+     )
+     result = results[0].strip()
+     print(result)
+     return result
+
+
+ def caption_generate(
+     img: str,
+     max_gen_len=512,
+     temperature: float = 0.1,
+     top_p: float = 0.75,
+ ):
+     imgs = [Image.open(img).convert('RGB')]
+     prompts = ["Generate caption of this image :",] * len(imgs)
+
+     results = generator.generate(
+         prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
+     )
+     result = results[0].strip()
+     print(result)
+     return result
+
+
+ def download_llama_7b(ckpt_dir, tokenizer_path):
+     print("LLaMA-7B downloading")
+     os.makedirs(ckpt_dir, exist_ok=True)
+     ckpt_path = os.path.join(ckpt_dir, "consolidated.00.pth")
+     param_path = os.path.join(ckpt_dir, "params.json")
+     if not os.path.exists(ckpt_path):
+         os.system(
+             f"wget -O {ckpt_path} https://huggingface.co/nyanko7/LLaMA-7B/resolve/main/consolidated.00.pth")
+     if not os.path.exists(param_path):
+         os.system(
+             f"wget -O {param_path} https://huggingface.co/nyanko7/LLaMA-7B/raw/main/params.json")
+     if not os.path.exists(tokenizer_path):
+         os.system(
+             f"wget -O {tokenizer_path} https://huggingface.co/nyanko7/LLaMA-7B/resolve/main/tokenizer.model")
+     print("LLaMA-7B downloaded")
+
+ def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
+     if not os.path.exists(instruct_adapter_path):
+         os.system(f"wget -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
+
+     if not os.path.exists(caption_adapter_path):
+         os.system(f"wget -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
+
+
+ # ckpt_dir = "/data1/llma/7B"
+ # tokenizer_path = "/data1/llma/tokenizer.model"
+ ckpt_dir = "llama_7B/"
+ tokenizer_path = "tokenizer.model"
+ instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
+ caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
+ max_seq_len = 512
+ max_batch_size = 1
+
+ # download models
+ download_llama_7b(ckpt_dir, tokenizer_path)
+ download_llama_adapter(instruct_adapter_path, caption_adapter_path)
+
+ local_rank, world_size = setup_model_parallel()
+ if local_rank > 0:
+     sys.stdout = open(os.devnull, "w")
+
+ generator = load(
+     ckpt_dir, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
+ )
+
+
+ def create_instruct_demo():
+     with gr.Blocks() as instruct_demo:
+         with gr.Row():
+             with gr.Column():
+                 instruction = gr.Textbox(lines=2, label="Instruction")
+                 input = gr.Textbox(
+                     lines=2, label="Context input", placeholder='none')
+                 max_len = gr.Slider(minimum=1, maximum=512,
+                                     value=128, label="Max length")
+                 with gr.Accordion(label='Advanced options', open=False):
+                     temp = gr.Slider(minimum=0, maximum=1,
+                                      value=0.1, label="Temperature")
+                     top_p = gr.Slider(minimum=0, maximum=1,
+                                       value=0.75, label="Top p")
+
+                 run_button = gr.Button("Run")
+
+             with gr.Column():
+                 outputs = gr.Textbox(lines=10, label="Output")
+
+         inputs = [instruction, input, max_len, temp, top_p]
+
+         examples = [
+             "Tell me about alpacas.",
+             "Write a Python program that prints the first 10 Fibonacci numbers.",
+             "Write a conversation between the sun and pluto.",
+             "Write a theory to explain why cat never existed",
+         ]
+         examples = [
+             [x, "none", 128, 0.1, 0.75]
+             for x in examples]
+
+         gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=instruct_generate,
+             cache_examples=os.getenv('SYSTEM') == 'spaces'
+         )
+         run_button.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
+     return instruct_demo
+
+
+ def create_caption_demo():
+     with gr.Blocks() as instruct_demo:
+         with gr.Row():
+             with gr.Column():
+                 img = gr.Image(label='Input', type='filepath')
+                 max_len = gr.Slider(minimum=1, maximum=512,
+                                     value=64, label="Max length")
+                 with gr.Accordion(label='Advanced options', open=False):
+                     temp = gr.Slider(minimum=0, maximum=1,
+                                      value=0.1, label="Temperature")
+                     top_p = gr.Slider(minimum=0, maximum=1,
+                                       value=0.75, label="Top p")
+
+                 run_button = gr.Button("Run")
+
+             with gr.Column():
+                 outputs = gr.Textbox(lines=10, label="Output")
+
+         inputs = [img, max_len, temp, top_p]
+
+         examples = glob.glob("caption_demo/*.jpg")
+         examples = [
+             [x, 64, 0.1, 0.75]
+             for x in examples]
+
+         gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=caption_generate,
+             cache_examples=os.getenv('SYSTEM') == 'spaces'
+         )
+         run_button.click(fn=caption_generate, inputs=inputs, outputs=outputs)
+     return instruct_demo
+
+ description = """
+ # LLaMA-Adapter
+ The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
+ Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
+ """
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(description)
+     with gr.TabItem("Instruction-Following"):
+         create_instruct_demo()
+     with gr.TabItem("Image Captioning"):
+         create_caption_demo()
+
+ demo.queue(api_open=True, concurrency_count=1).launch()
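For readers following the diff, a minimal standalone sketch of how `instruct_generate` expands the Alpaca-style templates in `PROMPT_DICT` before calling `generator.generate`; the instruction string is just an example and no model or GPU is needed:

```python
# Mirrors PROMPT_DICT["prompt_no_input"] from app.py above; purely illustrative.
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)

prompt = PROMPT_NO_INPUT.format_map({"instruction": "Tell me about alpacas."})
print(prompt)  # exactly the string passed to generator.generate([prompt], ...)
```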
caption_demo/COCO_val2014_000000111104.jpg ADDED
caption_demo/COCO_val2014_000000111165.jpg ADDED
caption_demo/COCO_val2014_000000111179.jpg ADDED
caption_demo/COCO_val2014_000000111180.jpg ADDED
caption_demo/COCO_val2014_000000111194.jpg ADDED
caption_demo/base_logo.jpg ADDED
llama/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from .generation import LLaMA
+ from .model import ModelArgs, Transformer, VisionModel
+ from .tokenizer import Tokenizer
llama/generation.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from typing import List
+
+ import torch
+
+ from llama.tokenizer import Tokenizer
+ from llama.model import Transformer
+
+
+ class LLaMA:
+     def __init__(self, model: Transformer, tokenizer: Tokenizer, vision_model=None):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.vision_model = vision_model
+
+     def generate(
+         self,
+         prompts: List[str],
+         imgs=None,
+         max_gen_len: int = 512,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+     ) -> List[str]:
+         bsz = len(prompts)
+         params = self.model.params
+         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+         mode = 'instruct'
+         vision_tokens = None
+         if imgs is not None and self.vision_model is not None:
+             vision_tokens = self.vision_model(imgs)
+             mode = 'caption'
+
+         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+
+         min_prompt_size = min([len(t) for t in prompt_tokens])
+         max_prompt_size = max([len(t) for t in prompt_tokens])
+
+         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+         tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+         for k, t in enumerate(prompt_tokens):
+             tokens[k, : len(t)] = torch.tensor(t).long()
+         input_text_mask = tokens != self.tokenizer.pad_id
+         start_pos = min_prompt_size
+         prev_pos = 0
+         for cur_pos in range(start_pos, total_len):
+             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, vision_tokens, mode)
+             if temperature > 0:
+                 probs = torch.softmax(logits / temperature, dim=-1)
+                 next_token = sample_top_p(probs, top_p)
+             else:
+                 next_token = torch.argmax(logits, dim=-1)
+             next_token = next_token.reshape(-1)
+             # only replace token if prompt has already been generated
+             next_token = torch.where(
+                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+             )
+             tokens[:, cur_pos] = next_token
+             prev_pos = cur_pos
+
+         decoded = []
+         for i, t in enumerate(tokens.tolist()):
+             # cut to max gen len
+             t = t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len]
+             # cut to eos tok if any
+             try:
+                 t = t[: t.index(self.tokenizer.eos_id)]
+             except ValueError:
+                 pass
+             decoded.append(self.tokenizer.decode(t))
+         return decoded
+
+
+ def sample_top_p(probs, p):
+     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort[mask] = 0.0
+     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+     next_token = torch.multinomial(probs_sort, num_samples=1)
+     next_token = torch.gather(probs_idx, -1, next_token)
+     return next_token
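The sampling path in `generate` (temperature > 0) defers to `sample_top_p`. A standalone toy run, assuming only `torch` is installed and inlining the same logic on a made-up distribution, shows how the nucleus cutoff behaves:

```python
import torch

probs = torch.tensor([[0.50, 0.30, 0.10, 0.05, 0.05]])   # already a softmax output

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > 0.75        # tokens outside the top-p (0.75) nucleus
probs_sort[mask] = 0.0                      # here only token ids 0 and 1 survive
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)                           # always tensor([[0]]) or tensor([[1]])
```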
llama/model.py ADDED
@@ -0,0 +1,423 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from typing import Optional, Tuple
+ from dataclasses import dataclass
+ import math
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ import clip
+ from timm.models.vision_transformer import Block
+
+ import fairscale.nn.model_parallel.initialize as fs_init
+ from fairscale.nn.model_parallel.layers import (
+     ParallelEmbedding,
+     RowParallelLinear,
+     ColumnParallelLinear,
+ )
+
+ @dataclass
+ class ModelArgs:
+     dim: int = 512
+     n_layers: int = 8
+     n_heads: int = 8
+     vocab_size: int = -1  # defined later by tokenizer
+     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+     norm_eps: float = 1e-5
+
+     max_batch_size: int = 32
+     max_seq_len: int = 2048
+
+     adapter_len: int = 10
+     adapter_layer: int = 30
+
+     cap_adapter_len: int = 10
+     cap_adapter_layer: int = 30
+     cap_vision_model: str = "ViT-L/14"
+     cap_vision_dim: int = 512
+     cap_vision_block: int = 2
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+
+         self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+         self.head_dim = args.dim // args.n_heads
+
+         self.wq = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wk = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wv = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wo = RowParallelLinear(
+             args.n_heads * self.head_dim,
+             args.dim,
+             bias=False,
+             input_is_parallel=True,
+             init_method=lambda x: x,
+         )
+
+         self.cache_k = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+         ).cuda()
+         self.cache_v = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+         ).cuda()
+         self.gate = torch.nn.Parameter(torch.zeros(1))
+
+         self.cap_gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
+
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, mode='instruct'):
+         if mode == 'instruct':
+             return self.forward_instruct(x, start_pos, freqs_cis, mask, adapter)
+         elif mode == 'caption':
+             return self.forward_caption(x, start_pos, freqs_cis, mask, adapter)
+
+
+     def forward_instruct(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         self.cache_k = self.cache_k.to(xq)
+         self.cache_v = self.cache_v.to(xq)
+
+         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+         keys = self.cache_k[:bsz, : start_pos + seqlen]
+         values = self.cache_v[:bsz, : start_pos + seqlen]
+
+         if adapter is not None:
+             adapter_len = adapter.shape[1]
+             adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+             adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+             adapter_k = adapter_k.transpose(1, 2)
+             adapter_v = adapter_v.transpose(1, 2)
+         xq = xq.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if mask is not None:
+             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+         if adapter is not None:
+             adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
+             adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
+             output = output + torch.matmul(adapter_scores, adapter_v)
+         output = output.transpose(
+             1, 2
+         ).contiguous().view(bsz, seqlen, -1)
+
+         return self.wo(output)
+
+
+     def forward_caption(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         self.cache_k = self.cache_k.to(xq)
+         self.cache_v = self.cache_v.to(xq)
+
+         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+         keys = self.cache_k[:bsz, : start_pos + seqlen]
+         values = self.cache_v[:bsz, : start_pos + seqlen]
+
+         if adapter is not None:
+             adapter_len = adapter.shape[1]
+             adapter_k = self.wk(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
+             adapter_v = self.wv(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
+             adapter_k = adapter_k.transpose(1, 2)
+             adapter_v = adapter_v.transpose(1, 2)
+         xq = xq.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if mask is not None:
+             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+         if adapter is not None:
+             adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
+             adapter_scores = self.cap_gate.tanh() * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
+
+             output = output + torch.matmul(adapter_scores, adapter_v)
+         output = output.transpose(
+             1, 2
+         ).contiguous().view(bsz, seqlen, -1)
+
+         return self.wo(output)
+
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+     ):
+         super().__init__()
+         hidden_dim = int(2 * hidden_dim / 3)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = ColumnParallelLinear(
+             dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+         )
+         self.w2 = RowParallelLinear(
+             hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+         )
+         self.w3 = ColumnParallelLinear(
+             dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+         )
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, args: ModelArgs):
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+         self.attention = Attention(args)
+         self.feed_forward = FeedForward(
+             dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
+         )
+         self.layer_id = layer_id
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None, mode='instruct'):
+         h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter, mode=mode)
+         out = h + self.feed_forward.forward(self.ffn_norm(h))
+         return out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+         self.params = params
+         self.vocab_size = params.vocab_size
+         self.n_layers = params.n_layers
+
+         self.tok_embeddings = ParallelEmbedding(
+             params.vocab_size, params.dim, init_method=lambda x: x
+         )
+
+         self.layers = torch.nn.ModuleList()
+         for layer_id in range(params.n_layers):
+             self.layers.append(TransformerBlock(layer_id, params))
+
+         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+         self.output = ColumnParallelLinear(
+             params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+         )
+
+         self.freqs_cis = precompute_freqs_cis(
+             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+         )
+
+         # Note: this is only a preview of multimodal LLaMA-Adapter
+         # and requires more efforts to decouple LLaMA-Adapter from LLaMA.
+         # instruct model
+         self.adapter_query = nn.Embedding(params.adapter_len * params.adapter_layer, params.dim)
+         self.adapter_len = params.adapter_len
+         self.adapter_layer = params.adapter_layer
+
+         # caption model
+         self.cap_adapter_query = nn.Embedding(params.cap_adapter_len * params.cap_adapter_layer, params.dim)
+         self.cap_adapter_len = params.cap_adapter_len
+         self.cap_adapter_layer = params.cap_adapter_layer
+
+     @torch.inference_mode()
+     def forward(self, tokens: torch.Tensor, start_pos: int, visual_tokens: torch.Tensor = None, mode: str = 'instruct'):
+         if mode == 'instruct':
+             return self.forward_instruct(tokens, start_pos, mode)
+         elif mode == 'caption':
+             return self.forward_caption(tokens, start_pos, visual_tokens, mode)
+
+     def forward_instruct(self, tokens: torch.Tensor, start_pos: int, mode=None):
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         self.freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+         adapter = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)
+         mask = None
+         if seqlen > 1:
+             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+         for layer in self.layers[: -1 * self.params.adapter_layer]:
+             h = layer(h, start_pos, freqs_cis, mask)
+         layer_index = 0
+         for layer in self.layers[-1 * self.params.adapter_layer:]:
+             h = layer(h, start_pos, freqs_cis, mask, adapter[layer_index], mode=mode)
+             layer_index = layer_index + 1
+         h = self.norm(h)
+         output = self.output(h[:, -1, :])  # only compute last logits
+         return output.float()
+
+     def forward_caption(self, tokens: torch.Tensor, start_pos: int, visual_tokens: torch.Tensor = None, mode=None):
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         self.freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+         adapter = self.cap_adapter_query.weight.reshape(self.params.cap_adapter_layer, self.params.cap_adapter_len, self.params.dim).unsqueeze(1)
+         mask = None
+         if seqlen > 1:
+             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+         for layer in self.layers[: -1 * self.params.cap_adapter_layer]:
+             h = layer(h, start_pos, freqs_cis, mask)
+         layer_index = 0
+         for layer in self.layers[-1 * self.params.cap_adapter_layer:]:
+             adapter_per_layer = adapter[layer_index]
+             if visual_tokens is not None:
+                 adapter_per_layer = adapter_per_layer + visual_tokens
+             h = layer(h, start_pos, freqs_cis, mask, adapter_per_layer, mode=mode)
+             layer_index = layer_index + 1
+         h = self.norm(h)
+         output = self.output(h[:, -1, :])  # only compute last logits
+         return output.float()
+
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+
+         self.params = params
+
+         self.clip, self.clip_transform = clip.load(params.cap_vision_model)
+         self.clip.float()
+         for param in self.clip.parameters():
+             param.requires_grad = False
+
+         self.clip_proj = nn.Linear(self.clip.visual.output_dim, params.cap_vision_dim)
+         self.clip_proj_norm = nn.LayerNorm(params.cap_vision_dim)
+
+         self.visual_query = nn.Embedding(params.cap_adapter_len, params.cap_vision_dim)
+
+         self.visual_blocks = nn.ModuleList([
+             Block(params.cap_vision_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
+             for i in range(params.cap_vision_block)])
+
+         self.visual_proj = nn.Linear(params.cap_vision_dim, params.dim)
+         self.visual_proj_norm = nn.LayerNorm(params.dim)
+
+     def clip_encode_image(self, x):
+         x = self.clip.visual.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+         x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+         x = x + self.clip.visual.positional_embedding.to(x.dtype)
+         x = self.clip.visual.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.clip.visual.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.clip.visual.ln_post(x[:, :, :])
+
+         if self.clip.visual.proj is not None:
+             x = x @ self.clip.visual.proj
+
+         return x
+
+     def forward(self, imgs):
+         x = [self.clip_transform(img) for img in imgs]
+         x = torch.stack(x, dim=0).to(self.visual_query.weight.device)
+         _bsz = x.shape[0]
+
+         visual_feats = self.clip_encode_image(x).half()
+         visual_feats = self.clip_proj_norm(self.clip_proj(visual_feats))
+         visual_query = self.visual_query.weight.unsqueeze(0).repeat(_bsz, 1, 1)
+         visual_query = torch.cat([visual_query, visual_feats], dim=1)
+         for block in self.visual_blocks:
+             visual_query = block(visual_query)
+         visual_query = visual_query[:, :self.params.cap_adapter_len, :]
+         visual_query = self.visual_proj(visual_query)
+         visual_query = self.visual_proj_norm(visual_query)
+
+         return visual_query
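The zero-init attention of the paper appears in `Attention.__init__` as `self.gate = torch.nn.Parameter(torch.zeros(1))` and in `forward_instruct`, where the softmaxed adapter scores are multiplied by that gate. A simplified, self-contained illustration with toy tensors (not the real module) of why the adapter branch is a no-op at initialization:

```python
import torch

gate = torch.nn.Parameter(torch.zeros(1))   # same zero init as Attention.gate
base_out = torch.randn(1, 8, 16, 64)        # (bsz, n_heads, seqlen, head_dim)
adapter_out = torch.randn(1, 8, 16, 64)     # attention read from the adapter prompts

out = base_out + gate * adapter_out         # gated residual, as in forward_instruct
print(torch.equal(out, base_out))           # True: frozen LLaMA is reproduced exactly
```

In the real code the gate scales the adapter attention scores rather than the summed output, but because the gate is a scalar the effect at initialization is the same: the pretrained attention output passes through untouched until the gate is trained away from zero.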
llama/tokenizer.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from sentencepiece import SentencePieceProcessor
+ from logging import getLogger
+ from typing import List
+ import os
+
+
+ logger = getLogger()
+
+
+ class Tokenizer:
+     def __init__(self, model_path: str):
+         # reload tokenizer
+         assert os.path.isfile(model_path), model_path
+         self.sp_model = SentencePieceProcessor(model_file=model_path)
+         logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+         # BOS / EOS token IDs
+         self.n_words: int = self.sp_model.vocab_size()
+         self.bos_id: int = self.sp_model.bos_id()
+         self.eos_id: int = self.sp_model.eos_id()
+         self.pad_id: int = self.sp_model.pad_id()
+         logger.info(
+             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+         )
+         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+         assert type(s) is str
+         t = self.sp_model.encode(s)
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+         return t
+
+     def decode(self, t: List[int]) -> str:
+         return self.sp_model.decode(t)
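A minimal round-trip sketch, assuming `tokenizer.model` has already been fetched (app.py downloads it via `download_llama_7b`):

```python
from llama import Tokenizer

tok = Tokenizer(model_path="tokenizer.model")
ids = tok.encode("Tell me about alpacas.", bos=True, eos=False)
print(ids[0] == tok.bos_id)   # True: BOS is prepended when bos=True
print(tok.decode(ids))        # SentencePiece skips control ids, so the text round-trips
```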
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ fairscale
+ sentencepiece
+ Pillow
+ timm==0.3.2
+ git+https://github.com/openai/CLIP.git
style.css ADDED
@@ -0,0 +1,4 @@
+ h1,p {
+     text-align: center;
+ }
+