Spaces:
Sleeping
Sleeping
Timo
commited on
Commit
·
fa1e508
1
Parent(s):
1c44fce
added model
Browse files- requirements.txt +5 -4
- src/draft_model.py +60 -0
- src/streamlit_app.py +2 -18
requirements.txt
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
| 2 |
requests
|
| 3 |
-
huggingface_hub>=0.23
|
| 4 |
-
joblib # if you pickle weights
|
| 5 |
-
scrython
|
|
|
|
| 1 |
+
# requirements.txt
|
| 2 |
+
streamlit>=1.33
|
| 3 |
+
torch>=2.3 # adjust to whatever you trained with
|
| 4 |
+
scipy # for Spearman correlation in p1p1
|
| 5 |
requests
|
| 6 |
+
huggingface_hub>=0.23
|
|
|
|
|
|
src/draft_model.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
from typing import List, Dict
|
| 5 |
+
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
|
| 8 |
+
from src.models.winrate_model import Winrate_Model
|
| 9 |
+
from src.training import train_mlp
|
| 10 |
+
from src.utils import utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
MODEL_REPO = "TimoBertram/MTG_Model_FIN/"
|
| 14 |
+
CFG_FILE = "config.json"
|
| 15 |
+
MODEL_FILE = "network.pt"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DATA_REPO = "TimoBertram/MTG_Drafting_Dataset/"
|
| 19 |
+
CARD_FILE = "cards_eoe.json"
|
| 20 |
+
ENCODING_FILE = "card_encodings.pt"
|
| 21 |
+
|
| 22 |
+
class DraftModel:
|
| 23 |
+
def __init__(self, model_path: str):
|
| 24 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
|
| 26 |
+
cfg_path = hf_hub_download(
|
| 27 |
+
repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
|
| 28 |
+
)
|
| 29 |
+
weight_path = hf_hub_download(
|
| 30 |
+
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# ---- load network ---------------------------------------------------
|
| 34 |
+
cfg = open(cfg_path, "r")
|
| 35 |
+
cfg.pop("name", None)
|
| 36 |
+
|
| 37 |
+
self.net = train_mlp.MLP_CrossAttention(**cfg).to(self.device)
|
| 38 |
+
self.net.load_state_dict(torch.load(Path(model_path) / "network.pt", map_location=self.device))
|
| 39 |
+
self.net.eval()
|
| 40 |
+
|
| 41 |
+
# ---- embeddings – one-time load ------------------------------------
|
| 42 |
+
self.embed_dict = utils.get_embedding_dict(
|
| 43 |
+
hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
|
| 44 |
+
add_nontransformed=True
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# --------------------------------------------------------------------- #
|
| 48 |
+
# Public API expected by streamlit_app.py #
|
| 49 |
+
# --------------------------------------------------------------------- #
|
| 50 |
+
@torch.no_grad()
|
| 51 |
+
def predict(self, pack: List[Dict], picks: List[Dict], deck: List[Dict]) -> Dict:
|
| 52 |
+
names = [c["name"] for c in pack]
|
| 53 |
+
|
| 54 |
+
def embed(name): # helper
|
| 55 |
+
return utils.get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]
|
| 56 |
+
|
| 57 |
+
card_t = torch.stack([embed(n) for n in names]).unsqueeze(0).to(self.device)
|
| 58 |
+
deck_t = torch.zeros((1, 45, self.emb_size), device=self.device)
|
| 59 |
+
|
| 60 |
+
return torch.softmax(self.net(card_t, deck_t), dim=1).squeeze(0).cpu().numpy().tolist()
|
src/streamlit_app.py
CHANGED
|
@@ -22,19 +22,8 @@ from typing import Dict, List
|
|
| 22 |
import requests
|
| 23 |
import streamlit as st
|
| 24 |
|
| 25 |
-
|
| 26 |
-
# Disable telemetry write‑outs (avoids PermissionError: /.streamlit on Spaces)
|
| 27 |
-
# -----------------------------------------------------------------------------
|
| 28 |
-
try:
|
| 29 |
-
from streamlit.runtime.metrics_util import disable_gather_usage_stats
|
| 30 |
|
| 31 |
-
disable_gather_usage_stats()
|
| 32 |
-
except Exception:
|
| 33 |
-
st.set_option("browser.gatherUsageStats", False)
|
| 34 |
-
|
| 35 |
-
# -----------------------------------------------------------------------------
|
| 36 |
-
# 0. Constants & helpers
|
| 37 |
-
# -----------------------------------------------------------------------------
|
| 38 |
|
| 39 |
SUPPORTED_SETS_PATH = Path("supported_sets.txt")
|
| 40 |
st.write("Running Streamlit", st.__version__)
|
|
@@ -49,14 +38,9 @@ def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
|
|
| 49 |
return []
|
| 50 |
|
| 51 |
|
| 52 |
-
# -----------------------------------------------------------------------------
|
| 53 |
-
# 1. Model & ranking stubs
|
| 54 |
-
# -----------------------------------------------------------------------------
|
| 55 |
-
|
| 56 |
@st.cache_resource(show_spinner="Loading draft model …")
|
| 57 |
def load_model():
|
| 58 |
-
|
| 59 |
-
return None
|
| 60 |
|
| 61 |
|
| 62 |
@st.cache_data(show_spinner="Calculating card rankings …")
|
|
|
|
| 22 |
import requests
|
| 23 |
import streamlit as st
|
| 24 |
|
| 25 |
+
from draft_model import DraftModel
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
SUPPORTED_SETS_PATH = Path("supported_sets.txt")
|
| 29 |
st.write("Running Streamlit", st.__version__)
|
|
|
|
| 38 |
return []
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
@st.cache_resource(show_spinner="Loading draft model …")
|
| 42 |
def load_model():
|
| 43 |
+
return DraftModel()
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
@st.cache_data(show_spinner="Calculating card rankings …")
|