from typing import Optional

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import GenerationMixin, PreTrainedModel
from transformers.generation import TextStreamer

from .configuration_mamba import MambaConfig


class MambaModel(PreTrainedModel):
    """Hugging Face-compatible wrapper around mamba_ssm's MambaLMHeadModel."""

    config_class = MambaConfig

    def __init__(
        self,
        config,
        initializer_cfg=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        # The underlying mamba_ssm model owns all parameters; this class only
        # adapts it to the transformers PreTrainedModel interface.
        self.model = MambaLMHeadModel(
            config,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **kwargs,
    ):
        # Pass arguments by keyword so the call stays correct even if the
        # upstream signature gains new parameters.
        return self.model.forward(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=num_last_tokens,
        )


class MambaModelForCausalLM(MambaModel, GenerationMixin):
    def generate(
        self,
        input_ids,
        max_length: int = 2048,
        top_k: int = 1,
        top_p: float = 0.0,
        temperature: float = 1.0,
        return_dict_in_generate: bool = False,
        output_scores: bool = False,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        teacher_outputs: Optional[torch.Tensor] = None,
        vocab_size: Optional[int] = None,
        cg: bool = False,
        enable_timing: bool = False,
        streamer: Optional[TextStreamer] = None,
        **kwargs,
    ):
        # Delegate to mamba_ssm's own decoding loop rather than the
        # transformers GenerationMixin implementation; `cg` enables CUDA
        # graph capture of the decoding step for faster generation.
        return self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_token_id,
            teacher_outputs=teacher_outputs,
            vocab_size=vocab_size,
            cg=cg,
            enable_timing=enable_timing,
            streamer=streamer,
        )
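

# Usage sketch (not part of the module's public API): a minimal example of
# constructing the wrapper and generating text. The tokenizer name below and
# the use of a default MambaConfig are assumptions for illustration only;
# substitute the config/weights your project actually loads. mamba_ssm's
# kernels require a CUDA device.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Assumed tokenizer; the reference Mamba checkpoints reuse GPT-NeoX's.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

    # Assumed: MambaConfig() with defaults; in practice load a real config.
    config = MambaConfig()
    model = MambaModelForCausalLM(config, device="cuda", dtype=torch.float16)

    input_ids = tokenizer(
        "The capital of France is", return_tensors="pt"
    ).input_ids.to("cuda")

    # With return_dict_in_generate=False (the default), mamba_ssm's generate
    # returns the generated token ids as a tensor.
    out = model.generate(input_ids, max_length=32, top_k=1)
    print(tokenizer.decode(out[0]))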