HaileyStorm committed · Commit 5e634b7 · Parent: 6645dae
Upload chess-gpt-eval/mamba_lm.py with huggingface_hub
chess-gpt-eval/mamba_lm.py CHANGED
@@ -5,7 +5,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from mamba import Mamba, MambaConfig, RMSNorm
+from mamba import MambaConfig #Mamba, MambaConfig, RMSNorm
+from mamba_ssm import MambaLMHeadModel
 
 """
 
@@ -65,7 +66,7 @@ def from_pretrained(name: str):
     config_data = load_config_hf(name)
     config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
 
-    model = MambaLM(config)
+    model = MambaLMHeadModel(config)
 
     # copy weights
     state_dict = load_state_dict_hf(name)
@@ -87,7 +88,7 @@ class MambaLM(nn.Module):
     def __init__(self, lm_config: MambaLMConfig):
         super().__init__()
         self.lm_config = lm_config
-        self.config = lm_config.to_mamba_config()
+        self.config = lm_config#.to_mamba_config()
 
         self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
         self.mamba = Mamba(self.config)
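For context, here is a minimal sketch of the `mamba_ssm` loading path this commit switches to. It assumes the `mamba_ssm` package and a CUDA GPU (its fused kernels are CUDA-only); the checkpoint id and config values are illustrative placeholders, not taken from the commit. Note that `mamba_ssm`'s config dataclass spells the depth field `n_layer`, while the local `MambaLMConfig` uses `n_layers`:

```python
# Illustrative sketch only: assumes mamba_ssm is installed and a CUDA GPU is
# available. The checkpoint id and config values below are placeholders.
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig as SSMConfig

# Direct construction from a config, as the patched from_pretrained does;
# mamba_ssm expects its own MambaConfig dataclass (field name: n_layer).
cfg = SSMConfig(d_model=768, n_layer=24, vocab_size=32)
model = MambaLMHeadModel(cfg, device="cuda")

# mamba_ssm also ships a from_pretrained that pulls config and weights from
# the Hugging Face Hub in one step, covering what the patch hand-rolls.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
model.eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    logits = model(input_ids).logits  # (batch, seq_len, vocab_size)
```

Passing the local `MambaLMConfig` straight into `MambaLMHeadModel`, as the second hunk does, relies on the two config types exposing compatible fields; converting explicitly to `mamba_ssm`'s own dataclass, as sketched above, is the defensive alternative.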