from typing import Optional

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    GenerationMixin,
    PreTrainedModel,
)
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput

from .configuration_lstm import LstmConfig


class MLP(nn.Module):
    """SwiGLU-style gated MLP: gate/up projections with SiLU, then a down projection."""

    def __init__(self, config: LstmConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class LstmLayer(nn.Module):
    """Pre-norm residual block: an LSTM sub-layer followed by a gated-MLP sub-layer."""

    def __init__(self, config: LstmConfig):
        super().__init__()
        self.lstm = nn.LSTM(
            config.hidden_size, config.hidden_size, num_layers=1, batch_first=True, bias=False
        )
        self.mlp = MLP(config)
        self.input_ln = nn.RMSNorm((config.hidden_size,), eps=1e-6)
        self.post_ln = nn.RMSNorm((config.hidden_size,), eps=1e-6)

    def forward(self, hidden_states):
        # LSTM sub-layer with residual connection
        lstm_part = self.input_ln(hidden_states)
        lstm_part, _ = self.lstm(lstm_part)
        hidden_states = hidden_states + lstm_part

        # MLP sub-layer with residual connection
        mlp_part = self.post_ln(hidden_states)
        mlp_part = self.mlp(mlp_part)
        return hidden_states + mlp_part


class LstmPreTrainedModel(PreTrainedModel):
    config_class = LstmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LstmLayer"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        gain = self.config.initializer_gain
        if isinstance(module, nn.Linear):
            # nn.init.normal_(module.weight.data, std=std)
            nn.init.kaiming_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight.data, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.RMSNorm):
            # RMSNorm scales start at 0.4 rather than the usual 1.0
            module.weight.data.fill_(0.4)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if "weight" in name:
                    nn.init.xavier_uniform_(param, gain=gain)
                elif "bias" in name:
                    with torch.no_grad():
                        param.zero_()


class LstmModel(LstmPreTrainedModel):
    """Stack of LstmLayer blocks over token embeddings, followed by a final RMSNorm."""

    def __init__(self, config: LstmConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LstmLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.RMSNorm((config.hidden_size,), eps=1e-6)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithNoAttention:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                )
            else:
                hidden_states = block(hidden_states)

        last_hidden_state = self.norm(hidden_states)

        return BaseModelOutputWithNoAttention(last_hidden_state=last_hidden_state)


class LstmForCausalLM(LstmPreTrainedModel, GenerationMixin):
    """LstmModel with a language-modeling head for next-token prediction."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LstmConfig):
        super().__init__(config)
        self.model = LstmModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        hidden_states = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds).last_hidden_state

        # Only compute logits for the last `num_logits_to_keep` positions (0 keeps every position)
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )


AutoConfig.register("lstm", LstmConfig)
AutoModel.register(LstmConfig, LstmModel)
AutoModelForCausalLM.register(LstmConfig, LstmForCausalLM)

LstmConfig.register_for_auto_class()
LstmModel.register_for_auto_class("AutoModel")
LstmForCausalLM.register_for_auto_class("AutoModelForCausalLM")
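

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept behind a __main__ guard so importing the module
# stays side-effect free). Because of the relative import above, run it as a
# module from the package root, e.g. `python -m <your_package>.<this_module>`,
# not as a bare script. The LstmConfig keyword arguments below are assumptions
# chosen to match the attributes this file reads from the config; the real
# constructor signature lives in configuration_lstm.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LstmConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        pad_token_id=0,
        initializer_range=0.02,
        initializer_gain=1.0,
    )
    model = LstmForCausalLM(config)

    # Toy batch: next-token prediction with the inputs reused as labels.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs.logits.shape)  # torch.Size([2, 16, 1000])
    print(outputs.loss)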