chansung committed
Commit
9cb6afd
0 Parent(s):

Duplicate from chansung/LLaMA-7B

Files changed (9)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +76 -0
  4. gen.py +97 -0
  5. llama/generation.py +77 -0
  6. llama/model.py +238 -0
  7. llama/tokenizer.py +40 -0
  8. requirements.txt +4 -0
  9. strings.py +7 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: LLaMA 7B
+ emoji: 👀
+ colorFrom: indigo
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.19.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: chansung/LLaMA-7B
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,76 @@
+ import os
+ import time
+ import torch
+ import gradio as gr
+
+ from strings import TITLE, ABSTRACT
+ from gen import get_pretrained_models, get_output, setup_model_parallel
+
+ os.environ["RANK"] = "0"
+ os.environ["WORLD_SIZE"] = "1"
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
+ os.environ["MASTER_PORT"] = "50505"
+
+ local_rank, world_size = setup_model_parallel()
+ generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)
+
+ history = []
+
+ def chat(user_input, top_p, temperature, max_gen_len, state_chatbot):
+     bot_response = get_output(
+         generator=generator,
+         prompt=user_input,
+         max_gen_len=max_gen_len,
+         temperature=temperature,
+         top_p=top_p)
+
+     # remove the leading phrase identical to the user prompt
+     bot_response = bot_response[0][len(user_input):]
+     bot_response = bot_response.replace("\n", "<br><br>")
+     # trim the trailing incomplete sentence, if any
+     last_period = bot_response.rfind(".")
+     if last_period != -1:
+         bot_response = bot_response[:last_period + 1]
+
+     history.append({
+         "role": "user",
+         "content": user_input
+     })
+     history.append({
+         "role": "system",
+         "content": bot_response
+     })
+
+     state_chatbot = state_chatbot + [(user_input, None)]
+
+     # stream the response back word by word
+     response = ""
+     for word in bot_response.split(" "):
+         time.sleep(0.1)
+         response += word + " "
+         current_pair = (user_input, response)
+         state_chatbot[-1] = current_pair
+         yield state_chatbot, state_chatbot
+
+ def reset_textbox():
+     return gr.update(value='')
+
+ with gr.Blocks(css="""#col_container {width: 95%; margin-left: auto; margin-right: auto;}
+                       #chatbot {height: 400px; overflow: auto;}""") as demo:
+
+     state_chatbot = gr.State([])
+
+     with gr.Column(elem_id='col_container'):
+         gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
+         chatbot = gr.Chatbot(elem_id='chatbot')
+         textbox = gr.Textbox(placeholder="Enter a prompt")
+
+         with gr.Accordion("Parameters", open=False):
+             max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Generation Length")
+             top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
+             temperature = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature")
+
+     textbox.submit(chat, [textbox, top_p, temperature, max_gen_len, state_chatbot], [state_chatbot, chatbot])
+     textbox.submit(reset_textbox, [], [textbox])
+
+ demo.queue(api_open=False).launch()
gen.py ADDED
@@ -0,0 +1,97 @@
+ from typing import Tuple
+
+ import os
+ import time
+ import json
+ from pathlib import Path
+
+ import torch
+ from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+ from llama.generation import LLaMA
+ from llama.model import ModelArgs, Transformer
+ from llama.tokenizer import Tokenizer
+
+ from google.cloud import storage
+
+ bucket_name = os.environ.get("GCS_BUCKET")
+
+ llama_weight_path = "weights/llama"
+ tokenizer_weight_path = "weights/tokenizer"
+
+ def setup_model_parallel() -> Tuple[int, int]:
+     local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+     torch.distributed.init_process_group("nccl")
+     initialize_model_parallel(world_size)
+     torch.cuda.set_device(local_rank)
+
+     # seed must be the same in all processes
+     torch.manual_seed(1)
+     return local_rank, world_size
+
+ def download_pretrained_models(
+     ckpt_path: str,
+     tokenizer_path: str
+ ):
+     os.makedirs(llama_weight_path, exist_ok=True)
+     os.makedirs(tokenizer_weight_path, exist_ok=True)
+
+     storage_client = storage.Client.create_anonymous_client()
+     bucket = storage_client.bucket(bucket_name)
+
+     blobs = bucket.list_blobs(prefix=f"{ckpt_path}/")
+     for blob in blobs:
+         filename = blob.name.split("/")[1]
+         blob.download_to_filename(f"{llama_weight_path}/{filename}")
+
+     blobs = bucket.list_blobs(prefix=f"{tokenizer_path}/")
+     for blob in blobs:
+         filename = blob.name.split("/")[1]
+         blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")
+
+ def get_pretrained_models(
+     ckpt_path: str,
+     tokenizer_path: str,
+     local_rank: int,
+     world_size: int) -> LLaMA:
+
+     download_pretrained_models(ckpt_path, tokenizer_path)
+
+     start_time = time.time()
+     checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
+
+     llama_ckpt_path = checkpoints[local_rank]
+     print("Loading")
+     checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
+     with open(Path(llama_weight_path) / "params.json", "r") as f:
+         params = json.loads(f.read())
+
+     model_args: ModelArgs = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
+     tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
+     model_args.vocab_size = tokenizer.n_words
+     torch.set_default_tensor_type(torch.cuda.HalfTensor)
+     model = Transformer(model_args).cuda().half()
+     torch.set_default_tensor_type(torch.FloatTensor)
+     model.load_state_dict(checkpoint, strict=False)
+
+     generator = LLaMA(model, tokenizer)
+     print(f"Loaded in {time.time() - start_time:.2f} seconds")
+     return generator
+
+ def get_output(
+     generator: LLaMA,
+     prompt: str,
+     max_gen_len: int = 256,
+     temperature: float = 0.8,
+     top_p: float = 0.95):
+
+     prompts = [prompt]
+     results = generator.generate(
+         prompts,
+         max_gen_len=max_gen_len,
+         temperature=temperature,
+         top_p=top_p
+     )
+
+     return results
llama/generation.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from typing import List
+
+ import torch
+
+ from llama.tokenizer import Tokenizer
+ from llama.model import Transformer
+
+
+ class LLaMA:
+     def __init__(self, model: Transformer, tokenizer: Tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer
+
+     def generate(
+         self,
+         prompts: List[str],
+         max_gen_len: int,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+     ) -> List[str]:
+         bsz = len(prompts)
+         params = self.model.params
+         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+
+         min_prompt_size = min([len(t) for t in prompt_tokens])
+         max_prompt_size = max([len(t) for t in prompt_tokens])
+
+         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+         tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+         for k, t in enumerate(prompt_tokens):
+             tokens[k, : len(t)] = torch.tensor(t).long()
+         input_text_mask = tokens != self.tokenizer.pad_id
+         start_pos = min_prompt_size
+         prev_pos = 0
+         for cur_pos in range(start_pos, total_len):
+             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+             if temperature > 0:
+                 probs = torch.softmax(logits / temperature, dim=-1)
+                 next_token = sample_top_p(probs, top_p)
+             else:
+                 next_token = torch.argmax(logits, dim=-1)
+             next_token = next_token.reshape(-1)
+             # only replace token if prompt has already been generated
+             next_token = torch.where(
+                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+             )
+             tokens[:, cur_pos] = next_token
+             prev_pos = cur_pos
+
+         decoded = []
+         for i, t in enumerate(tokens.tolist()):
+             # cut to max gen len
+             t = t[: len(prompt_tokens[i]) + max_gen_len]
+             # cut to eos tok if any
+             try:
+                 t = t[: t.index(self.tokenizer.eos_id)]
+             except ValueError:
+                 pass
+             decoded.append(self.tokenizer.decode(t))
+         return decoded
+
+
+ def sample_top_p(probs, p):
+     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort[mask] = 0.0
+     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+     next_token = torch.multinomial(probs_sort, num_samples=1)
+     next_token = torch.gather(probs_idx, -1, next_token)
+     return next_token
llama/model.py ADDED
@@ -0,0 +1,238 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from typing import Optional, Tuple
+ from dataclasses import dataclass
+ import math
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ import fairscale.nn.model_parallel.initialize as fs_init
+ from fairscale.nn.model_parallel.layers import (
+     ParallelEmbedding,
+     RowParallelLinear,
+     ColumnParallelLinear,
+ )
+
+
+ @dataclass
+ class ModelArgs:
+     dim: int = 512
+     n_layers: int = 8
+     n_heads: int = 8
+     vocab_size: int = -1  # defined later by tokenizer
+     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+     norm_eps: float = 1e-5
+
+     max_batch_size: int = 32
+     max_seq_len: int = 1024
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+
+         self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+         self.head_dim = args.dim // args.n_heads
+
+         self.wq = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wk = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wv = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wo = RowParallelLinear(
+             args.n_heads * self.head_dim,
+             args.dim,
+             bias=False,
+             input_is_parallel=True,
+             init_method=lambda x: x,
+         )
+
+         self.cache_k = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+         ).cuda()
+         self.cache_v = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+         ).cuda()
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         self.cache_k = self.cache_k.to(xq)
+         self.cache_v = self.cache_v.to(xq)
+
+         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+         keys = self.cache_k[:bsz, : start_pos + seqlen]
+         values = self.cache_v[:bsz, : start_pos + seqlen]
+
+         xq = xq.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if mask is not None:
+             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+         output = output.transpose(
+             1, 2
+         ).contiguous().view(bsz, seqlen, -1)
+
+         return self.wo(output)
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+     ):
+         super().__init__()
+         hidden_dim = int(2 * hidden_dim / 3)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = ColumnParallelLinear(
+             dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+         )
+         self.w2 = RowParallelLinear(
+             hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+         )
+         self.w3 = ColumnParallelLinear(
+             dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+         )
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, args: ModelArgs):
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+         self.attention = Attention(args)
+         self.feed_forward = FeedForward(
+             dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
+         )
+         self.layer_id = layer_id
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+         h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
+         out = h + self.feed_forward.forward(self.ffn_norm(h))
+         return out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+         self.params = params
+         self.vocab_size = params.vocab_size
+         self.n_layers = params.n_layers
+
+         self.tok_embeddings = ParallelEmbedding(
+             params.vocab_size, params.dim, init_method=lambda x: x
+         )
+
+         self.layers = torch.nn.ModuleList()
+         for layer_id in range(params.n_layers):
+             self.layers.append(TransformerBlock(layer_id, params))
+
+         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+         self.output = ColumnParallelLinear(
+             params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+         )
+
+         self.freqs_cis = precompute_freqs_cis(
+             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+         )
+
+     @torch.inference_mode()
+     def forward(self, tokens: torch.Tensor, start_pos: int):
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         self.freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+         mask = None
+         if seqlen > 1:
+             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+         for layer in self.layers:
+             h = layer(h, start_pos, freqs_cis, mask)
+         h = self.norm(h)
+         output = self.output(h[:, -1, :])  # only compute last logits
+         return output.float()
llama/tokenizer.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from sentencepiece import SentencePieceProcessor
+ from logging import getLogger
+ from typing import List
+ import os
+
+
+ logger = getLogger()
+
+
+ class Tokenizer:
+     def __init__(self, model_path: str):
+         # reload tokenizer
+         assert os.path.isfile(model_path), model_path
+         self.sp_model = SentencePieceProcessor(model_file=model_path)
+         logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+         # BOS / EOS token IDs
+         self.n_words: int = self.sp_model.vocab_size()
+         self.bos_id: int = self.sp_model.bos_id()
+         self.eos_id: int = self.sp_model.eos_id()
+         self.pad_id: int = self.sp_model.pad_id()
+         logger.info(
+             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+         )
+         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+         assert type(s) is str
+         t = self.sp_model.encode(s)
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+         return t
+
+     def decode(self, t: List[int]) -> str:
+         return self.sp_model.decode(t)
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ fairscale
+ sentencepiece
+ google-cloud-storage
strings.py ADDED
@@ -0,0 +1,7 @@
+ TITLE = "LLaMA 7B Model Playground"
+
+ ABSTRACT = """
+ This Space lets you play with the 7B variant of [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) (Large Language Model Meta AI), released by Meta AI.
+
+ LLaMA is a general-purpose language model, so it behaves differently from [ChatGPT](https://openai.com/blog/chatgpt/). Even though this Space presents a chat-like UI, the generated output is simply the completion of the given prompt, so your prompt should be written to guide the text you want generated.
+ """
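Since the model completes text rather than answering chat turns, a prompt that reads as the opening of the passage you want tends to work best. Below is a minimal sketch of calling `get_output` from gen.py directly, mirroring the setup in app.py; it assumes the same environment as the Space (a single GPU and `GCS_BUCKET` pointing at a bucket with `7B/` and `tokenizer/` prefixes), and the prompt wording is only illustrative.

```python
# Minimal sketch, not part of the Space: drive gen.py without the Gradio UI.
# Assumes one GPU and GCS_BUCKET set as gen.py expects (7B/ and tokenizer/ prefixes).
import os

from gen import setup_model_parallel, get_pretrained_models, get_output

# Same single-process distributed setup app.py uses.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Phrase the prompt as the beginning of the text you want completed,
# not as a question addressed to a chat assistant.
prompt = "Simply put, the theory of relativity states that"
results = get_output(generator, prompt, max_gen_len=128, temperature=0.8, top_p=0.95)

# The decoded output echoes the prompt followed by its continuation,
# which is why app.py strips the leading len(user_input) characters.
print(results[0])
```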