import json
import os
import time
from pathlib import Path
from typing import Tuple

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from llama.generation import LLaMA
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

from google.cloud import storage

# GCS bucket holding the pretrained weights; read from the environment.
bucket_name = os.environ.get("GCS_BUCKET")

# Local directories the weights are downloaded into.
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"

def setup_model_parallel() -> Tuple[int, int]:
    """Initialize torch.distributed and fairscale model parallelism for this process."""
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size
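
# Launch note (an assumption, not from the original file): LOCAL_RANK and
# WORLD_SIZE are normally injected per process by the PyTorch launcher, e.g.
#   torchrun --nproc_per_node=<num_gpus> <script>.py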

def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str
):
    """Download model and tokenizer weights from the GCS bucket into local dirs."""
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    for blob in bucket.list_blobs(prefix=f"{ckpt_path}/"):
        # Take the last path component so nested prefixes work too.
        filename = blob.name.split("/")[-1]
        if filename:  # skip directory placeholder objects
            blob.download_to_filename(f"{llama_weight_path}/{filename}")

    for blob in bucket.list_blobs(prefix=f"{tokenizer_path}/"):
        filename = blob.name.split("/")[-1]
        if filename:
            blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")

def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
) -> LLaMA:
    """Download the weights and build a LLaMA generator for this rank."""
    download_pretrained_models(ckpt_path, tokenizer_path)

    start_time = time.time()
    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
    # One checkpoint shard per model-parallel rank.
    assert world_size == len(checkpoints), (
        f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    )

    llama_ckpt_path = checkpoints[local_rank]
    print("Loading checkpoint")
    # Load the shard directly onto the GPU; the lambda parameter is named `s`
    # to avoid shadowing the imported google.cloud `storage` module.
    checkpoint = torch.load(llama_ckpt_path, map_location=lambda s, loc: s.cuda(0))
    with open(Path(llama_weight_path) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words
    # Build the model in fp16 on the GPU, then restore the default tensor type.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args).cuda().half()
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator

def get_output(
    generator: LLaMA,
    prompt: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
    """Run generation on a single prompt and return the decoded results."""
    prompts = [prompt]
    results = generator.generate(
        prompts, max_gen_len=256, temperature=temperature, top_p=top_p
    )

    return results
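
# Minimal end-to-end sketch (an assumption, not part of the original file):
# shows how the helpers above compose into an entrypoint. The GCS prefixes
# "llama/7B" and "tokenizer" are placeholders, not paths confirmed by the source.
if __name__ == "__main__":
    local_rank, world_size = setup_model_parallel()
    generator = get_pretrained_models("llama/7B", "tokenizer", local_rank, world_size)
    results = get_output(generator, "The capital of France is")
    # Print from rank 0 only so model-parallel runs don't duplicate output.
    if local_rank == 0:
        for result in results:
            print(result)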