chansung committed
Commit 7932e5a
1 Parent(s): de9bfc1

Create gen.py

Files changed (1):
  1. gen.py +89 -0
gen.py ADDED
from typing import Tuple

import os
import time
import json
from pathlib import Path

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from app.llama import ModelArgs, Transformer, Tokenizer, LLaMA

from google.cloud import storage

bucket_name = os.environ.get("GCS_BUCKET")

llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"


def setup_model_parallel() -> Tuple[int, int]:
    # torchrun provides these variables for each process
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str
):
    # exist_ok prevents a crash when the directories already exist,
    # e.g. after a restart
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    # download the checkpoint shards from the GCS bucket
    blobs = bucket.list_blobs(prefix=f"{ckpt_path}/")
    for blob in blobs:
        if blob.name.endswith("/"):  # skip folder placeholder objects
            continue
        filename = blob.name.split("/")[-1]
        blob.download_to_filename(f"{llama_weight_path}/{filename}")

    # download the tokenizer model
    blobs = bucket.list_blobs(prefix=f"{tokenizer_path}/")
    for blob in blobs:
        if blob.name.endswith("/"):
            continue
        filename = blob.name.split("/")[-1]
        blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")


def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int) -> LLaMA:

    download_pretrained_models(ckpt_path, tokenizer_path)

    start_time = time.time()
    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))

    # each model-parallel rank loads its own checkpoint shard
    llama_ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(llama_ckpt_path, map_location=lambda stor, loc: stor.cuda(0))
    with open(Path(llama_weight_path) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words
    # build the model with fp16 parameters on the GPU, then restore
    # the default tensor type for everything else
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def get_output(
    generator: LLaMA,
    prompt: str,
    temperature: float = 0.8,
    top_p: float = 0.95):

    prompts = [prompt]
    results = generator.generate(prompts, max_gen_len=256, temperature=temperature, top_p=top_p)

    return results
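
A minimal usage sketch of how these helpers fit together. Everything below is an assumption for illustration, not part of this commit: the GCS prefixes ("llama_7b", "tokenizer"), the driver filename, and the prompt are placeholders, and the script is assumed to be launched with torchrun, which sets the LOCAL_RANK and WORLD_SIZE variables that setup_model_parallel() reads.

# hypothetical driver (not part of this commit), e.g. run as:
#   GCS_BUCKET=<your-bucket> torchrun --nproc_per_node 1 driver.py
from gen import setup_model_parallel, get_pretrained_models, get_output

local_rank, world_size = setup_model_parallel()

# "llama_7b" and "tokenizer" are placeholder bucket prefixes: the first
# holds the *.pth shards plus params.json, the second tokenizer.model
generator = get_pretrained_models("llama_7b", "tokenizer", local_rank, world_size)

results = get_output(generator, "The three primary colors are", temperature=0.8, top_p=0.95)
print(results[0])  # generate() returns one completion per prompt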