Update gen.py
gen.py CHANGED
@@ -1,15 +1,13 @@
 from typing import Tuple
 
 import os
+import torch
 import time
 import json
 from pathlib import Path
 
-import torch
-from fairscale.nn.model_parallel.initialize import initialize_model_parallel
-from llama.generation import LLaMA
-from llama.model import ModelArgs, Transformer
-from llama.tokenizer import Tokenizer
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize
 
 from google.cloud import storage
 
@@ -18,18 +16,6 @@ bucket_name = os.environ.get("GCS_BUCKET")
 llama_weight_path = "weights/llama"
 tokenizer_weight_path = "weights/tokenizer"
 
-def setup_model_parallel() -> Tuple[int, int]:
-    local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    world_size = int(os.environ.get("WORLD_SIZE", -1))
-
-    torch.distributed.init_process_group("nccl")
-    initialize_model_parallel(world_size)
-    torch.cuda.set_device(local_rank)
-
-    # seed must be the same in all processes
-    torch.manual_seed(1)
-    return local_rank, world_size
-
 def download_pretrained_models(
     ckpt_path: str,
     tokenizer_path: str
@@ -52,33 +38,92 @@ def download_pretrained_models(
 
 def get_pretrained_models(
     ckpt_path: str,
-    tokenizer_path: str,
-    local_rank: int,
-    world_size: int) -> LLaMA:
+    tokenizer_path: str) -> LLaMA:
 
     download_pretrained_models(ckpt_path, tokenizer_path)
+    # max_seq_len: int = 512, max_batch_size: int = 32
+    generator = load(
+        ckpt_dir=llama_weight_path,
+        tokenizer_path=tokenizer_weight_path,
+        max_seq_len=512,
+        max_batch_size=1
+    )
 
+    return generator
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
     start_time = time.time()
-    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+    print(checkpoints)
 
-
-    print("Loading")
-    checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
-    with open(Path(llama_weight_path) / "params.json", "r") as f:
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
         params = json.loads(f.read())
 
-    model_args: ModelArgs = ModelArgs(
-        max_seq_len=512, max_batch_size=32, **params)
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
     model_args.vocab_size = tokenizer.n_words
-    torch.set_default_tensor_type(torch.cuda.HalfTensor)
-    model = Transformer(model_args)
+
+    torch.set_default_tensor_type(torch.HalfTensor)
+    print("Allocating transformer on host")
+    ctx_tok = default_quantize.set(True)
+    model = Transformer(model_args)
+    default_quantize.set(ctx_tok)
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    # ?
     torch.set_default_tensor_type(torch.FloatTensor)
-    model.load_state_dict(checkpoint, strict=False)
+
+    # load the state dict incrementally, to avoid memory problems
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i : size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+            del checkpoint[parameter_name]
+        del checkpoint
+
+    model.cuda()
 
     generator = LLaMA(model, tokenizer)
-    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    print(
+        f"Loaded in {time.time() - start_time:.2f} seconds with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
+    )
     return generator
 
+
 def get_output(
     generator: LLaMA,
     prompt: str,
@@ -94,4 +139,4 @@ def get_output(
         top_p=top_p
     )
 
-    return results
+    return results
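For context, here is a minimal sketch of how the reworked entry points would be driven end to end. The bucket name and prompt are invented, and the sketch assumes that get_output's parameters elided from the diff all have defaults; none of this is taken from the commit itself.

# Hypothetical driver for the new single-GPU code path in gen.py.
import os

# Assumed: a bucket holding the checkpoint and tokenizer files. Must be set
# before importing gen, since bucket_name is read at module import time.
os.environ["GCS_BUCKET"] = "my-llama-weights"

from gen import get_pretrained_models, get_output

# Fetches the weights from GCS if needed, then builds the quantized
# generator via load().
generator = get_pretrained_models("weights/llama", "weights/tokenizer")

results = get_output(generator, "The capital of France is")
print(results)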
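The key_to_dim table encodes how the original model-parallel checkpoints were sharded: 0 means the parameter was split along rows, -1 along columns, and None means it is replicated, so it can be taken from the first checkpoint alone. A toy sketch of the merge rule (the tensor and world size are invented; only the indexing mirrors the diff):

# Toy illustration of the shard-merge rule in load().
import torch

world_size = 2                       # the original model-parallel degree
full = torch.arange(16.0).reshape(4, 4)

# Pretend training sharded this weight along dim 0 across two ranks.
shards = full.chunk(world_size, dim=0)

merged = torch.empty_like(full)
for i, shard in enumerate(shards):
    size = shard.size(0)
    merged[size * i : size * (i + 1), :] = shard  # same slice as the checkpoint loop

assert torch.equal(merged, full)     # row shards concatenate back losslessly

Deleting each tensor from the checkpoint dict right after it is copied (del checkpoint[parameter_name]) keeps at most one shard's worth of extra memory alive at a time, which is the point of the "incrementally, to avoid memory problems" comment.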
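One small observation on the quantization toggle: if default_quantize is a plain contextvars.ContextVar, the conventional pairing is set()/reset(), since set() returns a Token rather than the previous value. A sketch of that pattern, with a stand-in variable rather than the library's actual object:

# Stand-in for the default_quantize toggle, assuming a plain ContextVar.
from contextvars import ContextVar

default_quantize: ContextVar[bool] = ContextVar("default_quantize", default=False)

token = default_quantize.set(True)   # enable quantization while the model is built
try:
    assert default_quantize.get() is True
    # ... allocate the Transformer here ...
finally:
    default_quantize.reset(token)    # restore the previous value

assert default_quantize.get() is False

Under that assumption, the diff's default_quantize.set(ctx_tok) would set the variable to the Token object instead of restoring the previous value; harmless if nothing reads the flag after construction, but reset(ctx_tok) would be the idiomatic restore.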