davidmezzetti
committed on
Commit
•
cf62ef7
1
Parent(s):
58839f9
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Baseball statistics application with txtai and Streamlit.
|
3 |
+
|
4 |
+
Install txtai and streamlit to run:
|
5 |
+
pip install txtai streamlit
|
6 |
+
"""
|
7 |
+
|
8 |
+
import datetime
|
9 |
+
import os
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
import streamlit as st
|
14 |
+
|
15 |
+
from txtai.embeddings import Embeddings
|
16 |
+
|
17 |
+
|
18 |
+
class Stats:
    """
    Base stats class. Contains methods for loading, indexing and searching baseball stats.

    Subclasses provide the dataset by implementing loadcolumns, load, sort and vector.
    """

    def __init__(self):
        """
        Creates a new Stats instance.

        The steps below depend on each other: names and the index are both derived from
        the loaded stats, and the index vectors are built from the column list.
        """

        # Load columns
        self.columns = self.loadcolumns()

        # Load stats data
        self.stats = self.load()

        # Load names
        self.names = self.loadnames()

        # Build index
        self.vectors, self.data, self.embeddings = self.index()

    def loadcolumns(self):
        """
        Returns a list of data columns.

        Returns:
            list of columns
        """

        raise NotImplementedError

    def load(self):
        """
        Loads and returns raw stats.

        Returns:
            stats
        """

        raise NotImplementedError

    def sort(self, rows):
        """
        Sorts rows stored as a DataFrame.

        Args:
            rows: input DataFrame

        Returns:
            sorted DataFrame
        """

        raise NotImplementedError

    def vector(self, row):
        """
        Build a vector for input row.

        Args:
            row: input row

        Returns:
            row vector
        """

        raise NotImplementedError

    def loadnames(self):
        """
        Loads a name - player id dictionary.

        When a "first last" name is already taken by another player id, the player id is
        appended in parentheses so every key stays unique.

        Returns:
            {player name: player id}
        """

        # Get unique (first name, last name, player id) combinations
        rows = self.stats[["nameFirst", "nameLast", "playerID"]].drop_duplicates()

        names = {}
        for first, last, playerid in zip(rows["nameFirst"], rows["nameLast"], rows["playerID"]):
            # Name key
            key = f"{first} {last}"
            suffix = f" ({playerid})" if key in names else ""

            # Save name key - player id pair
            names[f"{key}{suffix}"] = playerid

        return names

    def index(self):
        """
        Builds an embeddings index to stats data. Returns vectors, input data and embeddings index.

        Vectors and data are both keyed by "{yearID}{playerID}".

        Returns:
            vectors, data, embeddings
        """

        # Build both dictionaries in a single pass over the stats rows
        vectors, data = {}, {}
        for _, row in self.stats.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'
            vectors[uid] = self.transform(row)
            data[uid] = dict(row)

        embeddings = Embeddings({
            "transform": self.transform,
        })

        embeddings.index((uid, vector, None) for uid, vector in vectors.items())

        return vectors, data, embeddings

    def years(self, player):
        """
        Looks up the years active for a player along with the player's best statistical year.

        The best year is the first row after applying this instance's sort. Unknown players
        get a default range of 1871 through the current year, with 1950 as the selected year.

        Args:
            player: player name

        Returns:
            start, end, best
        """

        if player in self.names:
            df = self.sort(self.stats[self.stats["playerID"] == self.names[player]])
            years = df["yearID"]
            return int(years.min()), int(years.max()), int(years.iloc[0])

        return 1871, datetime.datetime.today().year, 1950

    def search(self, player=None, year=None, row=None, limit=10):
        """
        Runs an embeddings search. This method takes either a player-year or stats row as input.

        Args:
            player: player name to search
            year: year to search
            row: row of stats to search
            limit: max results to return

        Returns:
            list of results, at most one per player. Empty when the player-year is not indexed.
        """

        if row:
            query = self.vector(row)
        else:
            # Lookup player key and build vector id
            query = self.vectors.get(f"{year}{self.names.get(player)}")

        results, ids = [], set()
        if query is not None:
            # Request more matches than limit since several seasons of one player can match
            for uid, _ in self.embeddings.search(query, limit * 5):
                # Only add unique players. The stored player id is used instead of slicing
                # the uid, which would assume a fixed-width year prefix.
                playerid = self.data[uid]["playerID"]
                if playerid in ids:
                    continue

                result = self.data[uid].copy()
                result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
                result["yearID"] = str(result["yearID"])
                results.append(result)
                ids.add(playerid)

                if len(ids) >= limit:
                    break

        return results

    def transform(self, row):
        """
        Transforms a stats row into a vector.

        Values are read in self.columns order. Missing (None), zero, NaN and infinite values
        are all stored as 0.0 so that a failed calculation cannot leak a non-finite number
        into the vector.

        Args:
            row: stats row, or an already built NumPy vector which is returned as is

        Returns:
            vector
        """

        if isinstance(row, np.ndarray):
            return row

        values = []
        for column in self.columns:
            value = row[column]
            values.append(value if value and np.isfinite(value) else 0.0)

        return np.array(values)
|
196 |
+
|
197 |
+
|
198 |
+
class Batting(Stats):
    """
    Batting stats.
    """

    def loadcolumns(self):
        """
        Returns the list of batting columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth", "age", "weight", "height", "yearID", "G", "AB", "R", "H", "1B", "2B", "3B", "HR", "RBI", "SB", "CS",
            "BB", "SO", "IBB", "HBP", "SH", "SF", "GIDP", "POS", "AVG", "OBP", "TB", "SLG", "OPS", "OPS+"
        ]

    def load(self):
        """
        Downloads player, batting and fielding data and builds the batting stats table.

        Returns:
            DataFrame of player-season batting rows with calculated columns added
        """

        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        batting = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Batting.csv")
        fielding = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Fielding.csv")

        # Merge player data in
        batting = pd.merge(players, batting, how="inner", on=["playerID"])

        # Require player to have at least 350 plate appearances.
        # Copy the filtered rows so the calculated columns below are assigned to an
        # independent DataFrame rather than to a slice of the merged data.
        batting = batting[(batting["AB"] + batting["BB"]) >= 350].copy()

        # Derive primary player positions
        positions = self.positions(fielding)

        # Calculated columns
        batting["age"] = batting["yearID"] - batting["birthYear"]
        batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1)
        batting["AVG"] = batting["H"] / batting["AB"]
        batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
        batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
        batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
        batting["SLG"] = batting["TB"] / batting["AB"]
        batting["OPS"] = batting["OBP"] + batting["SLG"]

        # OPS+ here is OPS relative to the mean OPS of the loaded rows, centered on 100
        batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100

        return batting

    def sort(self, rows):
        """
        Sorts rows by OPS+, best first.

        Args:
            rows: input DataFrame

        Returns:
            sorted DataFrame
        """

        return rows.sort_values(by="OPS+", ascending=False)

    def vector(self, row):
        """
        Builds a vector for an input row of batting stats.

        The calculated fields (TB, AVG, OBP, SLG, OPS, OPS+) are written back into row.
        A rate stat whose denominator is zero is stored as None, which transform maps
        to 0.0, instead of raising ZeroDivisionError.

        Args:
            row: input row, must contain AB, BB, H, 1B, 2B, 3B and HR

        Returns:
            row vector
        """

        atbats = row["AB"]
        appearances = row["AB"] + row["BB"]

        row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
        row["AVG"] = row["H"] / atbats if atbats else None
        row["OBP"] = (row["H"] + row["BB"]) / appearances if appearances else None
        row["SLG"] = row["TB"] / atbats if atbats else None

        # OPS and OPS+ are only defined when both of their inputs are
        if row["OBP"] is None or row["SLG"] is None:
            row["OPS"] = None
            row["OPS+"] = None
        else:
            row["OPS"] = row["OBP"] + row["SLG"]
            row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100

        return self.transform(row)

    def positions(self, fielding):
        """
        Derives primary positions for players.

        Positions are encoded as numbers: P=1, C=2, 1B=3, 2B=4, 3B=5, SS=6, OF=7.
        Any other or missing value is encoded as 0 so the POS column is always numeric.
        The primary position for a player-year is the one with the most games. On a tie
        the position seen first is kept.

        Args:
            fielding: fielding data

        Returns:
            {player id: (position, number of games)}
        """

        codes = {"P": 1, "C": 2, "1B": 3, "2B": 4, "3B": 5, "SS": 6, "OF": 7}

        positions = {}
        for year, playerid, position, games in zip(fielding["yearID"], fielding["playerID"], fielding["POS"], fielding["G"]):
            uid = f"{year}{playerid}"

            # Save position if not set or player played more at this position
            if uid not in positions or positions[uid][1] < games:
                positions[uid] = (codes.get(position, 0), games)

        return positions

    def position(self, positions, row):
        """
        Looks up primary position for player row.

        Args:
            positions: all player positions
            row: player row

        Returns:
            primary player position, 0 when the player-year has no fielding data
        """

        uid = f'{row["yearID"]}{row["playerID"]}'
        if uid in positions:
            return positions[uid][0]

        return 0
