File size: 24,809 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Automatic Tensor Parallelism
import re
from torch import nn
from .replace_policy import replace_policies
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import *
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode
def move(tensor, device, copy=True):
    """Return ``tensor`` on ``device``.

    A meta tensor is not transferred; a new uninitialized tensor created with
    ``torch.empty_like`` on ``device`` is returned in its place. Any other
    tensor goes through ``Tensor.to`` with the given ``copy`` flag.
    """
    if not tensor.is_meta:
        # Producing a new tensor helps free memory (for example after a split); this used to be
        # done with clone(). Passing copy=True instead also covers the cpu --> cpu case, where a
        # plain to() would hand back a view of the full tensor that is never de-referenced.
        return tensor.to(device, copy=copy)
    return torch.empty_like(tensor, device=device)
class ReplaceWithTensorSlicing:
    """Copy source tensors into destination tensors, slicing out one rank's part when sizes differ.

    ``gpu_index`` is the rank of this process inside ``mp_group`` (0 when no group is given)
    and selects which slice of a larger source tensor is copied. ``out_dim`` and ``in_dim``
    are the destination dimensions that are compared against the source's outer and inner
    dimensions respectively.
    """

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
        else:
            self.gpu_index = 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        """Assert that the source dimension ``dim1`` is strictly larger than the destination ``dim2``."""
        assert dim1 > dim2, ('Merging tensors is not allowed here! Please use deepspeed load_checkpoint '
                             'for merging your checkpoints before replacing the transformer layer with '
                             'inference-kernels')

    def _rank_slice(self, width):
        """Return the slice ``[gpu_index * width, (gpu_index + 1) * width)``."""
        return slice(self.gpu_index * width, (self.gpu_index + 1) * width)

    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        """Copy ``src`` into ``dst``, taking this rank's part of each of ``num_splits`` equal chunks.

        ``src`` is split into ``num_splits`` chunks along its outer dimension (dim 0 when
        ``int8`` is set, otherwise the last dim). Each chunk is split again into pieces sized
        for ``dst``; the ``gpu_index``-th piece of every chunk is concatenated and copied.

        Args:
            dst: destination tensor; replaced by ``torch.empty_like(dst)`` when
                ``allocate_tensor`` is set.
            src: source tensor, or ``None``.
            num_splits: number of equal chunks in ``src``.
            int8: selects dim 0 instead of the last dim as the outer dimension.
            allocate_tensor: allocate a new destination instead of writing into ``dst``.

        Returns:
            ``None`` when ``src`` is ``None``; otherwise a ``Parameter``. It has
            ``requires_grad=False`` except in the non-2D same-size case, where ``src`` itself
            is wrapped with the default setting. ``src.scale`` is carried over when present,
            except in that same case.

        Raises:
            RuntimeError: if the full-size 2D copy fails (for example on a shape mismatch).
            AssertionError: if the source is not larger than the destination when slicing.
        """
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape
        outer_dim = 0 if int8 else -1

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if len(src_shape) == 2 and len(dst_shape) == 2:
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                # Outer sizes already match: copy everything without slicing.
                try:
                    dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                except RuntimeError as err:
                    # Report the offending shapes through an exception rather than terminating
                    # the whole process from inside library code.
                    raise RuntimeError(f"strided_copy failed: cannot copy source of shape {tuple(src.shape)} "
                                       f"into destination of shape {tuple(dst.shape)}") from err
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            # Regroup: entry i holds the i-th piece of every chunk, concatenated.
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            shard = weight_split[self.gpu_index]
            dst = dst.reshape(-1).data.copy_(shard.contiguous().reshape(-1)).reshape(shard.shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst

    def copy(self, dst, src, int8=False, allocate_tensor=False):
        """Copy ``src`` into ``dst``, slicing out this rank's part when ``src`` is larger.

        For 2D tensors the source's inner/outer dimensions (dims 0/1, swapped when ``int8`` is
        set) are compared with ``dst``'s ``in_dim``/``out_dim``. If both match, the whole
        tensor is copied; otherwise the mismatching dimension is sliced by ``gpu_index``.
        For other ranks a size mismatch on dim 0 is sliced using ``dst``'s last dimension.

        Returns:
            ``None`` when ``src`` is ``None``; otherwise a ``Parameter`` with
            ``requires_grad=False``, carrying ``src.scale`` when present.

        Raises:
            AssertionError: if ``dst`` is a meta tensor, or the source is not larger than the
                destination along the dimension being sliced.
        """
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if len(src_shape) == 2 and len(dst_shape) == 2:
            inner_matches = src_shape[inner_dim] == dst_shape[self.in_dim]
            if inner_matches and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            elif not inner_matches:
                self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                part = self._rank_slice(dst_shape[self.in_dim])
                dst.data.copy_(src[:, part] if inner_dim == 1 else src[part, :])
            else:
                self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                part = self._rank_slice(dst_shape[self.out_dim])
                dst.data.copy_(src[:, part] if outer_dim == 1 else src[part, :])
        else:
            if src_shape[0] == dst_shape[0]:
                # Same size: reuse src directly when dtypes agree, otherwise copy into dst.
                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
            else:
                dst.data.copy_(src[self._rank_slice(dst_shape[-1])])
        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
class Loading():
    """Namespace of helpers that fill a module's parameters and buffers from a ``state_dict``.

    All helpers are static: they can be called on the class (``Loading.load(...)``) or on
    an instance.
    """

    # Layer classes that are loaded; matched on the exact class, not on subclasses.
    _LOAD_LAYERS = (nn.Linear, nn.Embedding, nn.LayerNorm)
    # Module names (as reported by ``module._get_name()``) that are loaded as well.
    _LOAD_LAYER_NAMES = frozenset([
        "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
        "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
        "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
        "Qwen3RMSNorm", "Qwen3MoeRMSNorm", "DeepseekV2RMSNorm", "DeepseekV3RMSNorm",
        "DeepseekV2YarnRotaryEmbedding", "DeepseekV3YarnRotaryEmbedding", "MoEGate"
    ])

    @staticmethod
    def is_load_module(module):
        """Return True if ``module``'s class or name is one of the layers handled by ``load``."""
        return module.__class__ in Loading._LOAD_LAYERS or module._get_name() in Loading._LOAD_LAYER_NAMES

    @staticmethod
    def load_buffer(module, state_dict, prefix):
        """Copy every buffer of ``module`` found under ``prefix + name`` in ``state_dict``.

        Meta buffers are first replaced by an uninitialized CPU tensor so they can be copied
        into. Buffers registered as ``None`` are skipped.
        """
        for name in module._buffers.keys():
            buffer = module._buffers[name]
            if buffer is None:
                # A buffer may be registered with no tensor; there is nothing to fill.
                continue
            if buffer.data.is_meta:
                module._buffers[name] = torch.nn.parameter.Parameter(data=torch.empty_like(buffer.data, device="cpu"),
                                                                     requires_grad=buffer.data.requires_grad)
            if prefix + name in state_dict:
                module._buffers[name].data.copy_(state_dict[prefix + name])

    @staticmethod
    def _materialize_if_meta(owner, attr):
        """Replace the meta parameter ``owner.<attr>`` with an uninitialized CPU parameter."""
        param = getattr(owner, attr)
        if param.data.is_meta:
            # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
            setattr(
                owner, attr,
                torch.nn.parameter.Parameter(data=torch.empty_like(param.data, device="cpu"),
                                             requires_grad=param.data.requires_grad))

    @staticmethod
    def load(module, state_dict, prefix, mp_group=None):
        """Load ``prefix + 'weight'`` and, when present, ``prefix + 'bias'`` into ``module``.

        The weight goes to ``module.weight`` or, if the module has no ``weight`` attribute, to
        ``module.norm.weight``. When ``prefix`` contains ``'query_key_value'`` the weight is
        copied with ``strided_copy`` using three splits; otherwise ``copy`` is used. The bias
        follows the same ``module`` / ``module.norm`` fallback.
        """
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        if hasattr(module, 'weight'):
            Loading._materialize_if_meta(module, 'weight')
            if 'query_key_value' in prefix:
                module.weight = mp_replace.strided_copy(module.weight.data,
                                                        state_dict[prefix + 'weight'],
                                                        num_splits=3)
            else:
                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
        elif hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
            Loading._materialize_if_meta(module.norm, 'weight')
            module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])

        if prefix + 'bias' in state_dict:
            if hasattr(module, 'bias'):
                Loading._materialize_if_meta(module, 'bias')
                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
            elif hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
                Loading._materialize_if_meta(module.norm, 'bias')
                module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
