Cal Mitchell commited on
Commit
5f26252
1 Parent(s): b9bdc56
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ .python-version
LICENSE CHANGED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # License
2
+
3
+ You may use the code and weights in this repository for personal, non-commercial purposes.
4
+
5
+ You may publicly share any results output by the model.
6
+
7
+ You may not build or release any products (whether open source or proprietary), or provide any services, using the code, weights, or derivatives of either, in this repository.
8
+
9
+ All code written and all weights trained by Cal Mitchell. All rights are reserved.
README.md CHANGED
@@ -1,5 +1,33 @@
1
- ---
2
- license: other
3
- license_name: license
4
- license_link: LICENSE
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NBA Predictions
2
+
3
+ This repo contains AI model code and weights which predicts the outcome of NBA games. Its output represents the chance that a given point spread will occur.
4
+
5
+ The model requires 8 players on the home and away teams, plus their ages, as input. It will then output probabilities for each point spread between -20 and +20 points, from the home team's point of view.
6
+
7
+ For example, the following text and chart shows the model predicting the home team with a 77% chance to win and a 14% chance of winning by 20 or more points. This kind of chart is indicative of a dominant team playing at home. Most games will have more of a bell curve shape to them.
8
+
9
+ ![NBA prediction graph](prediction.png)
10
+
11
+ ## Installation
12
+
13
+ I recommend installing Python 3.11.8, as that is what the repo was written / tested in. The code will likely work with most recent versions of Python, though.
14
+
15
+ Once you have Python installed, run `pip install -r requirements.txt`. It will take a while to install dependencies if you don't already have PyTorch cached.
16
+
17
+ ## Usage
18
+
19
+ The `example.ipynb` notebook shows how to use the model to predict the final game of the 2023-24 NBA season - a game between the Dallas Mavericks and Boston Celtics. It will output the chart above.
20
+
21
+ To change the players and their ages, you must reference the `player_tokens.csv` and `age_tokens.csv` files.
22
+
23
+ For example, if you wanted to subtract Kristaps Porzingis from Boston's team and swap who was home / away, you would take the token representing Porzingis `4416` out of the `home_team_tokens` list, and replace him with, say, Payton Pritchard `4999`. You would then have to look up Pritchard's age (26), find the corresponding age token in `age_tokens.csv`, which is `11`, and replace Porzingis' age token (which is the second to last token).
24
+
25
+ To swap home and away, you could replace the variables containing all of the player and age tokens, or just set the `swap_home_away` variable to `True`. The results are as follows:
26
+
27
+ ![NBA Finals prediction without Porzingis](porzingis-swapped-for-pritchard.png)
28
+
29
+ As you can see, Dallas' win probability improved from 23% to 35%, and their chance of being blown out by 20+ points decreased from 14% to 10%. Clearly, the model thinks Porzingis is important to the Celtics' chances, but still considers Boston to be the superior team without him.
30
+
31
+ ## Training Process
32
+
33
+ I downloaded data from stats.nba.com using the [https://github.com/swar/nba_api](swar/nba_api) package to get information on minutes played, game outcomes, and a few other dimensional elements to make everything fit together. Then, I ran a custom PyTorch training loop to train the model(s) on their chosen loss objective (spread, money line, or spread probability).
age_tokens.csv ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ age,token
2
+ 16,1
3
+ 17,2
4
+ 18,3
5
+ 19,4
6
+ 20,5
7
+ 21,6
8
+ 22,7
9
+ 23,8
10
+ 24,9
11
+ 25,10
12
+ 26,11
13
+ 27,12
14
+ 28,13
15
+ 29,14
16
+ 30,15
17
+ 31,16
18
+ 32,17
19
+ 33,18
20
+ 34,19
21
+ 35,20
22
+ 36,21
23
+ 37,22
24
+ 38,23
25
+ 39,24
26
+ 40,25
27
+ 41,26
28
+ 42,27
29
+ 43,28
30
+ 44,29
31
+ 45,30
32
+ 46,31
33
+ 47,32
example.ipynb ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from model import NBAModel, NBAConfig\n",
10
+ "from torch import device as torch_device, load as torch_load, int32, Tensor, bfloat16\n",
11
+ "import matplotlib.pyplot as plt\n",
12
+ "\n",
13
+ "device = torch_device(\"cpu\") \n",
14
+ "num_age_tokens=32\n",
15
+ "num_player_tokens=5141\n",
16
+ "num_net_score_tokens=41\n",
17
+ "players_per_team=8\n",
18
+ "\n",
19
+ "model_config = NBAConfig(\n",
20
+ " players_per_team=players_per_team,\n",
21
+ " player_tokens=num_player_tokens+2,\n",
22
+ " age_tokens=num_age_tokens+2,\n",
23
+ " num_labels=num_net_score_tokens+2,\n",
24
+ " n_layer=4,\n",
25
+ " n_head=4,\n",
26
+ " n_embd=1024,\n",
27
+ " dropout=0.0,\n",
28
+ " bias=False,\n",
29
+ " dtype=bfloat16,\n",
30
+ " seed=29,\n",
31
+ ")\n",
32
+ "\n",
33
+ "model = NBAModel(model_config).to(device)\n",
34
+ "state_dict = torch_load('weights.pt', map_location='cpu')\n",
35
+ "model.load_state_dict(state_dict)\n",
36
+ "model = model.eval()"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 2,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# Change player and age tokens here!\n",
46
+ "# You can find these values in player_tokens.csv and age_tokens.csv\n",
47
+ "# You must provide exactly 8 player tokens and 8 age tokens for each team.\n",
48
+ "\n",
49
+ "# Boston Celtics final game of 2023-24 season roster\n",
50
+ "home_player_tokens = [1994, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
51
+ "home_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
52
+ "\n",
53
+ "# Dallas Mavericks final game of 2023-24 season roster\n",
54
+ "away_player_tokens = [5117, 5097, 4956, 5109, 55, 149, 5121, 5112]\n",
55
+ "away_age_tokens = [10, 17, 10, 12, 10, 5, 8, 17]\n",
56
+ "\n",
57
+ "# The model usually gives the home team a bump in win probability.\n",
58
+ "# Change this to \"True\" to swap home and away teams.\n",
59
+ "swap_home_away = False\n",
60
+ "if swap_home_away:\n",
61
+ " home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
62
+ " home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 3,
68
+ "metadata": {},
69
+ "outputs": [
70
+ {
71
+ "name": "stdout",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Home team win probability: 0.77\n"
75
+ ]
76
+ },
77
+ {
78
+ "data": {
79
+ "text/plain": [
80
+ "<BarContainer object of 40 artists>"
81
+ ]
82
+ },
83
+ "execution_count": 3,
84
+ "metadata": {},
85
+ "output_type": "execute_result"
86
+ },
87
+ {
88
+ "data": {
89
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqWklEQVR4nO3df1DU953H8ReggL9YjSjrDxS9WI1VwaBQTBqTy46Y4S4h8Sw6mUgYx05SMVpyNuAptE1zcPFHSJWG2BnN9XoWz7nTWvVo6VbMtWKsIJdqEmsyMRDJgrYnKCZg2O/9kcvaPVdhV3Q/rM/HzHfifnl/v/v+zFfcVz7fHxtmWZYlAAAAg4UHuwEAAIDuEFgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMbrF+wGeoPb7VZTU5OGDBmisLCwYLcDAAB6wLIsXbx4UaNHj1Z4+I3nUEIisDQ1NSk+Pj7YbQAAgAA0NjZq7NixN6wJicAyZMgQSV8MOCYmJsjdAACAnmhra1N8fLznc/xGQiKwfHkaKCYmhsACAEAf05PLObjoFgAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4/YLdAAAACK6E/P3d1pwpybgNnVwfMywAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYLKLCUlZUpISFB0dHRSk1N1dGjR69be/LkSS1YsEAJCQkKCwtTaWnpDfddUlKisLAwrVq1KpDWAABACPI7sOzcuVN5eXkqKipSXV2dEhMTlZ6erpaWFp/1ly9f1sSJE1VSUiK73X7Dff/+97/X66+/rhkzZvjbFgAACGF+B5ZNmzZp2bJlysnJ0dSpU1VeXq6BAwdq27ZtPutnz56t9evXa9GiRYqKirrufi9duqQnn3xSP/7xjzVs2DB/2wIAACHMr8DS2dmp2tpaORyOqzsID5fD4VBNTc1NNbJ8+XJlZGR47ft6Ojo61NbW5rUAAIDQ5VdgOX/+vLq6uhQXF+e1Pi4uTi6XK+AmKioqVFdXp+Li4h7VFxcXy2azeZb4+PiA3xsAAJgv6HcJNTY2auXKlfrXf/1XRUdH92ibgoICtba2epbGxsZb3CUAAAgmv778MDY2VhEREWpubvZa39zc3O0FtddTW1urlpYW3XvvvZ51XV1devPNN7VlyxZ1dHQoIiLCa5uoqKgbXg8DAABCi18zLJGRkUpOTpbT6fSsc7vdcjqdSktLC6iBhx9+WH/4wx9UX1/vWWbNmqUnn3xS9fX114QVAABw5/FrhkWS8vLylJ2drVmzZiklJUWlpaVqb29XTk6OJGnJkiUaM2aM53qUzs5OvfPOO54/nz17VvX19Ro8eLDuvvtuDRkyRNOmTfN6j0GDBmn48OHXrAcAAHcmvwNLVlaWzp07p8LCQrlcLiUlJamystJzIW5DQ4PCw69O3DQ1NWnmzJme1xs2bNCGDRs0d+5cVVdX3/wIAABAyAuzLMsKdhM3q62tTTabTa2trYqJiQl2OwAA9CkJ+fu7rTlTktHr7+vP53fQ7xICAADoDoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPECCixlZWVKSEhQdHS0UlNTdfTo0evWnjx5UgsWLFBCQoLCwsJUWlp6TU1xcbFmz56tIUOGaOTIkcrMzNSpU6cCaQ0AAIQgvwPLzp07lZeXp6KiItXV1SkxMVHp6elqaWnxWX/58mVNnDhRJSUlstvtPmsOHTqk5cuX68iRI6qqqtKVK1c0b948tbe3+9seAAAIQWGWZVn+bJCamqrZs2dry5YtkiS32634+HitWLFC+fn5N9w2ISFBq1at0qpVq25Yd+7cOY0cOVKHDh3SAw880G1PbW1tstlsam1tVUxMTI/HAgAApIT8/d3WnCnJ6PX39efz268Zls7OTtXW1srhcFzdQXi4HA6HampqAuvWh9bWVknSXXfd5fPnHR0damtr81oAAEDo8iuwnD9/Xl1dXYqLi/NaHxcXJ5fL1SsNud1urVq1Svfdd5+mTZvms6a4uFg2m82zxMfH98p7AwAAMxl3l9Dy5ct14sQJVVRUXLemoKBAra2tnqWxsfE2dggAAG63fv4Ux8bGKiIiQs3NzV7rm5ubr3tBrT9yc3O1b98+vfnmmxo7dux166KiohQVFXXT7wcAAPoGv2ZYIiMjlZycLKfT6VnndrvldDqVlpYWcBOWZSk3N1e7d+/Wb37zG02YMCHgfQEAgNDj1wyLJOXl5Sk7O1uzZs1SSkqKSktL1d7erpycHEnSkiVLNGbMGBUXF0v64kLdd955x/Pns2fPqr6+XoMHD9bdd98t6YvTQDt27NDPf/5zDRkyxHM9jM1m04ABA3ploAAAoO/yO7BkZWXp3LlzKiwslMvlUlJSkiorKz0X4jY0NCg8/OrETVNTk2bOnOl5vWHDBm3YsEFz585VdXW1JOm1116TJD344INe77V9+3Y9/fTT/rYIAABCjN/PYTERz2EBACBwIfccFgAAgGAgsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwXkCBpaysTAkJCYqOjlZqaqqOHj163dqTJ09qwYIFSkhIUFhYmEpLS296nwAA4M7id2DZuXOn8vLyVFRUpLq6OiUmJio9PV0tLS0+6y9fvqyJEyeqpKREdru9V/YJAADuLH4Hlk2bNmnZsmXKycnR1KlTVV5eroEDB2rbtm0+62fPnq3169dr0aJFioqK6pV9AgCAO4tfgaWzs1O1tbVyOBxXdxAeLofDoZqamoAaCGSfHR0damtr81oAAEDo8iuwnD9/Xl1dXYqLi/NaHxcXJ5fLFVADgeyzuLhYNpvNs8THxwf03gAAoG/ok3cJFRQUqLW11bM0NjYGuyUAAHAL9fOnODY2VhEREWpubvZa39zcfN0Lam/FPqOioq57PQwAAAg9fs2wREZGKjk5WU6n07PO7XbL6XQqLS0toAZuxT4BAEBo8WuGRZLy8vKUnZ2tWbNmKSUlRaWlpWpvb1dOTo4kacmSJRozZoyKi4slfXFR7TvvvOP589mzZ1VfX6/Bgwfr7rvv7tE+AQDAnc3vwJKVlaVz586psLBQLpdLSUlJqqys9Fw029DQoPDwqxM3TU1Nmjlzpuf1hg0btGHDBs2dO1fV1dU92icAALizhVmWZQW7iZvV1tYmm82m1tZWxcTEBLsdAAD6lIT8/d3WnCnJ6PX39efzu0/eJQQAAO4sBBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjBdQYCkrK1NCQoKio6OVmpqqo0eP3rB+165dmjJliqKjozV9+nQdOHDA6+eXLl1Sbm6uxo4dqwEDBmjq1KkqLy8PpDUAABCC/A4sO3fuVF5enoqKilRXV6fExESlp6erpaXFZ/3hw4e1ePFiLV26VMePH1dmZqYyMzN14sQJT01eXp4qKyv105/+VO+++65WrVql3Nxc7d27N/CRAQCAkBFmWZblzwapqamaPXu2tmzZIklyu92Kj4/XihUrlJ+ff019VlaW2tvbtW/fPs+6r33ta0pKSvLMokybNk1ZWVlat26dpyY5OVmPPPKIfvCDH3TbU1tbm2w2m1pbWxUTE+PPcAAAuOMl5O/vtuZMSUavv68/n99+zbB0dnaqtrZWDofj6g7Cw+VwOFRTU+Nzm5qaGq96SUpPT/eqnzNnjvbu3auzZ8/KsiwdPHhQf/zjHzVv3jyf++zo6FBbW5vXAgAAQpdfgeX8+fPq6upSXFyc1/q4uDi5XC6f27hcrm7rN2/erKlTp2rs2LGKjIzU/PnzVVZWpgceeMDnPouLi2Wz2TxLfHy8P8MAAAB9jBF3CW3evFlHjhzR3r17VVtbq40bN2r58uX69a9/7bO+oKBAra2tnqWxsfE2dwwAAG6nfv4Ux8bGKiIiQs3NzV7rm5ubZbfbfW5jt9tvWP/pp59qzZo12r17tzIyvjg/NmPGDNXX12vDhg3XnE6SpKioKEVFRfnTOgAA6MP8mmGJjIxUcnKynE6nZ53b7ZbT6VRaWprPbdLS0rzqJamqqspTf+XKFV25ckXh4d6tREREyO12+9MeAAAIUX7NsEhf3IKcnZ2tWbNmKSUlRaWlpWpvb1dOTo4kacmSJRozZoyKi4slSStXrtTcuXO1ceNGZWRkqKKiQseOHdPWrVslSTExMZo7d65Wr16tAQMGaPz48Tp06JB+8pOfaNOmTb04VAAA0Ff5HViysrJ07tw5FRYWyuVyKSkpSZWVlZ4LaxsaGrxmS+bMmaMdO3Zo7dq1WrNmjSZNmqQ9e/Zo2rRpnpqKigoVFBToySef1J///GeNHz9eL730kp555pleGCIAAOjr/H4Oi4l4DgsAAIELueewAAAABAOBBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYLx+wW4AAHDzEvL3d1tzpiTjNnQC3BrMsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwXkCBpaysTAkJCYqOjlZqaqqOHj16w/pdu3ZpypQpio6O1vTp03XgwIFrat599109+uijstlsGjRokGbPnq2GhoZA2gMAACHG78Cyc+dO5eXlqaioSHV1dUpMTFR6erpaWlp81h8+fFiLFy/W0qVLdfz4cWVmZiozM1MnTpzw1HzwwQe6//77NWXKFFVXV+vtt9/WunXrFB0dHfjIAABAyAizLMvyZ4PU1FTNnj1bW7ZskSS53W7Fx8drxYoVys/Pv6Y+KytL7e3t2rdvn2fd1772NSUlJam8vFyStGjRIvXv31//8i//EtAg2traZLPZ1NraqpiYmID2AQB9WUL+/m5rzpRk3IZO0BcF6++PP5/ffs2wdHZ2qra2Vg6H4+oOwsPlcDhUU1Pjc5uamhqveklKT0/31Lvdbu3fv19f+cpXlJ6erpEjRyo1NVV79uy5bh8dHR1qa2vzWgAAQOjyK7CcP39eXV1diouL81ofFxcnl8vlcxuXy3XD+paWFl26dEklJSWaP3++fvWrX+nxxx/XE088oUOHDvncZ3FxsWw2m2eJj4/3ZxgAAKCPCfpdQm63W5L02GOP6dvf/raSkpKUn5+vv/mbv/GcMvr/CgoK1Nra6lkaGxtvZ8sAAOA26+dPcWxsrCIiItTc3Oy1vrm5WXa73ec2drv9hvWxsbHq16+fpk6d6lVzzz336Le//a3PfUZFRSkqKsqf1gEAQB/m1wxLZGSkkpOT5XQ6PevcbrecTqfS0tJ8bpOWluZVL0lVVVWe+sjISM2ePVunTp3yqvnjH/+o8ePH+9MeAAAIUX7NsEhSXl6esrOzNWvWLKWkpKi0tFTt7e3KycmRJC1ZskRjxoxRcXGxJGnlypWaO3euNm7cqIyMDFVUVOjYsWPaunWrZ5+rV69WVlaWHnjgAT300EOqrKzUL37xC1VXV/fOKAEAQJ/md2DJysrSuXPnVFhYKJfLpaSkJFVWVnourG1oaFB4+NWJmzlz5mjHjh1au3at1qxZo0mTJmnPnj2aNm2ap+bxxx9XeXm5iouL9dxzz2ny5Mn693//d91///29MEQA6Ju4VRm4yu/nsJiI57AACEX+BBbCDW5GyD2HBQAAIBgILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACM1y/YDQAAzJaQv/+GPz9TknGbOsGdjBkWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjNcv2A0AgIkS8vd3W3OmJOM2dAJAYoYFAAD0AQQWAABgvIACS1lZmRISEhQdHa3U1FQdPXr0hvW7du3SlClTFB0drenTp+vAgQPXrX3mmWcUFham0tLSQFoDAAAhyO/AsnPnTuXl5amoqEh1dXVKTExUenq6WlpafNYfPnxYixcv1tKlS3X8+HFlZmYqMzNTJ06cuKZ29+7dOnLkiEaPHu3/SAAAQMjy+6LbTZs2admyZcrJyZEklZeXa//+/dq2bZvy8/OvqX/11Vc1f/58rV69WpL04osvqqqqSlu2bFF5ebmn7uzZs1qxYoV++ctfKiODC9kA9B1coAvcen7NsHR2dqq2tlYOh+PqDsLD5XA4VFNT43Obmpoar3pJSk9P96p3u9166qmntHr1an31q1/tto+Ojg61tbV5LQAAIHT5FVjOnz+vrq4uxcXFea2Pi4uTy+XyuY3L5eq2/p/+6Z/Ur18/Pffccz3qo7i4WDabzbPEx8f7MwwAANDHBP0uodraWr366qt64403FBYW1qNtCgoK1Nra6lkaGxtvcZcAACCY/AossbGxioiIUHNzs9f65uZm2e12n9vY7fYb1v/Xf/2XWlpaNG7cOPXr10/9+vXTRx99pOeff14JCQk+9xkVFaWYmBivBQAAhC6/AktkZKSSk5PldDo969xut5xOp9LS0nxuk5aW5lUvSVVVVZ76p556Sm+//bbq6+s9y+jRo7V69Wr98pe/9Hc8AAAgBPl9l1BeXp6ys7M1a9YspaSkqLS0VO3t7Z67hpYsWaIxY8aouLhYkrRy5UrNnTtXGzduVEZGhioqKnTs2DFt3bpVkjR8+HANHz7c6z369+8vu92uyZMn3+z4AABACPA7sGRlZencuXMqLCyUy+VSUlKSKisrPRfWNjQ0KDz86sTNnDlztGPHDq1du1Zr1qzRpEmTtGfPHk2bNq33RgEAfQS3QAOBCejLD3Nzc5Wbm+vzZ9XV1desW7hwoRYuXNjj/Z85cyaQtgDghggLQN8V9LuEAAAAukNgAQAAxiOwAAAA4wV0DQsAAL50d50Q1wghUMywAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAe39YMoE/r7tuBJb4hGAgFzLAAAADjEVgAAIDxOCUEAAiK7k7nBXoqj9OEoYkZFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeDw4DgDuMH3xwWp9sWf0LmZYAACA8QgsAADAeJwSAgDcsTjV1HcQWIBbjH8QAeDmcUoIAAAYjxkWAACCiFnYnmGGBQAAGC+gwFJWVqaEhARFR0crNTVVR48evWH9rl27NGXKFEVHR2v69Ok6cOCA52dXrlzRCy+8oOnTp2vQoEEaPXq0lixZoqampkBaAwAAIcjvwLJz507l5eWpqKhIdXV1SkxMVHp6ulpaWnzWHz58WIsXL9bSpUt1/PhxZWZmKjMzUydOnJAkXb58WXV1dVq3bp3q6ur0H//xHzp16pQeffTRmxsZAAAIGX5fw7Jp0yYtW7ZMOTk5kqTy8nLt379f27ZtU35+/jX1r776qubPn6/Vq1dLkl588UVVVVVpy5YtKi8vl81mU1VVldc2W7ZsUUpKihoaGjRu3LhAxgXgL3COHEBf59cMS2dnp2pra+VwOK7uIDxcDodDNTU1Prepqanxqpek9PT069ZLUmtrq8LCwjR06FB/2gMAACHKrxmW8+fPq6urS3FxcV7r4+Li9N577/ncxuVy+ax3uVw+6z/77DO98MILWrx4sWJiYnzWdHR0qKOjw/O6ra3Nn2EAuAFmYwCYyKi7hK5cuaJvfOMbsixLr7322nXriouLZbPZPEt8fPxt7BIAANxufgWW2NhYRUREqLm52Wt9c3Oz7Ha7z23sdnuP6r8MKx999JGqqqquO7siSQUFBWptbfUsjY2N/gwDAAD0MX6dEoqMjFRycrKcTqcyMzMlSW63W06nU7m5uT63SUtLk9Pp1KpVqzzrqqqqlJaW5nn9ZVg5ffq0Dh48qOHDh9+wj6ioKEVFRfnTOhByOHUD4E7i911CeXl5ys7O1qxZs5SSkqLS0lK1t7d77hpasmSJxowZo+LiYknSypUrNXfuXG3cuFEZGRmqqKjQsWPHtHXrVklfhJW/+7u/U11dnfbt26euri7P9S133XWXIiMje2usAAD0aXfy/6j4HViysrJ07tw5FRYWyuVyKSkpSZWVlZ4LaxsaGhQefvVM05w5c7Rjxw6tXbtWa9as0aRJk7Rnzx5NmzZNknT27Fnt3btXkpSUlOT1XgcPHtSDDz4Y4NAAAECoCOi7hHJzc697Cqi6uvqadQsXLtTChQt91ickJMiyrEDaAAAA1xFqszF8+SEAAD3gTwAItbBgAqNuawYAAPCFwAIAAIxHYAEAAMbjGhYAAeM8PYDbhcACBKC7D2o+pAGgdxFYANwWzMYAuBlcwwIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB7PYekBnh+B24W/awDgGzMsAADAeMywAP+Hx+0DgLmYYQEAAMZjhgV9ir/XeDBrAgChgRkWAABgPAILAAAwHoEFAAAYj8ACAACMx0W3uCV4ABoAoDcxwwIAAIzHDAt6jFkTAECwMMMCAACMxwwLgo6ZGwBAd5hhAQAAxmOGJQQxYwEACDUElj7iVoUQwg0AoC/glBAAADAeMyxBxOwGAAA9wwwLAAAwHoEFAAAYj8ACAACMR2ABAADG46LbXsaFtAAA9L6AZljKysqUkJCg6Ohopaam6ujRozes37Vrl6ZMmaLo6GhNnz5dBw4c8Pq5ZVkqLCzUqFGjNGDAADkcDp0+fTqQ1gAAQAjyO7Ds3LlTeXl5KioqUl1dnRITE5Wenq6Wlhaf9YcPH9bixYu1dOlSHT9+XJmZmcrMzNSJEyc8NS+//LJ++MMfqry8XG+99ZYGDRqk9PR0ffbZZ4GPDAAAhAy/A8umTZu0bNky5eTkaOrUqSovL9fAgQO1bds2n/Wvvvqq5s+fr9WrV+uee+7Riy++qHvvvVdbtmyR9MXsSmlpqdauXavHHntMM2bM0E9+8hM1NTVpz549NzU4AAAQGvy6hqWzs1O1tbUqKCjwrAsPD5fD4VBNTY3PbWpqapSXl+e1Lj093RNGPvzwQ7lcLjkcDs/PbTabUlNTVVNTo0WLFl2zz46ODnV0dHhet7a2SpLa2tr8GU6PuTsud1vz5XtTa05tT+pNqP3Lemr9/x02od9Qru1JvQm1f1lPrTm/nz3dp2VZ3Rdbfjh79qwlyTp8+LDX+tWrV1spKSk+t+nfv7+1Y8cOr3VlZWXWyJEjLcuyrN/97neWJKupqcmrZuHChdY3vvENn/ssKiqyJLGwsLCwsLCEwNLY2NhtBumTdwkVFBR4zdq43W79+c9/1vDhwxUWFnZL37utrU3x8fFqbGxUTEzMLX2vYAjl8TG2vimUxyaF9vgYW991u8ZnWZYuXryo0aNHd1vrV2CJjY1VRESEmpubvdY3NzfLbrf73MZut9+w/sv/Njc3a9SoUV41SUlJPvcZFRWlqKgor3VDhw71Zyg3LSYmJiT/kn4plMfH2PqmUB6bFNrjY2x91+0Yn81m61GdXxfdRkZGKjk5WU6n07PO7XbL6XQqLS3N5zZpaWle9ZJUVVXlqZ8wYYLsdrtXTVtbm956663r7hMAANxZ/D4llJeXp+zsbM2aNUspKSkqLS1Ve3u7cnJyJElLlizRmDFjVFxcLElauXKl5s6dq40bNyojI0MVFRU6duyYtm7dKkkKCwvTqlWr9IMf/ECTJk3ShAkTtG7dOo0ePVqZmZm9N1IAANBn+R1YsrKydO7cORUWFsrlcikpKUmVlZWKi4uTJDU0NCg8/OrEzZw5c7Rjxw6tXbtWa9as0aRJk7Rnzx5NmzbNU/Od73xH7e3t+uY3v6kLFy7o/vvvV2VlpaKjo3thiL0rKipKRUVF15ySChWhPD7G1jeF8tik0B4fY+u7TBxfmGX15F4iAACA4OHLDwEAgPEILAAAwHgEFgAAYDwCCwAAMB6BpYfOnDmjpUuXasKECRowYID+6q/+SkVFRers7PSqe/vtt/X1r39d0dHRio+P18svvxykjv330ksvac6cORo4cOB1H8QXFhZ2zVJRUXF7Gw1AT8bW0NCgjIwMDRw4UCNHjtTq1av1+eef395Ge0FCQsI1x6ikpCTYbQWsrKxMCQkJio6OVmpqqo4ePRrslm7ad7/73WuO0ZQpU4LdVsDefPNN/e3f/q1Gjx6tsLCwa7641rIsFRYWatSoURowYIAcDodOnz4dnGb91N3Ynn766WuO5fz584PTrJ+Ki4s1e/ZsDRkyRCNHjlRmZqZOnTrlVfPZZ59p+fLlGj58uAYPHqwFCxZc8zDY24XA0kPvvfee3G63Xn/9dZ08eVKvvPKKysvLtWbNGk9NW1ub5s2bp/Hjx6u2tlbr16/Xd7/7Xc8zZ0zX2dmphQsX6tlnn71h3fbt2/XJJ594lr7wvJzuxtbV1aWMjAx1dnbq8OHD+ud//me98cYbKiwsvM2d9o7vf//7XsdoxYoVwW4pIDt37lReXp6KiopUV1enxMREpaenq6WlJdit3bSvfvWrXsfot7/9bbBbClh7e7sSExNVVlbm8+cvv/yyfvjDH6q8vFxvvfWWBg0apPT0dH322We3uVP/dTc2SZo/f77XsfzZz352GzsM3KFDh7R8+XIdOXJEVVVVunLliubNm6f29nZPzbe//W394he/0K5du3To0CE1NTXpiSeeCE7D3X7bEK7r5ZdftiZMmOB5/aMf/cgaNmyY1dHR4Vn3wgsvWJMnTw5GewHbvn27ZbPZfP5MkrV79+7b2k9vut7YDhw4YIWHh1sul8uz7rXXXrNiYmK8jmdfMH78eOuVV14Jdhu9IiUlxVq+fLnndVdXlzV69GiruLg4iF3dvKKiIisxMTHYbdwS///fCLfbbdntdmv9+vWedRcuXLCioqKsn/3sZ0HoMHC+/v3Lzs62HnvssaD009taWlosSdahQ4csy/riOPXv39/atWuXp+bdd9+1JFk1NTW3vT9mWG5Ca2ur7rrrLs/rmpoaPfDAA4qMjPSsS09P16lTp/Q///M/wWjxlli+fLliY2OVkpKibdu29exrwQ1XU1Oj6dOnex6AKH1x7Nra2nTy5MkgdhaYkpISDR8+XDNnztT69ev75Kmtzs5O1dbWyuFweNaFh4fL4XCopqYmiJ31jtOnT2v06NGaOHGinnzySTU0NAS7pVviww8/lMvl8jqONptNqampIXEcJam6ulojR47U5MmT9eyzz+pPf/pTsFsKSGtrqyR5Ptdqa2t15coVr2M3ZcoUjRs3LijHrk9+W7MJ3n//fW3evFkbNmzwrHO5XJowYYJX3ZcfgC6XS8OGDbutPd4K3//+9/XXf/3XGjhwoH71q1/pW9/6li5duqTnnnsu2K3dFJfL5RVWJO9j15c899xzuvfee3XXXXfp8OHDKigo0CeffKJNmzYFuzW/nD9/Xl1dXT6Py3vvvRekrnpHamqq3njjDU2ePFmffPKJvve97+nrX/+6Tpw4oSFDhgS7vV715e+Pr+PY1363fJk/f76eeOIJTZgwQR988IHWrFmjRx55RDU1NYqIiAh2ez3mdru1atUq3XfffZ4n0btcLkVGRl5z3V+wjt0dP8OSn5/v80LSv1z+/z+OZ8+e1fz587Vw4UItW7YsSJ33TCDju5F169bpvvvu08yZM/XCCy/oO9/5jtavX38LR3B9vT02k/kz1ry8PD344IOaMWOGnnnmGW3cuFGbN29WR0dHkEeBLz3yyCNauHChZsyYofT0dB04cEAXLlzQv/3bvwW7Nfhp0aJFevTRRzV9+nRlZmZq3759+v3vf6/q6upgt+aX5cuX68SJE0bfRHHHz7A8//zzevrpp29YM3HiRM+fm5qa9NBDD2nOnDnXXExrt9uvuXr6y9d2u713GvaTv+PzV2pqql588UV1dHTc9u+c6M2x2e32a+4+Cfax+0s3M9bU1FR9/vnnOnPmjCZPnnwLurs1YmNjFRER4fN3yoRj0puGDh2qr3zlK3r//feD3Uqv+/JYNTc3a9SoUZ71zc3NSkpKClJXt87EiRMVGxur999/Xw8//HCw2+mR3Nxc7du3T2+++abGjh3rWW+329XZ2akLFy54zbIE63fwjg8sI0aM0IgRI3pUe/bsWT300ENKTk7W9u3bvb7kUZLS0tL0D//wD7py5Yr69+8vSaqqqtLkyZODdjrIn/EFor6+XsOGDQvKF2T15tjS0tL00ksvqaWlRSNHjpT0xbGLiYnR1KlTe+U9bsbNjLW+vl7h4eGecfUVkZGRSk5OltPp9NyJ5na75XQ6lZubG9zmetmlS5f0wQcf6Kmnngp2K71uwoQJstvtcjqdnoDS1tamt956q9s7Evuijz/+WH/605+8wpmpLMvSihUrtHv3blVXV19zSUNycrL69+8vp9OpBQsWSJJOnTqlhoYGpaWlBaVh9MDHH39s3X333dbDDz9sffzxx9Ynn3ziWb504cIFKy4uznrqqaesEydOWBUVFdbAgQOt119/PYid99xHH31kHT9+3Pre975nDR482Dp+/Lh1/Phx6+LFi5ZlWdbevXutH//4x9Yf/vAH6/Tp09aPfvQja+DAgVZhYWGQO+9ed2P7/PPPrWnTplnz5s2z6uvrrcrKSmvEiBFWQUFBkDv3z+HDh61XXnnFqq+vtz744APrpz/9qTVixAhryZIlwW4tIBUVFVZUVJT1xhtvWO+88471zW9+0xo6dKjX3Vx90fPPP29VV1dbH374ofW73/3OcjgcVmxsrNXS0hLs1gJy8eJFz++UJGvTpk3W8ePHrY8++siyLMsqKSmxhg4dav385z+33n77beuxxx6zJkyYYH366adB7rx7NxrbxYsXrb//+7+3ampqrA8//ND69a9/bd17773WpEmTrM8++yzYrXfr2WeftWw2m1VdXe31mXb58mVPzTPPPGONGzfO+s1vfmMdO3bMSktLs9LS0oLSL4Glh7Zv325J8rn8pf/+7/+27r//fisqKsoaM2aMVVJSEqSO/Zedne1zfAcPHrQsy7L+8z//00pKSrIGDx5sDRo0yEpMTLTKy8utrq6u4DbeA92NzbIs68yZM9YjjzxiDRgwwIqNjbWef/5568qVK8FrOgC1tbVWamqqZbPZrOjoaOuee+6x/vEf/7FP/ON5PZs3b7bGjRtnRUZGWikpKdaRI0eC3dJNy8rKskaNGmVFRkZaY8aMsbKysqz3338/2G0F7ODBgz5/v7Kzsy3L+uLW5nXr1llxcXFWVFSU9fDDD1unTp0KbtM9dKOxXb582Zo3b541YsQIq3///tb48eOtZcuW9ZlAfb3PtO3bt3tqPv30U+tb3/qWNWzYMGvgwIHW448/7vU/6rdT2P81DQAAYKw7/i4hAABgPgILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIz3v5FUrqAH/4M1AAAAAElFTkSuQmCC",
90
+ "text/plain": [
91
+ "<Figure size 640x480 with 1 Axes>"
92
+ ]
93
+ },
94
+ "metadata": {},
95
+ "output_type": "display_data"
96
+ }
97
+ ],
98
+ "source": [
99
+ "# Run this cell to see the spread probabilities!\n",
100
+ "\n",
101
+ "assert len(home_player_tokens) == players_per_team\n",
102
+ "assert len(home_age_tokens) == players_per_team\n",
103
+ "assert len(away_player_tokens) == players_per_team\n",
104
+ "assert len(away_age_tokens) == players_per_team\n",
105
+ "\n",
106
+ "batch = {\n",
107
+ " 'home_player_tokens': Tensor([num_player_tokens+1] + home_player_tokens).to(dtype=int32).unsqueeze(0),\n",
108
+ " 'home_age_tokens': Tensor([num_age_tokens+1] + home_age_tokens).to(dtype=int32).unsqueeze(0),\n",
109
+ " 'away_player_tokens': Tensor(away_player_tokens).to(dtype=int32).unsqueeze(0),\n",
110
+ " 'away_age_tokens': Tensor(away_age_tokens).to(dtype=int32).unsqueeze(0),\n",
111
+ "}\n",
112
+ "\n",
113
+ "for key, value in batch.items():\n",
114
+ " if hasattr(value, 'to'):\n",
115
+ " batch[key] = value.to(device)\n",
116
+ "\n",
117
+ "output, _ = model(**batch)\n",
118
+ "output = output.squeeze().softmax(dim=0)\n",
119
+ "\n",
120
+ "probs = {}\n",
121
+ "loss_prob = 0\n",
122
+ "win_prob = 0\n",
123
+ "\n",
124
+ "first = True\n",
125
+ "for i, token in enumerate(output):\n",
126
+ " if first:\n",
127
+ " first = False\n",
128
+ " continue\n",
129
+ "\n",
130
+ " if i-21 < 0:\n",
131
+ " loss_prob += token.item()\n",
132
+ " elif i-21 > 0 and i-21 < 21:\n",
133
+ " win_prob += token.item()\n",
134
+ "\n",
135
+ " probs[i-21] = token.item()\n",
136
+ "\n",
137
+ "del probs[0]\n",
138
+ "del probs[21]\n",
139
+ "\n",
140
+ "print(f\"Home team win probability: {win_prob:.2f}\")\n",
141
+ "\n",
142
+ "plt.bar(probs.keys(), probs.values())"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": []
151
+ }
152
+ ],
153
+ "metadata": {
154
+ "kernelspec": {
155
+ "display_name": "nba",
156
+ "language": "python",
157
+ "name": "python3"
158
+ },
159
+ "language_info": {
160
+ "codemirror_mode": {
161
+ "name": "ipython",
162
+ "version": 3
163
+ },
164
+ "file_extension": ".py",
165
+ "mimetype": "text/x-python",
166
+ "name": "python",
167
+ "nbconvert_exporter": "python",
168
+ "pygments_lexer": "ipython3",
169
+ "version": "3.11.8"
170
+ }
171
+ },
172
+ "nbformat": 4,
173
+ "nbformat_minor": 2
174
+ }
model.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import (
2
+ Module,
3
+ Embedding,
4
+ Dropout,
5
+ ModuleDict,
6
+ LayerNorm,
7
+ ModuleList,
8
+ Linear,
9
+ GELU,
10
+ functional,
11
+ )
12
+ from torch.nn.init import normal_, zeros_
13
+ from dataclasses import dataclass
14
+ from rotary_embedding_torch import RotaryEmbedding
15
+ from torch import ones, cat
16
+ from torch.nn.functional import scaled_dot_product_attention
17
+ import torch.nn.functional as F
18
+ from math import sqrt
19
+
20
+ @dataclass
21
+ class NBAConfig:
22
+ players_per_team: int = None
23
+ player_tokens: int = None
24
+ age_tokens: int = None
25
+ n_layer: int = None
26
+ n_head: int = None
27
+ n_embd: int = None
28
+ dropout: float = None
29
+ seed: int = None
30
+ bias: bool = None
31
+ dtype: type = None
32
+ num_labels: int = None
33
+
34
+ class SelfAttention(Module):
35
+
36
+ def __init__(self, config):
37
+
38
+ block_size = config.players_per_team * 2 + 1
39
+
40
+ super().__init__()
41
+ assert config.n_embd % config.n_head == 0
42
+ self.c_attn = Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=config.dtype)
43
+ self.c_proj = Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype)
44
+ self.attn_dropout = Dropout(config.dropout)
45
+ self.resid_dropout = Dropout(config.dropout)
46
+ self.n_head = config.n_head
47
+ self.n_embd = config.n_embd
48
+ self.dropout = config.dropout
49
+ self.rotary_emb = RotaryEmbedding(config.n_embd)
50
+ self.flash = hasattr(functional, 'scaled_dot_product_attention')
51
+ if not self.flash:
52
+ self.register_buffer("bias", ones(block_size, block_size)
53
+ ).view(1, 1, block_size, block_size)
54
+
55
+ def forward(self, x):
56
+ B, T, C = x.size()
57
+
58
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
59
+
60
+ q = self.rotary_emb.rotate_queries_or_keys(q)
61
+ k = self.rotary_emb.rotate_queries_or_keys(k)
62
+
63
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
64
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
65
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
66
+
67
+ if self.flash:
68
+ y = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
69
+ else:
70
+ att = (q @ k.transpose(-2, -1)) * (1.0 / sqrt(k.size(-1)))
71
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
72
+ att = F.softmax(att, dim=-1)
73
+ att = self.attn_dropout(att)
74
+ y = att @ v
75
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
76
+
77
+ # output projection
78
+ y = self.resid_dropout(self.c_proj(y))
79
+ return y
80
+
81
+ class MLP(Module):
82
+
83
+ def __init__(self, config):
84
+ super().__init__()
85
+ self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=config.bias, dtype=config.dtype)
86
+ self.gelu = GELU()
87
+ self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype)
88
+ self.dropout = Dropout(config.dropout)
89
+
90
+ def forward(self, x):
91
+ x = self.c_fc(x)
92
+ x = self.gelu(x)
93
+ x = self.c_proj(x)
94
+ x = self.dropout(x)
95
+ return x
96
+
97
+ class Block(Module):
98
+
99
+ def __init__(self, config):
100
+ super().__init__()
101
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype)
102
+ self.attn = SelfAttention(config)
103
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype)
104
+ self.mlp = MLP(config)
105
+
106
+ def forward(self, x):
107
+ x = x + self.attn(self.ln_1(x))
108
+ return x + self.mlp(self.ln_2(x))
109
+
110
+ class NBAModel(Module):
111
+
112
+ def __init__(self, config) -> None:
113
+ super().__init__()
114
+
115
+ self.config = config
116
+
117
+ self.transformer = ModuleDict(dict(
118
+ home_player_embeddings = Embedding(config.player_tokens, config.n_embd, dtype=config.dtype),
119
+ away_player_embeddings = Embedding(config.player_tokens, config.n_embd, dtype=config.dtype),
120
+ home_age_embeddings = Embedding(config.age_tokens, config.n_embd, dtype=config.dtype),
121
+ away_age_embeddings = Embedding(config.age_tokens, config.n_embd, dtype=config.dtype),
122
+ drop = Dropout(config.dropout),
123
+ h = ModuleList([Block(config) for _ in range(config.n_layer)]),
124
+ ln_f = LayerNorm(config.n_embd, bias=config.bias, dtype=config.dtype),
125
+ ))
126
+
127
+ self.head = Linear(config.n_embd, config.num_labels, dtype=config.dtype)
128
+
129
+ self.apply(self._init_weights)
130
+ for pn, p in self.named_parameters():
131
+ if pn.endswith('c_proj.weight'):
132
+ normal_(p, mean=0.0, std=0.02/sqrt(2 * config.n_layer))
133
+
134
+ def _init_weights(self, module):
135
+ if isinstance(module, Linear):
136
+ normal_(module.weight, mean=0.0, std=0.02)
137
+ if module.bias is not None:
138
+ zeros_(module.bias)
139
+ elif isinstance(module, Embedding):
140
+ normal_(module.weight, mean=0.0, std=0.02)
141
+
142
+ def forward(self, **batch):
143
+ home_player_tokens = batch['home_player_tokens']
144
+ away_player_tokens = batch['away_player_tokens']
145
+ home_age_tokens = batch['home_age_tokens']
146
+ away_age_tokens = batch['away_age_tokens']
147
+
148
+ home_player_embeddings = self.transformer.home_player_embeddings(home_player_tokens)
149
+ away_player_embeddings = self.transformer.away_player_embeddings(away_player_tokens)
150
+
151
+ home_age_embeddings = self.transformer.home_age_embeddings(home_age_tokens)
152
+ away_age_embeddings = self.transformer.away_age_embeddings(away_age_tokens)
153
+
154
+ home_emb = home_player_embeddings + home_age_embeddings
155
+ away_emb = away_player_embeddings + away_age_embeddings
156
+
157
+ x = cat([home_emb, away_emb], dim=1)
158
+
159
+ x = self.transformer.drop(x)
160
+
161
+ for block in self.transformer.h:
162
+ x = block(x)
163
+ x = self.transformer.ln_f(x)
164
+
165
+ logits = self.head(x)
166
+ logits = logits[:, 0]
167
+
168
+ loss = None
169
+ if 'home_team_won' in batch:
170
+ loss = F.cross_entropy(logits, batch['home_net_score_token'])
171
+ loss = {'loss': loss}
172
+
173
+ return logits, loss
player_tokens.csv ADDED
The diff for this file is too large to render. See raw diff
 
porzingis-swapped-for-pritchard.png ADDED
prediction.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ rotary_embedding_torch
2
+ torch
3
+ jupyter
4
+ matplotlib
weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51ab232915c68ba50ac907b60139e7e45c08eb2ce92a95fa29488b19896ffa2e
3
+ size 121995384