"""
Sample chess moves from a trained nanoGPT model
"""
import os
import pickle
from contextlib import nullcontext
import torch
from typing import Optional
from nanogpt.model import GPTConfig, GPT

BASE_DIR = "nanogpt/"


class NanoGptPlayer:
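    """Wraps a trained nanoGPT checkpoint and decodes its completions into chess moves."""
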
    def __init__(self, model_name: str, move_num_in_gamestate: bool = False):
        self.model_name = model_name
        # -----------------------------------------------------------------------------

        init_from = "resume"  # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
        out_dir = "out"  # ignored if init_from is not 'resume'
        input_dir = "addition"
        test_name = "test.txt"
        start = "12+44="  # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
        num_samples = 1  # number of samples to draw
        max_new_tokens = 6  # number of tokens generated in each sample
        temperature = 0.01  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random
        top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
        seed = 1337
        device = "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
        #device = "cpu"
        dtype = "float16"  # 'float32' or 'bfloat16' or 'float16'
        compile = False  # use PyTorch 2.0 to compile the model to be faster
        exec(
            open(f"{BASE_DIR}configurator.py").read()
        )  # overrides from command line or config file
        # -----------------------------------------------------------------------------

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
        device_type = (
            "cuda" if "cuda" in device else "cpu"
        )  # for later use in torch.autocast
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
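        # Autocast context: mixed-precision inference on GPU, no-op on CPU.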
        ctx = (
            nullcontext()
            if device_type == "cpu"
            else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
        )

        # model
        if init_from == "resume":
            # init from a model saved in a specific directory
            #ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
            ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Xformer/{self.model_name}")
            checkpoint = torch.load(ckpt_path, map_location=device)
            #gptconf = GPTConfig(**checkpoint["model_args"])
            #model = GPT(gptconf)
            model = GPT(checkpoint["model_args"])
            state_dict = checkpoint["model"]
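            # torch.compile wraps the model and prefixes saved parameter names
            # with "_orig_mod."; strip it so load_state_dict matches.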
            unwanted_prefix = "_orig_mod."
            for k, v in list(state_dict.items()):
                if k.startswith(unwanted_prefix):
                    state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
            model.load_state_dict(state_dict)
        elif init_from.startswith("gpt2"):
            # init from a given GPT-2 model
            model = GPT.from_pretrained(init_from, dict(dropout=0.0))
        else:
            raise ValueError(f"unknown init_from: {init_from}")

        model.eval()
        model.to(device)
        if compile:
            model = torch.compile(model)  # requires PyTorch 2.0 (optional)

        # look for the meta pickle in case it is available in the dataset folder
        meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
        load_meta = os.path.exists(meta_path)
        if move_num_in_gamestate and load_meta:
            with open(meta_path, "rb") as f:
                meta = pickle.load(f)
            stoi, itos = meta["stoi"], meta["itos"]
            vocab_size = meta["vocab_size"]
            encode = lambda s: [stoi[c] for c in s]
            decode = lambda l: "".join([itos[i] for i in l])
        else:
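            # Fallback: fixed character-level vocab covering SAN chess notation
            # (files a-h, ranks 1-8, piece letters, 'x', '+', '#', '=', 'O' for castling).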
            stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
            itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
            for s in stoi:
                assert itos[stoi[s]] == s
            vocab_size = len(stoi)
            print(f"Vocab size {vocab_size}")
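            # The vocab has no '-': encode strips it from castling moves,
            # and decode re-inserts it ("OO" -> "O-O", "OOO" -> "O-O-O").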
            encode = lambda s: [stoi[c] for c in s.replace('-', '')]
            decode = lambda l: "".join([itos[i] for i in l if i < vocab_size]).replace("OOO", "O-O-O").replace("OO", "O-O")

        self.encode = encode
        self.decode = decode
        self.model = model
        self.ctx = ctx
        self.device = device

    def get_nanogpt_response(self, game_state: str, temperature: float) -> str:
        num_samples = 1  # number of samples to draw
        top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
        max_new_tokens = 8

        # Strip PGN headers (e.g. ["stockfish elo xxx"] lines) from game_state;
        # nanogpt was trained only on the move transcripts.
        game_state = game_state.split("\n\n")[-1].strip()

        # print("game_state", game_state)

        #game_state = ";" + game_state

        start_ids = self.encode(game_state)

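        # Encode the PGN prompt and add a batch dimension for generate().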
        x = torch.tensor(start_ids, dtype=torch.long, device=self.device)[None, ...]
        with torch.no_grad():
            with self.ctx:
                for k in range(num_samples):
                    y = self.model.generate(
                        x, max_new_tokens, temperature=temperature, top_k=top_k
                    )

                    model_response = self.decode(y[0].tolist())

        # print("model_response", model_response)
        # model_response includes the input string
        model_response = model_response[len(game_state):].split(";")[0]
        return model_response

    def get_move_from_response(self, response: str) -> Optional[str]:
        try:
            # Parse the response to get only the first move
            moves = response.split()
            first_move = moves[0]
            return first_move
        except IndexError:
            return None


    def get_move(self, board: str, game_state: str, temperature: float) -> Optional[str]:
        completion = self.get_nanogpt_response(game_state, temperature)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        return {"model": self.model_name}
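

# Minimal usage sketch (assumptions: a checkpoint named "example.pt" exists under
# ../chess-mamba-vs-xformer/out/Xformer/, nanogpt/configurator.py is importable,
# and a CUDA device is available; the names here are illustrative only).
if __name__ == "__main__":
    player = NanoGptPlayer(model_name="example.pt")
    # A PGN-style transcript; the model completes it with the next move.
    game_state = "1.e4 e5 2.Nf3"
    move = player.get_move(board="", game_state=game_state, temperature=0.01)
    print(f"model move: {move}")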