"""Utilities for spreading transformer language models across the available
CUDA devices according to their approximate memory footprint (in MB)."""

import time

import torch
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
                          OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          XLNetLMHeadModel, XLNetTokenizer,
                          TransfoXLLMHeadModel, TransfoXLTokenizer,
                          CTRLLMHeadModel, CTRLTokenizer)
# Approximate GPU memory footprint of each model ("size"), in MB.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small"
    },
    "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-gpt",
        "identifier": "gpt"
    },
    "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet"
    },
    "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp"
    },
    "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "gpt2-medium",
        "identifier": "gpt2/medium"
    },
    "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "gpt2-large",
        "identifier": "gpt2/large"
    },
    "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small"
    },
    "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "ctrl",
        "identifier": "ctrl"
    },
    "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "gpt2-xl",
        "identifier": "gpt2/xl"
    },
    "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True
            }
        }
    }
}

# Fixed per-GPU overhead (in MB) counted on top of the models' own footprints.
memory_overhead = 500
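

# Hypothetical sketch of how a metadata entry above could be turned into a live
# tokenizer/model pair; the helper name, the default device and the eval() call
# are assumptions for illustration, not code called elsewhere in this file.
def load_model_from_metadata(identifier, device="cpu"):
    meta = model_metadata[identifier]
    tokenizer = meta["tokenizer"].from_pretrained(meta["checkpoint"])
    configuration_options = meta.get("configuration_options")
    if configuration_options is not None:
        # e.g. the "pplm" entry asks GPT2Config for output_hidden_states=True
        config = configuration_options["config"].from_pretrained(
            meta["checkpoint"], **configuration_options["options"])
        model = meta["model"].from_pretrained(meta["checkpoint"], config=config)
    else:
        model = meta["model"].from_pretrained(meta["checkpoint"])
    return tokenizer, model.to(device).eval()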


class GPU:
    """Tracks the models assigned to a single CUDA device."""

    def __init__(self, id):
        self.id = id
        self.models = []
        # Usable memory in MB, keeping a 1 GB safety margin.
        self.total_memory = torch.cuda.get_device_properties(
            "cuda:{}".format(id)).total_memory / 1_000_000 - 1_000
        print("INIT GPU WITH DEVICE", "cuda:{}".format(id))

    def register_model(self, model, cached_path=None):
        """Assign the model to this GPU if it fits; return True on success."""
        if self.total_memory_used() + model["size"] < self.total_memory:
            model["device"] = "cuda:{}".format(self.id)
            if cached_path:
                model["cached_path"] = cached_path
            self.models.append(model)
            return True
        else:
            return False

    def total_memory_used(self):
        return sum([model["size"] for model in self.models]) + memory_overhead

    def __repr__(self):
        return str(
            [(model["checkpoint"], model["size"]) for model in self.models] +
            [str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"] +
            ["cuda:{}".format(self.id)]
        )


class GPUHandler:
    """Distributes the requested models across the given GPUs, first-fit."""

    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        if cached_models is None:
            cached_models = {}

        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))

        # Check that everything fits before registering models on the real GPUs.
        self.sanity_check([model_metadata[model] for model in model_list])

        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        # First-fit: the first GPU with enough free memory takes the model.
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            # No GPU had room for this model.
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        # Simulate the placement on the same devices we will actually use.
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]

        for model in model_list:
            current_gpu_index = 0
            while current_gpu_index < len(temp_gpus):
                if not temp_gpus[current_gpu_index].register_model(model):
                    current_gpu_index += 1
                else:
                    break
            if current_gpu_index >= len(temp_gpus):
                raise RuntimeError("SANITY CHECK FAILED")

        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"