Timo commited on
Commit
8ef1a8f
·
1 Parent(s): 4205025
Files changed (2) hide show
  1. src/draft_model.py +4 -3
  2. src/streamlit_app.py +141 -114
src/draft_model.py CHANGED
@@ -74,10 +74,11 @@ class DraftModel:
74
  deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)
75
 
76
  vals = self.net(deck = deck_t, cards = card_t)
77
- #scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
78
- scores = vals.squeeze(0).cpu().numpy()
79
  return {
80
  "pick": pack[scores.argmax()],
 
81
  "scores": scores.tolist(),
82
  }
83
  @torch.no_grad()
@@ -85,7 +86,7 @@ class DraftModel:
85
  keys = list(self.cards[set_code].keys())
86
  cards = torch.stack([self._embed(c) for c in keys]).unsqueeze(0).to(self.device)
87
 
88
- vals = self.predict(pack=keys, deck=None)["scores"]
89
  return keys, vals
90
 
91
 
 
74
  deck_t = torch.stack([self._embed(c) for c in deck]).unsqueeze(0).to(self.device)
75
 
76
  vals = self.net(deck = deck_t, cards = card_t)
77
+ scores = torch.softmax(vals, dim=1).squeeze(0).cpu().numpy()
78
+ logits = vals.squeeze(0).cpu().numpy()
79
  return {
80
  "pick": pack[scores.argmax()],
81
+ "logits": logits.tolist(),
82
  "scores": scores.tolist(),
83
  }
84
  @torch.no_grad()
 
86
  keys = list(self.cards[set_code].keys())
87
  cards = torch.stack([self._embed(c) for c in keys]).unsqueeze(0).to(self.device)
88
 
89
+ vals = self.predict(pack=keys, deck=None)["logits"]
90
  return keys, vals
91
 
92
 
src/streamlit_app.py CHANGED
@@ -47,6 +47,39 @@ SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
47
  def load_model():
48
  return DraftModel()
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  if "model" not in st.session_state:
52
  st.session_state.model = load_model() # your class
@@ -58,7 +91,7 @@ if "undo_stack" not in st.session_state:
58
  st.session_state.undo_stack: List[str] = []
59
  if "set_code" not in st.session_state:
60
  # choose a default set code that exists in model.cards, e.g., "eoe"
61
- st.session_state.set_code = "eoe"
62
 
63
  model = st.session_state.model
64
 
@@ -88,8 +121,9 @@ def rank_cards(deck: List[str], pack: List[str]) -> List[Dict]:
88
  else:
89
  out = model.predict(pack, deck = deck)
90
  pick = out["pick"]
 
91
  scores = {pack[i]: score for i, score in enumerate(out["scores"])}
92
- return pick, scores
93
 
94
 
95
  def fetch_card_image(card_name: str) -> str:
@@ -118,29 +152,13 @@ with st.sidebar:
118
  st.header("Draft setup")
119
 
120
  supported_sets = get_supported_sets()
121
-
122
- # Hide control in an expander (collapsed by default)
123
- if supported_sets:
124
- # UPDATED: dropdown instead of radio
125
- set_code = st.selectbox(
126
- "Choose a set",
127
- supported_sets,
128
- index=0,
129
- key="set_code",
130
- )
131
- else:
132
- st.warning(
133
- "*supported_sets.txt* not found or empty. Using free-text input instead.",
134
- icon="⚠️",
135
- )
136
- set_code = st.text_input("Set code", value="EOE", key="set_code")
137
-
138
-
139
- # -------- Session state ------------------------------------------------------
140
-
141
- st.session_state.setdefault("pack", [])
142
- st.session_state.setdefault("picks", [])
143
- st.session_state.setdefault("undo_stack", [])
144
 
145
  # -------- Main content organised in tabs ------------------------------------
146
 
@@ -148,7 +166,15 @@ tabs = st.tabs(["Draft", "P1P1 Rankings"])
148
 
149
  def add_card(target: str, card: str):
150
  """target is 'pack' or 'picks'."""
151
- st.session_state[target].append(card)
 
 
 
 
 
 
 
 
152
 
153
  def remove_card(target: str, key: str):
154
  lst = st.session_state[target]
