# mamba-1.4b/modeling_mamba.py

from typing import Optional

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import GenerationMixin, PreTrainedModel
from transformers.generation import TextStreamer

from .configuration_mamba import MambaConfig


class MambaModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around ``mamba_ssm``'s ``MambaLMHeadModel``."""

    config_class = MambaConfig

    def __init__(
        self,
        config,
        initializer_cfg=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        # Delegate all modeling to the reference mamba_ssm implementation.
        self.model = MambaLMHeadModel(
            config,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **kwargs,
    ):
        # Forward by keyword so the call stays correct even if
        # MambaLMHeadModel.forward reorders its optional parameters.
        return self.model.forward(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=num_last_tokens,
        )
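
# Forward-pass sketch (illustrative, not part of the uploaded file):
# MambaLMHeadModel.forward returns a namedtuple whose ``logits`` field has
# shape (batch, seq_len, vocab_size); ``num_last_tokens=1`` keeps only the
# final position's logits, which is what incremental decoding needs.
#
#     out = model(input_ids, num_last_tokens=1)
#     next_token_logits = out.logits[:, -1]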

class MambaModelForCausalLM(MambaModel, GenerationMixin):
    """Causal-LM variant that routes generation through mamba_ssm's own
    decoding loop (which supports CUDA-graph caching via ``cg``) rather than
    ``GenerationMixin.generate``."""

    def generate(
        self,
        input_ids,
        max_length: int = 2048,
        top_k: int = 1,
        top_p: float = 0.0,
        temperature: float = 1.0,
        return_dict_in_generate: bool = False,
        output_scores: bool = False,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        teacher_outputs: Optional[torch.Tensor] = None,
        vocab_size: Optional[int] = None,
        cg: bool = False,
        enable_timing: bool = False,
        streamer: Optional[TextStreamer] = None,
        **kwargs,
    ):
        return self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_token_id,
            teacher_outputs=teacher_outputs,
            vocab_size=vocab_size,
            cg=cg,
            enable_timing=enable_timing,
            streamer=streamer,
        )
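

# Usage sketch (an assumption, not part of the uploaded file): the repo id
# below is inferred from the file path, and the GPT-NeoX tokenizer is a
# common-but-unverified pairing for Mamba checkpoints; adjust both for your
# checkpoint. Assumes the weights were saved with ``save_pretrained`` and a
# CUDA device is available for mamba_ssm's kernels.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaModelForCausalLM.from_pretrained("mjschock/mamba-1.4b").to("cuda")

    prompt_ids = tokenizer(
        "Mamba is a state-space model that", return_tensors="pt"
    ).input_ids.to("cuda")
    # With return_dict_in_generate=False (the default here), mamba_ssm's
    # generate returns the full token sequence, prompt included.
    sequences = model.generate(
        prompt_ids, max_length=64, top_k=50, top_p=0.9, temperature=0.8
    )
    print(tokenizer.decode(sequences[0], skip_special_tokens=True))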