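"""MambaPlayer: loads a Mamba language-model checkpoint (or a pretrained
'state-spaces' model) and samples chess moves in SAN from a PGN-style
game-state string."""
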
import os
import pickle
from contextlib import nullcontext
from typing import Optional

import torch
from mamba_lm import MambaLM, MambaLMConfig, from_pretrained

BASE_DIR = "mamba/"

class MambaPlayer:
    def __init__(self, model_name: str):
        self.model_name = model_name
        # -----------------------------------------------------------------------------

        init_from = "resume"  # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
        move_num_in_gamestate = True
        out_dir = "out"  # ignored if init_from is not 'resume'
        device = "cuda" if torch.cuda.is_available() else "cpu"
        #device = "cpu"
        dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
        seed = 1337
        compile = False  # set to True if using PyTorch 2.0 and Mamba supports it
        # -----------------------------------------------------------------------------

        # Seed CPU and CUDA RNGs so sampling is reproducible
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        
        device_type = (
            "cuda" if "cuda" in device else "cpu"
        )  # for later use in torch.autocast
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
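        # Autocast context: mixed precision on GPU, a no-op on CPU; it is entered
        # around the generation loop in get_mamba_response.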
        ctx = (
            nullcontext()
            if device_type == "cpu"
            else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
        )

        # Model initialization
        if init_from == "resume":
            ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
            checkpoint = torch.load(ckpt_path, map_location=device)
            model_config = checkpoint["model_args"]
            model = MambaLM(model_config)
            model.load_state_dict(checkpoint["model"])
        elif init_from.startswith('state-spaces'):
            model = from_pretrained(init_from).to(device)
        else:
            raise ValueError("Invalid init_from value")

        model.eval()
        model.to(device)

        if compile and hasattr(torch, 'compile'):
            model = torch.compile(model)

        # look for the meta pickle in case it is available in the dataset folder
        meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
        load_meta = os.path.exists(meta_path)
        if move_num_in_gamestate and load_meta:
            with open(meta_path, "rb") as f:
                meta = pickle.load(f)
            stoi, itos = meta["stoi"], meta["itos"]
            vocab_size = meta['vocab_size']
            encode = lambda s: [stoi[c] for c in s]
            decode = lambda l: "".join([itos[i] for i in l])
        else:
            stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
            itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
            for s in stoi:
                assert itos[stoi[s]] == s
            vocab_size = len(stoi)
            print(f"Vocab size {vocab_size}")
            encode = lambda s: [stoi[c] for c in s.replace('-', '')]
            decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")

        self.encode = encode
        self.decode = decode
        self.model = model
        self.ctx = ctx
        self.device = device

    def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
        game_state = game_state.split("\n\n")[-1].strip()

        # Tokenize the game state
        encoded_prompt = self.encode(game_state)
        input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)

        self.model.eval()  # Ensure evaluation mode before sampling
        with torch.no_grad(), self.ctx:
            have_non_space = False
            for _ in range(max_new_tokens):
                logits = self.model(input_ids)[0, -1, :]  # Logits for the last position

                # Apply temperature scaling and optionally restrict sampling to the top-k tokens
                logits = logits / temperature
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = -float('Inf')

                probs = torch.nn.functional.softmax(logits, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)
                # Stop once a delimiter token (ids 0 and 4 in the trained vocabulary)
                # follows real content; leading spaces are skipped.
                if have_non_space and (next_token_id == 0 or next_token_id == 4):
                    break
                elif next_token_id != 0:
                    have_non_space = True
                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

        # Decode only the generated tokens; slicing the decoded string by the prompt's
        # character length would drift when decode() re-expands castling notation.
        model_response = self.decode(input_ids[0, len(encoded_prompt):].tolist())
        return model_response.split(";")[0]

    def get_move_from_response(self, response: str) -> Optional[str]:
        if not response:
            return None
        # Keep only the first move in the response
        moves = response.split()
        first_move = moves[0]
        # Strip leading periods left over from a quirk seen during training; it no
        # longer appears to occur, but the cleanup is harmless.
        first_move = first_move.lstrip('.')

        return first_move

    def get_move(self, board: str, game_state: str, temperature: float) -> Optional[str]:
        # board is accepted for interface compatibility but not used by this player
        completion = self.get_mamba_response(game_state, temperature, max_new_tokens=8, top_k=32)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        return {"model": self.model_name}