from torch import nn
from transformers import GPT2LMHeadModel as GPT2LMHeadModelBase
from transformers.models.gpt2.modeling_gpt2 import GPT2Block as GPT2BlockBase


class GPT2Block(GPT2BlockBase):
    """GPT-2 block with a simplified forward that returns plain tuples.

    Cross-attention inputs (encoder_hidden_states, encoder_attention_mask)
    are accepted for signature compatibility but are not used.
    """

    def forward(self,
                x,
                layer_past=None,
                attention_mask=None,
                head_mask=None,
                use_cache=False,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=False):
        # Pre-LN residual branch around self-attention: keep the block input
        # so the residual is added to the *unnormalized* activations, as in
        # the reference GPT-2 block.
        residual = x
        x = self.ln_1(x)
        output_attn = self.attn(
            x,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        a = output_attn[0]  # attention output; the tail holds present/attentions
        x = residual + a

        # Pre-LN residual branch around the MLP.
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = (x,) + output_attn[1:]
        return outputs  # hidden states, plus present/attentions when requested


class GPT2LMHeadModel(GPT2LMHeadModelBase):
    def __init__(self, config):
        super().__init__(config)
        # Replace every stock transformer block with the custom GPT2Block;
        # from_pretrained() will still load checkpoint weights into them.
        self.transformer.h = nn.ModuleList(
            [GPT2Block(config, layer_idx) for layer_idx in range(config.n_layer)]
        )
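A minimal usage sketch, assuming a transformers 4.x release where GPT2Block takes a layer_idx argument and the public "gpt2" checkpoint is available on the Hugging Face hub. Because from_pretrained() builds the model through __init__(config), the custom blocks are already in place before the pretrained weights are copied in.

import torch
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")  # the patched class defined above
model.eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch, seq_len, vocab_size)
print(logits.shape)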