xlstm_wikipedia_110M_1M / modeling_xlstm.py
PatrickHaller's picture
Upload modeling_xlstm.py with huggingface_hub
a4a40bb verified
from typing import Optional, Sequence, Tuple, Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from xlstm.components.init import small_init_init_
from xlstm.utils import WeightDecayOptimGroupMixin
from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
from .configuration_xlstm import xLSTMConfig
class xLSTMPreTrainedModel(PreTrainedModel):
"""Base class for all models."""
config_class = xLSTMConfig
class xLSTMBlockStack(_xLSTMBlockStack):
"""Small wrapper to expose hidden states"""
def forward(
self, x: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, Sequence[torch.Tensor]]:
hidden_states = ()
for block in self.blocks:
x = block(x, **kwargs)
hidden_states += (x,)
x = self.post_blocks_norm(x)
return x, hidden_states
class xLSTMModel(xLSTMPreTrainedModel):
def __init__(self, config: xLSTMConfig):
super().__init__(config)
self.config = config
self.token_embedding = nn.Embedding(
num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim
)
_config = config.to_xlstm_config()
self.emb_dropout = (
nn.Dropout(_config.dropout)
if _config.add_embedding_dropout
else nn.Identity()
)
self.xlstm_block_stack = xLSTMBlockStack(config=_config)
def forward(
self,
input_ids: torch.LongTensor,
output_hidden_states: Optional[bool] = None,
return_dict=Optional[bool],
) -> Union[Tuple, BaseModelOutput]:
token_embedding = self.token_embedding(input_ids)
x = self.emb_dropout(token_embedding)
x, hidden_states = self.xlstm_block_stack(x)
if output_hidden_states:
hidden_states = (token_embedding,) + hidden_states
if not return_dict:
return x, hidden_states
return BaseModelOutput(
last_hidden_state=x,
hidden_states=hidden_states if output_hidden_states else None,
)
class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: xLSTMConfig, **kwargs):
super().__init__(config)
self.config = config
self.vocab_size = config.vocab_size
self.model = xLSTMModel(config)
self.lm_head = nn.Linear(
in_features=config.embedding_dim,
out_features=config.vocab_size,
bias=False,
)
self.post_init()
# TODO: Add option for up-projection
def get_input_embeddings(self):
return self.model.token_embedding
def set_input_embeddings(self, value: nn.Module):
self.model.token_embedding = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, value):
self.lm_head = value
def reset_parameters(self):
self.model.xlstm_block_stack.reset_parameters()
small_init_init_(
self.get_input_embeddings().weight, dim=self.config.embedding_dim
)
if not self.config.tie_word_embeddings:
small_init_init_(
self.get_output_embeddings().weight, dim=self.config.embedding_dim
)
def forward(
self,
input_ids: torch.Tensor,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output = self.model(
input_ids,
output_hidden_states=output_hidden_states,
)
hidden_state = output[0]
logits = self.lm_head(hidden_state)
logits = logits.float()
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + output[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
hidden_states=output.hidden_states,
)
def step(
self,
idx: torch.Tensor,
state: dict[str, dict[str, tuple[torch.Tensor, ...]]] = None,
**kwargs,
) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
x = self.token_embedding(idx)
x = self.emb_dropout(x)
x, state = self.xlstm_block_stack.step(x, state=state, **kwargs)
logits = self.lm_head(x)
return logits, state
def _create_weight_decay_optim_groups(
self, **kwargs
) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(
**kwargs
)
# remove token embedding and add it to the correct group, accrording to the config
weight_decay = list(weight_decay)
removed = 0
for idx in range(len(weight_decay)):
if weight_decay[idx - removed] is self.get_input_embeddings().weight:
weight_decay.pop(idx - removed)
removed += 1
weight_decay = tuple(weight_decay)
# TODO: Fix this
# if self.config.weight_decay_on_embedding:
if True:
weight_decay += (self.get_input_embeddings().weight,)
else:
no_weight_decay += (self.get_input_embeddings().weight,)
return weight_decay, no_weight_decay
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = nn.Embedding(
new_num_tokens, self.token_embedding.embedding_dim
)
self.token_embedding = new_embeddings.to(self.device)
return new_embeddings
def tie_weights(self):
self.get_output_embeddings().weight = self.get_input_embeddings().weight
def prepare_inputs_for_generation(
self,
input_ids,
**kwargs,
):
model_inputs = {
"input_ids": input_ids.to(self.device),
}
return model_inputs