chansung committed on
Commit 0538992
1 Parent(s): 9f9f705

Update gen.py

Files changed (1):
  1. gen.py +77 -32
gen.py CHANGED
@@ -1,15 +1,13 @@
 from typing import Tuple
 
 import os
+import torch
 import time
 import json
 from pathlib import Path
 
-import torch
-from fairscale.nn.model_parallel.initialize import initialize_model_parallel
-from llama.generation import LLaMA
-from llama.model import ModelArgs, Transformer
-from llama.tokenizer import Tokenizer
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize
 
 from google.cloud import storage
 
@@ -18,18 +16,6 @@ bucket_name = os.environ.get("GCS_BUCKET")
 llama_weight_path = "weights/llama"
 tokenizer_weight_path = "weights/tokenizer"
 
-def setup_model_parallel() -> Tuple[int, int]:
-    local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    world_size = int(os.environ.get("WORLD_SIZE", -1))
-
-    torch.distributed.init_process_group("nccl")
-    initialize_model_parallel(world_size)
-    torch.cuda.set_device(local_rank)
-
-    # seed must be the same in all processes
-    torch.manual_seed(1)
-    return local_rank, world_size
-
 def download_pretrained_models(
     ckpt_path: str,
     tokenizer_path: str
@@ -52,33 +38,92 @@ def download_pretrained_models(
 
 def get_pretrained_models(
     ckpt_path: str,
-    tokenizer_path: str,
-    local_rank: int,
-    world_size: int) -> LLaMA:
+    tokenizer_path: str) -> LLaMA:
 
     download_pretrained_models(ckpt_path, tokenizer_path)
+    # max_seq_len: int = 512, max_batch_size: int = 32
+    generator = load(
+        ckpt_dir=llama_weight_path,
+        tokenizer_path=tokenizer_weight_path,
+        max_seq_len=512,
+        max_batch_size=1
+    )
 
+    return generator
+
+def load(
+    ckpt_dir: str,
+    tokenizer_path: str,
+    max_seq_len: int,
+    max_batch_size: int,
+) -> LLaMA:
     start_time = time.time()
-    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+    print(checkpoints)
 
-    llama_ckpt_path = checkpoints[local_rank]
-    print("Loading")
-    checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
-    with open(Path(llama_weight_path) / "params.json", "r") as f:
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
         params = json.loads(f.read())
 
-    model_args: ModelArgs = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
-    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
+    model_args: ModelArgs = ModelArgs(
+        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+    )
+    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
     model_args.vocab_size = tokenizer.n_words
-    torch.set_default_tensor_type(torch.cuda.HalfTensor)
-    model = Transformer(model_args).cuda().half()
+
+    torch.set_default_tensor_type(torch.HalfTensor)
+    print("Allocating transformer on host")
+    ctx_tok = default_quantize.set(True)
+    model = Transformer(model_args)
+    default_quantize.set(ctx_tok)
+    key_to_dim = {
+        "w1": 0,
+        "w2": -1,
+        "w3": 0,
+        "wo": -1,
+        "wq": 0,
+        "wk": 0,
+        "wv": 0,
+        "output": 0,
+        "tok_embeddings": -1,
+        "ffn_norm": None,
+        "attention_norm": None,
+        "norm": None,
+        "rope": None,
+    }
+
+    # ?
    torch.set_default_tensor_type(torch.FloatTensor)
-    model.load_state_dict(checkpoint, strict=False)
+
+    # load the state dict incrementally, to avoid memory problems
+    for i, ckpt in enumerate(checkpoints):
+        print(f"Loading checkpoint {i}")
+        checkpoint = torch.load(ckpt, map_location="cpu")
+        for parameter_name, parameter in model.named_parameters():
+            short_name = parameter_name.split(".")[-2]
+            if key_to_dim[short_name] is None and i == 0:
+                parameter.data = checkpoint[parameter_name]
+            elif key_to_dim[short_name] == 0:
+                size = checkpoint[parameter_name].size(0)
+                parameter.data[size * i : size * (i + 1), :] = checkpoint[
+                    parameter_name
+                ]
+            elif key_to_dim[short_name] == -1:
+                size = checkpoint[parameter_name].size(-1)
+                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
+                    parameter_name
+                ]
+            del checkpoint[parameter_name]
+        del checkpoint
+
+    model.cuda()
 
     generator = LLaMA(model, tokenizer)
-    print(f"Loaded in {time.time() - start_time:.2f} seconds")
+    print(
+        f"Loaded in {time.time() - start_time:.2f} seconds with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
+    )
     return generator
 
+
 def get_output(
     generator: LLaMA,
     prompt: str,
@@ -94,4 +139,4 @@ def get_output(
         top_p=top_p
     )
 
-    return results
+    return results
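For readers of this commit, the heart of the change is the shard merge inside `load`. The original LLaMA checkpoints store one `*.pth` file per model-parallel rank, and every tensor in a shard is a slice of the full weight; `key_to_dim` records which axis each layer type was split along: 0 for projections sharded across rows (`wq`, `wk`, `wv`, `w1`, `w3`, `output`), -1 for those sharded across columns (`wo`, `w2`, `tok_embeddings`), and `None` for norm weights that are replicated and only need copying from the first shard. A minimal self-contained sketch of the same merge on toy tensors follows; the names and shapes are illustrative stand-ins, not part of the commit.

import torch
import torch.nn as nn

# Toy stand-in for the Transformer: one row-sharded weight, one
# column-sharded weight, and one replicated norm vector.
key_to_dim = {"wq": 0, "wo": -1, "norm": None}

model = nn.ParameterDict({
    "wq":   nn.Parameter(torch.empty(4, 2)),  # full shape; shards split dim 0
    "wo":   nn.Parameter(torch.empty(2, 4)),  # full shape; shards split dim -1
    "norm": nn.Parameter(torch.empty(2)),     # identical in every shard
})

# Two fake checkpoint shards, shaped as torch.load would return them.
shards = [
    {"wq": torch.full((2, 2), 1.0), "wo": torch.full((2, 2), 1.0), "norm": torch.ones(2)},
    {"wq": torch.full((2, 2), 2.0), "wo": torch.full((2, 2), 2.0), "norm": torch.ones(2)},
]

for i, checkpoint in enumerate(shards):
    for name, parameter in model.named_parameters():
        # The real code derives this key via parameter_name.split(".")[-2].
        if key_to_dim[name] is None and i == 0:
            parameter.data = checkpoint[name]          # copy once; shards agree
        elif key_to_dim[name] == 0:
            size = checkpoint[name].size(0)
            parameter.data[size * i : size * (i + 1), :] = checkpoint[name]
        elif key_to_dim[name] == -1:
            size = checkpoint[name].size(-1)
            parameter.data[:, size * i : size * (i + 1)] = checkpoint[name]

print(model["wq"])  # rows 0-1 are 1.0 (shard 0), rows 2-3 are 2.0 (shard 1)
print(model["wo"])  # cols 0-1 are 1.0 (shard 0), cols 2-3 are 2.0 (shard 1)

Writing each shard into a slice of the preallocated parameters and deleting it immediately is what the commit's "load the state dict incrementally, to avoid memory problems" comment refers to: only one shard is resident at a time.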
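`default_quantize` comes from the quantized `llama` fork imported at the top (bitsandbytes-based, hence the `BITSANDBYTES_NOWELCOME` silencing). The `ctx_tok = default_quantize.set(True)` / `default_quantize.set(ctx_tok)` pair reads like a `contextvars.ContextVar` used as a construction-time flag: layers built while the flag is true allocate quantized weights. A hedged sketch of that pattern, with an invented `Linear` stand-in; `reset(token)` is shown as the conventional way to restore the prior value, whereas the diff passes the token back into `set`:

from contextvars import ContextVar

# Assumed shape of the flag in the quantized fork (illustrative only).
default_quantize: ContextVar[bool] = ContextVar("default_quantize", default=False)

class Linear:
    def __init__(self, in_features: int, out_features: int) -> None:
        # A layer reads the flag at construction time to pick its storage.
        self.quantized = default_quantize.get()

token = default_quantize.set(True)  # enable quantization for this scope
layer = Linear(512, 512)
default_quantize.reset(token)       # restore the previous value

print(layer.quantized)              # True
print(Linear(512, 512).quantized)   # False outside the scope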
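Finally, a hypothetical end-to-end driver, assuming this file is importable as `gen`: the bucket name and checkpoint paths are placeholders, and since `get_output`'s signature is truncated in this view, every keyword argument besides `top_p` is an assumption modeled on LLaMA's usual sampling knobs.

import os

# Placeholder bucket; gen reads GCS_BUCKET at import time.
os.environ["GCS_BUCKET"] = "my-llama-bucket"

import gen

# Hypothetical GCS prefixes for the checkpoint and tokenizer blobs.
generator = gen.get_pretrained_models("llama-7b", "tokenizer")

results = gen.get_output(
    generator,
    prompt="The capital of South Korea is",
    max_gen_len=64,   # assumed parameter
    temperature=0.8,  # assumed parameter
    top_p=0.95,
)
print(results[0])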