switch model deberta
Browse files
app.py
CHANGED
@@ -5,12 +5,12 @@ from datetime import datetime
|
|
5 |
import gradio as gr
|
6 |
import chess
|
7 |
import chess.svg
|
8 |
-
from transformers import
|
9 |
|
10 |
token = os.environ['auth_token']
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
|
13 |
-
model =
|
14 |
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
15 |
|
16 |
empty_field = '0'
|
@@ -58,8 +58,8 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
|
|
58 |
board = chess.Board(inp_fen)
|
59 |
|
60 |
if inp_move:
|
61 |
-
if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move)
|
62 |
-
elif inp_notation == 'SAN': mv = board.parse_san(inp_move)
|
63 |
else:
|
64 |
mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
|
65 |
|
@@ -76,8 +76,10 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
|
|
76 |
with gr.Blocks() as block:
|
77 |
gr.Markdown(
|
78 |
'''
|
79 |
-
# Play YoloChess - Policy Network v0.
|
80 |
-
110M Parameter Transformer (
|
|
|
|
|
81 |
'''
|
82 |
)
|
83 |
with gr.Row() as row:
|
|
|
5 |
import gradio as gr
|
6 |
import chess
|
7 |
import chess.svg
|
8 |
+
from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline
|
9 |
|
10 |
token = os.environ['auth_token']
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
|
13 |
+
model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv4', use_auth_token=token)
|
14 |
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
15 |
|
16 |
empty_field = '0'
|
|
|
58 |
board = chess.Board(inp_fen)
|
59 |
|
60 |
if inp_move:
|
61 |
+
if inp_notation == 'UCI': mv = chess.Move.from_uci(inp_move)
|
62 |
+
elif inp_notation == 'SAN': mv = board.parse_san(inp_move)
|
63 |
else:
|
64 |
mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
|
65 |
|
|
|
76 |
with gr.Blocks() as block:
|
77 |
gr.Markdown(
|
78 |
'''
|
79 |
+
# Play YoloChess - Policy Network v0.4
|
80 |
+
110M Parameter Transformer (DeBERTaV2-base architecture)
|
81 |
+
- pre-trained (MLM) from scratch on FENs
|
82 |
+
- fine-tuned for text classification expert games in modified FEN notation.
|
83 |
'''
|
84 |
)
|
85 |
with gr.Row() as row:
|