@@ -160,7 +186,7 @@ def push_undo():
160
  """Save a snapshot of pack + picks so we can undo one step."""
161
  st.session_state["undo_stack"].append({
162
  "pack": copy.deepcopy(st.session_state["pack"]),
163
- "picks": copy.deepcopy(st.session_state["picks"]),
164
  })
165
  # (optional) cap history
166
  if len(st.session_state["undo_stack"]) > 20:
@@ -170,29 +196,13 @@ def undo_last():
170
  if st.session_state.get("undo_stack"):
171
  snap = st.session_state["undo_stack"].pop()
172
  st.session_state["pack"] = snap["pack"]
173
- st.session_state["picks"] = snap["picks"]
174
 
175
- # --- callbacks ---
176
- def _add_selected_to_deck():
177
- val = st.session_state.get("deck_selectbox")
178
- if val:
179
- add_card("picks", val)
180
- st.session_state["deck_selectbox"] = None # clear selection
181
- st.toast(f"Added to deck: {val}")
182
-
183
- def _add_selected_to_pack():
184
- val = st.session_state.get("pack_selectbox")
185
- if val:
186
- add_card("pack", val)
187
- st.session_state["pack_selectbox"] = None # clear selection
188
- st.toast(f"Added to pack: {val}")
189
- #st.rerun()
190
 
191
 
192
- # --- Tab 1: Draft ------------------------------------------------------------
193
-
194
  with tabs[0]:
195
-
196
  if st.session_state["undo_stack"]:
197
  st.button("↩️ Undo last action", on_click=undo_last)
198
 
@@ -201,78 +211,95 @@ with tabs[0]:
201
 
202
  if st.session_state["pack"]:
203
  pack = st.session_state["pack"]
204
- deck = st.session_state["picks"]
205
 
206
  try:
207
- pick, scores = rank_cards(deck, pack)
 
 
208
  except Exception as e:
209
  st.error(f"Error calculating card rankings: {e}")
