import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2DoubleHeadsModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from types import MethodType

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
def get_model(device='cpu', gpt2_name='gpt2', in_net=False, in_net_init_identity=True, out_net=False, out_net_init_identity=True, freeze_ln=False, freeze_pos=True,
              freeze_wte=True, freeze_ff=True, freeze_attn=True, dup_lm_head=False, dup_lm_head_bias=False):
    # Available GPT-2 checkpoints: ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
    model = GPT2LMHeadModel.from_pretrained(gpt2_name).to(device)
    # model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

    """
    Initialize linear input layer
    """
    in_layer_sizes = []
    out_layer_sizes = []
    input_dim = model.config.n_embd
    dropout = 0.1
    orth_gain = 1.41
    # orth_gain = None

    # Model - in_net
    if in_net:
        in_layers = []
        last_output_size = input_dim
        for size in in_layer_sizes:
            layer = nn.Linear(last_output_size, size)
            if orth_gain is not None:
                torch.nn.init.orthogonal_(layer.weight, gain=orth_gain)
            layer.bias.data.zero_()

            in_layers.append(layer)
            in_layers.append(nn.ReLU())
            in_layers.append(nn.Dropout(dropout))
            last_output_size = size

        in_final_linear = nn.Linear(last_output_size, model.config.n_embd)
        # if orth_gain is not None:
        #     torch.nn.init.orthogonal_(in_final_linear.weight, gain=orth_gain)
        # in_final_linear.bias.data.zero_()

        # Initialize the final linear layer as an identity transformation
        if in_net_init_identity:
            nn.init.eye_(in_final_linear.weight)
            in_final_linear.bias.data.zero_()

        in_layers.append(in_final_linear)
        in_layers.append(nn.Dropout(dropout))

        model.in_net = nn.Sequential(*in_layers)
        # requires_grad_ sets requires_grad on the parameters, not on the Module object
        model.in_net.requires_grad_(True)
""" | |
Initialize linear output layer | |
""" | |
if out_net: | |
output_dim = model.config.n_embd | |
out_layers = [] | |
last_output_size = model.config.n_embd | |
for size in out_layer_sizes: | |
out_layers.append(nn.Linear(last_output_size, size)) | |
out_layers.append(nn.ReLU()) | |
out_layers.append(nn.Dropout(dropout)) | |
last_output_size = size | |
out_final_linear = nn.Linear(last_output_size, output_dim) | |
if out_net_init_identity: | |
nn.init.eye_(out_final_linear.weight) | |
out_final_linear.bias.data.zero_() | |
out_layers.append(out_final_linear) | |
model.out_net = nn.Sequential(*out_layers) | |
model.out_net.requires_grad = True | |
""" | |
out layer on top of lm_head | |
""" | |
# out_net_top = nn.Linear(model.config.vocab_size, model.config.vocab_size) | |
# nn.init.eye_(out_net_top.weight) | |
# model.out_net_top = out_net_top | |
# model.out_net_top.requires_grad = True | |
if dup_lm_head: | |
lm_head_new = nn.Linear(model.config.n_embd, | |
model.config.vocab_size, bias=dup_lm_head_bias) | |
lm_head_new.weight = torch.nn.Parameter( | |
model.lm_head.weight.data.detach().clone(), requires_grad=True) | |
# lm_head_new.bias.data.zero_() | |
model.lm_head_new = lm_head_new | |
model.lm_head_new.requires_grad = True | |
""" | |
Freeze transformer layers | |
""" | |
total_parameters = 0 | |
target_parameters = 0 | |
for name, p in model.transformer.named_parameters(): | |
name = name.lower() | |
size = p.size() | |
param_count = 1 | |
for dimension in size: | |
param_count *= dimension | |
total_parameters += param_count | |
if 'ln' in name or 'norm' in name: | |
p.requires_grad = not freeze_ln | |
elif 'wpe' in name or 'position_embeddings' in name or 'pos_drop' in name: | |
p.requires_grad = not freeze_pos | |
target_parameters += param_count | |
elif 'mlp' in name: | |
p.requires_grad = not freeze_ff | |
elif 'attn' in name: | |
p.requires_grad = not freeze_attn | |
elif 'wte' in name: | |
p.requires_grad = not freeze_wte | |
else: | |
p.requires_grad = False | |
# print(f'Total params: {total_parameters}') | |
# print( | |
# f'Target params: {target_parameters} ({target_parameters / total_parameters * 100:.2f}%)') | |
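
    # The forward below replaces GPT2LMHeadModel.forward for this instance (it is bound
    # via MethodType at the end of get_model): it embeds input_ids with wte, optionally
    # routes the embeddings through in_net, feeds them to the (partially frozen)
    # transformer as inputs_embeds, optionally maps the hidden states through out_net,
    # and projects to the vocabulary with lm_head_new when present (lm_head otherwise).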
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
r""" | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set | |
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to | |
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` | |
""" | |
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Convert input ids to word embeddings so that the linear input layer can be applied
        x = self.transformer.wte(input_ids)
        if hasattr(self, 'in_net'):
            x = self.in_net(x)

        transformer_outputs = self.transformer(
            inputs_embeds=x,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        if hasattr(self, 'out_net'):
            hidden_states = self.out_net(hidden_states)

        # Use the duplicated head if it exists, otherwise the pretrained (tied) lm_head
        if hasattr(self, 'lm_head_new'):
            lm_logits = self.lm_head_new(hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)
        # lm_logits = self.out_net_top(lm_logits)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    model.forward = MethodType(forward, model)

    return model
# model = get_model()

'''
Only for testing purposes.
'''
if __name__ == "__main__":
    # dup_lm_head=True so that the lm_head_new inspection below has something to print
    model = get_model(gpt2_name='gpt2', in_net=False, in_net_init_identity=True, out_net=False, out_net_init_identity=False, freeze_ln=True, freeze_pos=True,
                      freeze_wte=True, freeze_ff=True, freeze_attn=True, dup_lm_head=True)
    for name, p in model.named_parameters():
        if p.requires_grad:
            print(name, p.requires_grad)
    for p in model.lm_head_new.parameters():
        print('lm_head_new', p)
    # for p in model.out_net.parameters():
    #     print('out_net', p)
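
    # Minimal forward-pass smoke test (a sketch): assumes the default 'gpt2'
    # checkpoint loaded above, CPU execution, and the patched forward bound in get_model().
    enc = tokenizer('hello world', return_tensors='pt')
    out = model(input_ids=enc['input_ids'],
                attention_mask=enc['attention_mask'],
                labels=enc['input_ids'])
    print('loss:', out.loss.item())
    print('logits shape:', tuple(out.logits.shape))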