import torch
from transformers import AutoModelForCausalLM
from accelerate import dispatch_model


def _device_map(num_gpus, num_layers):
    # Budget num_layers transformer blocks plus 2 extra units: one for the
    # embedding (transformer.wte, kept on GPU 0 alongside the final layer
    # norm) and one for the output head (lm_head, kept on the last GPU).
    per_gpu_layers = (num_layers + 2) / num_gpus

    device_map = {
        'transformer.wte': 0,
        'transformer.ln_f': 0,
        'lm_head': num_gpus - 1,
    }

    used = 1  # GPU 0 already hosts the embedding, so it starts one unit full.
    gpu_target = 0
    for i in range(num_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            # The last GPU also hosts lm_head, so reserve one unit on it.
            used = 0 if gpu_target < num_gpus - 1 else 1
        assert gpu_target < num_gpus
        device_map[f'transformer.h.{i}'] = gpu_target
        used += 1

    return device_map


def load_model_on_gpus(model_name_or_path, num_gpus: int = 2):
    num_devices = torch.cuda.device_count()

    if num_gpus == 1:
        # Single GPU: let accelerate place everything automatically.
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map='auto', trust_remote_code=True
        ).eval()
    elif 1 < num_gpus <= num_devices:
        # Multiple GPUs: load onto CPU first, then dispatch layers with the
        # hand-built device map so no single card must fit the whole model.
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map='cpu', trust_remote_code=True
        ).eval()
        num_layers = model.config.num_hidden_layers
        device_map = _device_map(num_gpus, num_layers)
        print(device_map)
        model = dispatch_model(model, device_map=device_map)
    else:
        raise ValueError(
            f'num_gpus must be between 1 and {num_devices} (available GPUs), got {num_gpus}'
        )
    return model
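
# Minimal usage sketch. The checkpoint name below is an illustrative
# assumption, not implied by this file: it stands in for any causal-LM
# checkpoint whose module names match the device map built above
# (transformer.wte / transformer.h.{i} / transformer.ln_f / lm_head).
if __name__ == '__main__':
    model = load_model_on_gpus('Qwen/Qwen-7B-Chat', num_gpus=2)
    print(type(model).__name__, 'dispatched across',
          torch.cuda.device_count(), 'visible GPUs')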