jrahn commited on
Commit
4a02647
1 Parent(s): ded9643

update to chessv6 sl policy model

Browse files
Files changed (1) hide show
  1. app.py +6 -32
app.py CHANGED
@@ -9,40 +9,14 @@ from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipe
9
 
10
  token = os.environ['auth_token']
11
 
12
- tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
13
- model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv4', use_auth_token=token)
14
  pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
15
 
16
- empty_field = '0'
17
- board_split = ' | '
18
- nums = {str(n): empty_field * n for n in range(1, 9)}
19
- nums_rev = {v:k for k,v in reversed(nums.items())}
20
-
21
-
22
- def encode_fen(fen):
23
- # decompress fen representation
24
- # prepare for sub-word tokenization
25
- fen_board, fen_rest = fen.split(' ', 1)
26
- for n in nums:
27
- fen_board = fen_board.replace(n, nums[n])
28
- fen_board = '+' + fen_board
29
- fen_board = fen_board.replace('/', ' +')
30
- return board_split.join([fen_board, fen_rest])
31
-
32
- def decode_fen_repr(fen_repr):
33
- fen_board, fen_rest = fen_repr.split(board_split, 1)
34
- for n in nums_rev:
35
- fen_board = fen_board.replace(n, nums_rev[n])
36
- fen_board = fen_board.replace(' +', '/')
37
- fen_board = fen_board.replace('+', '')
38
- return ' '.join([fen_board, fen_rest])
39
-
40
  def predict_move(fen, top_k=3):
41
- fen_prep = encode_fen(fen)
42
- preds = pipe(fen_prep, top_k=top_k)
43
  weights = [p['score'] for p in preds]
44
  p = random.choices(preds, weights=weights)[0]
45
- # discard illegal moves (https://python-chess.readthedocs.io/en/latest/core.html#chess.Board.legal_moves), then select top_k
46
  return p['label']
47
 
48
  def btn_load(inp_fen):
@@ -76,9 +50,9 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
76
  with gr.Blocks() as block:
77
  gr.Markdown(
78
  '''
79
- # Play YoloChess - Policy Network v0.4
80
- 110M Parameter Transformer (DeBERTaV2-base architecture)
81
- - pre-trained (MLM) from scratch on chess positions in modified FEN notation
82
  - fine-tuned for text classification (moves) on expert games.
83
  '''
84
  )
 
9
 
10
  token = os.environ['auth_token']
11
 
12
+ tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv6', use_auth_token=token)
13
+ model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv6', use_auth_token=token)
14
  pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def predict_move(fen, top_k=3):
17
+ preds = pipe(fen, top_k=top_k)
 
18
  weights = [p['score'] for p in preds]
19
  p = random.choices(preds, weights=weights)[0]
 
20
  return p['label']
21
 
22
  def btn_load(inp_fen):
 
50
  with gr.Blocks() as block:
51
  gr.Markdown(
52
  '''
53
+ # Play YoloChess - Policy Network v0.6
54
+ 87M Parameter Transformer (DeBERTaV2-base architecture)
55
+ - pre-trained (MLM) from scratch on chess positions in FEN notation
56
  - fine-tuned for text classification (moves) on expert games.
57
  '''
58
  )