# Scraped from a Hugging Face Space file page.
# Last commit: julien-c (HF staff), 1f7c716 (verified) —
#   "move canonical checkpoints to their new location to keep this Space in
#   "linked Spaces" (#1)". File size: 5.18 kB.
import time
import torch
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
XLNetLMHeadModel, XLNetTokenizer,
TransfoXLLMHeadModel, TransfoXLTokenizer,
CTRLLMHeadModel, CTRLTokenizer)
# Registry of servable models, keyed by identifier. Per-entry fields:
#   tokenizer / model: transformers classes for this checkpoint.
#   size: memory budget in MB (10**6 bytes) reserved on a GPU; compared
#       against GPU.total_memory, which uses the same unit.
#   checkpoint: pretrained checkpoint name (presumably handed to the
#       loader's from_pretrained elsewhere — not visible here).
#   identifier: always the same string as the dict key.
#   configuration_options (optional): a config class plus option overrides;
#       presumably applied by the loader elsewhere — TODO confirm.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small",
    },
    "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt",
    },
    "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet",
    },
    "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp",
    },
    "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium",
    },
    "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large",
    },
    "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small",
    },
    "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl",
    },
    # NOTE: "pplm" used to appear twice in this literal; the first copy
    # (gpt2-large, size 3000) was always overwritten because a dict display
    # keeps only the last value for a repeated key. Only this gpt2-medium
    # entry ever took effect. It is placed here, where the key was first
    # inserted, so the dict's iteration order is unchanged.
    "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True,
            },
        },
    },
    "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl",
    },
}

# Fixed memory (MB) charged to every GPU on top of its models' sizes;
# added in GPU.total_memory_used.
memory_overhead = 500
class GPU:
    """Track the models assigned to a single CUDA device.

    All memory figures are in MB (10**6 bytes). The usable capacity is the
    device's reported total memory minus a fixed 1000 MB reserve.
    """

    def __init__(self, id):
        self.id = id
        self.models = []
        device_name = f"cuda:{id}"
        props = torch.cuda.get_device_properties(device_name)
        self.total_memory = props.total_memory / 1_000_000 - 1_000
        print("INIT GPU WITH DEVICE", device_name)

    def register_model(self, model, cached_path=None):
        """Assign ``model`` to this device if it fits.

        The model fits when used memory plus ``model["size"]`` stays strictly
        below ``self.total_memory``. On success the dict is modified in place:
        ``"device"`` is set (and ``"cached_path"`` too, when a truthy path is
        given), it is appended to ``self.models``, and True is returned.
        Otherwise the dict is left untouched and False is returned.
        """
        projected = self.total_memory_used() + model["size"]
        if not projected < self.total_memory:
            return False
        model["device"] = f"cuda:{self.id}"
        if cached_path:
            model["cached_path"] = cached_path
        self.models.append(model)
        return True

    def total_memory_used(self):
        """Return the summed sizes of registered models plus the module-wide overhead."""
        return sum(entry["size"] for entry in self.models) + memory_overhead

    def __repr__(self):
        summary = [(entry["checkpoint"], entry["size"]) for entry in self.models]
        percent_used = round(100 * (self.total_memory_used() / self.total_memory))
        summary.append(str(percent_used) + "%")
        summary.append(f"cuda:{self.id}")
        return str(summary)
class GPUHandler:
    """Place models from ``model_metadata`` onto a set of GPUs, first-fit.

    Args:
        ids: Unused; kept only so existing positional callers keep working.
        model_list: Keys of ``model_metadata`` to place, in placement order.
        gpu_ids: CUDA device indices to manage; one ``GPU`` is built per id.
        cached_models: Optional mapping from model key to a cached path that
            is forwarded to ``GPU.register_model``.

    Raises:
        RuntimeError: From ``sanity_check`` if the layout cannot fit.
        ValueError: From ``register_model`` if a model fits on no GPU.
    """

    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        if cached_models is None:
            cached_models = {}
        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))
        # Dry-run the whole layout first so an impossible configuration is
        # rejected before any model is assigned a device.
        self.sanity_check([model_metadata[model] for model in model_list])
        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        """Register ``model`` on the first GPU that has room for it.

        Raises:
            ValueError: If no GPU accepts the model, including when the
                handler manages no GPUs at all.
        """
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            # for/else: reached only when the loop finished without a break,
            # i.e. every GPU rejected the model (or there were none).
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        """Simulate first-fit placement of ``model_list`` on fresh GPU trackers.

        Prints the simulated layout on success.

        Raises:
            RuntimeError: If some model cannot be placed on any GPU.
        """
        # Use the real device ids (not 0..n-1) so each simulated GPU reports
        # the same capacity as the device it stands in for.
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]
        for model in model_list:
            for temp_gpu in temp_gpus:
                # Register a copy: GPU.register_model writes "device" into the
                # dict it is given, and these are shared model_metadata entries.
                if temp_gpu.register_model(dict(model)):
                    break
            else:
                raise RuntimeError("SANITY CHECK FAILED")
        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"