Timo commited on
Commit
fa1e508
·
1 Parent(s): 1c44fce

added model

Browse files
Files changed (3) hide show
  1. requirements.txt +5 -4
  2. src/draft_model.py +60 -0
  3. src/streamlit_app.py +2 -18
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- streamlit>=1.28
 
 
 
2
  requests
3
- huggingface_hub>=0.23 # if you pull the model from the Hub
4
- joblib # if you pickle weights
5
- scrython
 
1
+ # requirements.txt
2
+ streamlit>=1.33
3
+ torch>=2.3 # adjust to whatever you trained with
4
+ scipy # for Spearman correlation in p1p1
5
  requests
6
+ huggingface_hub>=0.23
 
 
src/draft_model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import json
3
+ import torch
4
+ from typing import List, Dict
5
+
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ from src.models.winrate_model import Winrate_Model
9
+ from src.training import train_mlp
10
+ from src.utils import utils
11
+
12
+
13
+ MODEL_REPO = "TimoBertram/MTG_Model_FIN/"
14
+ CFG_FILE = "config.json"
15
+ MODEL_FILE = "network.pt"
16
+
17
+
18
+ DATA_REPO = "TimoBertram/MTG_Drafting_Dataset/"
19
+ CARD_FILE = "cards_eoe.json"
20
+ ENCODING_FILE = "card_encodings.pt"
21
+
22
+ class DraftModel:
23
+ def __init__(self, model_path: str):
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ cfg_path = hf_hub_download(
27
+ repo_id=MODEL_REPO, filename=CFG_FILE, repo_type="model"
28
+ )
29
+ weight_path = hf_hub_download(
30
+ repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
31
+ )
32
+
33
+ # ---- load network ---------------------------------------------------
34
+ cfg = open(cfg_path, "r")
35
+ cfg.pop("name", None)
36
+
37
+ self.net = train_mlp.MLP_CrossAttention(**cfg).to(self.device)
38
+ self.net.load_state_dict(torch.load(Path(model_path) / "network.pt", map_location=self.device))
39
+ self.net.eval()
40
+
41
+ # ---- embeddings – one-time load ------------------------------------
42
+ self.embed_dict = utils.get_embedding_dict(
43
+ hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
44
+ add_nontransformed=True
45
+ )
46
+
47
+ # --------------------------------------------------------------------- #
48
+ # Public API expected by streamlit_app.py #
49
+ # --------------------------------------------------------------------- #
50
+ @torch.no_grad()
51
+ def predict(self, pack: List[Dict], picks: List[Dict], deck: List[Dict]) -> Dict:
52
+ names = [c["name"] for c in pack]
53
+
54
+ def embed(name): # helper
55
+ return utils.get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]
56
+
57
+ card_t = torch.stack([embed(n) for n in names]).unsqueeze(0).to(self.device)
58
+ deck_t = torch.zeros((1, 45, self.emb_size), device=self.device)
59
+
60
+ return torch.softmax(self.net(card_t, deck_t), dim=1).squeeze(0).cpu().numpy().tolist()
src/streamlit_app.py CHANGED
@@ -22,19 +22,8 @@ from typing import Dict, List
22
  import requests
23
  import streamlit as st
24
 
25
- # -----------------------------------------------------------------------------
26
- # Disable telemetry write‑outs (avoids PermissionError: /.streamlit on Spaces)
27
- # -----------------------------------------------------------------------------
28
- try:
29
- from streamlit.runtime.metrics_util import disable_gather_usage_stats
30
 
31
- disable_gather_usage_stats()
32
- except Exception:
33
- st.set_option("browser.gatherUsageStats", False)
34
-
35
- # -----------------------------------------------------------------------------
36
- # 0. Constants & helpers
37
- # -----------------------------------------------------------------------------
38
 
39
  SUPPORTED_SETS_PATH = Path("supported_sets.txt")
40
  st.write("Running Streamlit", st.__version__)
@@ -49,14 +38,9 @@ def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
49
  return []
50
 
51
 
52
- # -----------------------------------------------------------------------------
53
- # 1. Model & ranking stubs
54
- # -----------------------------------------------------------------------------
55
-
56
  @st.cache_resource(show_spinner="Loading draft model …")
57
  def load_model():
58
- # 🔄 Replace with your own pipeline or `hf_hub_download` call
59
- return None
60
 
61
 
62
  @st.cache_data(show_spinner="Calculating card rankings …")
 
22
  import requests
23
  import streamlit as st
24
 
25
+ from draft_model import DraftModel
 
 
 
 
26
 
 
 
 
 
 
 
 
27
 
28
  SUPPORTED_SETS_PATH = Path("supported_sets.txt")
29
  st.write("Running Streamlit", st.__version__)
 
38
  return []
39
 
40
 
 
 
 
 
41
  @st.cache_resource(show_spinner="Loading draft model …")
42
  def load_model():
43
+ return DraftModel()
 
44
 
45
 
46
  @st.cache_data(show_spinner="Calculating card rankings …")