|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""T5 model.""" |
|
|
|
import torch |
|
|
|
from megatron import ( |
|
get_args, |
|
mpu |
|
) |
|
from megatron.model.enums import AttnMaskType |
|
from megatron.model.language_model import parallel_lm_logits, get_language_model |
|
from megatron.model.transformer import LayerNorm |
|
from megatron.model.utils import ( |
|
openai_gelu, |
|
get_linear_layer, |
|
init_method_normal, |
|
scaled_init_method_normal |
|
) |
|
from .module import MegatronModule |
|
|
|
|
|
def t5_extended_attention_mask(attention_mask_list):
    """Add a singleton dimension at index 1 to every mask in the list.

    Arguments:
        attention_mask_list: iterable of mask tensors.

    Returns:
        A new list, in input order, holding ``mask.unsqueeze(1)`` for each
        mask.
    """
    # presumably [b, s, s] -> [b, 1, s, s] -- TODO confirm against callers
    return [mask.unsqueeze(1) for mask in attention_mask_list]
|
|
|
|
|
def t5_position_ids(token_ids):
    """Build position ids ``0 .. seq_length - 1`` shaped like ``token_ids``.

    Arguments:
        token_ids: tensor whose dimension 1 is used as the sequence length
            (presumably [b, s] -- TODO confirm against callers).

    Returns:
        A ``torch.long`` tensor on the same device as ``token_ids``,
        expanded to the shape of ``token_ids``.
    """
    positions = torch.arange(token_ids.size(1),
                             dtype=torch.long,
                             device=token_ids.device)
    # Add a leading dimension, then expand (no copy) to the input's shape.
    return positions.unsqueeze(0).expand_as(token_ids)
|
|
|
|
|
class T5LMHead(MegatronModule):
    """Masked LM head for T5.

    The head owns only a bias of size ``mpu_vocab_size``; the projection
    weight is supplied by the caller on every forward call.

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        parallel_output: whether output logits being distributed or not.
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        # Tag the bias with model-parallel attributes.
        # NOTE(review): presumably consumed by the mpu / checkpoint
        # utilities -- confirm against their expectations.
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
        """Compute LM logits from hidden states.

        Arguments:
            hidden_states: input passed to ``parallel_lm_logits``.
            word_embeddings_weight: weight passed to ``parallel_lm_logits``.

        Returns:
            The result of ``parallel_lm_logits`` called with this head's
            bias and ``parallel_output`` setting.
        """
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output
|
|
|
|
|
class T5Model(MegatronModule):
    """T5 Language model.

    Wraps the language model built by ``get_language_model`` and, when both
    ``post_process`` and ``add_decoder`` are set, a ``T5LMHead`` that turns
    decoder output into vocabulary logits.
    """

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        """Build the language model and, if applicable, the LM head.

        Arguments:
            num_tokentypes: forwarded to ``get_language_model``.
            parallel_output: forwarded to ``T5LMHead``.
            pre_process: forwarded to ``get_language_model``; also decides
                whether a separate head embedding is checkpointed.
            post_process: forwarded to ``get_language_model``; the LM head
                exists only when this and ``add_decoder`` are both true.
            add_encoder: forwarded to ``get_language_model``.
            add_decoder: forwarded to ``get_language_model``.
        """
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        # The factory function itself is passed here, not the ``init_method``
        # built above.
        # NOTE(review): confirm ``initialize_word_embeddings`` expects the
        # factory.
        self.initialize_word_embeddings(init_method_normal)

        if self.post_process and self.add_decoder:
            self.lm_head = T5LMHead(
                self.word_embeddings_weight().size(0),
                parallel_output)
            self._lm_head_key = 'lm_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self,
                encoder_input_ids,
                decoder_input_ids,
                encoder_attn_mask,
                decoder_attn_mask,
                encoder_decoder_attn_mask,
                tokentype_ids=None,
                lm_labels=None,
                enc_hidden_states=None):
        """Run the language model and, where applicable, the LM head.

        Arguments:
            encoder_input_ids: encoder token ids.
            decoder_input_ids: decoder token ids.
            encoder_attn_mask: encoder self-attention mask.
            decoder_attn_mask: decoder self-attention mask.
            encoder_decoder_attn_mask: cross-attention mask.
            tokentype_ids: forwarded to the language model.
            lm_labels: if given, the per-token loss is returned instead of
                the logits.
            enc_hidden_states: forwarded to the language model.

        Returns:
            * with ``post_process`` and ``add_decoder``: the logits with
              dimensions 0 and 1 swapped when ``lm_labels`` is None,
              otherwise the vocab-parallel cross-entropy loss with
              dimensions 0 and 1 swapped;
            * with ``add_decoder`` but no ``add_encoder`` (and the branch
              above not taken): the decoder output;
            * otherwise: the language model output unchanged.
        """
        # Each mask gets a singleton dimension inserted at index 1.
        (encoder_attn_mask,
         decoder_attn_mask,
         encoder_decoder_attn_mask) = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        decoder_attn_mask,
                                        encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

        if self.post_process and self.add_decoder:
            decoder_output, _ = lm_output

            lm_logits = self.lm_head(decoder_output,
                                     self.word_embeddings_weight())

            if lm_labels is None:
                # presumably [s, b, h] -> [b, s, h] -- TODO confirm
                return lm_logits.transpose(0, 1).contiguous()

            # presumably [b, s] -> [s, b] to line up with the logits
            # -- TODO confirm
            lm_labels = lm_labels.transpose(0, 1).contiguous()
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits,
                                                           lm_labels)
            else:
                # Compute the loss on a float32 copy of the logits.
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                           lm_labels)

            # Swap dimensions 0 and 1 of the loss back before returning.
            lm_loss = lm_loss.transpose(0, 1).contiguous()
            return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, _ = lm_output
            return decoder_output
        else:
            encoder_output = lm_output
            return encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process and self.add_decoder:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        # Save the word embeddings under their own key when the head exists
        # but ``pre_process`` is off.
        if self.post_process and not self.pre_process and self.add_decoder:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix,
                                                  keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Restores the sub-modules from the keys written by
        ``state_dict_for_save_checkpoint``.
        """

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process and self.add_decoder:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                         strict=strict)
        # Mirrors the save-side condition for the head's word embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
|
|