from pathlib import Path
import json
import torch
from typing import List, Dict, Optional
from collections import defaultdict
from huggingface_hub import hf_hub_download
from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention

MODEL_REPO = "TimoBertram/MTG_Model_FIN"
CFG_FILE = "config.json"
MODEL_FILE = "network.pt"

DATA_REPO = "TimoBertram/MTG_Drafting_Dataset"
CARD_FILE = "cards_eoe.json"
ENCODING_FILE = "card_encodings.pt"

# Number of all-zero deck embeddings fed to the network when no cards have
# been picked yet (deck is None or empty).
EMPTY_DECK_SLOTS = 45


class DraftModel:
    """Draft-pick model backed by files downloaded from the Hugging Face Hub.

    On construction it downloads the network config/weights (MODEL_REPO) and the
    card embeddings and card list (DATA_REPO), then exposes ``predict`` and
    ``get_p1p1`` for scoring packs of card names.
    """

    def __init__(self):
        # Inference is forced onto the CPU for compatibility. To use a GPU when
        # available: torch.device("cuda" if torch.cuda.is_available() else "cpu").
        self.device = torch.device("cpu")

        weight_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
        )
        cfg_path = hf_hub_download(
            repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
        )
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # ---- load network ---------------------------------------------------
        self.net = MLP_CrossAttention(**cfg)
        # NOTE(review): torch.load unpickles a file downloaded from the Hub.
        # Consider weights_only=True (supported since torch 1.13) so that only
        # tensors/containers are deserialized.
        self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.net.eval()
        self.net.to(self.device)

        # ---- embeddings – one-time load ------------------------------------
        self.embed_dict = get_embedding_dict(
            hf_hub_download(
                repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"
            ),
            add_nontransformed=True,
        )
        # Embedding width, read off the first stored embedding.
        self.emb_size = next(iter(self.embed_dict.values())).shape[0]

        card_path = hf_hub_download(
            repo_id=DATA_REPO, filename=CARD_FILE, repo_type="dataset"
        )
        # Context manager so the handle is closed deterministically; JSON is
        # UTF-8 and card names may contain non-ASCII characters.
        with open(card_path, "r", encoding="utf-8") as f:
            raw_card_file = json.load(f)

        # Index card records as self.cards[set_code][card_name] -> record.
        self.cards = defaultdict(dict)
        for card in raw_card_file:
            self.cards[card["set"]][card["name"]] = card

    def _embed(self, name):
        """Return the embedding tensor for a single card name."""
        return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]

    # --------------------------------------------------------------------- #
    #  Public API expected by streamlit_app.py                               #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def predict(self, pack: List[str], deck: Optional[List[str]]) -> Dict:
        """Score each card in ``pack`` given the cards already drafted.

        Args:
            pack: Names of the cards available to pick.
            deck: Names of the cards already picked. ``None`` or an empty list
                means no picks yet; an all-zero deck of ``EMPTY_DECK_SLOTS``
                embeddings is used in that case.

        Returns:
            A dict with:
              "pick": the name in ``pack`` with the highest score,
              "logits": raw network outputs, one per pack card,
              "scores": softmax of the logits over the pack, one per pack card.
        """
        card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
        if not deck:
            # torch.stack cannot stack an empty list, so an empty deck is
            # treated the same as "no deck".
            deck_t = torch.zeros(
                (1, EMPTY_DECK_SLOTS, self.emb_size), device=self.device
            )
        else:
            deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)

        vals = self.net(deck=deck_t, cards=card_t)
        scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
        logits = vals.squeeze(0).cpu().numpy()
        return {
            "pick": pack[scores.argmax()],
            "logits": logits.tolist(),
            "scores": scores.tolist(),
        }

    @torch.no_grad()
    def get_p1p1(self, set_code: str):
        """Score every card of ``set_code`` as a first pick (no prior picks).

        Returns:
            ``(names, logits)``: the set's card names and the network logit for
            each, in the same order.

        Raises:
            KeyError: if no cards are loaded for ``set_code``.
        """
        # .get avoids inserting an empty entry into the defaultdict for an
        # unknown code (plain indexing would create one as a side effect).
        set_cards = self.cards.get(set_code)
        if not set_cards:
            raise KeyError(f"No cards loaded for set {set_code!r}")
        keys = list(set_cards)
        vals = self.predict(pack=keys, deck=None)["logits"]
        return keys, vals