Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import time | |
import torch | |
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, | |
OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, | |
XLNetLMHeadModel, XLNetTokenizer, | |
TransfoXLLMHeadModel, TransfoXLTokenizer, | |
CTRLLMHeadModel, CTRLTokenizer) | |
model_metadata = { | |
"gpt2/small": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 550, | |
"checkpoint": "gpt2", | |
"identifier": "gpt2/small" | |
}, "gpt": { | |
"tokenizer": OpenAIGPTTokenizer, | |
"model": OpenAIGPTLMHeadModel, | |
"size": 550, | |
"checkpoint": "openai-community/openai-gpt", | |
"identifier": "gpt" | |
}, "xlnet": { | |
"tokenizer": XLNetTokenizer, | |
"model": XLNetLMHeadModel, | |
"size": 550, | |
"checkpoint": "xlnet-base-cased", | |
"identifier": "xlnet" | |
}, "gpt2/arxiv-nlp": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 550, | |
"checkpoint": "arxiv-nlp-v1", | |
"identifier": "gpt2/arxiv-nlp" | |
}, "gpt2/medium": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 1500, | |
"checkpoint": "openai-community/gpt2-medium", | |
"identifier": "gpt2/medium" | |
}, "gpt2/large": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 3300, | |
"checkpoint": "openai-community/gpt2-large", | |
"identifier": "gpt2/large" | |
}, "distilgpt2/small": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 350, | |
"checkpoint": "distilgpt2", | |
"identifier": "distilgpt2/small" | |
}, "ctrl": { | |
"tokenizer": CTRLTokenizer, | |
"model": CTRLLMHeadModel, | |
"size": 6300, | |
"checkpoint": "Salesforce/ctrl", | |
"identifier": "ctrl" | |
}, "pplm": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 3000, | |
"checkpoint": "openai-community/gpt2-large", | |
"identifier": "pplm" | |
}, "gpt2/xl": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 7000, | |
"checkpoint": "openai-community/gpt2-xl", | |
"identifier": "gpt2/xl" | |
}, "pplm": { | |
"tokenizer": GPT2Tokenizer, | |
"model": GPT2LMHeadModel, | |
"size": 4000, | |
"checkpoint": "openai-community/gpt2-medium", | |
"identifier": "pplm", | |
"configuration_options": { | |
"config": GPT2Config, | |
"options": { | |
"output_hidden_states": True | |
} | |
} | |
} | |
} | |
memory_overhead = 500 | |
class GPU: | |
def __init__(self, id): | |
self.id = id | |
self.models = [] | |
self.total_memory = torch.cuda.get_device_properties( | |
"cuda:{}".format(id)).total_memory / 1_000_000 - 1_000 | |
print("INIT GPU WITH DEVICE", "cuda:{}".format(id)) | |
def register_model(self, model, cached_path=None): | |
if self.total_memory_used() + model["size"] < self.total_memory: | |
model["device"] = "cuda:{}".format(self.id) | |
if cached_path: | |
model["cached_path"] = cached_path | |
self.models.append(model) | |
return True | |
else: | |
return False | |
def total_memory_used(self): | |
return sum([model["size"] for model in self.models]) + memory_overhead | |
def __repr__(self): | |
return str( | |
[(model["checkpoint"], model["size"]) for model in self.models] + | |
[str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"] + | |
["cuda:{}".format(self.id)] | |
) | |
class GPUHandler: | |
def __init__(self, ids, model_list, gpu_ids, cached_models=None): | |
if cached_models is None: | |
cached_models = {} | |
self.gpus = [GPU(id) for id in gpu_ids] | |
print("GPU handler initiated with {} gpus.".format(len(self.gpus))) | |
self.sanity_check([model_metadata[model] for model in model_list]) | |
for model in model_list: | |
self.register_model(model_metadata[model], cached_models.get(model)) | |
def register_model(self, model, cached_path=None): | |
for index, gpu in enumerate(self.gpus): | |
if gpu.register_model(model, cached_path): | |
print("Registered model", model, "in GPU", gpu) | |
break | |
if index >= len(self.gpus): | |
raise ValueError("Could not load model", model["checkpoint"]) | |
def sanity_check(self, model_list): | |
temp_gpus = [GPU(id) for id in range(len(self.gpus))] | |
for model in model_list: | |
current_gpu_index = 0 | |
while current_gpu_index < len(temp_gpus): | |
if not temp_gpus[current_gpu_index].register_model(model): | |
current_gpu_index += 1 | |
else: | |
break | |
if current_gpu_index >= len(temp_gpus): | |
raise RuntimeError("SANITY CHECK FAILED") | |
print("Current layout", temp_gpus) | |
def __repr__(self): | |
return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}" | |