|
296 |
+
|
297 |
+
class Pitching(Stats):
    """
    Pitching stats.
    """

    def loadcolumns(self):
        """
        Returns the list of pitching columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth", "age", "weight", "height", "yearID", "W", "L", "G", "GS", "CG", "SHO", "SV", "IPouts",
            "H", "ER", "HR", "BB", "SO", "BAOpp", "ERA", "IBB", "WP", "HBP", "BK", "BFP", "GF", "R", "SH", "SF",
            "GIDP", "WHIP", "WADJ"
        ]

    def load(self):
        """
        Downloads player and pitching data and builds the pitching stats table.

        Returns:
            DataFrame of player-season pitching rows with calculated columns added
        """

        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        pitching = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Pitching.csv")

        # Merge player data in
        pitching = pd.merge(players, pitching, how="inner", on=["playerID"])

        # Require player to have 20 appearances.
        # Copy the filtered rows so the calculated columns below are assigned to an
        # independent DataFrame rather than to a slice of the merged data.
        pitching = pitching[pitching["G"] >= 20].copy()

        # Calculated columns
        pitching["age"] = pitching["yearID"] - pitching["birthYear"]

        # IPouts / 3 converts outs recorded into innings pitched
        pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)

        # Wins adjusted: wins plus saves, divided by the sum of ERA and WHIP
        pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])

        return pitching

    def sort(self, rows):
        """
        Sorts rows by WADJ, best first.

        Args:
            rows: input DataFrame

        Returns:
            sorted DataFrame
        """

        return rows.sort_values(by="WADJ", ascending=False)

    def vector(self, row):
        """
        Builds a vector for an input row of pitching stats.

        The calculated fields (WHIP, WADJ) are written back into row. A field that cannot
        be calculated is stored as None, which transform maps to 0.0.

        Args:
            row: input row, must contain BB, H, IPouts, W, SV and ERA

        Returns:
            row vector
        """

        # WHIP needs a non-zero number of outs
        if row["IPouts"]:
            row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3)
        else:
            row["WHIP"] = None

        # WADJ is only calculated when both ERA and WHIP are set and non-zero
        if row["ERA"] and row["WHIP"]:
            row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"])
        else:
            row["WADJ"] = None

        return self.transform(row)
|
331 |
+
|
332 |
+
|
333 |
+
class Application:
    """
    Main application.
    """

    def __init__(self):
        """
        Creates a new application.

        Builds both stats instances up front, which loads their data and indexes it.
        """

        # Batting stats
        self.batting = Batting()

        # Pitching stats
        self.pitching = Pitching()

    def run(self):
        """
        Runs a Streamlit application.
        """

        st.title("⚾ Baseball Statistics")
        st.markdown("""
        This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
        Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project.
        """)

        self.player()

    def player(self):
        """
        Player tab.

        Renders the stat category, player and year inputs, then searches for the selected
        player-season and displays the matches.
        """

        st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")

        category = st.radio("Stat", ["Batting", "Pitching"], horizontal=True, key="playerstat")
        stats, default = (self.batting, "Babe Ruth") if category == "Batting" else (self.pitching, "Cy Young")

        # Player name
        names = sorted(stats.names)

        # Preselect the default player. Fall back to the first entry when the default is
        # not in the loaded data, since list.index raises ValueError for a missing value.
        selected = names.index(default) if default in names else 0
        player = st.selectbox("Player", names, selected)

        # Player year. A slider is only shown when the player has more than one year.
        start, end, best = stats.years(player)
        year = st.slider("Year", start, end, best) if start != end else start

        # Run search
        results = stats.search(player, year)

        # Display results. The first stats column is left out of the table.
        self.display(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])

    def display(self, results, columns):
        """
        Displays a list of results.

        Args:
            results: list of results
            columns: column names
        """

        if results:
            st.dataframe(pd.DataFrame(results)[columns])
        else:
            st.write("Player-Year not found")
|
399 |
+
|
400 |
+
|
401 |
+
@st.cache_resource(show_spinner=False)
def create():
    """
    Builds the Streamlit application. The result is cached with st.cache_resource.

    Returns:
        Application
    """

    application = Application()
    return application
|
411 |
+
|
412 |
+
|
413 |
+
if __name__ == "__main__":
    # Set before the application is created
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Create and run application
    create().run()
|