mjschock committed
Commit 3546d63 (1 parent: 32f2d5a)

Upload model

Files changed (3):
  1. config.json +6 -1
  2. model.safetensors +3 -0
  3. modeling_mamba.py +82 -0
config.json CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModel"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModel": "modeling_mamba.MambaModel"
   },
   "d_model": 1024,
   "fused_add_norm": true,
@@ -10,6 +14,7 @@
   "residual_in_fp32": true,
   "rms_norm": true,
   "ssm_cfg": {},
+  "torch_dtype": "float16",
   "transformers_version": "4.37.2",
   "vocab_size": 50277
 }
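
The added "architectures" and "auto_map" entries let the Transformers auto classes resolve the custom code shipped in the repo (configuration_mamba.py and modeling_mamba.py), provided trust_remote_code=True is passed; "torch_dtype": "float16" records that the uploaded weights are half precision. A minimal loading sketch, where the repository id is a placeholder rather than the actual repo name:

import torch
from transformers import AutoConfig, AutoModel

repo_id = "user/mamba-model"  # hypothetical repo id; substitute the actual repository

# auto_map routes AutoConfig to configuration_mamba.MambaConfig
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# auto_map routes AutoModel to modeling_mamba.MambaModel; float16 matches torch_dtype
model = AutoModel.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)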
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:390cb2bce8ff33d561bdef7f428f79e13ff08a2255e2c5c1bfeb6c7412cf927a
+size 746429904
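
The committed file is only the Git LFS pointer; the actual weights (746,429,904 bytes) live in LFS storage and are fetched on clone or download. A small sketch for checking that a downloaded model.safetensors matches the pointer's sha256 and size:

import hashlib
from pathlib import Path

weights = Path("model.safetensors")  # the downloaded weights, not the LFS pointer file
data = weights.read_bytes()

assert len(data) == 746429904, "size does not match the LFS pointer"
expected = "390cb2bce8ff33d561bdef7f428f79e13ff08a2255e2c5c1bfeb6c7412cf927a"
assert hashlib.sha256(data).hexdigest() == expected, "sha256 does not match the LFS pointer"
print("model.safetensors matches the LFS pointer")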
modeling_mamba.py ADDED
@@ -0,0 +1,82 @@
+from typing import Optional, Tuple
+
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+import torch
+from transformers import GenerationMixin, PreTrainedModel
+from transformers.generation import TextStreamer
+
+from .configuration_mamba import MambaConfig
+
+class MambaModel(PreTrainedModel):
+    config_class = MambaConfig
+
+    def __init__(
+        self,
+        config,
+        initializer_cfg=None,
+        device=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(
+            config,
+            **kwargs,
+        )
+
+        self.model = MambaLMHeadModel(
+            config,
+            initializer_cfg=initializer_cfg,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        inference_params=None,
+        num_last_tokens=0,
+        **kwargs,
+    ):
+        return self.model.forward(
+            input_ids,
+            position_ids,
+            inference_params,
+            num_last_tokens
+        )
+
+class MambaModelForCausalLM(MambaModel, GenerationMixin):
+    def generate(
+        self,
+        input_ids,
+        max_length: int = 2048,
+        top_k: int = 1,
+        top_p: float = 0.0,
+        temperature: float = 1.0,
+        return_dict_in_generate: bool = False,
+        output_scores: bool = False,
+        repetition_penalty: float = 1.0,
+        eos_token_id: Optional[int] = None,
+        teacher_outputs: Optional[torch.Tensor] = None,
+        vocab_size: Optional[int] = None,
+        cg: bool = False,
+        enable_timing: bool = False,
+        streamer: Optional[TextStreamer] = None,
+        **kwargs,
+    ):
+        return self.model.generate(
+            input_ids=input_ids,
+            max_length=max_length,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            return_dict_in_generate=return_dict_in_generate,
+            output_scores=output_scores,
+            repetition_penalty=repetition_penalty,
+            eos_token_id=eos_token_id,
+            teacher_outputs=teacher_outputs,
+            vocab_size=vocab_size,
+            cg=cg,
+            enable_timing=enable_timing,
+            streamer=streamer,
+        )
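
modeling_mamba.py wraps mamba_ssm's MambaLMHeadModel in a transformers PreTrainedModel: MambaModel delegates forward() to the wrapped model, while MambaModelForCausalLM hands generate() to mamba_ssm's own decoding loop (top_k/top_p/temperature sampling, optional CUDA-graph caching via cg=True) rather than the Transformers generation machinery. Only MambaModel is registered in auto_map in this commit, so AutoModel returns the base wrapper. A minimal forward-pass sketch, assuming mamba_ssm is installed, a CUDA device is available, the GPT-NeoX tokenizer matches the 50277-entry vocabulary, and a hypothetical repo id:

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "user/mamba-model"  # hypothetical repo id; substitute the actual repository
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer

model = AutoModel.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

inputs = tokenizer("Mamba is a state space model that", return_tensors="pt")
input_ids = inputs.input_ids.to("cuda")

# MambaModel.forward delegates to MambaLMHeadModel.forward, which returns a
# namedtuple whose .logits field has shape (batch, seq_len, vocab_size)
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)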