---
license: mit
datasets:
- Xmaster6y/stockfish-debug
name: Xmaster6y/gpt2-stockfish-debug
results:
- task: train
  metrics:
  - name: train-loss
    type: loss
    value: 0.151
    verified: false
  - name: eval-loss
    type: loss
    value: 0.138
    verified: false
widget:
- text: "FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\nMOVE:"
  example_title: "Init board"
- text: "FEN: r2qr1k1/1p3ppp/2n1bb2/p2p4/P2N4/1B1P2B1/1PP2PPP/R2QR1K1 b - - 0 16\nMOVE:"
  example_title: "Board with legal completion"
- text: "FEN: 6k1/1p3ppp/2b2b2/p2p4/P2Q4/1B1P2B1/1PP2PPP/4q1K1 w - - 2 24\nMOVE:"
  example_title: "Board with illegal completion"
---

# Model Card for gpt2-stockfish-debug

## Training Details

The model was trained for 1 epoch on the `Xmaster6y/stockfish-debug` dataset, with no hyperparameter tuning. Each sample has the form:

```json
{"prompt":"FEN: {fen}\nMOVE:", "completion": " {move}"}
```
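
As a minimal sketch of this format, the helper below builds one record from a position and a move. The name `make_sample` and the example move `e2e4` are illustrative assumptions, not part of the dataset tooling. The move is written in UCI notation because the usage code in this card parses completions with `chess.Move.from_uci`.

```python
import json

# Prompt layout taken from the sample format of this model card.
PROMPT_TEMPLATE = "FEN: {fen}\nMOVE:"


def make_sample(fen: str, move: str) -> dict:
    """Build one prompt/completion record (hypothetical helper)."""
    return {
        "prompt": PROMPT_TEMPLATE.format(fen=fen),
        # The completion starts with a space, as in the sample format;
        # GPT-2's BPE tokenizer attaches a leading space to the token that follows it.
        "completion": f" {move}",
    }


start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
sample = make_sample(start_fen, "e2e4")  # "e2e4" is an illustrative move
print(json.dumps(sample))
```

At inference time only the `prompt` part is fed to the model, and the generated continuation plays the role of `completion`.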

Two simple extensions are possible:

- Expand the FEN string so that each empty square gets its own character: `r2qk3/...` -> `r11qk111/...` or equivalent.
- Condition on the game result (Elo ratings are not available in the dataset):
  ```json
  {"prompt":"RES: {res}\nFEN: {fen}\nMOVE:", "completion": " {move}"}
  ```
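
Both extensions can be sketched in a few lines of plain Python. This is an illustration under assumptions, not the preprocessing used for this model: the function names `expand_fen`, `collapse_fen`, and `conditioned_prompt` are hypothetical, and the result encoding `"1-0"` is assumed since the card does not specify how `{res}` would be written.

```python
import re


def expand_fen(fen: str) -> str:
    """Rewrite each empty-square count in the board field as a run of '1's."""
    board, *rest = fen.split(" ")
    expanded = re.sub(r"[1-8]", lambda m: "1" * int(m.group()), board)
    return " ".join([expanded, *rest])


def collapse_fen(fen: str) -> str:
    """Inverse of expand_fen: merge each run of '1's back into a single count."""
    board, *rest = fen.split(" ")
    collapsed = re.sub(r"1+", lambda m: str(len(m.group())), board)
    return " ".join([collapsed, *rest])


def conditioned_prompt(res: str, fen: str) -> str:
    """Prompt for the result-conditioned variant; the encoding of `res` is an assumption."""
    return f"RES: {res}\nFEN: {fen}\nMOVE:"


fen = "r2qr1k1/1p3ppp/2n1bb2/p2p4/P2N4/1B1P2B1/1PP2PPP/R2QR1K1 b - - 0 16"
expanded = expand_fen(fen)
print(expanded)
print(conditioned_prompt("1-0", expanded))
```

Only the board field is rewritten; the side to move, castling rights, and move counters are left untouched. After expansion every rank is exactly 8 characters long, and `collapse_fen` recovers a standard FEN that chess libraries can parse.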

## Use the Model

The following code requires `python-chess` in addition to `transformers` (with PyTorch, since tensors are returned with `return_tensors="pt"`). Install it with `pip install python-chess`.

```python
import chess
from transformers import AutoModelForCausalLM, AutoTokenizer


def next_move(model, tokenizer, fen):
    """Sample the model's next move (a UCI string) for the given FEN."""
    inputs = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    out = model.generate(
        **inputs,
        max_new_tokens=10,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.1,
    )
    out_str = tokenizer.batch_decode(out)[0]
    # Keep only the text generated after the prompt and drop the end-of-text marker.
    return out_str.split("MOVE:")[-1].replace("<|endoftext|>", "").strip()


board = chess.Board()
model = AutoModelForCausalLM.from_pretrained("Xmaster6y/gpt2-stockfish-debug")
tokenizer = AutoTokenizer.from_pretrained("Xmaster6y/gpt2-stockfish-debug")  # or "gpt2"
tokenizer.pad_token = tokenizer.eos_token

# Let the model play both sides until it produces an illegal move or the game ends.
for i in range(100):
    if board.is_game_over():
        break
    fen = board.fen()
    move_uci = next_move(model, tokenizer, fen)
    print(move_uci)
    try:
        # from_uci raises a ValueError subclass when the string is not valid UCI.
        move = chess.Move.from_uci(move_uci)
        if move not in board.legal_moves:
            raise ValueError(f"illegal move: {move_uci}")
        board.push(move)
    except ValueError:
        print(board)
        print("Illegal move", i)
        break
```