Timo commited on
Commit
2866e65
·
1 Parent(s): e6f2697

Functionality works, UI needs work

Browse files
Files changed (2) hide show
  1. src/draft_model.py +1 -2
  2. src/streamlit_app.py +158 -75
src/draft_model.py CHANGED
@@ -63,7 +63,7 @@ class DraftModel:
63
  # Public API expected by streamlit_app.py #
64
  # --------------------------------------------------------------------- #
65
  @torch.no_grad()
66
- def predict(self, pack: List[Dict], deck: List[Dict]) -> Dict:
67
 
68
  card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
69
  if deck is None:
@@ -74,7 +74,6 @@ class DraftModel:
74
  vals = self.net(deck = deck_t, cards = card_t)
75
  #scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
76
  scores = vals.squeeze(0).cpu().numpy()
77
- print(scores)
78
  return {
79
  "pick": pack[scores.argmax()],
80
  "scores": scores.tolist(),
 
63
  # Public API expected by streamlit_app.py #
64
  # --------------------------------------------------------------------- #
65
  @torch.no_grad()
66
+ def predict(self, pack: List[str], deck: List[str]) -> Dict:
67
 
68
  card_t = torch.stack([self._embed(c) for c in pack]).unsqueeze(0).to(self.device)
69
  if deck is None:
 
74
  vals = self.net(deck = deck_t, cards = card_t)
75
  #scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
76
  scores = vals.squeeze(0).cpu().numpy()
 
77
  return {
78
  "pick": pack[scores.argmax()],
79
  "scores": scores.tolist(),
src/streamlit_app.py CHANGED
@@ -33,6 +33,8 @@ from pathlib import Path
33
 
34
  from typing import Dict, List
35
  import pandas as pd
 
 
36
 
37
  import requests
38
  import streamlit as st
@@ -41,20 +43,20 @@ from draft_model import DraftModel
41
 
42
  SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
43
 
 
 
 
 
 
44
 
45
  @st.cache_data(show_spinner="Reading supported sets …")
46
  def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
47
  """Return a list of legal set codes read from *supported_sets.txt*."""
48
-
49
  if path.is_file():
50
  return [ln.strip() for ln in path.read_text().splitlines() if ln.strip()]
51
  return []
52
 
53
 
54
- @st.cache_resource(show_spinner="Loading draft model …")
55
- def load_model():
56
- return DraftModel()
57
-
58
  @st.cache_data(show_spinner="Calculating P1P1 ...")
59
  def p1p1_ranking(set_code: str):
60
  names, scores = model.get_p1p1(set_code)
@@ -65,36 +67,11 @@ def p1p1_ranking(set_code: str):
65
 
66
 
67
  @st.cache_data(show_spinner="Calculating card rankings …")
68
- def rank_cards(set_code: str) -> List[Dict]:
69
- """Return a stubbed ranking list for *set_code*.
70
-
71
- Replace with your real evaluation logic. For now we just pull 30 random
72
- commons from the set and assign a dummy score.
73
- """
74
-
75
- url = f"https://api.scryfall.com/cards/search?q=set%3A{set_code}+unique%3Aprints+is%3Acommon"
76
- cards: List[Dict] = []
77
- while url and len(cards) < 60: # cap network use
78
- r = requests.get(url)
79
- r.raise_for_status()
80
- payload = r.json()
81
- cards += payload["data"]
82
- url = payload.get("next_page") if payload.get("has_more") else None
83
-
84
- sample = random.sample(cards, k=min(30, len(cards))) if cards else []
85
- ranked = [
86
- {"name": c["name"], "score": round(random.random(), 2)} for c in sample
87
- ]
88
- ranked.sort(key=lambda x: x["score"], reverse=True)
89
- return ranked
90
-
91
-
92
- model = load_model()
93
-
94
- def suggest_pick(pack: List[Dict], picks: List[Dict]) -> Dict:
95
- if model is None:
96
- return random.choice(pack)
97
- return model.predict(pack=pack, picks=picks) # type: ignore[attr-defined]
98
 
99
 
100
  def fetch_card_image(card_name: str) -> str:
@@ -125,61 +102,167 @@ with st.sidebar:
125
  supported_sets = get_supported_sets()
126
 
127
  # Hide control in an expander (collapsed by default)
128
- with st.expander("Set selection", expanded=False):
129
- if supported_sets:
130
- # UPDATED: dropdown instead of radio
131
- set_code = st.selectbox(
132
- "Choose a set",
133
- supported_sets,
134
- index=0,
135
- key="set_code",
136
- )
137
- else:
138
- st.warning(
139
- "*supported_sets.txt* not found or empty. Using free-text input instead.",
140
- icon="⚠️",
141
- )
142
- set_code = st.text_input("Set code", value="EOE", key="set_code")
143
 
144
- if st.button("Start new draft", type="primary"):
145
- st.session_state["pack"] = generate_booster(set_code)
146
- st.session_state["picks"] = []
147
 
148
  # -------- Session state ------------------------------------------------------
149
 
150
  st.session_state.setdefault("pack", [])
151
  st.session_state.setdefault("picks", [])
 
152
 
153
  # -------- Main content organised in tabs ------------------------------------
154
 
155
  tabs = st.tabs(["Draft", "P1P1 Rankings"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # --- Tab 1: Draft ------------------------------------------------------------
158
 
159
  with tabs[0]:
160
- if not st.session_state["pack"]:
161
- st.info("Choose **Start new draft** in the sidebar to open pack 1.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  else:
163
- pack: List[Dict] = st.session_state["pack"]
164
- picks: List[Dict] = st.session_state["picks"]
165
-
166
- st.subheader(f"Pack {len(picks) // 15 + 1} — Pick {len(picks) % 15 + 1}")
167
- suggested = suggest_pick(pack, picks)
168
- st.success(f"**Model suggests:** {suggested['name']}")
169
-
170
- cols = st.columns(5)
171
- for idx, card in enumerate(pack):
172
- col = cols[idx % 5]
173
- col.image(fetch_card_image(card["name"]), use_column_width=True)
174
- if col.button(f"Pick {card['name']}", key=f"pick-{idx}"):
175
- picks.append(card)
176
- pack.remove(card)
177
- if not pack: # end of pack ⇒ open a fresh booster
178
- st.session_state["pack"] = generate_booster(set_code)
179
- st.experimental_rerun()
180
-
181
- with st.expander("Current picks", expanded=False):
182
- st.write("\n".join([c["name"] for c in picks]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # --- Tab 2: Card rankings ----------------------------------------------------
185
 
 
33
 
34
  from typing import Dict, List
35
  import pandas as pd
36
+ import copy
37
+ import numpy as np
38
 
39
  import requests
40
  import streamlit as st
 
43
 
44
  SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
45
 
46
+ @st.cache_resource(show_spinner="Loading draft model …")
47
+ def load_model():
48
+ return DraftModel()
49
+
50
+ model = load_model()
51
 
52
  @st.cache_data(show_spinner="Reading supported sets …")
53
  def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
54
  """Return a list of legal set codes read from *supported_sets.txt*."""
 
55
  if path.is_file():
56
  return [ln.strip() for ln in path.read_text().splitlines() if ln.strip()]
57
  return []
58
 
59
 
 
 
 
 
60
  @st.cache_data(show_spinner="Calculating P1P1 ...")
61
  def p1p1_ranking(set_code: str):
62
  names, scores = model.get_p1p1(set_code)
 
67
 
68
 
69
  @st.cache_data(show_spinner="Calculating card rankings …")
70
+ def rank_cards(deck: List[str], pack: List[str]) -> List[Dict]:
71
+ out = model.predict(pack, deck)
72
+ pick = out["pick"]
73
+ scores = {pack[i]: score for i, score in enumerate(out["scores"])}
74
+ return pick, scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
  def fetch_card_image(card_name: str) -> str:
 
102
  supported_sets = get_supported_sets()
103
 
104
  # Hide control in an expander (collapsed by default)
105
+ if supported_sets:
106
+ # UPDATED: dropdown instead of radio
107
+ set_code = st.selectbox(
108
+ "Choose a set",
109
+ supported_sets,
110
+ index=0,
111
+ key="set_code",
112
+ )
113
+ else:
114
+ st.warning(
115
+ "*supported_sets.txt* not found or empty. Using free-text input instead.",
116
+ icon="⚠️",
117
+ )
118
+ set_code = st.text_input("Set code", value="EOE", key="set_code")
 
119
 
 
 
 
120
 
121
  # -------- Session state ------------------------------------------------------
122
 
123
  st.session_state.setdefault("pack", [])
124
  st.session_state.setdefault("picks", [])
125
+ st.session_state.setdefault("undo_stack", [])
126
 
127
  # -------- Main content organised in tabs ------------------------------------
128
 
129
  tabs = st.tabs(["Draft", "P1P1 Rankings"])
130
+ def ensure_state():
131
+ """Make sure session_state has the pack and picks lists."""
132
+ if "pack" not in st.session_state:
133
+ st.session_state["pack"] = []
134
+ if "picks" not in st.session_state:
135
+ st.session_state["picks"] = []
136
+ def add_card(target: str, card: str):
137
+ """target is 'pack' or 'picks'."""
138
+ ensure_state()
139
+ st.session_state[target].append(card)
140
+
141
+ def remove_card(target: str, key: str):
142
+ ensure_state()
143
+ lst = st.session_state[target]
144
+ idx = next((i for i, c in enumerate(lst) if c == key), None)
145
+ if idx is not None:
146
+ lst.pop(idx)
147
+
148
+ def push_undo():
149
+ """Save a snapshot of pack + picks so we can undo one step."""
150
+ ensure_state()
151
+ st.session_state["undo_stack"].append({
152
+ "pack": copy.deepcopy(st.session_state["pack"]),
153
+ "picks": copy.deepcopy(st.session_state["picks"]),
154
+ })
155
+ # (optional) cap history
156
+ if len(st.session_state["undo_stack"]) > 20:
157
+ st.session_state["undo_stack"].pop(0)
158
+
159
+ def undo_last():
160
+ if st.session_state.get("undo_stack"):
161
+ snap = st.session_state["undo_stack"].pop()
162
+ st.session_state["pack"] = snap["pack"]
163
+ st.session_state["picks"] = snap["picks"]
164
+ st.rerun()
165
+
166
+ # --- callbacks ---
167
+ def _add_selected_to_deck():
168
+ val = st.session_state.get("deck_selectbox")
169
+ if val:
170
+ add_card("picks", val)
171
+ st.session_state["deck_selectbox"] = None # clear selection
172
+ st.toast(f"Added to deck: {val}")
173
+ st.rerun()
174
+
175
+ def _add_selected_to_pack():
176
+ val = st.session_state.get("pack_selectbox")
177
+ if val:
178
+ add_card("pack", val)
179
+ st.session_state["pack_selectbox"] = None # clear selection
180
+ st.toast(f"Added to pack: {val}")
181
+ st.rerun()
182
+
183
 
184
  # --- Tab 1: Draft ------------------------------------------------------------
185
 
186
  with tabs[0]:
187
+ ensure_state()
188
+
189
+
190
+ if st.session_state["undo_stack"]:
191
+ st.button("↩️ Undo last action", on_click=undo_last)
192
+
193
+ scores = {}
194
+ pick = None
195
+
196
+ if st.session_state["pack"]:
197
+ pack = st.session_state["pack"]
198
+ deck = st.session_state["picks"]
199
+
200
+ try:
201
+ pick, scores = rank_cards(deck, pack)
202
+ except Exception as e:
203
+ st.error(f"Error calculating card rankings: {e}")
204
+ if not set_code:
205
+ st.info("Pick a set to draft")
206
  else:
207
+ options = list(model.cards[set_code.lower()].keys())
208
+ c1, c2 = st.columns(2)
209
+ with c1:
210
+ st.subheader("Add to Deck")
211
+ deck_sel = st.selectbox(
212
+ "Search card (deck)",
213
+ options,
214
+ index=None,
215
+ placeholder="Type to search…",
216
+ key="deck_selectbox",
217
+ on_change=_add_selected_to_deck, # <- auto-add
218
+ )
219
+
220
+ # Show current deck with remove buttons
221
+ if st.session_state["picks"]:
222
+ for i, card in enumerate(st.session_state["picks"]):
223
+ cols = st.columns([6, 2])
224
+ cols[0].write(card)
225
+ if cols[1].button("Remove", key=f"rm-deck-{i}"):
226
+ remove_card("picks", card)
227
+ st.rerun()
228
+ else:
229
+ st.caption("Deck is empty.")
230
+
231
+ with c2:
232
+ st.subheader("Add to pack")
233
+ pack_sel = st.selectbox(
234
+ "Search card (pack)",
235
+ options,
236
+ index=None,
237
+ placeholder="Type to search…",
238
+ key="pack_selectbox",
239
+ on_change=_add_selected_to_pack, # <- auto-add
240
+ )
241
+
242
+ if st.session_state["pack"]:
243
+ pack_list = st.session_state["pack"]
244
+ vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
245
+ df_scores = pd.DataFrame({"Card": pack_list, "Score": vals})
246
+ tbl_col, btn_col = st.columns([4, 1], gap="small")
247
+ with tbl_col:
248
+ st.dataframe(
249
+ df_scores,
250
+ use_container_width=True,
251
+ column_config={
252
+ "Score": st.column_config.NumberColumn(format="%.4f")
253
+ },
254
+ hide_index=True,
255
+ )
256
+ with btn_col:
257
+ for i, card in enumerate(pack_list):
258
+ cols = st.columns([8, 4])
259
+ if cols[1].button("Add", key=f"add_clear_{i}"):
260
+ push_undo()
261
+ st.session_state["picks"].append(card)
262
+ st.session_state["pack"] = [] # or generate_booster(set_code)
263
+ st.rerun()
264
+ else:
265
+ st.caption("Pack is empty.")
266
 
267
  # --- Tab 2: Card rankings ----------------------------------------------------
268