File size: 6,710 Bytes
f8519d1
 
 
07f1096
 
f8519d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07f1096
f8519d1
 
07f1096
f8519d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584a1a1
f8519d1
 
 
 
1230db0
 
f8519d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07f1096
f8519d1
 
 
 
 
 
 
 
 
1230db0
 
 
f8519d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dbf659
f8519d1
 
7dbf659
 
 
 
f8519d1
7dbf659
 
 
f8519d1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import pickle
from contextlib import nullcontext
from typing import Optional

import torch
from mamba_lm import MambaLMConfig, from_pretrained
from mamba_ssm import MambaLMHeadModel

# Directory prefix under which MambaPlayer looks for "out/meta.pkl"
# (the optional pickled vocabulary). Relative to the current working directory.
BASE_DIR = "mamba/"

class MambaPlayer:
    """Chess-move generator backed by a character-level Mamba language model.

    The constructor loads a checkpoint, builds the character vocabulary and
    stores ``encode``/``decode`` callables on the instance. ``get_move`` then
    samples a short continuation of a game transcript and returns its first
    whitespace-separated token as the move.
    """

    def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
        """Load the model checkpoint and set up the tokenizer.

        Args:
            model_name: File name of the checkpoint, resolved under
                ``../chess-mamba-vs-xformer/out/Mamba/`` relative to the
                current working directory.
            move_num_in_gamestate: When true and ``BASE_DIR/out/meta.pkl``
                exists, the vocabulary is read from that pickle; otherwise
                the built-in 28-symbol vocabulary is used.

        Raises:
            ValueError: If ``init_from`` is neither ``"resume"`` nor a
                ``state-spaces`` model name.
        """
        self.model_name = model_name
        self.move_num_in_gamestate = move_num_in_gamestate
        # -----------------------------------------------------------------------------
        init_from = "resume"  # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Only ask about bf16 support when CUDA was actually selected: on a
        # CPU-only setup the query is meaningless and, depending on the
        # PyTorch version, may try to touch a CUDA device.
        if device == "cuda" and torch.cuda.is_bf16_supported():
            dtype = "bfloat16"
        else:
            dtype = "float32"
        seed = 1337
        use_compile = False  # set to True if using PyTorch 2.0 and Mamba supports it
        # -----------------------------------------------------------------------------

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        device_type = "cuda" if "cuda" in device else "cpu"  # for torch.autocast
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
        if device_type == "cpu":
            ctx = nullcontext()
        else:
            ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

        # Model initialization
        if init_from == "resume":
            # The checkpoint is expected to be a dict holding the constructor
            # arguments under "model_args" and the weights under "model".
            ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Mamba/{self.model_name}")
            checkpoint = torch.load(ckpt_path, map_location=device)
            model_config = checkpoint["model_args"]
            model = MambaLMHeadModel(model_config)
            model.load_state_dict(checkpoint['model'])
        elif init_from.startswith('state-spaces'):
            model = from_pretrained(init_from).to(device)
        else:
            raise ValueError("Invalid init_from value")

        model.eval()
        model.to(device)

        if use_compile and hasattr(torch, 'compile'):
            model = torch.compile(model)

        # look for the meta pickle in case it is available in the dataset folder
        meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
        if move_num_in_gamestate and os.path.exists(meta_path):
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point this at a meta.pkl you produced yourself.
            with open(meta_path, "rb") as f:
                meta = pickle.load(f)
            stoi, itos = meta["stoi"], meta["itos"]
            vocab_size = meta['vocab_size']

            def encode(s):
                return [stoi[c] for c in s]

            def decode(l):
                return "".join([itos[i] for i in l])
        else:
            stoi = {
                ' ': 0, '.': 1,
                'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9,
                '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17,
                'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23,
                'x': 24, '+': 25, '#': 26, '=': 27,
            }
            # Derive the reverse table from stoi so the two can never drift apart.
            itos = {index: char for char, index in stoi.items()}
            vocab_size = len(stoi)
            print(f"Vocab size {vocab_size}")

            def encode(s):
                # '-' is not in the vocabulary; castling is stored as "OO"/"OOO".
                return [stoi[c] for c in s.replace('-', '')]

            def decode(l):
                # Ids at or above vocab_size are dropped, then castling
                # notation is restored.
                text = "".join([itos[i] for i in l if i < vocab_size])
                return text.replace("OOO", "O-O-O").replace("OO", "O-O")

        self.vocab_size = vocab_size
        self.encode = encode
        self.decode = decode
        self.space_tok = encode(' ')[0]
        self.dot_tok = encode('.')[0]
        self.model = model
        self.ctx = ctx
        self.device = device

    @staticmethod
    def _sample_next_token(logits, temperature: float, top_k: int):
        """Pick the next token id from a 1-D logits tensor.

        Args:
            logits: 1-D tensor of unnormalised scores (not modified).
            temperature: Softmax temperature. ``0`` selects the arg-max
                token deterministically instead of dividing by zero.
            top_k: If positive, sampling is restricted to the ``top_k``
                highest-scoring tokens; values larger than the number of
                logits are clamped.

        Returns:
            A tensor of shape ``(1,)`` holding the chosen token id.
        """
        if temperature == 0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        scaled = logits / temperature
        if top_k > 0:
            k = min(top_k, scaled.size(-1))  # torch.topk rejects k > size
            kth_best = torch.topk(scaled, k)[0][..., -1, None]
            scaled = scaled.masked_fill(scaled < kth_best, -float('Inf'))

        probs = torch.nn.functional.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int) -> str:
        """Sample a continuation of ``game_state`` from the model.

        Only the text after the last blank line of ``game_state`` is used as
        the prompt. Generation stops after ``max_new_tokens`` tokens, or as
        soon as a space or dot token is sampled once at least one other token
        has been produced (that terminating token is not included).

        Args:
            game_state: Transcript text; must consist of characters known to
                ``self.encode``.
            temperature: Sampling temperature (``0`` means greedy).
            max_new_tokens: Upper bound on the number of sampled tokens.
            top_k: Top-k cut-off; non-positive disables the cut-off.

        Returns:
            The decoded continuation, truncated at the first ``";"``.
        """
        game_state = game_state.split("\n\n")[-1].strip()

        # Tokenize the game state
        encoded_prompt = self.encode(game_state)
        input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
        prompt_len = input_ids.size(1)

        # NOTE(review): self.ctx (the autocast context built in __init__) is
        # stored on the instance but is not entered here - confirm whether
        # that is intentional.
        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            have_non_space = False
            for _ in range(max_new_tokens):
                logits = self.model(input_ids).logits[0, -1, :]  # logits for the last position
                next_token_id = self._sample_next_token(logits, temperature, top_k)

                token = int(next_token_id.item())
                if token == self.space_tok or token == self.dot_tok:
                    if have_non_space:
                        break
                else:
                    have_non_space = True
                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

        # Decode only what was generated. Decoding prompt + continuation and
        # slicing by len(game_state) goes wrong whenever the prompt does not
        # round-trip through encode/decode at the same length.
        generated = input_ids[0, prompt_len:].tolist()
        model_response = self.decode(generated).split(";")[0]
        return model_response

    def get_move_from_response(self, response: str) -> Optional[str]:
        """Extract the first move from a model response.

        Args:
            response: Raw text returned by ``get_mamba_response``.

        Returns:
            The first whitespace-separated token with leading dots removed,
            or ``None`` when ``response`` is not a string or contains no
            token.
        """
        if not isinstance(response, str):
            return None
        moves = response.split()
        if not moves:
            return None
        # A patch for a weird phase during training ... doesn't seem to be an
        # issue anymore, but don't see the harm.
        return moves[0].lstrip('.')

    def get_move(self, board: str, game_state: str, temperature: float) -> Optional[str]:
        """Return the model's next move for ``game_state``.

        Args:
            board: Accepted for interface compatibility; not used here.
            game_state: Transcript passed to ``get_mamba_response``.
            temperature: Sampling temperature.

        Returns:
            The move string, or ``None`` if the model produced no token.
        """
        # Up to 8 new tokens, with top_k set to the full vocabulary size.
        completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        """Return a dict identifying this player by its checkpoint name."""
        return {"model": self.model_name}