davidmezzetti committed on
Commit
cf62ef7
1 Parent(s): 58839f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +418 -0
app.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseball statistics application with txtai and Streamlit.
3
+
4
+ Install txtai and streamlit to run:
5
+ pip install txtai streamlit
6
+ """
7
+
8
+ import datetime
9
+ import os
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import streamlit as st
14
+
15
+ from txtai.embeddings import Embeddings
16
+
17
+
18
class Stats:
    """
    Base stats class. Contains methods for loading, indexing and searching baseball stats.

    Subclasses supply the column list, the raw data, the sort order and the query vector
    logic by overriding loadcolumns, load, sort and vector.
    """

    def __init__(self):
        """
        Creates a new Stats instance.
        """

        # Load columns
        self.columns = self.loadcolumns()

        # Load stats data
        self.stats = self.load()

        # Load names
        self.names = self.loadnames()

        # Build index
        self.vectors, self.data, self.embeddings = self.index()

    def loadcolumns(self):
        """
        Returns a list of data columns.

        Returns:
            list of columns
        """

        raise NotImplementedError

    def load(self):
        """
        Loads and returns raw stats.

        Returns:
            stats
        """

        raise NotImplementedError

    def sort(self, rows):
        """
        Sorts rows stored as a DataFrame.

        Args:
            rows: input DataFrame

        Returns:
            sorted DataFrame
        """

        raise NotImplementedError

    def vector(self, row):
        """
        Build a vector for input row.

        Args:
            row: input row

        Returns:
            row vector
        """

        raise NotImplementedError

    def loadnames(self):
        """
        Loads a name - player id dictionary.

        The key is "nameFirst nameLast". When that key is already taken by another
        player id, the new entry is stored as "nameFirst nameLast (playerID)".

        Returns:
            {player name: player id}
        """

        # Get unique names
        names = {}
        rows = self.stats[["nameFirst", "nameLast", "playerID"]].drop_duplicates()
        for _, row in rows.iterrows():
            # Name key
            key = f"{row['nameFirst']} {row['nameLast']}"
            suffix = f" ({row['playerID']})" if key in names else ""

            # Save name key - player id pair
            names[f"{key}{suffix}"] = row["playerID"]

        return names

    def index(self):
        """
        Builds an embeddings index to stats data. Returns vectors, input data and embeddings index.

        Both dictionaries are keyed by the string "<yearID><playerID>". Rows sharing a key
        overwrite earlier ones, so the last such row in self.stats is kept.

        Returns:
            vectors, data, embeddings
        """

        # Build both dictionaries in a single pass over the stats rows
        vectors, data = {}, {}
        for _, row in self.stats.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'
            vectors[uid] = self.transform(row)
            data[uid] = dict(row)

        embeddings = Embeddings({
            "transform": self.transform,
        })

        embeddings.index((uid, vectors[uid], None) for uid in vectors)

        return vectors, data, embeddings

    def years(self, player):
        """
        Looks up the years active for a player along with the player's best statistical year.

        The best year is the first row after self.sort is applied to the player's rows.
        Unknown players get the default range 1871 - current year with 1950 as best.

        Args:
            player: player name

        Returns:
            start, end, best
        """

        if player in self.names:
            df = self.sort(self.stats[self.stats["playerID"] == self.names[player]])
            return int(df["yearID"].min()), int(df["yearID"].max()), int(df["yearID"].iloc[0])

        return 1871, datetime.datetime.today().year, 1950

    def search(self, player=None, year=None, row=None, limit=10):
        """
        Runs an embeddings search. This method takes either a player-year or stats row as input.

        Args:
            player: player name to search
            year: year to search
            row: row of stats to search
            limit: max results to return

        Returns:
            list of results, at most one per player. Empty when the player-year is not indexed.
        """

        if row:
            query = self.vector(row)
        else:
            # Lookup player key and build vector id
            query = self.vectors.get(f"{year}{self.names.get(player)}")

        results, ids = [], set()
        if query is not None:
            # Request more than limit matches since several seasons of one player can match
            for uid, _ in self.embeddings.search(query, limit * 5):
                # Only add unique players, identified by the stored player id
                playerid = self.data[uid]["playerID"]
                if playerid not in ids:
                    result = self.data[uid].copy()
                    result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
                    result["yearID"] = str(result["yearID"])
                    results.append(result)
                    ids.add(playerid)

                    if len(ids) >= limit:
                        break

        return results

    def transform(self, row):
        """
        Transforms a stats row into a vector.

        NumPy arrays are returned unchanged. Otherwise one value is read per entry in
        self.columns; empty, zero, NaN and infinite values are all stored as 0.0.

        Args:
            row: stats row

        Returns:
            vector
        """

        if isinstance(row, np.ndarray):
            return row

        values = []
        for column in self.columns:
            value = row[column]

            # Infinite values (e.g. from a division by zero) are treated like missing values
            values.append(value if value and np.isfinite(value) else 0.0)

        return np.array(values)
196
+
197
+
198
class Batting(Stats):
    """
    Batting stats. Loads player, batting and fielding data and derives rate statistics.
    """

    def loadcolumns(self):
        """
        Returns the list of batting columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth", "age", "weight", "height", "yearID", "G", "AB", "R", "H", "1B", "2B", "3B", "HR", "RBI", "SB", "CS",
            "BB", "SO", "IBB", "HBP", "SH", "SF", "GIDP", "POS", "AVG", "OBP", "TB", "SLG", "OPS", "OPS+"
        ]

    def load(self):
        """
        Loads batting stats merged with player data and adds calculated columns.

        Returns:
            batting DataFrame
        """

        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        batting = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Batting.csv")
        fielding = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Fielding.csv")

        # Merge player data in
        batting = pd.merge(players, batting, how="inner", on=["playerID"])

        # Require player to have at least 350 plate appearances.
        # Copy the filtered rows so the column assignments below act on an independent frame.
        batting = batting[(batting["AB"] + batting["BB"]) >= 350].copy()

        # Derive primary player positions
        positions = self.positions(fielding)

        # Calculated columns
        batting["age"] = batting["yearID"] - batting["birthYear"]
        batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1)
        batting["AVG"] = batting["H"] / batting["AB"]
        batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
        batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
        batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
        batting["SLG"] = batting["TB"] / batting["AB"]
        batting["OPS"] = batting["OBP"] + batting["SLG"]
        batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100

        return batting

    def sort(self, rows):
        """
        Sorts rows by OPS+, highest first.

        Args:
            rows: input DataFrame

        Returns:
            sorted DataFrame
        """

        return rows.sort_values(by="OPS+", ascending=False)

    def vector(self, row):
        """
        Builds a vector for an input row of batting stats.

        The calculated fields (TB, AVG, OBP, SLG, OPS, OPS+) are written back into row.
        A rate whose denominator is zero is stored as None, which transform turns into 0.0.

        Args:
            row: input row

        Returns:
            row vector
        """

        atbats = row["AB"]
        appearances = row["AB"] + row["BB"]

        row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
        row["AVG"] = row["H"] / atbats if atbats else None
        row["OBP"] = (row["H"] + row["BB"]) / appearances if appearances else None
        row["SLG"] = row["TB"] / atbats if atbats else None

        # OPS and OPS+ need both OBP and SLG
        if row["OBP"] is None or row["SLG"] is None:
            row["OPS"] = None
            row["OPS+"] = None
        else:
            row["OPS"] = row["OBP"] + row["SLG"]
            row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100

        return self.transform(row)

    def positions(self, fielding):
        """
        Derives primary positions for players.

        Positions are stored as numeric codes (P=1, C=2, 1B=3, 2B=4, 3B=5, SS=6, OF=7).
        Any other value, including a missing one, is stored as 0 so the position is always numeric.

        Args:
            fielding: fielding data

        Returns:
            {player id: (position, number of games)}
        """

        codes = {"P": 1, "C": 2, "1B": 3, "2B": 4, "3B": 5, "SS": 6, "OF": 7}

        positions = {}
        for _, row in fielding.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'
            position = codes.get(row["POS"], 0)

            # Save position if not set or player played more at this position
            if uid not in positions or positions[uid][1] < row["G"]:
                positions[uid] = (position, row["G"])

        return positions

    def position(self, positions, row):
        """
        Looks up primary position for player row.

        Args:
            positions: all player positions
            row: player row

        Returns:
            primary player position, 0 when the player-year has no entry
        """

        uid = f'{row["yearID"]}{row["playerID"]}'
        return positions[uid][0] if uid in positions else 0