210
- if not set_code:
211
- st.info("Pick a set to draft")
212
- else:
213
- options = list(model.cards[set_code.lower()].keys())
214
- c1, c2 = st.columns(2)
215
- with c1:
216
- st.subheader("Add to Deck")
217
- deck_sel = st.selectbox(
218
- "Search card (deck)",
219
- options,
220
- index=None,
221
- placeholder="Type to search…",
222
- key="deck_selectbox",
223
- on_change=_add_selected_to_deck, # <- auto-add
224
- )
225
-
226
- # Show current deck with remove buttons
227
- if st.session_state["picks"]:
228
- for i, card in enumerate(st.session_state["picks"]):
229
- name_col, rm_col = st.columns([6, 3], gap="small")
230
- name_col.write(card)
231
- with rm_col:
232
- if st.button("Remove", key=f"rm-deck-{i}", use_container_width=True):
233
- remove_card("picks", card)
234
- st.rerun()
235
- else:
236
- st.caption("Deck is empty.")
237
-
238
- with c2:
239
- st.subheader("Add to pack")
240
- pack_sel = st.selectbox(
241
- "Search card (pack)",
242
- options,
243
- index=None,
244
- placeholder="Type to search…",
245
- key="pack_selectbox",
246
- on_change=_add_selected_to_pack, # <- auto-add
247
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
- if st.session_state["pack"]:
250
- pack_list = st.session_state["pack"]
251
- vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
252
- df_scores = pd.DataFrame({"Card": pack_list, "Score": vals})
253
-
254
- pack_list = st.session_state["pack"]
255
- vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
256
-
257
- # header row
258
- h1, h2, h3 = st.columns([6, 2, 3])
259
- h1.markdown("**Card**")
260
- h2.markdown("**Score**")
261
- h3.markdown("**Pick**")
262
-
263
-
264
- # rows
265
- for i, card in enumerate(pack_list):
266
- s = vals[i]
267
- c1, c2, c3 = st.columns([6, 2, 3], gap="small")
268
- c1.write(card)
269
- c2.markdown(f"<div class='score-cell'>{'' if np.isnan(s) else f'{s:.4f}'}</div>", unsafe_allow_html=True)
270
- with c3:
271
- if st.button("Pick", key=f"pick_btn_{i}", use_container_width=True, help="Add to deck & clear pack"):
272
- push_undo()
273
- st.session_state["picks"].append(card)
274
- st.session_state["pack"] = [] # or generate_booster(set_code)
275
- st.rerun()
276
 
277
  # --- Tab 2: Card rankings ----------------------------------------------------
278
 
@@ -289,4 +316,4 @@ with tabs[1]:
289
  except Exception as e:
290
  st.error(f"Could not calculate P1P1: {e}")
291
  else:
292
- st.info("Select a set in the sidebar to view P1P1.")
 
47
  def load_model():
48
  return DraftModel()
49
 
50
+ # --- callbacks ---
51
+ def _add_selected_to_deck():
52
+ val = st.session_state.get("deck_selectbox")
53
+ if val:
54
+ add_card("deck", val)
55
+ st.session_state["deck_selectbox"] = None # clear selection
56
+ st.toast(f"Added to deck: {val}")
57
+
58
+ def _add_selected_to_pack():
59
+ val = st.session_state.get("pack_selectbox")
60
+ if val:
61
+ success = add_card("pack", val)
62
+ st.session_state["pack_selectbox"] = None # clear selection
63
+ if success:
64
+ st.toast(f"Added to pack: {val}")
65
+ #st.rerun()
66
+
67
+ def _reset_draft_state():
68
+ st.session_state["pack"] = []
69
+ st.session_state["deck"] = []
70
+ st.session_state["undo_stack"] = []
71
+ st.session_state["deck_selectbox"] = None
72
+ st.session_state["pack_selectbox"] = None
73
+
74
+ def _on_set_changed():
75
+ curr = st.session_state.get("set_code")
76
+ prev = st.session_state.get("prev_set_code")
77
+ if prev != curr:
78
+ _reset_draft_state()
79
+ st.session_state["prev_set_code"] = curr
80
+ st.toast(f"Switched to set {curr}. Cleared current pack & deck.")
81
+
82
+
83
 
84
  if "model" not in st.session_state:
85
  st.session_state.model = load_model() # your class
 
91
  st.session_state.undo_stack: List[str] = []
92
  if "set_code" not in st.session_state:
93
  # choose a default set code that exists in model.cards, e.g., "eoe"
94
+ st.session_state.set_code = "EOE"
95
 
96
  model = st.session_state.model
97
 
 
121
  else:
122
  out = model.predict(pack, deck = deck)
123
  pick = out["pick"]
124
+ logits = {pack[i]: score for i, score in enumerate(out["logits"])}
125
  scores = {pack[i]: score for i, score in enumerate(out["scores"])}
126
+ return pick, logits, scores
127
 
128
 
129
  def fetch_card_image(card_name: str) -> str:
 
152
  st.header("Draft setup")
153
 
154
  supported_sets = get_supported_sets()
155
+ set_code = st.selectbox(
156
+ "Choose a set",
157
+ supported_sets,
158
+ index=0,
159
+ key="set_code",
160
+ on_change=_on_set_changed
161
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # -------- Main content organised in tabs ------------------------------------
164
 
 
166
 
167
  def add_card(target: str, card: str):
168
  """target is 'pack' or 'picks'."""
169
+ if target == "pack":
170
+ if card not in st.session_state["pack"]:
171
+ st.session_state[target].append(card)
172
+ else:
173
+ st.warning(f"{card} is already in the pack.", icon="⚠️")
174
+ return False
175
+ elif target == "deck":
176
+ st.session_state[target].append(card)
177
+ return True
178
 
179
  def remove_card(target: str, key: str):
180
  lst = st.session_state[target]
 
186
  """Save a snapshot of pack + picks so we can undo one step."""
187
  st.session_state["undo_stack"].append({
188
  "pack": copy.deepcopy(st.session_state["pack"]),
189
+ "deck": copy.deepcopy(st.session_state["deck"]),
190
  })
191
  # (optional) cap history
192
  if len(st.session_state["undo_stack"]) > 20:
 
196
  if st.session_state.get("undo_stack"):
197
  snap = st.session_state["undo_stack"].pop()
198
  st.session_state["pack"] = snap["pack"]
199
+ st.session_state["deck"] = snap["deck"]
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
 
203
+ # --- Tab 1: Draft -------------------------------------------------------
 
204
  with tabs[0]:
205
+
206
  if st.session_state["undo_stack"]:
207
  st.button("↩️ Undo last action", on_click=undo_last)
208
 
 
211
 
212
  if st.session_state["pack"]:
213
  pack = st.session_state["pack"]
214
+ deck = st.session_state["deck"]
215
 
216
  try:
217
+ pick, logits, scores = rank_cards(deck, pack)
218
+ if pick:
219
+ st.success(f"💡 Suggested pick: **{pick}**", icon="✨")
220
  except Exception as e:
221
  st.error(f"Error calculating card rankings: {e}")
222
+
223
+ options = list(model.cards[set_code.lower()].keys())
224
+ c1, c2 = st.columns(2)
225
+ with c1:
226
+ st.subheader("Add to Deck")
227
+ deck_sel = st.selectbox(
228
+ "Search card (deck)",
229
+ options,
230
+ index=None,
231
+ placeholder="Type to search…",
232
+ key="deck_selectbox",
233
+ on_change=_add_selected_to_deck, # <- auto-add
234
+ )
235
+ if st.session_state["deck"]:
236
+ # header row
237
+ st.button("🗑️ Clear deck", on_click=lambda: st.session_state.update(deck=[]), use_container_width=True)
238
+ h1, h2 = st.columns([6, 3])
239
+ h1.markdown("**Card**")
240
+ h2.markdown("**Remove?**")
241
+ for i, card in enumerate(st.session_state["deck"]):
242
+ name_col, rm_col = st.columns([6, 3], gap="small")
243
+ name_col.write(card)
244
+ with rm_col:
245
+ if st.button("Remove", key=f"rm-deck-{i}", use_container_width=True):
246
+ remove_card("deck", card)
247
+ st.rerun()
248
+ else:
249
+ st.caption("Deck is empty.")
250
+
251
+ with c2:
252
+ st.subheader("Add to pack")
253
+ pack_sel = st.selectbox(
254
+ "Search card (pack)",
255
+ options,
256
+ index=None,
257
+ placeholder="Type to search…",
258
+ key="pack_selectbox",
259
+ on_change=_add_selected_to_pack, # <- auto-add
260
+ )
261
+
262
+
263
+
264
+ if st.session_state["pack"]:
265
+ # header row
266
+ st.button("🗑️ Clear pack", on_click=lambda: st.session_state.update(pack=[]), use_container_width=True)
267
+ h1, h2, h3 = st.columns([6, 2, 3])
268
+ h1.markdown("**Card**")
269
+ h2.markdown("**Score**")
270
+ h3.markdown("**Pick?**")
271
+
272
+ pack_list = st.session_state["pack"]
273
+ vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
274
+ logits = [logits.get(c) if logits and c in logits else np.nan for c in pack_list]
275
+ df_scores = pd.DataFrame({"Card": pack_list, "Score": vals, "Logits": logits})
276
+ df_scores = df_scores.sort_values("Score", ascending=False, na_position="last").reset_index(drop=True)
277
+
278
+
279
+ # rows
280
+ for i, row in df_scores.iterrows():
281
+ card = row["Card"]
282
+ score = row["Score"]
283
+ logit = row["Logits"]
284
+
285
+
286
+ c1, c2, c3 = st.columns([6, 2, 3], gap="small")
287
+ c1.write(card)
288
 
289
+ tooltip_html = f"""
290
+ <div title="{logit:.4f}">
291
+ <progress Value="{score}" max="1" style="width: 100%; height: 20px;"></progress>
292
+ </div>
293
+ """
294
+ c2.markdown(tooltip_html, unsafe_allow_html=True)
295
+ with c3:
296
+ if st.button("Pick", key=f"pick_btn_{i}", use_container_width=True, help="Add to deck & clear pack"):
297
+ push_undo()
298
+ st.session_state["deck"].append(card)
299
+ st.session_state["pack"] = []
300
+ st.rerun()
301
+ else:
302
+ st.caption("Pack is empty.")
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  # --- Tab 2: Card rankings ----------------------------------------------------
305
 
 
316
  except Exception as e:
317
  st.error(f"Could not calculate P1P1: {e}")
318
  else:
319
+ st.info("Select a set in the sidebar to view P1P1.")