from typing import Optional

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import GenerationMixin, PreTrainedModel
from transformers.generation import TextStreamer

from .configuration_mamba import MambaConfig


class MambaModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around mamba_ssm's
    ``MambaLMHeadModel``.

    The wrapped model is stored on ``self.model``; ``forward`` delegates to
    it directly.
    """

    config_class = MambaConfig

    def __init__(
        self,
        config,
        initializer_cfg=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """Build the wrapper and the underlying ``MambaLMHeadModel``.

        Args:
            config: A ``MambaConfig`` instance (also forwarded to
                ``PreTrainedModel.__init__``).
            initializer_cfg: Optional weight-init settings forwarded to
                ``MambaLMHeadModel``.
            device: Optional device for the inner model's parameters.
            dtype: Optional dtype for the inner model's parameters.
            **kwargs: Extra arguments forwarded to ``PreTrainedModel``.
        """
        super().__init__(config, **kwargs)
        self.model = MambaLMHeadModel(
            config,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **kwargs,
    ):
        """Delegate to the wrapped ``MambaLMHeadModel.forward``.

        NOTE: extra ``**kwargs`` (e.g. ``attention_mask`` injected by HF
        pipelines) are accepted for interface compatibility but intentionally
        dropped — the inner model does not take them.
        """
        # Pass by keyword so this stays correct even if the upstream
        # signature reorders its optional parameters.
        return self.model.forward(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=num_last_tokens,
        )


class MambaModelForCausalLM(MambaModel, GenerationMixin):
    """Causal-LM variant exposing mamba_ssm's native ``generate``.

    Inherits ``GenerationMixin`` so HF utilities recognize the model as
    generation-capable, but the sampling itself is done by the wrapped
    ``MambaLMHeadModel.generate``.
    """

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        repetition_penalty=1.0,
        eos_token_id=None,
        teacher_outputs=None,
        vocab_size=None,
        cg=False,
        enable_timing=False,
        streamer: Optional[TextStreamer] = None,
        **kwargs,
    ):
        """Generate tokens with the wrapped model's decoding loop.

        Args mirror ``MambaLMHeadModel.generate``; unknown ``**kwargs`` are
        accepted (for HF-API compatibility) but not forwarded.

        Returns:
            Whatever ``MambaLMHeadModel.generate`` returns — a tensor of
            token ids, or a generate-output object when
            ``return_dict_in_generate=True``.
        """
        return self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_token_id,
            teacher_outputs=teacher_outputs,
            vocab_size=vocab_size,
            cg=cg,
            enable_timing=enable_timing,
            streamer=streamer,
        )