296
+
297
class Pitching(Stats):
    """
    Pitching stats. Loads player and pitching data and derives WHIP and WADJ.
    """

    def loadcolumns(self):
        """
        Returns the list of pitching columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth", "age", "weight", "height", "yearID", "W", "L", "G", "GS", "CG", "SHO", "SV", "IPouts",
            "H", "ER", "HR", "BB", "SO", "BAOpp", "ERA", "IBB", "WP", "HBP", "BK", "BFP", "GF", "R", "SH", "SF",
            "GIDP", "WHIP", "WADJ"
        ]

    def load(self):
        """
        Loads pitching stats merged with player data and adds calculated columns.

        Returns:
            pitching DataFrame
        """

        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        pitching = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Pitching.csv")

        # Merge player data in
        pitching = pd.merge(players, pitching, how="inner", on=["playerID"])

        # Require player to have 20 appearances.
        # Copy the filtered rows so the column assignments below act on an independent frame.
        pitching = pitching[pitching["G"] >= 20].copy()

        # Calculated columns
        pitching["age"] = pitching["yearID"] - pitching["birthYear"]
        pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
        pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])

        return pitching

    def sort(self, rows):
        """
        Sorts rows by WADJ, highest first.

        Args:
            rows: input DataFrame

        Returns:
            sorted DataFrame
        """

        return rows.sort_values(by="WADJ", ascending=False)

    def vector(self, row):
        """
        Builds a vector for an input row of pitching stats.

        WHIP and WADJ are written back into row using the same formulas as load. A value
        that cannot be computed is stored as None, which transform turns into 0.0.

        Args:
            row: input row

        Returns:
            row vector
        """

        row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None

        # A single zero (ERA or WHIP) is a valid input, only a missing value or a
        # zero sum prevents the calculation
        era, whip = row["ERA"], row["WHIP"]
        total = None if era is None or whip is None else era + whip
        row["WADJ"] = (row["W"] + row["SV"]) / total if total else None

        return self.transform(row)
331
+
332
+
333
class Application:
    """
    Main application.
    """

    def __init__(self):
        """
        Creates a new application.
        """

        # Batting stats
        self.batting = Batting()

        # Pitching stats
        self.pitching = Pitching()

    def run(self):
        """
        Runs a Streamlit application.
        """

        st.title("⚾ Baseball Statistics")
        st.markdown("""
            This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
            Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project.
        """)

        self.player()

    def player(self):
        """
        Player tab.

        Lets the user pick a stat category, a player and a year, then displays the
        search results for that player-year.
        """

        st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")

        category = st.radio("Stat", ["Batting", "Pitching"], horizontal=True, key="playerstat")
        stats, default = (self.batting, "Babe Ruth") if category == "Batting" else (self.pitching, "Cy Young")

        # Player name, fall back to the first entry when the default player is not loaded
        names = sorted(stats.names)
        selected = names.index(default) if default in names else 0
        player = st.selectbox("Player", names, selected)

        # Player year, a slider is only shown when there is more than one year to pick from
        start, end, best = stats.years(player)
        year = st.slider("Year", start, end, best) if start != end else start

        # Run search
        results = stats.search(player, year)

        # Display results, the first stats column is left out
        self.display(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])

    def display(self, results, columns):
        """
        Displays a list of results.

        Args:
            results: list of results
            columns: column names
        """

        if results:
            st.dataframe(pd.DataFrame(results)[columns])
        else:
            st.write("Player-Year not found")
399
+
400
+
401
@st.cache_resource(show_spinner=False)
def create():
    """
    Builds the Application instance. The function is wrapped with st.cache_resource.

    Returns:
        Application
    """

    application = Application()
    return application
411
+
412
+
413
if __name__ == "__main__":
    # Set TOKENIZERS_PARALLELISM to "false" before the application is created
    os.environ.update(TOKENIZERS_PARALLELISM="false")

    # Build the application and start it
    create().run()