HaileyStorm committed on
Commit b32cef0
1 Parent(s): 27f8947

Upload chess-mamba-vs-xformer/mamba_lm.py with huggingface_hub

Files changed (1):
  chess-mamba-vs-xformer/mamba_lm.py  +14 -7
chess-mamba-vs-xformer/mamba_lm.py CHANGED
@@ -5,7 +5,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from mamba import Mamba, MambaConfig, RMSNorm
+#from mamba import Mamba, MambaConfig, RMSNorm
+from mamba_ssm import MambaLMHeadModel
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.ops.triton.layernorm import RMSNorm
 
 """
 
@@ -22,15 +25,18 @@ class MambaLMConfig(MambaConfig):
     pad_vocab_size_multiple: int = 8
 
     def __post_init__(self):
-        super().__post_init__()
+        pass
+        #super().__post_init__()
 
         #if self.vocab_size % self.pad_vocab_size_multiple != 0:
        # self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
 
     def to_mamba_config(self) -> MambaConfig:
-        mamba_config_fields = {field.name for field in fields(MambaConfig)}
-        filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
-        return MambaConfig(**filtered_dict)
+        #mamba_config_fields = {field.name for field in fields(MambaConfig)}
+        #print(mamba_config_fields)
+        #filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
+        #return MambaConfig(**filtered_dict)
+        return MambaConfig(d_model=self.d_model, n_layer=self.n_layer, vocab_size=self.vocab_size, ssm_cfg=self.ssm_cfg)
 
 # adapted from https://github.com/johnma2006/mamba-minimal
 def from_pretrained(name: str):
@@ -65,7 +71,8 @@ def from_pretrained(name: str):
     config_data = load_config_hf(name)
     config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
 
-    model = MambaLM(config)
+    #model = MambaLM(config)
+    model = MambaLMHeadModel(config)
 
     # copy weights
     state_dict = load_state_dict_hf(name)
@@ -90,7 +97,7 @@ class MambaLM(nn.Module):
         self.config = lm_config.to_mamba_config()
 
         self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
-        self.mamba = Mamba(self.config)
+        self.mamba = Mamba(**self.config.__dict__)
         self.norm_f = RMSNorm(self.config.d_model)
 
         self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
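
For reference, a minimal sketch of how the mamba_ssm path introduced by this commit would be exercised. The hyperparameters below are placeholders chosen only for illustration, not the values used by the chess models in this repo, and mamba_ssm's fused selective-scan kernels assume a CUDA device:

import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

# Placeholder hyperparameters for illustration only.
config = MambaConfig(d_model=256, n_layer=8, vocab_size=32, ssm_cfg={})

# mamba_ssm's kernels run on CUDA.
model = MambaLMHeadModel(config, device="cuda")

tokens = torch.randint(0, config.vocab_size, (1, 64), device="cuda")
out = model(tokens)        # returns a namedtuple with a .logits field
print(out.logits.shape)    # (batch, seq_len, vocab_size padded to a multiple of 8)

The diff itself suggests the motivation: the pure-PyTorch mamba.py implementation (from mamba import Mamba, ...) is swapped for the official mamba_ssm package, with to_mamba_config now building mamba_ssm's MambaConfig directly and from_pretrained constructing MambaLMHeadModel instead of the local MambaLM.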