import time
import torch
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
                          OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          XLNetLMHeadModel, XLNetTokenizer,
                          TransfoXLLMHeadModel, TransfoXLTokenizer,
                          CTRLLMHeadModel, CTRLTokenizer)

# Metadata for each model that can be served: tokenizer and model classes, an
# approximate GPU memory footprint in MB, and the checkpoint to load it from.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small"
    }, "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt"
    }, "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet"
    }, "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp"
    }, "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium"
    }, "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large"
    }, "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small"
    }, "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl"
    }, "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3000,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "pplm"
    }, "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl"
    }, "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True
            }
        }
    }
}
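

# Illustrative sketch, not part of the original module: one plausible way a
# metadata entry could be turned into a usable (tokenizer, model) pair once a
# GPU has been assigned to it. `load_model` is a hypothetical helper introduced
# here for illustration only; the project's real loading code may differ.
def load_model(identifier):
    metadata = model_metadata[identifier]
    tokenizer = metadata["tokenizer"].from_pretrained(metadata["checkpoint"])

    # Entries such as "pplm" carry extra options that are applied through the
    # model configuration class.
    configuration_options = metadata.get("configuration_options")
    if configuration_options is not None:
        config = configuration_options["config"].from_pretrained(
            metadata["checkpoint"], **configuration_options["options"])
        model = metadata["model"].from_pretrained(metadata["checkpoint"], config=config)
    else:
        model = metadata["model"].from_pretrained(metadata["checkpoint"])

    # "device" is only present after a GPU has registered this model.
    return tokenizer, model.to(metadata.get("device", "cpu"))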

# Fixed per-GPU memory overhead, in MB, counted on top of the registered models.
memory_overhead = 500

class GPU:
    """Tracks which models are assigned to a single CUDA device."""

    def __init__(self, id):
        self.id = id
        self.models = []
        # Total device memory in MB, keeping roughly 1 GB of headroom free.
        self.total_memory = torch.cuda.get_device_properties(
            "cuda:{}".format(id)).total_memory / 1_000_000 - 1_000

        print("INIT GPU WITH DEVICE", "cuda:{}".format(id))

    def register_model(self, model, cached_path=None):
        # Only take the model if it fits next to what is already on this device.
        if self.total_memory_used() + model["size"] < self.total_memory:
            model["device"] = "cuda:{}".format(self.id)

            if cached_path:
                model["cached_path"] = cached_path

            self.models.append(model)
            return True
        else:
            return False

    def total_memory_used(self):
        # Memory committed on this device, including the fixed overhead.
        return sum(model["size"] for model in self.models) + memory_overhead

    def __repr__(self):
        return str(
            [(model["checkpoint"], model["size"]) for model in self.models] +
            [str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"] +
            ["cuda:{}".format(self.id)]
        )


class GPUHandler:
    """Spreads the requested models across the available GPUs."""

    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        if cached_models is None:
            cached_models = {}

        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))

        # Check that the requested models can fit at all before doing the
        # real assignment.
        self.sanity_check([model_metadata[model] for model in model_list])

        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        # Place the model on the first GPU that has enough free memory; fail
        # loudly if no GPU can hold it.
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        # Dry run: pack every requested model onto temporary GPU descriptors to
        # make sure the whole set fits before the real assignment happens.
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]

        for model in model_list:
            for gpu in temp_gpus:
                if gpu.register_model(model):
                    break
            else:
                raise RuntimeError("SANITY CHECK FAILED")

        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"