class AutoTP():
    """Automatic tensor parallelism: derive a replacement policy and swap layers for sharded ones.

    The static helpers (``tp_parser`` and friends) inspect a model and return, per block
    class, the names of linear layers that are replaced by all-reduce layers. The instance
    methods walk a module tree, load weights from ``state_dict`` when one is given, and
    replace layers according to ``linear_policies``.
    """

    def __init__(self,
                 module,
                 all_reduce_linears,
                 prefix,
                 state_dict,
                 linear_layer_setting,
                 orig_layer_impl,
                 keep_module_on_host=False):
        self.module = module
        self.all_reduce_linears = all_reduce_linears
        self.prefix = prefix
        self.state_dict = state_dict
        # Filled in by set_tensor_parallel_config().
        self.mp_size = None
        self.mp_group = None
        self.linear_layer_setting = linear_layer_setting
        self.orig_layer_impl = orig_layer_impl
        # Filled in by update_linear_policies().
        self.linear_policies = None
        self.conv_linear_layer = False
        TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)

    @staticmethod
    def in_module_list(module, module_list):
        """Return True if some entry of ``module_list`` has the same class name as ``module``."""
        wanted = type(module).__name__
        return any(type(item).__name__ == wanted for item in module_list)

    @staticmethod
    def get_module_list(model):
        """Collect one representative module per class name from every ``nn.ModuleList`` in ``model``.

        Children that are not a ``nn.ModuleList`` are searched recursively.
        """
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not AutoTP.in_module_list(module, mlist):
                        mlist.append(module)
            else:
                mlist.extend(AutoTP.get_module_list(child))
        return mlist

    @staticmethod
    def supported(model):
        """Return False when the model family parsed from ``str(model)`` is in the unsupported list.

        Raises:
            AssertionError: if no ``...Model`` / ``...Stack`` name can be found in ``str(model)``.
        """
        unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        assert key is not None, "Not able to determine model policy automatically. Please provide policy."
        return key.group(1).lower() not in unsupported

    @staticmethod
    def get_layers(parent, module):
        """Flatten ``module`` into ``"<parent>.<name>"`` for each linear layer and ``"ln"`` for each norm.

        Other submodules are expanded recursively, with their own key as the new ``parent``.
        """
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list.append(parent + "." + key)
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                layer_list.append("ln")
            else:
                layer_list.extend(AutoTP.get_layers(key, submodule))
        return layer_list

    @staticmethod
    def update_policy_list(policy_list, new_module, new_gems):
        """Add ``(type(new_module), new_gems)`` to ``policy_list``, merging with an existing entry.

        If the class is already present, its gems and ``new_gems`` are united (duplicates
        removed) and stored as a set. ``policy_list`` is modified in place and returned.
        """
        module_type = type(new_module)
        for i, policy in enumerate(policy_list):
            # if module already exists in policy, combine gems and remove duplicates
            if policy[0] == module_type:
                # The stored gems are a list after the first insert but a set after a merge,
                # so both sides are converted before uniting them.
                policy_list[i] = (module_type, set(new_gems) | set(policy[1]))
                return policy_list
        policy_list.append((module_type, new_gems))
        return policy_list

    @staticmethod
    def kernel_supported(module_list):
        """Return True if any module in ``module_list`` has a class covered by ``replace_policies``."""
        policy = []
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                policy.extend(plcy._orig_layer_class)
            elif plcy._orig_layer_class is not None:
                policy.append(plcy._orig_layer_class)
        return any(child.__class__ in policy for child in module_list)

    @staticmethod
    def _is_all_reduce_layer(layer, module_type, model_repr):
        """Return True if the layer name matches one of the known all-reduce linear patterns."""
        if 'out_proj' in layer or 'o_proj' in layer or 'down_proj' in layer:
            return True
        if 'attention.dense' in layer and 'GPTNeoX' in model_repr:
            return True
        # this is a hack to get the right linear layer for this model!
        if 'self_attention.dense' in layer and 'falcon' in module_type:
            return True
        # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
        if 'w2' in layer and 'Mixtral' in module_type:
            return True
        if 'self_attn.dense' in layer and 'Phi' in module_type:
            return True
        if 'self_attention.dense' in layer and 'ChatGLM' in model_repr:
            return True
        if 'dense_4h_to_h' in layer and 'ChatGLM' in model_repr:
            return True
        return False

    @staticmethod
    def tp_parser(model):
        """Build the policy list of ``(block class, all-reduce layer names)`` for ``model``.

        Raises:
            AssertionError: if the model is unsupported or no policy entry could be derived.
        """
        policy_list = []
        module_list = AutoTP.get_module_list(model)
        assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
        # ln_1 , ln_2 for Qwen
        norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']
        # The model repr does not change during parsing, so it is computed once.
        model_repr = str(model)
        for module in module_list:
            module_type = str(type(module))
            layer_list = []
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list.append("." + key)
                elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
                    layer_list.append("ln")
                else:
                    layer_list.extend(AutoTP.get_layers(key, submodule))
            gem_list = []
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    # The entry preceding a norm is selected. For i == 0 the index -1 wraps
                    # around to the last entry of layer_list.
                    # NOTE(review): the wrap-around looks relied upon (last linear of a block
                    # precedes the next block's first norm) - confirm before changing.
                    if layer_list[i - 1] != 'ln':
                        gem_list.append(layer_list[i - 1])
                elif AutoTP._is_all_reduce_layer(layer, module_type, model_repr):
                    gem_list.append(layer)
            if gem_list != []:
                gem_list = list(set(gem_list))
                policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
        assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
        return policy_list

    def set_tensor_parallel_config(self, mp_size, mp_group):
        """Store the tensor-parallel size and group.

        In AutoTP training mode the arguments are ignored and the values are taken from
        ``deepspeed.utils.groups`` instead.
        """
        if is_autotp_training_mode():
            self.mp_group = groups.get_tensor_model_parallel_group()
            self.mp_size = groups.get_tensor_model_parallel_world_size()
            return
        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
        """Return the tensor-parallel replacement for the linear layer ``child`` called ``name``.

        Returns ``None`` when ``child`` is already marked as replaced, and ``child`` itself
        for layers that are skipped. ``conv_linear_layer`` is accepted for the policy call
        signature; the instance attribute ``self.conv_linear_layer`` is what is consulted.
        """
        # This function should clearly define the routing rules for specific layers
        # and avoid any complex shard-related logic.
        if getattr(child, "replaced", False) == True:
            return
        # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
        if (name in ("mlp.gate", "block_sparse_moe.gate") or "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                or (name == 'mlp.shared_expert_gate' and 'qwen2_moe' in str(type(self.module)))):
            return child

        module_repr = str(self.module)
        # For Yuan model
        if 'Yuan' in module_repr:
            if 'v_proj' in name:
                return Yuan_LinearLayer(child, self.mp_group)
            elif 'o_proj' in name:
                return Yuan_LinearAllreduce(child, self.mp_group)

        # For MLP including chunk layer.
        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in module_repr):
            return GateUpPack_LinearLayer(child, self.mp_group)

        # For Arctic model, bypass to all_reduce replacement for w2 weights
        arctic_w2_all_reduce_linear = 'Arctic' in module_repr and 'w2' in name
        # For MoE MLP model, e.g., deepseek and jamba
        down_proj = 'down_proj' in name
        use_all_reduce = name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj

        setattr(child, "replaced", True)
        if use_all_reduce:
            if self.conv_linear_layer:
                return Conv_LinearALlreduce(child, self.mp_group, name=name)
            if name == "lm_head" or name == 'embed_out':
                return LmHeadLinearAllreduce(child, self.mp_group)
            return LinearAllreduce(child, self.mp_group, name=name)

        if self.conv_linear_layer:
            # The conv replacement must be returned; otherwise it is built and discarded.
            return conv_LinearLayer(child, self.mp_group)
        if require_tp_fused_qkvw(name, self.mp_size):
            # Check and handle fused qkv for TP
            return fused_LinearLayer(child, self.mp_group, fused_module=self.module)
        return LinearLayer(child, self.mp_group, name=name)

    def _slice_embedding(self, child, name, conv_linear_layer):
        """Return a new ``nn.Embedding`` holding this rank's slice of ``child``'s weight along dim 1.

        Returns ``None`` when ``child`` is already marked as replaced.
        """
        if getattr(child, "replaced", False) == True:
            return
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

        if hasattr(child.weight, 'ds_tensor'):
            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
        else:
            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
        data = torch.nn.parameter.Parameter(data, requires_grad=False)

        new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size, name))
        new_embedding.weight.data.copy_(data)
        setattr(child, "replaced", True)
        return new_embedding

    def update_mp_params(self, child):
        """Replace known size attributes of ``child`` with their per-rank shard size and mark it replaced."""
        if getattr(child, "replaced", False) == True:
            return
        param_list = [
            "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
            "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model",
            "num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition"
        ]
        # 'embed_dim' is left untouched for Yuan modules. The exclusion is decided once,
        # before the loop, so the list is never modified while it is being iterated.
        if "Yuan" in str(child):
            param_list.remove('embed_dim')
        for param in param_list:
            if hasattr(child, param):
                param_val = getattr(child, param)
                setattr(child, param, get_shard_size(param_val, self.mp_size))
        setattr(child, "replaced", True)

    def update_linear_policies(self):
        """Populate ``self.linear_policies`` (layer class -> replacement method) and ``conv_linear_layer``.

        With an explicit ``linear_layer_setting`` its first entry maps to ``_replace`` and, if
        there are exactly two entries, the second maps to ``_slice_embedding``. Otherwise GPT2
        blocks use ``Conv1D`` and everything else uses ``nn.Linear`` and ``nn.Embedding``.
        """
        self.conv_linear_layer = False
        if self.linear_layer_setting is not None:
            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
            if len(self.linear_layer_setting) == 2:
                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
            return

        import transformers
        if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
            try:
                conv1d_cls = transformers.pytorch_utils.Conv1D
            except ImportError:
                # Fall back to plain linear layers; the conv flag stays False so it agrees
                # with the policy that is actually installed.
                self.linear_policies = {nn.Linear: self._replace}
            else:
                self.conv_linear_layer = True
                self.linear_policies = {conv1d_cls: self._replace}
        else:
            self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}

    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
        """Recursively load and replace the children of ``r_module``; return ``r_module``.

        For each child: weights are loaded when the child is a load module and a
        ``state_dict`` is set (children with no matching key are skipped entirely), buffers
        are loaded, and the child is then either replaced through ``linear_policies`` or has
        its size attributes updated and is descended into.
        """
        for name, child in r_module.named_children():
            if prev_class_name == "":
                class_name = prev_name
            elif prev_name == "":
                class_name = prev_class_name
            else:
                class_name = prev_class_name + '.' + prev_name

            if class_name != "":
                checking_key = self.prefix + '.' + class_name + '.' + name + '.'
            else:
                checking_key = self.prefix + '.' + name + '.'

            if Loading.is_load_module(child) and self.state_dict is not None:
                if any(checking_key in item for item in self.state_dict):
                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
                else:
                    continue
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)

            policy_key = child.__class__
            if policy_key not in self.linear_policies:
                # Added for falcon model support
                # Note: isinstance will account for class inheritance, child.__class__ does not
                policy_key = next((lp for lp in self.linear_policies if isinstance(child, lp)), None)

            if policy_key is not None:
                setattr(r_module, name, self.linear_policies[policy_key](child, prev_name + '.' + name,
                                                                         self.conv_linear_layer))
            else:
                self.update_mp_params(child)
                self._replace_module(child, name, class_name)
        return r_module

    def get_model_num_kv_heads(self, config):
        """Return the first non-None head-count attribute found on ``config``, or ``None``."""
        # multi_query_group_num is for chatglm2 & chatglm3
        kv_head_names = [
            'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads',
            'attention_heads'
        ]
        for name in kv_head_names:
            num_kv_heads = getattr(config, name, None)
            if num_kv_heads is not None:
                return num_kv_heads
        return None

    def _replace_last_linear_module(self, r_module):
        """Replace ``r_module.lm_head`` (or ``embed_out``) through ``linear_policies``; return ``r_module``."""
        if hasattr(r_module, "lm_head"):
            name = "lm_head"
            child = r_module.lm_head
        elif hasattr(r_module, "embed_out"):
            name = "embed_out"
            child = r_module.embed_out
        else:
            return r_module
        if child.__class__ in self.linear_policies:
            setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
        return r_module
|