HaileyStorm committed on
Commit 07f1096
Parent: 5e634b7

Upload chess-gpt-eval/mamba_module.py with huggingface_hub

Files changed (1)
  1. chess-gpt-eval/mamba_module.py +5 -4
chess-gpt-eval/mamba_module.py CHANGED
@@ -1,7 +1,8 @@
 import os
 import pickle
 import torch
-from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
+from mamba_lm import MambaLMConfig, from_pretrained
+from mamba_ssm import MambaLMHeadModel
 from contextlib import nullcontext
 
 BASE_DIR = "mamba/"
@@ -41,10 +42,10 @@ class MambaPlayer:
         # Model initialization
         if init_from == "resume":
             #ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
-            ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
+            ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Mamba/{self.model_name}")
             checkpoint = torch.load(ckpt_path, map_location=device)
             model_config = checkpoint["model_args"]
-            model = MambaLM(model_config)
+            model = MambaLMHeadModel(model_config)
             model.load_state_dict(checkpoint['model'])
         elif init_from.startswith('state-spaces'):
             model = from_pretrained(init_from).to(device)
@@ -96,7 +97,7 @@ class MambaPlayer:
         with torch.no_grad():
             have_non_space = False
             for _ in range(max_new_tokens):
-                logits = self.model(input_ids)[0, -1, :]  # Get logits for the last token
+                logits = self.model(input_ids).logits[0, -1, :]  # Get logits for the last token
 
                 # Apply temperature scaling and optionally sample from top k tokens
                 logits = logits / temperature
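
For readers outside the diff: the commit swaps mamba.py's MambaLM for mamba_ssm's MambaLMHeadModel, whose forward pass returns a CausalLMOutput namedtuple rather than a raw logits tensor, which is why the last hunk adds the .logits access. Below is a minimal sketch of how the updated "resume" path and logits read fit together. The checkpoint layout ({"model_args": ..., "model": ...}) is taken from the diff; the filename, dummy prompt, and temperature value are hypothetical.

import torch
from mamba_ssm import MambaLMHeadModel

device = "cuda"  # mamba_ssm's fused selective-scan kernels expect a CUDA device

# Checkpoint layout as implied by the diff: "model_args" holds the model
# config and "model" holds the state dict. The filename is hypothetical.
ckpt_path = "../chess-mamba-vs-xformer/out/Mamba/ckpt.pt"
checkpoint = torch.load(ckpt_path, map_location=device)

# MambaLMHeadModel takes the config object directly, mirroring
# `model = MambaLMHeadModel(model_config)` in the diff.
model = MambaLMHeadModel(checkpoint["model_args"])
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()

input_ids = torch.zeros((1, 1), dtype=torch.long, device=device)  # dummy prompt
with torch.no_grad():
    # Forward returns a CausalLMOutput namedtuple, not a tensor, hence .logits
    logits = model(input_ids).logits[0, -1, :]
    logits = logits / 0.7  # temperature scaling, as in the sampling loop

Note this sketch assumes "model_args" was saved as a mamba_ssm-compatible config; a checkpoint written against the old mamba.py MambaLMConfig would need its config translated before MambaLMHeadModel will accept it.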