Xmaster6y committed on
Commit
315b9b8
1 Parent(s): 3422e7f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -6
README.md CHANGED
@@ -40,22 +40,24 @@ Two possible simple extensions:
40
  import chess
41
  from transformers import AutoModelForCausalLM, AutoTokenizer
42
 
 
43
  def next_move(model, tokenizer, fen):
44
  input_ids = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
45
- input_ids = {k:v.to(model.device) for k,v in input_ids.items()}
46
  out = model.generate(
47
- **input_ids,
48
  max_new_tokens=10,
49
  pad_token_id=tokenizer.eos_token_id,
50
  do_sample=True,
51
  temperature=0.1,
52
- )
53
  out_str = tokenizer.batch_decode(out)[0]
54
- return out_str.split('MOVE:')[-1].replace("<|endoftext|>", "").strip()
 
55
 
56
  board = chess.Board()
57
- model = AutoModelForCausalLM.from_pretrained('Xmaster6y/gpt2-stockfish-debug')
58
- tokenizer = AutoTokenizer.from_pretrained('gpt2')
59
  tokenizer.pad_token = tokenizer.eos_token
60
  for i in range(100):
61
  fen = board.fen()
 
40
  import chess
41
  from transformers import AutoModelForCausalLM, AutoTokenizer
42
 
43
+
44
  def next_move(model, tokenizer, fen):
45
  input_ids = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
46
+ input_ids = {k: v.to(model.device) for k, v in input_ids.items()}
47
  out = model.generate(
48
+ **input_ids,
49
  max_new_tokens=10,
50
  pad_token_id=tokenizer.eos_token_id,
51
  do_sample=True,
52
  temperature=0.1,
53
+ )
54
  out_str = tokenizer.batch_decode(out)[0]
55
+ return out_str.split("MOVE:")[-1].replace("<|endoftext|>", "").strip()
56
+
57
 
58
  board = chess.Board()
59
+ model = AutoModelForCausalLM.from_pretrained("Xmaster6y/gpt2-stockfish-debug")
60
+ tokenizer = AutoTokenizer.from_pretrained("Xmaster6y/gpt2-stockfish-debug") # or "gpt2"
61
  tokenizer.pad_token = tokenizer.eos_token
62
  for i in range(100):
63
  fen = board.fen()