File size: 3,236 Bytes
313445a
 
72998f1
 
 
313445a
 
4a50f91
313445a
 
 
4a02647
 
313445a
 
 
4a02647
313445a
 
 
 
b02b814
72998f1
1e7240f
b02b814
313445a
 
3efeb97
313445a
0671ff1
7763e11
799a947
b02b814
 
4a50f91
 
b02b814
0671ff1
b02b814
0c203c2
 
 
 
b02b814
 
 
 
 
 
 
799a947
 
4a02647
 
 
b836b8f
799a947
 
b02b814
 
0671ff1
 
 
799a947
ded9643
b02b814
 
 
1e7240f
 
 
 
 
 
 
 
 
b02b814
 
 
3efeb97
0671ff1
b02b814
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import random
from datetime import datetime

import gradio as gr
import chess
import chess.svg
from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline

# Hugging Face access token read from the `auth_token` environment variable.
# os.environ[...] raises KeyError at import time when it is not set.
# NOTE(review): presumably required because the 'jrahn/chessv6' repo is
# private/gated — confirm.
token = os.environ['auth_token']

# Load the move-classification model and its tokenizer once, at import time,
# and wrap them in a text-classification pipeline used by predict_move().
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token=`; confirm against the installed version before changing.
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv6', use_auth_token=token)
model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv6', use_auth_token=token)
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)

def predict_move(fen, top_k=3):
    """Sample an engine move label for the position ``fen``.

    The classification pipeline returns its ``top_k`` best labels with
    scores; one of them is drawn at random, weighted by its score, so the
    engine does not always play the single highest-scoring label.

    Args:
        fen: The position in FEN notation, passed to the model as text.
        top_k: Number of highest-scoring labels to sample from.

    Returns:
        The sampled label string (btn_play parses it as a UCI move).
    """
    candidates = pipe(fen, top_k=top_k)
    # random.choices always returns a list; with the default k=1 it holds
    # exactly one element, which is unpacked here.
    (choice,) = random.choices(candidates, weights=[c['score'] for c in candidates])
    return choice['label']

def btn_load(inp_fen):
    """Reset the game to the standard starting position and render it.

    ``inp_fen`` is only written to the log line; the board is always the
    initial chess position regardless of its value.

    Returns:
        ``('board.svg', starting_fen, '')`` — the path of the SVG written to
        the working directory, the FEN of the fresh board, and an empty string
        that clears the move textbox in the UI.
    """
    timestamp = datetime.now().isoformat()
    print(f'** log - load - ts {timestamp}, fen: {inp_fen}')

    start = chess.Board()
    svg_markup = str(chess.svg.board(start))
    with open('board.svg', 'w') as out:
        out.write(svg_markup)

    return 'board.svg', start.fen(), ''

def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Play one move (human or engine) from ``inp_fen`` and render the result.

    If ``inp_move`` contains non-blank text, it is parsed as the human
    player's move in the notation named by ``inp_notation`` ('UCI' or 'SAN').
    Otherwise the engine moves: a label is sampled with ``predict_move`` from
    the model's top ``inp_k`` predictions and parsed as UCI.

    Args:
        inp_fen: FEN of the position to move from.
        inp_move: Human move text; empty, blank, or None means "engine to move".
        inp_notation: 'UCI' or 'SAN', the notation of ``inp_move``.
        inp_k: ``top_k`` forwarded to ``predict_move`` for engine moves.

    Returns:
        ``('board.svg', fen_after_move, '')`` — the rendered board image path,
        the updated FEN, and an empty string that clears the move textbox.

    Raises:
        ValueError: if the notation is unknown or the (human or engine) move
            is not legal in the position. Parsing errors from python-chess for
            a bad FEN or malformed move text also propagate.
    """
    print(f'** log - play - ts {datetime.now().isoformat()}, fen: {inp_fen}, move: {inp_move}, notation: {inp_notation}, top_k: {inp_k}')
    board = chess.Board(inp_fen)

    # Tolerate surrounding whitespace (and a None from the textbox) so that
    # " e4 " still parses and a whitespace-only box means "engine to move".
    move_text = (inp_move or '').strip()

    if move_text:
        if inp_notation == 'UCI':
            mv = chess.Move.from_uci(move_text)
        elif inp_notation == 'SAN':
            mv = board.parse_san(move_text)
        else:
            # Previously `mv` was left unbound here, so the legality check
            # below crashed with an UnboundLocalError instead of a clear error.
            raise ValueError(f'Unknown move notation: {inp_notation!r}')
    else:
        # The sampled model label is parsed as UCI; it may still be illegal
        # in this position, which the check below reports.
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))

    if mv not in board.legal_moves:
        raise ValueError(f'Illegal Move: {str(mv)} @ {board.fen()}')
    board.push(mv)

    with open('board.svg', 'w') as f:
        f.write(str(chess.svg.board(board, lastmove=mv)))

    return 'board.svg', board.fen(), ''

# Build the Gradio UI. The left column holds the move/FEN inputs, controls
# and help text; the right column shows the rendered board image.
with gr.Blocks() as block:
    # Header describing the model.
    gr.Markdown(
    '''
    # Play YoloChess - Policy Network v0.6
    87M Parameter Transformer (DeBERTaV2-base architecture)  
    - pre-trained (MLM) from scratch on chess positions in FEN notation
    - fine-tuned for text classification (moves) on expert games.  
    '''
    )
    with gr.Row() as row:
        # Left column: inputs, buttons and usage notes.
        with gr.Column():
            with gr.Row():
                # Human move text, interpreted in the selected notation.
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            # Current position; both handlers write the updated FEN back here.
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            # Forwarded as top_k to the engine sampler; precision=0 presumably
            # makes Gradio deliver an integer — confirm for the installed version.
            top_k = gr.Number(value=3, label='sample from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(
                '''
                - Click "Load" button to start and reset board.  
                - Click "Play" button to get Engine move.  
                - Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.  
                - Output "ERROR" generally occurs on illegal moves (Human or Engine).
                - Enter "FEN" to start from a custom position.
                '''
            )
        # Right column: the board image whose file path the handlers return.
        with gr.Column():
            position_output = gr.Image(label='board')

    # Both handlers return (svg path, new FEN, ''): the image is refreshed,
    # the FEN box updated, and the move textbox cleared.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen, move])
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])
    

# Launch the app at import time (there is no `if __name__ == '__main__'` guard).
block.launch()