jrahn commited on
Commit
4a50f91
1 Parent(s): 7763e11

switch model deberta

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -5,12 +5,12 @@ from datetime import datetime
5
  import gradio as gr
6
  import chess
7
  import chess.svg
8
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
9
 
10
  token = os.environ['auth_token']
11
 
12
  tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
13
- model = AutoModelForSequenceClassification.from_pretrained('jrahn/chessv3', use_auth_token=token)
14
  pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
15
 
16
  empty_field = '0'
@@ -58,8 +58,8 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
58
  board = chess.Board(inp_fen)
59
 
60
  if inp_move:
61
- if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move) #board.push_uci(inp_move)
62
- elif inp_notation == 'SAN': mv = board.parse_san(inp_move) #chess.Move.from_san(inp_move) #board.push_san(inp_move)
63
  else:
64
  mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
65
 
@@ -76,8 +76,10 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
76
  with gr.Blocks() as block:
77
  gr.Markdown(
78
  '''
79
- # Play YoloChess - Policy Network v0.3
80
- 110M Parameter Transformer (BERT-base architecture) trained for text classification from scratch on expert games in modified FEN notation.
 
 
81
  '''
82
  )
83
  with gr.Row() as row:
 
5
  import gradio as gr
6
  import chess
7
  import chess.svg
8
+ from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline
9
 
10
  token = os.environ['auth_token']
11
 
12
  tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
13
+ model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv4', use_auth_token=token)
14
  pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
15
 
16
  empty_field = '0'
 
58
  board = chess.Board(inp_fen)
59
 
60
  if inp_move:
61
+ if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move)
62
+ elif inp_notation == 'SAN': mv = board.parse_san(inp_move)
63
  else:
64
  mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
65
 
 
76
  with gr.Blocks() as block:
77
  gr.Markdown(
78
  '''
79
+ # Play YoloChess - Policy Network v0.4
80
+ 110M Parameter Transformer (DeBERTaV2-base architecture)
81
+ - pre-trained (MLM) from scratch on FENs
82
+ - fine-tuned for text classification expert games in modified FEN notation.
83
  '''
84
  )
85
  with gr.Row() as row: