davidmezzetti commited on
Commit
12c5ad7
1 Parent(s): 6100210

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -17
app.py CHANGED
@@ -7,6 +7,7 @@ Install txtai and streamlit (>= 1.23) to run:
7
 
8
  import datetime
9
  import os
 
10
 
11
  import altair as alt
12
  import numpy as np
@@ -124,20 +125,20 @@ class Stats:
124
 
125
  return vectors, data, embeddings
126
 
127
- def metrics(self, player):
128
  """
129
  Looks up a player's active years, best statistical year and key metrics.
130
 
131
  Args:
132
- player: player name
133
 
134
  Returns:
135
  active, best, metrics
136
  """
137
 
138
- if player in self.names:
139
  # Get player stats
140
- stats = self.stats[self.stats["playerID"] == self.names[player]]
141
 
142
  # Build key metrics
143
  metrics = stats[["yearID", self.metric()]]
@@ -150,12 +151,12 @@ class Stats:
150
 
151
  return range(1871, datetime.datetime.today().year), 1950, None
152
 
153
- def search(self, player=None, year=None, row=None, limit=10):
154
  """
155
  Runs an embeddings search. This method takes either a player-year or stats row as input.
156
 
157
  Args:
158
- player: player name to search
159
  year: year to search
160
  row: row of stats to search
161
  limit: max results to return
@@ -168,7 +169,7 @@ class Stats:
168
  query = self.vector(row)
169
  else:
170
  # Lookup player key and build vector id
171
- query = f"{year}{self.names.get(player)}"
172
  query = self.vectors.get(query)
173
 
174
  results, ids = [], set()
@@ -443,29 +444,115 @@ class Application:
443
 
444
  st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")
445
 
446
- category = st.radio("Stat", ["Batting", "Pitching"], horizontal=True, key="playerstat")
447
- stats, default = (self.batting, "Babe Ruth") if category == "Batting" else (self.pitching, "Cy Young")
 
 
 
 
448
 
449
  # Player name
450
  names = sorted(stats.names)
451
- player = st.selectbox("Player", names, names.index(default))
452
 
453
  # Player metrics
454
- active, best, metrics = stats.metrics(player)
455
 
456
  # Player year
457
- year = int(st.select_slider("Year", active, best) if len(active) > 1 else active[0])
458
 
459
  # Display metrics chart
460
  if len(active) > 1:
461
  self.chart(category, metrics)
462
 
463
  # Run search
464
- results = stats.search(player, year)
465
 
466
  # Display results
467
  self.table(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  def chart(self, category, metrics):
470
  """
471
  Displays a metric chart.
@@ -485,10 +572,7 @@ class Application:
485
  chart = (
486
  alt.Chart(metrics)
487
  .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
488
- .encode(
489
- x=alt.X("yearID", title=""),
490
- y=alt.Y(metric, scale=alt.Scale(zero=False))
491
- )
492
  )
493
 
494
  # Create metric median rule line
 
7
 
8
  import datetime
9
  import os
10
+ import random
11
 
12
  import altair as alt
13
  import numpy as np
 
125
 
126
  return vectors, data, embeddings
127
 
128
+ def metrics(self, name):
129
  """
130
  Looks up a player's active years, best statistical year and key metrics.
131
 
132
  Args:
133
+ name: player name
134
 
135
  Returns:
136
  active, best, metrics
137
  """
138
 
139
+ if name in self.names:
140
  # Get player stats
141
+ stats = self.stats[self.stats["playerID"] == self.names[name]]
142
 
143
  # Build key metrics
144
  metrics = stats[["yearID", self.metric()]]
 
151
 
152
  return range(1871, datetime.datetime.today().year), 1950, None
153
 
