# debiasing-lms/model.py
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2DoubleHeadsModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from types import MethodType
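# Byte-pair GPT-2 tokenizer, loaded once at module import time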
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
def get_model(device='cpu', gpt2_name='gpt2', in_net=False, in_net_init_identity=True, out_net=False, out_net_init_identity=True, freeze_ln=False, freeze_pos=True,
freeze_wte=True, freeze_ff=True, freeze_attn=True, dup_lm_head=False, dup_lm_head_bias=False):
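    """
    Load a pretrained GPT-2 language model, optionally attach small trainable
    adapter layers (in_net / out_net) and a duplicated LM head, and freeze the
    selected groups of pretrained transformer parameters.

    Args:
        device: torch device string, e.g. 'cpu' or 'cuda'.
        gpt2_name: name of the pretrained GPT-2 checkpoint to load (see the options below).
        in_net / out_net: add a linear adapter before / after the transformer.
        in_net_init_identity / out_net_init_identity: initialize the adapter's final
            linear layer to the identity transformation.
        freeze_ln, freeze_pos, freeze_wte, freeze_ff, freeze_attn: freeze the layer
            norms, positional embeddings, token embeddings, feed-forward blocks,
            and attention blocks, respectively.
        dup_lm_head: train an untied copy of the LM head instead of the original
            weight-tied one.
        dup_lm_head_bias: give the duplicated LM head a bias term.
    """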
# ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
model = GPT2LMHeadModel.from_pretrained(gpt2_name).to(device)
# model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
"""
Initialize linear input layer
"""
in_layer_sizes = []
out_layer_sizes = []
input_dim = model.config.n_embd
dropout = 0.1
orth_gain = 1.41
# orth_gain = None
    # Input adapter ("in_net") applied to the word embeddings before the transformer
if in_net:
in_layers = []
last_output_size = input_dim
for size in in_layer_sizes:
layer = nn.Linear(last_output_size, size)
if orth_gain is not None:
torch.nn.init.orthogonal_(layer.weight, gain=orth_gain)
layer.bias.data.zero_()
in_layers.append(layer)
in_layers.append(nn.ReLU())
in_layers.append(nn.Dropout(dropout))
last_output_size = size
in_final_linear = nn.Linear(last_output_size, model.config.n_embd)
# if orth_gain is not None:
# torch.nn.init.orthogonal_(in_final_linear.weight, gain=orth_gain)
# in_final_linear.bias.data.zero_()
# Initialize final_linear layer to identity transformation
if in_net_init_identity:
nn.init.eye_(in_final_linear.weight)
in_final_linear.bias.data.zero_()
in_layers.append(in_final_linear)
in_layers.append(nn.Dropout(dropout))
model.in_net = nn.Sequential(*in_layers)
        model.in_net.requires_grad_(True)  # keep the new adapter parameters trainable
"""
Initialize linear output layer
"""
if out_net:
output_dim = model.config.n_embd
out_layers = []
last_output_size = model.config.n_embd
for size in out_layer_sizes:
out_layers.append(nn.Linear(last_output_size, size))
out_layers.append(nn.ReLU())
out_layers.append(nn.Dropout(dropout))
last_output_size = size
out_final_linear = nn.Linear(last_output_size, output_dim)
if out_net_init_identity:
nn.init.eye_(out_final_linear.weight)
out_final_linear.bias.data.zero_()
out_layers.append(out_final_linear)
model.out_net = nn.Sequential(*out_layers)
        model.out_net.requires_grad_(True)  # keep the new adapter parameters trainable
"""
out layer on top of lm_head
"""
# out_net_top = nn.Linear(model.config.vocab_size, model.config.vocab_size)
# nn.init.eye_(out_net_top.weight)
# model.out_net_top = out_net_top
# model.out_net_top.requires_grad = True
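    # Optionally duplicate the LM head: GPT2LMHeadModel ties lm_head.weight to the
    # token embeddings (wte), so cloning the weights yields an untied copy that can
    # be fine-tuned without touching the embeddings.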
if dup_lm_head:
lm_head_new = nn.Linear(model.config.n_embd,
model.config.vocab_size, bias=dup_lm_head_bias)
lm_head_new.weight = torch.nn.Parameter(
model.lm_head.weight.data.detach().clone(), requires_grad=True)
# lm_head_new.bias.data.zero_()
model.lm_head_new = lm_head_new
        model.lm_head_new.requires_grad_(True)  # keep the duplicated head trainable
"""
Freeze transformer layers
"""
total_parameters = 0
target_parameters = 0
for name, p in model.transformer.named_parameters():
name = name.lower()
        param_count = p.numel()
        total_parameters += param_count
if 'ln' in name or 'norm' in name:
p.requires_grad = not freeze_ln
elif 'wpe' in name or 'position_embeddings' in name or 'pos_drop' in name:
p.requires_grad = not freeze_pos
target_parameters += param_count
elif 'mlp' in name:
p.requires_grad = not freeze_ff
elif 'attn' in name:
p.requires_grad = not freeze_attn
elif 'wte' in name:
p.requires_grad = not freeze_wte
else:
p.requires_grad = False
# print(f'Total params: {total_parameters}')
# print(
# f'Target params: {target_parameters} ({target_parameters / total_parameters * 100:.2f}%)')
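    # Patched forward pass: embed the input ids manually so in_net can transform the
    # embeddings, then run out_net / the duplicated LM head on the hidden states
    # before computing the usual shifted cross-entropy LM loss.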
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Convert from input ids to word embeddings so that we can apply a linear layer
x = self.transformer.wte(input_ids)
try:
x = self.in_net(x)
except AttributeError:
pass
transformer_outputs = self.transformer(
inputs_embeds=x,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
try:
hidden_states = self.out_net(hidden_states)
except AttributeError:
pass
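        # Use the duplicated (trainable) LM head when it exists; otherwise fall back
        # to the original weight-tied head.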
try:
lm_logits = self.lm_head_new(hidden_states)
except AttributeError:
lm_logits = self.lm_head(hidden_states)
# lm_logits = self.out_net_top(lm_logits)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
model.forward = MethodType(forward, model)
return model
# model = get_model()
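
# A minimal usage sketch (not part of the original demo code): build the patched model
# and sample a continuation with the module-level tokenizer. The prompt and decoding
# settings below are illustrative assumptions rather than project defaults.
def example_generate(prompt='The doctor said that', max_length=30, device='cpu'):
    example_model = get_model(device=device, gpt2_name='gpt2', in_net=True, out_net=True)
    example_model.eval()  # disable the adapter dropout during generation
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    with torch.no_grad():
        output_ids = example_model.generate(
            input_ids,
            max_length=max_length,
            do_sample=True,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
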
'''
For testing purposes only
'''
if __name__ == "__main__":
model = get_model(gpt2_name='gpt2', in_net=False, in_net_init_identity=True, out_net=False, out_net_init_identity=False, freeze_ln=True, freeze_pos=True,
freeze_wte=True, freeze_ff=True, freeze_attn=True)
for name, p in model.named_parameters():
if p.requires_grad:
print(name, p.requires_grad)
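    # Summarize how many parameters remain trainable after freezing
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Trainable params: {trainable_params} / {total_params} '
          f'({trainable_params / total_params * 100:.2f}%)')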
    # lm_head_new only exists when get_model() is called with dup_lm_head=True
    if hasattr(model, 'lm_head_new'):
        for p in model.lm_head_new.parameters():
            print('lm_head_new', p)
# for p in model.out_net.parameters():
# print('out_net',p)