```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, GPT2Model


class ChessMoveClassifier(nn.Module):
    def __init__(self, model_name, num_labels=4096):
        super().__init__()
        self.base_model = GPT2Model.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # One logit per candidate move (4096 = 64 * 64 from/to square pairs)
        self.classifier = nn.Linear(self.base_model.config.n_embd, num_labels)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the final token's hidden state; this assumes inputs are
        # not right-padded, otherwise this position may be a pad token
        hidden_state = outputs.last_hidden_state[:, -1, :]
        logits = self.classifier(self.dropout(hidden_state))
        return {"logits": logits}


def model_fn(model_dir):
    # Model-loading entry point (SageMaker PyTorch inference convention):
    # rebuild the architecture, then restore the fine-tuned weights
    model = ChessMoveClassifier(model_name="austindavis/ChessGPT_d12")
    model.load_state_dict(torch.load(f"{model_dir}/model.pt", map_location="cpu"))
    model.eval()
    return model
```
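The `model_fn` signature above follows the SageMaker PyTorch inference toolkit convention, which also lets the same script define `input_fn` and `predict_fn` to handle deserialization and prediction. Below is a minimal sketch of those companion handlers; the JSON request schema (`{"moves": "..."}`) and loading the tokenizer from the same `austindavis/ChessGPT_d12` repo are illustrative assumptions, not part of the original script.

```python
import json

import torch
from transformers import AutoTokenizer

# Assumption: the tokenizer ships in the same hub repo as the base model
tokenizer = AutoTokenizer.from_pretrained("austindavis/ChessGPT_d12")


def input_fn(request_body, request_content_type="application/json"):
    # Hypothetical request schema: {"moves": "e2e4 e7e5 g1f3"}
    moves = json.loads(request_body)["moves"]
    return tokenizer(moves, return_tensors="pt")


def predict_fn(input_data, model):
    # Score every candidate move and return the index of the best one
    with torch.no_grad():
        logits = model(**input_data)["logits"]
    return {"move_id": logits.argmax(dim=-1).item()}
```

At serving time the toolkit chains these together: `input_fn` turns the request into tensors, `predict_fn` runs them through the model returned by `model_fn`, and the resulting dict is serialized by the default `output_fn`.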