Xmaster6y commited on
Commit
82bd166
1 Parent(s): 27166aa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -0
README.md CHANGED
@@ -25,3 +25,41 @@ Simple completion tuning using an equivalent of:
25
  ```json
26
  {"prompt":"FEN: {fen}\nMOVE:", "completion": " {move}"}
27
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ```json
26
  {"prompt":"FEN: {fen}\nMOVE:", "completion": " {move}"}
27
  ```
28
+
29
+ ## Use
30
+
31
+ ```python
32
+ import chess
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer
34
+
35
+ def next_move(model, tokenizer, fen):
36
+ input_ids = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
37
+ input_ids = {k:v.to(model.device) for k,v in input_ids.items()}
38
+ out = model.generate(
39
+ **input_ids,
40
+ max_new_tokens=10,
41
+ pad_token_id=tokenizer.eos_token_id,
42
+ do_sample=True,
43
+ temperature=0.1,
44
+ )
45
+ out_str = tokenizer.batch_decode(out)[0]
46
+ return out_str.split('MOVE:')[-1].replace("<|endoftext|>", "").strip()
47
+
48
+ board = chess.Board()
49
+ model = AutoModelForCausalLM.from_pretrained('Xmaster6y/gpt2-stockfish-debug')
50
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
51
+ tokenizer.pad_token = tokenizer.eos_token
52
+ for i in range(100):
53
+ fen = board.fen()
54
+ move_uci = next_move(model, tokenizer, fen)
55
+ try:
56
+ print(move_uci)
57
+ move = chess.Move.from_uci(move_uci)
58
+ if move not in board.legal_moves:
59
+ raise chess.IllegalMoveError
60
+ board.push(move)
61
+ except chess.IllegalMoveError:
62
+ print(board)
63
+ print("Illegal move", i)
64
+ break
65
+ ```