154
+ def search(self, name=None, year=None, row=None, limit=10):
155
  """
156
  Runs an embeddings search. This method takes either a player-year or stats row as input.
157
 
158
  Args:
159
+ name: player name to search
160
  year: year to search
161
  row: row of stats to search
162
  limit: max results to return
 
169
  query = self.vector(row)
170
  else:
171
  # Lookup player key and build vector id
172
+ query = f"{year}{self.names.get(name)}"
173
  query = self.vectors.get(query)
174
 
175
  results, ids = [], set()
 
444
 
445
  st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")
446
 
447
+ # Get parameters
448
+ params = self.params()
449
+
450
+ # Category and stats
451
+ category = self.category(params.get("category"))
452
+ stats = self.batting if category == "Batting" else self.pitching
453
 
454
  # Player name
455
  names = sorted(stats.names)
456
+ name = self.name(names, params.get("name"))
457
 
458
  # Player metrics
459
+ active, best, metrics = stats.metrics(name)
460
 
461
  # Player year
462
+ year = self.year(active, params.get("year"), best)
463
 
464
  # Display metrics chart
465
  if len(active) > 1:
466
  self.chart(category, metrics)
467
 
468
  # Run search
469
+ results = stats.search(name, year)
470
 
471
  # Display results
472
  self.table(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])
473
 
474
+ # Save parameters
475
+ st.experimental_set_query_params(category=category, name=name, year=year)
476
+
477
+ def params(self):
478
+ """
479
+ Get application parameters. This method combines URL parameters with session parameters.
480
+
481
+ Returns:
482
+ parameters
483
+ """
484
+
485
+ # Get parameters
486
+ params = st.experimental_get_query_params()
487
+ params = {x: params[x][0] for x in params}
488
+
489
+ # Sync parameters with session state
490
+ if all(x in st.session_state for x in ["category", "name", "year"]):
491
+ # Only use session year if name is unchanged
492
+ params["year"] = str(st.session_state["year"]) if params["name"] == st.session_state["name"] else None
493
+
494
+ # Copy category and name from session state
495
+ params["category"] = st.session_state["category"]
496
+ params["name"] = st.session_state["name"]
497
+
498
+ return params
499
+
500
+ def category(self, category):
501
+ """
502
+ Builds category input widget.
503
+
504
+ Args:
505
+ category: category parameter
506
+
507
+ Returns:
508
+ category component
509
+ """
510
+
511
+ # List of stat categories
512
+ categories = ["Batting", "Pitching"]
513
+
514
+ # Get category parameter, default if not available or valid
515
+ default = categories.index(category) if category and category in categories else 0
516
+
517
+ # Radio box component
518
+ return st.radio("Stat", categories, index=default, horizontal=True, key="category")
519
+
520
+ def name(self, names, name):
521
+ """
522
+ Builds name input widget.
523
+
524
+ Args:
525
+ names: list of all allowable names
526
+
527
+ Returns:
528
+ name component
529
+ """
530
+
531
+ # Get name parameter, default to random value if not valid
532
+ name = name if name and name in names else random.choice(names)
533
+
534
+ # Select box component
535
+ return st.selectbox("Name", names, names.index(name), key="name")
536
+
537
+ def year(self, years, year, best):
538
+ """
539
+ Builds year input widget.
540
+
541
+ Args:
542
+ years: active years for a player
543
+ year: year parameter
544
+ best: default to best year if year is invalid
545
+
546
+ Returns:
547
+ year component
548
+ """
549
+
550
+ # Get year parameter, default if not available or valid
551
+ year = int(year) if year and year.isdigit() and int(year) in years else best
552
+
553
+ # Slider component
554
+ return int(st.select_slider("Year", years, year, key="year") if len(years) > 1 else years[0])
555
+
556
  def chart(self, category, metrics):
557
  """
558
  Displays a metric chart.
 
572
  chart = (
573
  alt.Chart(metrics)
574
  .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
575
+ .encode(x=alt.X("yearID", title=""), y=alt.Y(metric, scale=alt.Scale(zero=False)))
 
 
 
576
  )
577
 
578
  # Create metric median rule line