Cal Mitchell
commited on
Commit
•
5f26252
1
Parent(s):
b9bdc56
init
Browse files- .gitignore +2 -0
- LICENSE +9 -0
- README.md +33 -5
- age_tokens.csv +33 -0
- example.ipynb +174 -0
- model.py +173 -0
- player_tokens.csv +0 -0
- porzingis-swapped-for-pritchard.png +0 -0
- prediction.png +0 -0
- requirements.txt +4 -0
- weights.pt +3 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.python-version
|
LICENSE
CHANGED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# License
|
2 |
+
|
3 |
+
You may use the code and weights in this repository for personal, non-commercial purposes.
|
4 |
+
|
5 |
+
You may publicly share any results output by the model.
|
6 |
+
|
7 |
+
You may not build or release any products (whether open source or proprietary), or provide any services, using the code, weights, or derivatives of either, in this repository.
|
8 |
+
|
9 |
+
All code written and all weights trained by Cal Mitchell. All rights are reserved.
|
README.md
CHANGED
@@ -1,5 +1,33 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NBA Predictions
|
2 |
+
|
3 |
+
This repo contains AI model code and weights which predicts the outcome of NBA games. Its output represents the chance that a given point spread will occur.
|
4 |
+
|
5 |
+
The model requires 8 players on the home and away teams, plus their ages, as input. It will then output probabilities for each point spread between -20 and +20 points, from the home team's point of view.
|
6 |
+
|
7 |
+
For example, the following text and chart shows the model predicting the home team with a 77% chance to win and a 14% chance of winning by 20 or more points. This kind of chart is indicative of a dominant team playing at home. Most games will have more of a bell curve shape to them.
|
8 |
+
|
9 |
+
![NBA prediction graph](prediction.png)
|
10 |
+
|
11 |
+
## Installation
|
12 |
+
|
13 |
+
I recommend installing Python 3.11.8, as that is what the repo was written / tested in. The code will likely work with most recent versions of Python, though.
|
14 |
+
|
15 |
+
Once you have Python installed, run `pip install -r requirements.txt`. It will take a while to install dependencies if you don't already have PyTorch cached.
|
16 |
+
|
17 |
+
## Usage
|
18 |
+
|
19 |
+
The `example.ipynb` notebook shows how to use the model to predict the final game of the 2023-24 NBA season - a game between the Dallas Mavericks and Boston Celtics. It will output the chart above.
|
20 |
+
|
21 |
+
To change the players and their ages, you must reference the `player_tokens.csv` and `age_tokens.csv` files.
|
22 |
+
|
23 |
+
For example, if you wanted to subtract Kristaps Porzingis from Boston's team and swap who was home / away, you would take the token representing Porzingis `4416` out of the `home_team_tokens` list, and replace him with, say, Payton Pritchard `4999`. You would then have to look up Pritchard's age (26), find the corresponding age token in `age_tokens.csv`, which is `11`, and replace Porzingis' age token (which is the second to last token).
|
24 |
+
|
25 |
+
To swap home and away, you could replace the variables containing all of the player and age tokens, or just set the `swap_home_away` variable to `True`. The results are as follows:
|
26 |
+
|
27 |
+
![NBA Finals prediction without Porzingis](porzingis-swapped-for-pritchard.png)
|
28 |
+
|
29 |
+
As you can see, Dallas' win probability improved from 23% to 35%, and their chance of being blown out by 20+ points decreased from 14% to 10%. Clearly, the model thinks Porzingis is important to the Celtics' chances, but still considers Boston to be the superior team without him.
|
30 |
+
|
31 |
+
## Training Process
|
32 |
+
|
33 |
+
I downloaded data from stats.nba.com using the [https://github.com/swar/nba_api](swar/nba_api) package to get information on minutes played, game outcomes, and a few other dimensional elements to make everything fit together. Then, I ran a custom PyTorch training loop to train the model(s) on their chosen loss objective (spread, money line, or spread probability).
|
age_tokens.csv
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
age,token
|
2 |
+
16,1
|
3 |
+
17,2
|
4 |
+
18,3
|
5 |
+
19,4
|
6 |
+
20,5
|
7 |
+
21,6
|
8 |
+
22,7
|
9 |
+
23,8
|
10 |
+
24,9
|
11 |
+
25,10
|
12 |
+
26,11
|
13 |
+
27,12
|
14 |
+
28,13
|
15 |
+
29,14
|
16 |
+
30,15
|
17 |
+
31,16
|
18 |
+
32,17
|
19 |
+
33,18
|
20 |
+
34,19
|
21 |
+
35,20
|
22 |
+
36,21
|
23 |
+
37,22
|
24 |
+
38,23
|
25 |
+
39,24
|
26 |
+
40,25
|
27 |
+
41,26
|
28 |
+
42,27
|
29 |
+
43,28
|
30 |
+
44,29
|
31 |
+
45,30
|
32 |
+
46,31
|
33 |
+
47,32
|
example.ipynb
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from model import NBAModel, NBAConfig\n",
|
10 |
+
"from torch import device as torch_device, load as torch_load, int32, Tensor, bfloat16\n",
|
11 |
+
"import matplotlib.pyplot as plt\n",
|
12 |
+
"\n",
|
13 |
+
"device = torch_device(\"cpu\") \n",
|
14 |
+
"num_age_tokens=32\n",
|
15 |
+
"num_player_tokens=5141\n",
|
16 |
+
"num_net_score_tokens=41\n",
|
17 |
+
"players_per_team=8\n",
|
18 |
+
"\n",
|
19 |
+
"model_config = NBAConfig(\n",
|
20 |
+
" players_per_team=players_per_team,\n",
|
21 |
+
" player_tokens=num_player_tokens+2,\n",
|
22 |
+
" age_tokens=num_age_tokens+2,\n",
|
23 |
+
" num_labels=num_net_score_tokens+2,\n",
|
24 |
+
" n_layer=4,\n",
|
25 |
+
" n_head=4,\n",
|
26 |
+
" n_embd=1024,\n",
|
27 |
+
" dropout=0.0,\n",
|
28 |
+
" bias=False,\n",
|
29 |
+
" dtype=bfloat16,\n",
|
30 |
+
" seed=29,\n",
|
31 |
+
")\n",
|
32 |
+
"\n",
|
33 |
+
"model = NBAModel(model_config).to(device)\n",
|
34 |
+
"state_dict = torch_load('weights.pt', map_location='cpu')\n",
|
35 |
+
"model.load_state_dict(state_dict)\n",
|
36 |
+
"model = model.eval()"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 2,
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"# Change player and age tokens here!\n",
|
46 |
+
"# You can find these values in player_tokens.csv and age_tokens.csv\n",
|
47 |
+
"# You must provide exactly 8 player tokens and 8 age tokens for each team.\n",
|
48 |
+
"\n",
|
49 |
+
"# Boston Celtics final game of 2023-24 season roster\n",
|
50 |
+
"home_player_tokens = [1994, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
|
51 |
+
"home_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
|
52 |
+
"\n",
|
53 |
+
"# Dallas Mavericks final game of 2023-24 season roster\n",
|
54 |
+
"away_player_tokens = [5117, 5097, 4956, 5109, 55, 149, 5121, 5112]\n",
|
55 |
+
"away_age_tokens = [10, 17, 10, 12, 10, 5, 8, 17]\n",
|
56 |
+
"\n",
|
57 |
+
"# The model usually gives the home team a bump in win probability.\n",
|
58 |
+
"# Change this to \"True\" to swap home and away teams.\n",
|
59 |
+
"swap_home_away = False\n",
|
60 |
+
"if swap_home_away:\n",
|
61 |
+
" home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
|
62 |
+
" home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": 3,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [
|
70 |
+
{
|
71 |
+
"name": "stdout",
|
72 |
+
"output_type": "stream",
|
73 |
+
"text": [
|
74 |
+
"Home team win probability: 0.77\n"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"data": {
|
79 |
+
"text/plain": [
|
80 |
+
"<BarContainer object of 40 artists>"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
"execution_count": 3,
|
84 |
+
"metadata": {},
|
85 |
+
"output_type": "execute_result"
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"data": {
|
89 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqWklEQVR4nO3df1DU953H8ReggL9YjSjrDxS9WI1VwaBQTBqTy46Y4S4h8Sw6mUgYx05SMVpyNuAptE1zcPFHSJWG2BnN9XoWz7nTWvVo6VbMtWKsIJdqEmsyMRDJgrYnKCZg2O/9kcvaPVdhV3Q/rM/HzHfifnl/v/v+zFfcVz7fHxtmWZYlAAAAg4UHuwEAAIDuEFgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMbrF+wGeoPb7VZTU5OGDBmisLCwYLcDAAB6wLIsXbx4UaNHj1Z4+I3nUEIisDQ1NSk+Pj7YbQAAgAA0NjZq7NixN6wJicAyZMgQSV8MOCYmJsjdAACAnmhra1N8fLznc/xGQiKwfHkaKCYmhsACAEAf05PLObjoFgAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4/YLdAAAACK6E/P3d1pwpybgNnVwfMywAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYLKLCUlZUpISFB0dHRSk1N1dGjR69be/LkSS1YsEAJCQkKCwtTaWnpDfddUlKisLAwrVq1KpDWAABACPI7sOzcuVN5eXkqKipSXV2dEhMTlZ6erpaWFp/1ly9f1sSJE1VSUiK73X7Dff/+97/X66+/rhkzZvjbFgAACGF+B5ZNmzZp2bJlysnJ0dSpU1VeXq6BAwdq27ZtPutnz56t9evXa9GiRYqKirrufi9duqQnn3xSP/7xjzVs2DB/2wIAACHMr8DS2dmp2tpaORyOqzsID5fD4VBNTc1NNbJ8+XJlZGR47ft6Ojo61NbW5rUAAIDQ5VdgOX/+vLq6uhQXF+e1Pi4uTi6XK+AmKioqVFdXp+Li4h7VFxcXy2azeZb4+PiA3xsAAJgv6HcJNTY2auXKlfrXf/1XRUdH92ibgoICtba2epbGxsZb3CUAAAgmv778MDY2VhEREWpubvZa39zc3O0FtddTW1urlpYW3XvvvZ51XV1devPNN7VlyxZ1dHQoIiLCa5uoqKgbXg8DAABCi18zLJGRkUpOTpbT6fSsc7vdcjqdSktLC6iBhx9+WH/4wx9UX1/vWWbNmqUnn3xS9fX114QVAABw5/FrhkWS8vLylJ2drVmzZiklJUWlpaVqb29XTk6OJGnJkiUaM2aM53qUzs5OvfPOO54/nz17VvX19Ro8eLDuvvtuDRkyRNOmTfN6j0GDBmn48OHXrAcAAHcmvwNLVlaWzp07p8LCQrlcLiUlJamystJzIW5DQ4PCw69O3DQ1NWnmzJme1xs2bNCGDRs0d+5cVVdX3/wIAABAyAuzLMsKdhM3q62tTTabTa2trYqJiQl2OwAA9CkJ+fu7rTlTktHr7+vP53fQ7xICAADoDoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPECCixlZWVKSEhQdHS0UlNTdfTo0evWnjx5UgsWLFBCQoLCwsJUWlp6TU1xcbFmz56tIUOGaOTIkcrMzNSpU6cCaQ0AAIQgvwPLzp07lZeXp6KiItXV1SkxMVHp6elqaWnxWX/58mVNnDhRJSUlstvtPmsOHTqk5cuX68iRI6qqqtKVK1c0b948tbe3+9seAAAIQWGWZVn+bJCamqrZs2dry5YtkiS32634+HitWLFC+fn5N9w2ISFBq1at0qpVq25Yd+7cOY0cOVKHDh3SAw880G1PbW1tstlsam1tVUxMTI/HAgAApIT8/d3WnCnJ6PX39efz268Zls7OTtXW1srhcFzdQXi4HA6HampqAuvWh9bWVknSXXfd5fPnHR0damtr81oAAEDo8iuwnD9/Xl1dXYqLi/NaHxcXJ5fL1SsNud1urVq1Svfdd5+mTZvms6a4uFg2m82zxMfH98p7AwAAMxl3l9Dy5ct14sQJVVRUXLemoKBAra2tnqWxsfE2dggAAG63fv4Ux8bGKiIiQs3NzV7rm5ubr3tBrT9yc3O1b98+vfnmmxo7dux166KiohQVFXXT7wcAAPoGv2ZYIiMjlZycLKfT6VnndrvldDqVlpYWcBOWZSk3N1e7d+/Wb37zG02YMCHgfQEAgNDj1wyLJOXl5Sk7O1uzZs1SSkqKSktL1d7erpycHEnSkiVLNGbMGBUXF0v64kLdd955x/Pns2fPqr6+XoMHD9bdd98t6YvTQDt27NDPf/5zDRkyxHM9jM1m04ABA3ploAAAoO/yO7BkZWXp3LlzKiwslMvlUlJSkiorKz0X4jY0NCg8/OrETVNTk2bOnOl5vWHDBm3YsEFz585VdXW1JOm1116TJD344INe77V9+3Y9/fTT/rYIAABCjN/PYTERz2EBACBwIfccFgAAgGAgsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwXkCBpaysTAkJCYqOjlZqaqqOHj163dqTJ09qwYIFSkhIUFhYmEpLS296nwAA4M7id2DZuXOn8vLyVFRUpLq6OiUmJio9PV0tLS0+6y9fvqyJEyeqpKREdru9V/YJAADuLH4Hlk2bNmnZsmXKycnR1KlTVV5eroEDB2rbtm0+62fPnq3169dr0aJFioqK6pV9AgCAO4tfgaWzs1O1tbVyOBxXdxAeLofDoZqamoAaCGSfHR0damtr81oAAEDo8iuwnD9/Xl1dXYqLi/NaHxcXJ5fLFVADgeyzuLhYNpvNs8THxwf03gAAoG/ok3cJFRQUqLW11bM0NjYGuyUAAHAL9fOnODY2VhEREWpubvZa39zcfN0Lam/FPqOioq57PQwAAAg9fs2wREZGKjk5WU6n07PO7XbL6XQqLS0toAZuxT4BAEBo8WuGRZLy8vKUnZ2tWbNmKSUlRaWlpWpvb1dOTo4kacmSJRozZoyKi4slfXFR7TvvvOP589mzZ1VfX6/Bgwfr7rvv7tE+AQDAnc3vwJKVlaVz586psLBQLpdLSUlJqqys9Fw029DQoPDwqxM3TU1Nmjlzpuf1hg0btGHDBs2dO1fV1dU92icAALizhVmWZQW7iZvV1tYmm82m1tZWxcTEBLsdAAD6lIT8/d3WnCnJ6PX39efzu0/eJQQAAO4sBBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjBdQYCkrK1NCQoKio6OVmpqqo0eP3rB+165dmjJliqKjozV9+nQdOHDA6+eXLl1Sbm6uxo4dqwEDBmjq1KkqLy8PpDUAABCC/A4sO3fuVF5enoqKilRXV6fExESlp6erpaXFZ/3hw4e1ePFiLV26VMePH1dmZqYyMzN14sQJT01eXp4qKyv105/+VO+++65WrVql3Nxc7d27N/CRAQCAkBFmWZblzwapqamaPXu2tmzZIklyu92Kj4/XihUrlJ+ff019VlaW2tvbtW/fPs+6r33ta0pKSvLMokybNk1ZWVlat26dpyY5OVmPPPKIfvCDH3TbU1tbm2w2m1pbWxUTE+PPcAAAuOMl5O/vtuZMSUavv68/n99+zbB0dnaqtrZWDofj6g7Cw+VwOFRTU+Nzm5qaGq96SUpPT/eqnzNnjvbu3auzZ8/KsiwdPHhQf/zjHzVv3jyf++zo6FBbW5vXAgAAQpdfgeX8+fPq6upSXFyc1/q4uDi5XC6f27hcrm7rN2/erKlTp2rs2LGKjIzU/PnzVVZWpgceeMDnPouLi2Wz2TxLfHy8P8MAAAB9jBF3CW3evFlHjhzR3r17VVtbq40bN2r58uX69a9/7bO+oKBAra2tnqWxsfE2dwwAAG6nfv4Ux8bGKiIiQs3NzV7rm5ubZbfbfW5jt9tvWP/pp59qzZo12r17tzIyvjg/NmPGDNXX12vDhg3XnE6SpKioKEVFRfnTOgAA6MP8mmGJjIxUcnKynE6nZ53b7ZbT6VRaWprPbdLS0rzqJamqqspTf+XKFV25ckXh4d6tREREyO12+9MeAAAIUX7NsEhf3IKcnZ2tWbNmKSUlRaWlpWpvb1dOTo4kacmSJRozZoyKi4slSStXrtTcuXO1ceNGZWRkqKKiQseOHdPWrVslSTExMZo7d65Wr16tAQMGaPz48Tp06JB+8pOfaNOmTb04VAAA0Ff5HViysrJ07tw5FRYWyuVyKSkpSZWVlZ4LaxsaGrxmS+bMmaMdO3Zo7dq1WrNmjSZNmqQ9e/Zo2rRpnpqKigoVFBToySef1J///GeNHz9eL730kp555pleGCIAAOjr/H4Oi4l4DgsAAIELueewAAAABAOBBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYLx+wW4AAHDzEvL3d1tzpiTjNnQC3BrMsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwXkCBpaysTAkJCYqOjlZqaqqOHj16w/pdu3ZpypQpio6O1vTp03XgwIFrat599109+uijstlsGjRokGbPnq2GhoZA2gMAACHG78Cyc+dO5eXlqaioSHV1dUpMTFR6erpaWlp81h8+fFiLFy/W0qVLdfz4cWVmZiozM1MnTpzw1HzwwQe6//77NWXKFFVXV+vtt9/WunXrFB0dHfjIAABAyAizLMvyZ4PU1FTNnj1bW7ZskSS53W7Fx8drxYoVys/Pv6Y+KytL7e3t2rdvn2fd1772NSUlJam8vFyStGjRIvXv31//8i//EtAg2traZLPZ1NraqpiYmID2AQB9WUL+/m5rzpRk3IZO0BcF6++PP5/ffs2wdHZ2qra2Vg6H4+oOwsPlcDhUU1Pjc5uamhqveklKT0/31Lvdbu3fv19f+cpXlJ6erpEjRyo1NVV79uy5bh8dHR1qa2vzWgAAQOjyK7CcP39eXV1diouL81ofFxcnl8vlcxuXy3XD+paWFl26dEklJSWaP3++fvWrX+nxxx/XE088oUOHDvncZ3FxsWw2m2eJj4/3ZxgAAKCPCfpdQm63W5L02GOP6dvf/raSkpKUn5+vv/mbv/GcMvr/CgoK1Nra6lkaGxtvZ8sAAOA26+dPcWxsrCIiItTc3Oy1vrm5WXa73ec2drv9hvWxsbHq16+fpk6d6lVzzz336Le//a3PfUZFRSkqKsqf1gEAQB/m1wxLZGSkkpOT5XQ6PevcbrecTqfS0tJ8bpOWluZVL0lVVVWe+sjISM2ePVunTp3yqvnjH/+o8ePH+9MeAAAIUX7NsEhSXl6esrOzNWvWLKWkpKi0tFTt7e3KycmRJC1ZskRjxoxRcXGxJGnlypWaO3euNm7cqIyMDFVUVOjYsWPaunWrZ5+rV69WVlaWHnjgAT300EOqrKzUL37xC1VXV/fOKAEAQJ/md2DJysrSuXPnVFhYKJfLpaSkJFVWVnourG1oaFB4+NWJmzlz5mjHjh1au3at1qxZo0mTJmnPnj2aNm2ap+bxxx9XeXm5iouL9dxzz2ny5Mn693//d91///29MEQA6Ju4VRm4yu/nsJiI57AACEX+BBbCDW5GyD2HBQAAIBgILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACM1y/YDQAAzJaQv/+GPz9TknGbOsGdjBkWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjNcv2A0AgIkS8vd3W3OmJOM2dAJAYoYFAAD0AQQWAABgvIACS1lZmRISEhQdHa3U1FQdPXr0hvW7du3SlClTFB0drenTp+vAgQPXrX3mmWcUFham0tLSQFoDAAAhyO/AsnPnTuXl5amoqEh1dXVKTExUenq6WlpafNYfPnxYixcv1tKlS3X8+HFlZmYqMzNTJ06cuKZ29+7dOnLkiEaPHu3/SAAAQMjy+6LbTZs2admyZcrJyZEklZeXa//+/dq2bZvy8/OvqX/11Vc1f/58rV69WpL04osvqqqqSlu2bFF5ebmn7uzZs1qxYoV++ctfKiODC9kA9B1coAvcen7NsHR2dqq2tlYOh+PqDsLD5XA4VFNT43Obmpoar3pJSk9P96p3u9166qmntHr1an31q1/tto+Ojg61tbV5LQAAIHT5FVjOnz+vrq4uxcXFea2Pi4uTy+XyuY3L5eq2/p/+6Z/Ur18/Pffccz3qo7i4WDabzbPEx8f7MwwAANDHBP0uodraWr366qt64403FBYW1qNtCgoK1Nra6lkaGxtvcZcAACCY/AossbGxioiIUHNzs9f65uZm2e12n9vY7fYb1v/Xf/2XWlpaNG7cOPXr10/9+vXTRx99pOeff14JCQk+9xkVFaWYmBivBQAAhC6/AktkZKSSk5PldDo969xut5xOp9LS0nxuk5aW5lUvSVVVVZ76p556Sm+//bbq6+s9y+jRo7V69Wr98pe/9Hc8AAAgBPl9l1BeXp6ys7M1a9YspaSkqLS0VO3t7Z67hpYsWaIxY8aouLhYkrRy5UrNnTtXGzduVEZGhioqKnTs2DFt3bpVkjR8+HANHz7c6z369+8vu92uyZMn3+z4AABACPA7sGRlZencuXMqLCyUy+VSUlKSKisrPRfWNjQ0KDz86sTNnDlztGPHDq1du1Zr1qzRpEmTtGfPHk2bNq33RgEAfQS3QAOBCejLD3Nzc5Wbm+vzZ9XV1desW7hwoRYuXNjj/Z85cyaQtgDghggLQN8V9LuEAAAAukNgAQAAxiOwAAAA4wV0DQsAAL50d50Q1wghUMywAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAe39YMoE/r7tuBJb4hGAgFzLAAAADjEVgAAIDxOCUEAAiK7k7nBXoqj9OEoYkZFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeDw4DgDuMH3xwWp9sWf0LmZYAACA8QgsAADAeJwSAgDcsTjV1HcQWIBbjH8QAeDmcUoIAAAYjxkWAACCiFnYnmGGBQAAGC+gwFJWVqaEhARFR0crNTVVR48evWH9rl27NGXKFEVHR2v69Ok6cOCA52dXrlzRCy+8oOnTp2vQoEEaPXq0lixZoqampkBaAwAAIcjvwLJz507l5eWpqKhIdXV1SkxMVHp6ulpaWnzWHz58WIsXL9bSpUt1/PhxZWZmKjMzUydOnJAkXb58WXV1dVq3bp3q6ur0H//xHzp16pQeffTRmxsZAAAIGX5fw7Jp0yYtW7ZMOTk5kqTy8nLt379f27ZtU35+/jX1r776qubPn6/Vq1dLkl588UVVVVVpy5YtKi8vl81mU1VVldc2W7ZsUUpKihoaGjRu3LhAxgXgL3COHEBf59cMS2dnp2pra+VwOK7uIDxcDodDNTU1Prepqanxqpek9PT069ZLUmtrq8LCwjR06FB/2gMAACHKrxmW8+fPq6urS3FxcV7r4+Li9N577/ncxuVy+ax3uVw+6z/77DO98MILWrx4sWJiYnzWdHR0qKOjw/O6ra3Nn2EAuAFmYwCYyKi7hK5cuaJvfOMbsixLr7322nXriouLZbPZPEt8fPxt7BIAANxufgWW2NhYRUREqLm52Wt9c3Oz7Ha7z23sdnuP6r8MKx999JGqqqquO7siSQUFBWptbfUsjY2N/gwDAAD0MX6dEoqMjFRycrKcTqcyMzMlSW63W06nU7m5uT63SUtLk9Pp1KpVqzzrqqqqlJaW5nn9ZVg5ffq0Dh48qOHDh9+wj6ioKEVFRfnTOhByOHUD4E7i911CeXl5ys7O1qxZs5SSkqLS0lK1t7d77hpasmSJxowZo+LiYknSypUrNXfuXG3cuFEZGRmqqKjQsWPHtHXrVklfhJW/+7u/U11dnfbt26euri7P9S133XWXIiMje2usAAD0aXfy/6j4HViysrJ07tw5FRYWyuVyKSkpSZWVlZ4LaxsaGhQefvVM05w5c7Rjxw6tXbtWa9as0aRJk7Rnzx5NmzZNknT27Fnt3btXkpSUlOT1XgcPHtSDDz4Y4NAAAECoCOi7hHJzc697Cqi6uvqadQsXLtTChQt91ickJMiyrEDaAAAA1xFqszF8+SEAAD3gTwAItbBgAqNuawYAAPCFwAIAAIxHYAEAAMbjGhYAAeM8PYDbhcACBKC7D2o+pAGgdxFYANwWzMYAuBlcwwIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB7PYekBnh+B24W/awDgGzMsAADAeMywAP+Hx+0DgLmYYQEAAMZjhgV9ir/XeDBrAgChgRkWAABgPAILAAAwHoEFAAAYj8ACAACMx0W3uCV4ABoAoDcxwwIAAIzHDAt6jFkTAECwMMMCAACMxwwLgo6ZGwBAd5hhAQAAxmOGJQQxYwEACDUElj7iVoUQwg0AoC/glBAAADAeMyxBxOwGAAA9wwwLAAAwHoEFAAAYj8ACAACMR2ABAADG46LbXsaFtAAA9L6AZljKysqUkJCg6Ohopaam6ujRozes37Vrl6ZMmaLo6GhNnz5dBw4c8Pq5ZVkqLCzUqFGjNGDAADkcDp0+fTqQ1gAAQAjyO7Ds3LlTeXl5KioqUl1dnRITE5Wenq6Wlhaf9YcPH9bixYu1dOlSHT9+XJmZmcrMzNSJEyc8NS+//LJ++MMfqry8XG+99ZYGDRqk9PR0ffbZZ4GPDAAAhAy/A8umTZu0bNky5eTkaOrUqSovL9fAgQO1bds2n/Wvvvqq5s+fr9WrV+uee+7Riy++qHvvvVdbtmyR9MXsSmlpqdauXavHHntMM2bM0E9+8hM1NTVpz549NzU4AAAQGvy6hqWzs1O1tbUqKCjwrAsPD5fD4VBNTY3PbWpqapSXl+e1Lj093RNGPvzwQ7lcLjkcDs/PbTabUlNTVVNTo0WLFl2zz46ODnV0dHhet7a2SpLa2tr8GU6PuTsud1vz5XtTa05tT+pNqP3Lemr9/x02od9Qru1JvQm1f1lPrTm/nz3dp2VZ3Rdbfjh79qwlyTp8+LDX+tWrV1spKSk+t+nfv7+1Y8cOr3VlZWXWyJEjLcuyrN/97neWJKupqcmrZuHChdY3vvENn/ssKiqyJLGwsLCwsLCEwNLY2NhtBumTdwkVFBR4zdq43W79+c9/1vDhwxUWFnZL37utrU3x8fFqbGxUTEzMLX2vYAjl8TG2vimUxyaF9vgYW991u8ZnWZYuXryo0aNHd1vrV2CJjY1VRESEmpubvdY3NzfLbrf73MZut9+w/sv/Njc3a9SoUV41SUlJPvcZFRWlqKgor3VDhw71Zyg3LSYmJiT/kn4plMfH2PqmUB6bFNrjY2x91+0Yn81m61GdXxfdRkZGKjk5WU6n07PO7XbL6XQqLS3N5zZpaWle9ZJUVVXlqZ8wYYLsdrtXTVtbm956663r7hMAANxZ/D4llJeXp+zsbM2aNUspKSkqLS1Ve3u7cnJyJElLlizRmDFjVFxcLElauXKl5s6dq40bNyojI0MVFRU6duyYtm7dKkkKCwvTqlWr9IMf/ECTJk3ShAkTtG7dOo0ePVqZmZm9N1IAANBn+R1YsrKydO7cORUWFsrlcikpKUmVlZWKi4uTJDU0NCg8/OrEzZw5c7Rjxw6tXbtWa9as0aRJk7Rnzx5NmzbNU/Od73xH7e3t+uY3v6kLFy7o/vvvV2VlpaKjo3thiL0rKipKRUVF15ySChWhPD7G1jeF8tik0B4fY+u7TBxfmGX15F4iAACA4OHLDwEAgPEILAAAwHgEFgAAYDwCCwAAMB6BpYfOnDmjpUuXasKECRowYID+6q/+SkVFRers7PSqe/vtt/X1r39d0dHRio+P18svvxykjv330ksvac6cORo4cOB1H8QXFhZ2zVJRUXF7Gw1AT8bW0NCgjIwMDRw4UCNHjtTq1av1+eef395Ge0FCQsI1x6ikpCTYbQWsrKxMCQkJio6OVmpqqo4ePRrslm7ad7/73WuO0ZQpU4LdVsDefPNN/e3f/q1Gjx6tsLCwa7641rIsFRYWatSoURowYIAcDodOnz4dnGb91N3Ynn766WuO5fz584PTrJ+Ki4s1e/ZsDRkyRCNHjlRmZqZOnTrlVfPZZ59p+fLlGj58uAYPHqwFCxZc8zDY24XA0kPvvfee3G63Xn/9dZ08eVKvvPKKysvLtWbNGk9NW1ub5s2bp/Hjx6u2tlbr16/Xd7/7Xc8zZ0zX2dmphQsX6tlnn71h3fbt2/XJJ594lr7wvJzuxtbV1aWMjAx1dnbq8OHD+ud//me98cYbKiwsvM2d9o7vf//7XsdoxYoVwW4pIDt37lReXp6KiopUV1enxMREpaenq6WlJdit3bSvfvWrXsfot7/9bbBbClh7e7sSExNVVlbm8+cvv/yyfvjDH6q8vFxvvfWWBg0apPT0dH322We3uVP/dTc2SZo/f77XsfzZz352GzsM3KFDh7R8+XIdOXJEVVVVunLliubNm6f29nZPzbe//W394he/0K5du3To0CE1NTXpiSeeCE7D3X7bEK7r5ZdftiZMmOB5/aMf/cgaNmyY1dHR4Vn3wgsvWJMnTw5GewHbvn27ZbPZfP5MkrV79+7b2k9vut7YDhw4YIWHh1sul8uz7rXXXrNiYmK8jmdfMH78eOuVV14Jdhu9IiUlxVq+fLnndVdXlzV69GiruLg4iF3dvKKiIisxMTHYbdwS///fCLfbbdntdmv9+vWedRcuXLCioqKsn/3sZ0HoMHC+/v3Lzs62HnvssaD009taWlosSdahQ4csy/riOPXv39/atWuXp+bdd9+1JFk1NTW3vT9mWG5Ca2ur7rrrLs/rmpoaPfDAA4qMjPSsS09P16lTp/Q///M/wWjxlli+fLliY2OVkpKibdu29exrwQ1XU1Oj6dOnex6AKH1x7Nra2nTy5MkgdhaYkpISDR8+XDNnztT69ev75Kmtzs5O1dbWyuFweNaFh4fL4XCopqYmiJ31jtOnT2v06NGaOHGinnzySTU0NAS7pVviww8/lMvl8jqONptNqampIXEcJam6ulojR47U5MmT9eyzz+pPf/pTsFsKSGtrqyR5Ptdqa2t15coVr2M3ZcoUjRs3LijHrk9+W7MJ3n//fW3evFkbNmzwrHO5XJowYYJX3ZcfgC6XS8OGDbutPd4K3//+9/XXf/3XGjhwoH71q1/pW9/6li5duqTnnnsu2K3dFJfL5RVWJO9j15c899xzuvfee3XXXXfp8OHDKigo0CeffKJNmzYFuzW/nD9/Xl1dXT6Py3vvvRekrnpHamqq3njjDU2ePFmffPKJvve97+nrX/+6Tpw4oSFDhgS7vV715e+Pr+PY1363fJk/f76eeOIJTZgwQR988IHWrFmjRx55RDU1NYqIiAh2ez3mdru1atUq3XfffZ4n0btcLkVGRl5z3V+wjt0dP8OSn5/v80LSv1z+/z+OZ8+e1fz587Vw4UItW7YsSJ33TCDju5F169bpvvvu08yZM/XCCy/oO9/5jtavX38LR3B9vT02k/kz1ry8PD344IOaMWOGnnnmGW3cuFGbN29WR0dHkEeBLz3yyCNauHChZsyYofT0dB04cEAXLlzQv/3bvwW7Nfhp0aJFevTRRzV9+nRlZmZq3759+v3vf6/q6upgt+aX5cuX68SJE0bfRHHHz7A8//zzevrpp29YM3HiRM+fm5qa9NBDD2nOnDnXXExrt9uvuXr6y9d2u713GvaTv+PzV2pqql588UV1dHTc9u+c6M2x2e32a+4+Cfax+0s3M9bU1FR9/vnnOnPmjCZPnnwLurs1YmNjFRER4fN3yoRj0puGDh2qr3zlK3r//feD3Uqv+/JYNTc3a9SoUZ71zc3NSkpKClJXt87EiRMVGxur999/Xw8//HCw2+mR3Nxc7du3T2+++abGjh3rWW+329XZ2akLFy54zbIE63fwjg8sI0aM0IgRI3pUe/bsWT300ENKTk7W9u3bvb7kUZLS0tL0D//wD7py5Yr69+8vSaqqqtLkyZODdjrIn/EFor6+XsOGDQvKF2T15tjS0tL00ksvqaWlRSNHjpT0xbGLiYnR1KlTe+U9bsbNjLW+vl7h4eGecfUVkZGRSk5OltPp9NyJ5na75XQ6lZubG9zmetmlS5f0wQcf6Kmnngp2K71uwoQJstvtcjqdnoDS1tamt956q9s7Evuijz/+WH/605+8wpmpLMvSihUrtHv3blVXV19zSUNycrL69+8vp9OpBQsWSJJOnTqlhoYGpaWlBaVh9MDHH39s3X333dbDDz9sffzxx9Ynn3ziWb504cIFKy4uznrqqaesEydOWBUVFdbAgQOt119/PYid99xHH31kHT9+3Pre975nDR482Dp+/Lh1/Phx6+LFi5ZlWdbevXutH//4x9Yf/vAH6/Tp09aPfvQja+DAgVZhYWGQO+9ed2P7/PPPrWnTplnz5s2z6uvrrcrKSmvEiBFWQUFBkDv3z+HDh61XXnnFqq+vtz744APrpz/9qTVixAhryZIlwW4tIBUVFVZUVJT1xhtvWO+88471zW9+0xo6dKjX3Vx90fPPP29VV1dbH374ofW73/3OcjgcVmxsrNXS0hLs1gJy8eJFz++UJGvTpk3W8ePHrY8++siyLMsqKSmxhg4dav385z+33n77beuxxx6zJkyYYH366adB7rx7NxrbxYsXrb//+7+3ampqrA8//ND69a9/bd17773WpEmTrM8++yzYrXfr2WeftWw2m1VdXe31mXb58mVPzTPPPGONGzfO+s1vfmMdO3bMSktLs9LS0oLSL4Glh7Zv325J8rn8pf/+7/+27r//fisqKsoaM2aMVVJSEqSO/Zedne1zfAcPHrQsy7L+8z//00pKSrIGDx5sDRo0yEpMTLTKy8utrq6u4DbeA92NzbIs68yZM9YjjzxiDRgwwIqNjbWef/5568qVK8FrOgC1tbVWamqqZbPZrOjoaOuee+6x/vEf/7FP/ON5PZs3b7bGjRtnRUZGWikpKdaRI0eC3dJNy8rKskaNGmVFRkZaY8aMsbKysqz3338/2G0F7ODBgz5/v7Kzsy3L+uLW5nXr1llxcXFWVFSU9fDDD1unTp0KbtM9dKOxXb582Zo3b541YsQIq3///tb48eOtZcuW9ZlAfb3PtO3bt3tqPv30U+tb3/qWNWzYMGvgwIHW448/7vU/6rdT2P81DQAAYKw7/i4hAABgPgILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIz3v5FUrqAH/4M1AAAAAElFTkSuQmCC",
|
90 |
+
"text/plain": [
|
91 |
+
"<Figure size 640x480 with 1 Axes>"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
"metadata": {},
|
95 |
+
"output_type": "display_data"
|
96 |
+
}
|
97 |
+
],
|
98 |
+
"source": [
|
99 |
+
"# Run this cell to see the spread probabilities!\n",
|
100 |
+
"\n",
|
101 |
+
"assert len(home_player_tokens) == players_per_team\n",
|
102 |
+
"assert len(home_age_tokens) == players_per_team\n",
|
103 |
+
"assert len(away_player_tokens) == players_per_team\n",
|
104 |
+
"assert len(away_age_tokens) == players_per_team\n",
|
105 |
+
"\n",
|
106 |
+
"batch = {\n",
|
107 |
+
" 'home_player_tokens': Tensor([num_player_tokens+1] + home_player_tokens).to(dtype=int32).unsqueeze(0),\n",
|
108 |
+
" 'home_age_tokens': Tensor([num_age_tokens+1] + home_age_tokens).to(dtype=int32).unsqueeze(0),\n",
|
109 |
+
" 'away_player_tokens': Tensor(away_player_tokens).to(dtype=int32).unsqueeze(0),\n",
|
110 |
+
" 'away_age_tokens': Tensor(away_age_tokens).to(dtype=int32).unsqueeze(0),\n",
|
111 |
+
"}\n",
|
112 |
+
"\n",
|
113 |
+
"for key, value in batch.items():\n",
|
114 |
+
" if hasattr(value, 'to'):\n",
|
115 |
+
" batch[key] = value.to(device)\n",
|
116 |
+
"\n",
|
117 |
+
"output, _ = model(**batch)\n",
|
118 |
+
"output = output.squeeze().softmax(dim=0)\n",
|
119 |
+
"\n",
|
120 |
+
"probs = {}\n",
|
121 |
+
"loss_prob = 0\n",
|
122 |
+
"win_prob = 0\n",
|
123 |
+
"\n",
|
124 |
+
"first = True\n",
|
125 |
+
"for i, token in enumerate(output):\n",
|
126 |
+
" if first:\n",
|
127 |
+
" first = False\n",
|
128 |
+
" continue\n",
|
129 |
+
"\n",
|
130 |
+
" if i-21 < 0:\n",
|
131 |
+
" loss_prob += token.item()\n",
|
132 |
+
" elif i-21 > 0 and i-21 < 21:\n",
|
133 |
+
" win_prob += token.item()\n",
|
134 |
+
"\n",
|
135 |
+
" probs[i-21] = token.item()\n",
|
136 |
+
"\n",
|
137 |
+
"del probs[0]\n",
|
138 |
+
"del probs[21]\n",
|
139 |
+
"\n",
|
140 |
+
"print(f\"Home team win probability: {win_prob:.2f}\")\n",
|
141 |
+
"\n",
|
142 |
+
"plt.bar(probs.keys(), probs.values())"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "code",
|
147 |
+
"execution_count": null,
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [],
|
150 |
+
"source": []
|
151 |
+
}
|
152 |
+
],
|
153 |
+
"metadata": {
|
154 |
+
"kernelspec": {
|
155 |
+
"display_name": "nba",
|
156 |
+
"language": "python",
|
157 |
+
"name": "python3"
|
158 |
+
},
|
159 |
+
"language_info": {
|
160 |
+
"codemirror_mode": {
|
161 |
+
"name": "ipython",
|
162 |
+
"version": 3
|
163 |
+
},
|
164 |
+
"file_extension": ".py",
|
165 |
+
"mimetype": "text/x-python",
|
166 |
+
"name": "python",
|
167 |
+
"nbconvert_exporter": "python",
|
168 |
+
"pygments_lexer": "ipython3",
|
169 |
+
"version": "3.11.8"
|
170 |
+
}
|
171 |
+
},
|
172 |
+
"nbformat": 4,
|
173 |
+
"nbformat_minor": 2
|
174 |
+
}
|
model.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import (
|
2 |
+
Module,
|
3 |
+
Embedding,
|
4 |
+
Dropout,
|
5 |
+
ModuleDict,
|
6 |
+
LayerNorm,
|
7 |
+
ModuleList,
|
8 |
+
Linear,
|
9 |
+
GELU,
|
10 |
+
functional,
|
11 |
+
)
|
12 |
+
from torch.nn.init import normal_, zeros_
|
13 |
+
from dataclasses import dataclass
|
14 |
+
from rotary_embedding_torch import RotaryEmbedding
|
15 |
+
from torch import ones, cat
|
16 |
+
from torch.nn.functional import scaled_dot_product_attention
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from math import sqrt
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class NBAConfig:
|
22 |
+
players_per_team: int = None
|
23 |
+
player_tokens: int = None
|
24 |
+
age_tokens: int = None
|
25 |
+
n_layer: int = None
|
26 |
+
n_head: int = None
|
27 |
+
n_embd: int = None
|
28 |
+
dropout: float = None
|
29 |
+
seed: int = None
|
30 |
+
bias: bool = None
|
31 |
+
dtype: type = None
|
32 |
+
num_labels: int = None
|
33 |
+
|
34 |
+
class SelfAttention(Module):
|
35 |
+
|
36 |
+
def __init__(self, config):
|
37 |
+
|
38 |
+
block_size = config.players_per_team * 2 + 1
|
39 |
+
|
40 |
+
super().__init__()
|
41 |
+
assert config.n_embd % config.n_head == 0
|
42 |
+
self.c_attn = Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=config.dtype)
|
43 |
+
self.c_proj = Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype)
|
44 |
+
self.attn_dropout = Dropout(config.dropout)
|
45 |
+
self.resid_dropout = Dropout(config.dropout)
|
46 |
+
self.n_head = config.n_head
|
47 |
+
self.n_embd = config.n_embd
|
48 |
+
self.dropout = config.dropout
|
49 |
+
self.rotary_emb = RotaryEmbedding(config.n_embd)
|
50 |
+
self.flash = hasattr(functional, 'scaled_dot_product_attention')
|
51 |
+
if not self.flash:
|
52 |
+
self.register_buffer("bias", ones(block_size, block_size)
|
53 |
+
).view(1, 1, block_size, block_size)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
B, T, C = x.size()
|
57 |
+
|
58 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
59 |
+
|
60 |
+
q = self.rotary_emb.rotate_queries_or_keys(q)
|
61 |
+
k = self.rotary_emb.rotate_queries_or_keys(k)
|
62 |
+
|
63 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
64 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
65 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
66 |
+
|
67 |
+
if self.flash:
|
68 |
+
y = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
|
69 |
+
else:
|
70 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / sqrt(k.size(-1)))
|
71 |
+
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
72 |
+
att = F.softmax(att, dim=-1)
|
73 |
+
att = self.attn_dropout(att)
|
74 |
+
y = att @ v
|
75 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
76 |
+
|
77 |
+
# output projection
|
78 |
+
y = self.resid_dropout(self.c_proj(y))
|
79 |
+
return y
|
80 |
+
|
81 |
+
class MLP(Module):
|
82 |
+
|
83 |
+
def __init__(self, config):
|
84 |
+
super().__init__()
|
85 |
+
self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=config.bias, dtype=config.dtype)
|
86 |
+
self.gelu = GELU()
|
87 |
+
self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype)
|
88 |
+
self.dropout = Dropout(config.dropout)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
x = self.c_fc(x)
|
92 |
+
x = self.gelu(x)
|
93 |
+
x = self.c_proj(x)
|
94 |
+
x = self.dropout(x)
|
95 |
+
return x
|
96 |
+
|
97 |
+
class Block(Module):
|
98 |
+
|
99 |
+
def __init__(self, config):
|
100 |
+
super().__init__()
|
101 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype)
|
102 |
+
self.attn = SelfAttention(config)
|
103 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype)
|
104 |
+
self.mlp = MLP(config)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
x = x + self.attn(self.ln_1(x))
|
108 |
+
return x + self.mlp(self.ln_2(x))
|
109 |
+
|
110 |
+
class NBAModel(Module):
|
111 |
+
|
112 |
+
def __init__(self, config) -> None:
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
self.config = config
|
116 |
+
|
117 |
+
self.transformer = ModuleDict(dict(
|
118 |
+
home_player_embeddings = Embedding(config.player_tokens, config.n_embd, dtype=config.dtype),
|
119 |
+
away_player_embeddings = Embedding(config.player_tokens, config.n_embd, dtype=config.dtype),
|
120 |
+
home_age_embeddings = Embedding(config.age_tokens, config.n_embd, dtype=config.dtype),
|
121 |
+
away_age_embeddings = Embedding(config.age_tokens, config.n_embd, dtype=config.dtype),
|
122 |
+
drop = Dropout(config.dropout),
|
123 |
+
h = ModuleList([Block(config) for _ in range(config.n_layer)]),
|
124 |
+
ln_f = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype),
|
125 |
+
))
|
126 |
+
|
127 |
+
self.head = Linear(config.n_embd, config.num_labels, dtype=config.dtype)
|
128 |
+
|
129 |
+
self.apply(self._init_weights)
|
130 |
+
for pn, p in self.named_parameters():
|
131 |
+
if pn.endswith('c_proj.weight'):
|
132 |
+
normal_(p, mean=0.0, std=0.02/sqrt(2 * config.n_layer))
|
133 |
+
|
134 |
+
def _init_weights(self, module):
|
135 |
+
if isinstance(module, Linear):
|
136 |
+
normal_(module.weight, mean=0.0, std=0.02)
|
137 |
+
if module.bias is not None:
|
138 |
+
zeros_(module.bias)
|
139 |
+
elif isinstance(module, Embedding):
|
140 |
+
normal_(module.weight, mean=0.0, std=0.02)
|
141 |
+
|
142 |
+
def forward(self, **batch):
|
143 |
+
home_player_tokens = batch['home_player_tokens']
|
144 |
+
away_player_tokens = batch['away_player_tokens']
|
145 |
+
home_age_tokens = batch['home_age_tokens']
|
146 |
+
away_age_tokens = batch['away_age_tokens']
|
147 |
+
|
148 |
+
home_player_embeddings = self.transformer.home_player_embeddings(home_player_tokens)
|
149 |
+
away_player_embeddings = self.transformer.away_player_embeddings(away_player_tokens)
|
150 |
+
|
151 |
+
home_age_embeddings = self.transformer.home_age_embeddings(home_age_tokens)
|
152 |
+
away_age_embeddings = self.transformer.away_age_embeddings(away_age_tokens)
|
153 |
+
|
154 |
+
home_emb = home_player_embeddings + home_age_embeddings
|
155 |
+
away_emb = away_player_embeddings + away_age_embeddings
|
156 |
+
|
157 |
+
x = cat([home_emb, away_emb], dim=1)
|
158 |
+
|
159 |
+
x = self.transformer.drop(x)
|
160 |
+
|
161 |
+
for block in self.transformer.h:
|
162 |
+
x = block(x)
|
163 |
+
x = self.transformer.ln_f(x)
|
164 |
+
|
165 |
+
logits = self.head(x)
|
166 |
+
logits = logits[:, 0]
|
167 |
+
|
168 |
+
loss = None
|
169 |
+
if 'home_team_won' in batch:
|
170 |
+
loss = F.cross_entropy(logits, batch['home_net_score_token'])
|
171 |
+
loss = {'loss': loss}
|
172 |
+
|
173 |
+
return logits, loss
|
player_tokens.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
porzingis-swapped-for-pritchard.png
ADDED
prediction.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
rotary_embedding_torch
|
2 |
+
torch
|
3 |
+
jupyter
|
4 |
+
matplotlib
|
weights.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51ab232915c68ba50ac907b60139e7e45c08eb2ce92a95fa29488b19896ffa2e
|
3 |
+
size 121995384
|