HaileyStorm committed on
Commit
1230db0
1 Parent(s): 1401d32

Fixed early-stopping in get_mamba_response based on space/dot tokens (now decodes the strings instead of using hardcoded token ids).

Browse files
Files changed (1) hide show
  1. chess-gpt-eval/mamba_module.py +5 -2
chess-gpt-eval/mamba_module.py CHANGED
@@ -81,6 +81,8 @@ class MambaPlayer:
81
  self.vocab_size = vocab_size
82
  self.encode = encode
83
  self.decode = decode
 
 
84
  self.model = model
85
  self.ctx = ctx
86
  self.device = device
@@ -107,8 +109,9 @@ class MambaPlayer:
107
 
108
  probs = torch.nn.functional.softmax(logits, dim=-1)
109
  next_token_id = torch.multinomial(probs, num_samples=1)
110
- if have_non_space and (next_token_id == 0 or next_token_id==4):
111
- break
 
112
  else:
113
  have_non_space = True
114
  input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
 
81
  self.vocab_size = vocab_size
82
  self.encode = encode
83
  self.decode = decode
84
+ self.space_tok = encode(' ')[0]
85
+ self.dot_tok = encode('.')[0]
86
  self.model = model
87
  self.ctx = ctx
88
  self.device = device
 
109
 
110
  probs = torch.nn.functional.softmax(logits, dim=-1)
111
  next_token_id = torch.multinomial(probs, num_samples=1)
112
+ if next_token_id == self.space_tok or next_token_id==self.dot_tok:
113
+ if have_non_space:
114
+ break
115
  else:
116
  have_non_space = True
117
  input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)