import torch
from transformers import AutoModelForCausalLM
from accelerate import dispatch_model


def _device_map(num_gpus, num_layers):
    # Spread num_layers transformer blocks plus two extra "layer
    # equivalents" (the input embedding and the output head) evenly.
    per_gpu_layers = (num_layers + 2) / num_gpus

    # The embedding goes on the first GPU; the final layer norm stays with
    # lm_head on the last GPU so the logits need no extra cross-device hop.
    device_map = {
        'transformer.wte': 0,
        'transformer.ln_f': num_gpus - 1,
        'lm_head': num_gpus - 1,
    }

    used = 1  # transformer.wte already occupies one slot on GPU 0
    gpu_target = 0
    for i in range(num_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            # Reserve one slot on the last GPU for ln_f / lm_head.
            used = 0 if gpu_target < num_gpus - 1 else 1
        assert gpu_target < num_gpus
        device_map[f'transformer.h.{i}'] = gpu_target
        used += 1

    return device_map
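
# Illustration (hypothetical sizes, not tied to any particular checkpoint):
# with num_gpus=2 and num_layers=4, each GPU gets 3 "layer equivalents":
#   {'transformer.wte': 0, 'transformer.ln_f': 1, 'lm_head': 1,
#    'transformer.h.0': 0, 'transformer.h.1': 0,
#    'transformer.h.2': 1, 'transformer.h.3': 1}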


def load_model_on_gpus(model_name_or_path, num_gpus: int = 2):
    num_devices = torch.cuda.device_count()

    if num_gpus == 1:
        # Single GPU: let accelerate place the whole model automatically.
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map='auto',
            trust_remote_code=True).eval()
    elif 1 < num_gpus <= num_devices:
        # Load onto CPU first, then shard the blocks across the GPUs.
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map='cpu',
            trust_remote_code=True).eval()
        num_layers = model.config.num_hidden_layers
        device_map = _device_map(num_gpus, num_layers)
        print(device_map)  # show the layer placement for debugging
        model = dispatch_model(model, device_map=device_map)
    else:
        raise ValueError(
            f'num_gpus must be between 1 and {num_devices} '
            f'(available CUDA devices), got {num_gpus}')

    return model
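

if __name__ == '__main__':
    # Minimal usage sketch: the checkpoint name below is a placeholder, not
    # taken from this file. _device_map assumes a GPT-2-style module layout
    # (transformer.wte / transformer.h.* / transformer.ln_f / lm_head), so
    # pick a checkpoint that matches it.
    model = load_model_on_gpus('your-org/your-causal-lm', num_gpus=2)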