File size: 17,688 Bytes
5f26252
 
 
 
d81c324
5f26252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d81c324
5f26252
 
 
 
 
 
d81c324
5f26252
 
 
 
 
 
 
 
d81c324
5f26252
 
 
 
 
d81c324
5f26252
 
 
 
 
 
 
 
 
7a4fc48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21e77ce
 
7a4fc48
 
 
 
 
 
 
5f26252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import NBAModel, NBAConfig\n",
    "from torch import device as torch_device, load as torch_load, int32, Tensor, bfloat16\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "device = torch_device(\"cpu\")\n",
    "\n",
    "# Sizes of the real token vocabularies. Every table in the config below gets +2 extra ids;\n",
    "# the prediction cell uses num_player_tokens+1 / num_age_tokens+1 as a leading home-team token,\n",
    "# so the extra ids are presumably reserved special tokens (TODO: confirm against model.py).\n",
    "num_age_tokens=32\n",
    "num_player_tokens=5141\n",
    "num_net_score_tokens=41\n",
    "players_per_team=8\n",
    "\n",
    "# These hyperparameters must match the ones weights.pt was trained with;\n",
    "# load_state_dict raises on missing/unexpected keys or tensor shape mismatches.\n",
    "model_config = NBAConfig(\n",
    "    players_per_team=players_per_team,\n",
    "    player_tokens=num_player_tokens+2,\n",
    "    age_tokens=num_age_tokens+2,\n",
    "    num_labels=num_net_score_tokens+2,\n",
    "    n_layer=4,\n",
    "    n_head=4,\n",
    "    n_embd=1024,\n",
    "    dropout=0.0,\n",
    "    bias=False,\n",
    "    dtype=bfloat16,\n",
    "    seed=29,\n",
    ")\n",
    "\n",
    "model = NBAModel(model_config).to(device)\n",
    "# weights_only=True limits unpickling to tensors and plain containers, so loading a\n",
    "# tampered checkpoint cannot run arbitrary code. map_location places tensors on `device`.\n",
    "state_dict = torch_load('weights.pt', map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Home team win probability: 0.65\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<BarContainer object of 40 artists>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjzklEQVR4nO3dfVCVdf7/8ReogJhgiXK8IbFd8yYVEoWwG2tjhIa2qNbIadQcx8byNlpTXIXKbXEtzVKK3Jls29ZwnZ1cM5fJKGyLU66o29qWaZNi2gGtBKUE43x+f/Tz1Pl6FA4i58Px+Zi5prh4Xxfvz1wdePW57kKMMUYAAAAWCw10AwAAAE0hsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArNcx0A20BrfbrcOHD6tr164KCQkJdDsAAKAZjDE6fvy4evfurdDQc8+hBEVgOXz4sOLi4gLdBgAAaIGDBw+qb9++56wJisDStWtXST8OOCoqKsDdAACA5qitrVVcXJzn7/i5BEVgOX0aKCoqisACAEA705zLObjoFgAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6HQPdAAAACKz4+W80WbN/SWYbdHJ2zLAAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwXosCS2FhoeLj4xUREaGUlBRt27btnPXr16/XoEGDFBERoWHDhmnz5s1e3z9x4oRmzJihvn37qnPnzhoyZIiKiopa0hoAAAhCfgeWdevWKScnR/n5+dqxY4cSEhKUnp6u6upqn/Xl5eUaP368pkyZop07dyorK0tZWVnavXu3pyYnJ0clJSV65ZVX9Mknn2jOnDmaMWOGNm7c2PKRAQCAoBFijDH+bJCSkqJRo0Zp1apVkiS32624uDjNnDlT8+fPP6M+OztbdXV12rRpk2fdNddco8TERM8sytChQ5Wdna1FixZ5apKSknTLLbfo97//fZM91dbWKjo6WjU1NYqKivJnOAAAXPTi57/RZM3+JZmt/nP9+fvt1wxLQ0ODKioqlJaW9tMOQkOVlpYmp9Ppcxun0+lVL0np6ele9aNHj9bGjRt16NAhGWP0zjvv6LPPPtPYsWN97rO+vl61tbVeCwAACF5+BZajR4+qsbFRsbGxXutjY2Plcrl8buNyuZqsX7lypYYMGaK+ffsqLCxMGRkZKiws1A033OBznwUFBYqOjvYscXFx/gwDAAC0M1bcJbRy5Up98MEH2rhxoyoqKrRs2TJNnz5db731ls/63Nxc1dTUeJaDBw+2cccAAKAtdfSnOCYmRh06dFBVVZXX+qqqKjkcDp/bOByOc9Z///33WrBggV577TVlZv54fmz48OHatWuXnnrqqTNOJ0lSeHi4wsPD/WkdAAC0Y37NsISFhSkpKUmlpaWedW63W6WlpUpNTfW5TWpqqle9JG3ZssVTf+rUKZ06dUqhod6tdOjQQW6325/2AABAkPJrhkX68RbkSZMmaeTIkUpOTtaKFStUV1enyZMnS5ImTpyoPn36qKCgQJI0e/ZsjRkzRsuWLVNmZqaKi4u1fft2rV69WpIUFRWlMWPG
aO7cuercubP69eunrVu36uWXX9by5ctbcagAAKC98juwZGdn68iRI8rLy5PL5VJiYqJKSko8F9ZWVlZ6zZaMHj1aa9eu1cKFC7VgwQINGDBAGzZs0NChQz01xcXFys3N1b333qtvvvlG/fr10xNPPKFp06a1whABAEB75/dzWGzEc1gAAGi5oHsOCwAAQCAQWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrtSiwFBYWKj4+XhEREUpJSdG2bdvOWb9+/XoNGjRIERERGjZsmDZv3nxGzSeffKLbbrtN0dHR6tKli0aNGqXKysqWtAcAAIKM34Fl3bp1ysnJUX5+vnbs2KGEhASlp6erurraZ315ebnGjx+vKVOmaOfOncrKylJWVpZ2797tqfn888913XXXadCgQSorK9NHH32kRYsWKSIiouUjAwAAQSPEGGP82SAlJUWjRo3SqlWrJElut1txcXGaOXOm5s+ff0Z9dna26urqtGnTJs+6a665RomJiSoqKpIk3XPPPerUqZP+8pe/tGgQtbW1io6OVk1NjaKiolq0DwAALlbx899osmb/ksxW/7n+/P32a4aloaFBFRUVSktL+2kHoaFKS0uT0+n0uY3T6fSql6T09HRPvdvt1htvvKErr7xS6enp6tmzp1JSUrRhw4az9lFfX6/a2lqvBQAABC+/AsvRo0fV2Nio2NhYr/WxsbFyuVw+t3G5XOesr66u1okTJ7RkyRJlZGTozTff1B133KE777xTW7du9bnPgoICRUdHe5a4uDh/hgEAANqZgN8l5Ha7JUm33367HnroISUmJmr+/Pm69dZbPaeM/q/c3FzV1NR4loMHD7ZlywAAoI119Kc4JiZGHTp0UFVVldf6qqoqORwOn9s4HI5z1sfExKhjx44aMmSIV83gwYP13nvv+dxneHi4wsPD/WkdAAC0Y37NsISFhSkpKUmlpaWedW63W6WlpUpNTfW5TWpqqle9JG3ZssVTHxYWplGjRmnPnj1eNZ999pn69evnT3sAACBI+TXDIkk5OTmaNGmSRo4cqeTkZK1YsUJ1dXWaPHmyJGnixInq06ePCgoKJEmzZ8/WmDFjtGzZMmVmZqq4uFjbt2/X6tWrPfucO3eusrOzdcMNN+imm25SSUmJXn/9dZWVlbXOKAEAQLvmd2DJzs7WkSNHlJeXJ5fLpcTERJWUlHgurK2srFRo6E8TN6NHj9batWu1cOFCLViwQAMGDNCGDRs0dOhQT80dd9yhoqIiFRQUaNasWRo4cKD+/ve/67rrrmuFIQIAgPbO7+ew2IjnsAAA0HJB9xwWAACAQCCwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9Qgs
AADAegQWAABgvY6BbgAAcP7i57/RZM3+JZlt0AlwYTDDAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANbrGOgGAOBiEj//jSZr9i/JbINOgPaFGRYAAGA9AgsAALAep4QAwAdO3QB2IbAAAM6pqfBGcENbILAAuGgwawK0X1zDAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWa1FgKSwsVHx8vCIiIpSSkqJt27ads379+vUaNGiQIiIiNGzYMG3evPmstdOmTVNISIhWrFjRktYAIGjEz3+jyQW4WPgdWNatW6ecnBzl5+drx44dSkhIUHp6uqqrq33Wl5eXa/z48ZoyZYp27typrKwsZWVlaffu3WfUvvbaa/rggw/Uu3dv/0cCAACClt+BZfny5Zo6daomT56sIUOGqKioSJGRkXrxxRd91j/zzDPKyMjQ3LlzNXjwYC1evFgjRozQqlWrvOoOHTqkmTNn6q9//as6derUstEAAICg5FdgaWhoUEVFhdLS0n7aQWio0tLS5HQ6fW7jdDq96iUpPT3dq97tdmvChAmaO3eurrrqqib7qK+vV21trdcCAACCl1+B5ejRo2psbFRsbKzX+tjYWLlcLp/buFyuJuv/+Mc/qmPHjpo1a1az+igoKFB0dLRniYuL82cYAACgnQn4XUIVFRV65pln9NJLLykkJKRZ2+Tm5qqmpsazHDx48AJ3CQAAAsmvwBITE6MOHTqoqqrKa31VVZUcDofPbRwOxznr//Wvf6m6ulqXX365OnbsqI4dO+rAgQN6+OGHFR8f73Of4eHhioqK8loAAEDw6uhPcVhYmJKSklRaWqqsrCxJP15/UlpaqhkzZvjcJjU1VaWlpZozZ45n3ZYtW5SamipJmjBhgs9rXCZMmKDJkyf70x4ABERzbi/evySzDToBgpdfgUWScnJyNGnSJI0cOVLJyclasWKF6urqPOFi4sSJ6tOnjwoKCiRJs2fP1pgxY7Rs2TJlZmaquLhY27dv1+rVqyVJ3bt3V/fu3b1+RqdOneRwODRw4MDzHR8AAAgCfgeW7OxsHTlyRHl5eXK5XEpMTFRJSYnnwtrKykqFhv50pmn06NFau3atFi5cqAULFmjAgAHasGGDhg4d2nqjAAAAQc3vwCJJM2bMOOspoLKysjPWjRs3TuPGjWv2/vfv39+StgAAQJAK+F1CAAAATSGwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWa9GTbgEA8KWpF0HyEki0FDMsAADAegQWAABgPU4JAWjXmjoFIXEaAggGzLAAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD1efghcYLycDwDOHzMsAADAegQWAABgPQILAACwHoEFAABYj4tugYsAF/4CaO+YYQEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6/FofgBAQDT1ygheF4GfY4YFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9nsMCwEtTz8aQeD4GgLZHYAGAi0x7DKXtsWe0Lk4JAQAA6xFYAACA9TglBLQA70D5EdP0ANoKgaUZ+KUMAEBgcUoIAABYj8ACAACsxykhoJ3iVCWAiwkzLAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA67UosBQWFio+Pl4RERFKSUnRtm3bzlm/fv16DRo0SBERERo2bJg2b97s
+d6pU6c0b948DRs2TF26dFHv3r01ceJEHT58uCWtAQCAIOT3g+PWrVunnJwcFRUVKSUlRStWrFB6err27Nmjnj17nlFfXl6u8ePHq6CgQLfeeqvWrl2rrKws7dixQ0OHDtV3332nHTt2aNGiRUpISNC3336r2bNn67bbbtP27dtbZZAAAo8H3QE4H37PsCxfvlxTp07V5MmTNWTIEBUVFSkyMlIvvviiz/pnnnlGGRkZmjt3rgYPHqzFixdrxIgRWrVqlSQpOjpaW7Zs0d13362BAwfqmmuu0apVq1RRUaHKysrzGx0AAAgKfgWWhoYGVVRUKC0t7acdhIYqLS1NTqfT5zZOp9OrXpLS09PPWi9JNTU1CgkJUbdu3Xx+v76+XrW1tV4LAAAIXn6dEjp69KgaGxsVGxvrtT42Nlaffvqpz21cLpfPepfL5bP+5MmTmjdvnsaPH6+oqCifNQUFBXrsscf8aR1oFzhtAgC+WXWX0KlTp3T33XfLGKPnn3/+rHW5ubmqqanxLAcPHmzDLgEAQFvza4YlJiZGHTp0UFVVldf6qqoqORwOn9s4HI5m1Z8OKwcOHNDbb7991tkVSQoPD1d4eLg/rQMALhLMVAYnv2ZYwsLClJSUpNLSUs86t9ut0tJSpaam+twmNTXVq16StmzZ4lV/Oqzs3btXb731lrp37+5PWwAAIMj5fVtzTk6OJk2apJEjRyo5OVkrVqxQXV2dJk+eLEmaOHGi+vTpo4KCAknS7NmzNWbMGC1btkyZmZkqLi7W9u3btXr1akk/hpXf/OY32rFjhzZt2qTGxkbP9S2XXXaZwsLCWmusAACgnfI7sGRnZ+vIkSPKy8uTy+VSYmKiSkpKPBfWVlZWKjT0p4mb0aNHa+3atVq4cKEWLFigAQMGaMOGDRo6dKgk6dChQ9q4caMkKTEx0etnvfPOO7rxxhtbODQAABAs/A4skjRjxgzNmDHD5/fKysrOWDdu3DiNGzfOZ318fLyMMS1pA2hVTZ335pw3AASOVXcJAQAA+NKiGRYAAIIBdxS1H8ywAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsx3NY0K7wzAQAuDgxwwIAAKzHDAuCGu8HAoDgwAwLAACwHjMsAAA0w4W6ho5r85qHGRYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOtxWzOajVvvAACBwgwLAACwHjMsAAAEoWCbFSewIOCC7UMFAGh9BBZcEIQQAEBr4hoWAABgPQILAACwHqeELnKcugGA9uNi/p1NYAEAoJVdzMHiQuGUEAAAsB6BBQAAWI9TQkGIqUgAQLBhhgUAAFiPwAIAAKxHYAEAANYjsAAAAOtx0W0AcXEsAADNwwwLAACwHoEFAABYj8ACAACsR2ABAADW46LbdoILdAEAFzNmWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHi8/bGW8pBAAgNbHDAsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsF6LAkthYaHi4+MVERGhlJQUbdu27Zz169ev16BBgxQREaFhw4Zp8+bNXt83xigvL0+9evVS586dlZaWpr1797akNQAAEIT8Dizr1q1TTk6O8vPztWPHDiUkJCg9PV3V1dU+68vLyzV+/HhNmTJFO3fuVFZWlrKysrR7925PzdKlS/Xss8+qqKhIH374obp06aL09HSdPHmy5SMDAABBw+/Asnz5ck2dOlWTJ0/WkCFDVFRUpMjISL344os+65955hllZGRo7ty5Gjx4sBYvXqwRI0Zo1apVkn6cXVmxYoUWLlyo22+/XcOHD9fLL7+sw4cPa8OGDec1OAAAEBz8evlhQ0ODKioqlJub61kXGhqqtLQ0OZ1On9s4
nU7l5OR4rUtPT/eEkS+++EIul0tpaWme70dHRyslJUVOp1P33HPPGfusr69XfX295+uamhpJUm1trT/DaTZ3/XdN1pz+2dTaU9ucehtqf15Prf+fYRv6Deba5tTbUPvzemrt+Xw2d5/GmKaLjR8OHTpkJJny8nKv9XPnzjXJyck+t+nUqZNZu3at17rCwkLTs2dPY4wx77//vpFkDh8+7FUzbtw4c/fdd/vcZ35+vpHEwsLCwsLCEgTLwYMHm8wgfs2w2CI3N9dr1sbtduubb75R9+7dFRISckF/dm1treLi4nTw4EFFRUVd0J8VCME8PsbWPgXz2KTgHh9ja7/aanzGGB0/fly9e/dustavwBITE6MOHTqoqqrKa31VVZUcDofPbRwOxznrT/+zqqpKvXr18qpJTEz0uc/w8HCFh4d7revWrZs/QzlvUVFRQfkf6WnBPD7G1j4F89ik4B4fY2u/2mJ80dHRzarz66LbsLAwJSUlqbS01LPO7XartLRUqampPrdJTU31qpekLVu2eOr79+8vh8PhVVNbW6sPP/zwrPsEAAAXF79PCeXk5GjSpEkaOXKkkpOTtWLFCtXV1Wny5MmSpIkTJ6pPnz4qKCiQJM2ePVtjxozRsmXLlJmZqeLiYm3fvl2rV6+WJIWEhGjOnDn6/e9/rwEDBqh///5atGiRevfuraysrNYbKQAAaLf8DizZ2dk6cuSI8vLy5HK5lJiYqJKSEsXGxkqSKisrFRr608TN6NGjtXbtWi1cuFALFizQgAEDtGHDBg0dOtRT88gjj6iurk7333+/jh07puuuu04lJSWKiIhohSG2rvDwcOXn559xSipYBPP4GFv7FMxjk4J7fIyt/bJxfCHGNOdeIgAAgMDhXUIAAMB6BBYAAGA9AgsAALAegQUAAFiPwNJM+/fv15QpU9S/f3917txZv/jFL5Sfn6+Ghgavuo8++kjXX3+9IiIiFBcXp6VLlwaoY/898cQTGj16tCIjI8/6IL6QkJAzluLi4rZttAWaM7bKykplZmYqMjJSPXv21Ny5c/XDDz+0baOtID4+/oxjtGTJkkC31WKFhYWKj49XRESEUlJStG3btkC3dN4effTRM47RoEGDAt1Wi7377rv69a9/rd69eyskJOSMF9caY5SXl6devXqpc+fOSktL0969ewPTrJ+aGtt99913xrHMyMgITLN+Kigo0KhRo9S1a1f17NlTWVlZ2rNnj1fNyZMnNX36dHXv3l2XXHKJ7rrrrjMeBttWCCzN9Omnn8rtduuFF17Qxx9/rKefflpFRUVasGCBp6a2tlZjx45Vv379VFFRoSeffFKPPvqo55kztmtoaNC4ceP0wAMPnLNuzZo1+uqrrzxLe3heTlNja2xsVGZmphoaGlReXq4///nPeumll5SXl9fGnbaOxx9/3OsYzZw5M9Attci6deuUk5Oj/Px87dixQwkJCUpPT1d1dXWgWztvV111ldcxeu+99wLdUovV1dUpISFBhYWFPr+/dOlSPfvssyoqKtKHH36oLl26KD09XSdPnmzjTv3X1NgkKSMjw+tYvvrqq23YYctt3bpV06dP1wcffKAtW7bo1KlTGjt2rOrq6jw1Dz30kF5//XWtX79eW7du1eHDh3XnnXcGpuEm3zaEs1q6dKnp37+/5+vnnnvOXHrppaa+vt6zbt68eWbgwIGBaK/F1qxZY6Kjo31+T5J57bXX2rSf1nS2sW3evNmEhoYal8vlWff888+bqKgor+PZHvTr1888/fTTgW6jVSQnJ5vp06d7vm5sbDS9e/c2BQUFAezq/OXn55uEhIRAt3FB/N/fEW632zgcDvPkk0961h07dsyEh4ebV199NQAdtpyv33+TJk0yt99+e0D6aW3V1dVGktm6dasx5sfj1KlTJ7N+/XpPzSeffGIkGafT2eb9McNyHmpqanTZZZd5vnY6nbrhhhsUFhbmWZeenq49e/bo22+/DUSLF8T06dMV
ExOj5ORkvfjii817LbjlnE6nhg0b5nkAovTjsautrdXHH38cwM5aZsmSJerevbuuvvpqPfnkk+3y1FZDQ4MqKiqUlpbmWRcaGqq0tDQ5nc4AdtY69u7dq969e+uKK67Qvffeq8rKykC3dEF88cUXcrlcXscxOjpaKSkpQXEcJamsrEw9e/bUwIED9cADD+jrr78OdEstUlNTI0mev2sVFRU6deqU17EbNGiQLr/88oAcu3b5tmYb7Nu3TytXrtRTTz3lWedyudS/f3+vutN/AF0uly699NI27fFCePzxx/WrX/1KkZGRevPNN/Xggw/qxIkTmjVrVqBbOy8ul8srrEjex649mTVrlkaMGKHLLrtM5eXlys3N1VdffaXly5cHujW/HD16VI2NjT6Py6effhqgrlpHSkqKXnrpJQ0cOFBfffWVHnvsMV1//fXavXu3unbtGuj2WtXpz4+v49jePlu+ZGRk6M4771T//v31+eefa8GCBbrlllvkdDrVoUOHQLfXbG63W3PmzNG1117reRK9y+VSWFjYGdf9BerYXfQzLPPnz/d5IenPl//7y/HQoUPKyMjQuHHjNHXq1AB13jwtGd+5LFq0SNdee62uvvpqzZs3T4888oiefPLJCziCs2vtsdnMn7Hm5OToxhtv1PDhwzVt2jQtW7ZMK1euVH19fYBHgdNuueUWjRs3TsOHD1d6ero2b96sY8eO6W9/+1ugW4Of7rnnHt12220aNmyYsrKytGnTJv373/9WWVlZoFvzy/Tp07V7926rb6K46GdYHn74Yd13333nrLniiis8/3748GHddNNNGj169BkX0zocjjOunj79tcPhaJ2G/eTv+PyVkpKixYsXq76+vs3fOdGaY3M4HGfcfRLoY/dz5zPWlJQU/fDDD9q/f78GDhx4Abq7MGJiYtShQwefnykbjklr6tatm6688krt27cv0K20utPHqqqqSr169fKsr6qqUmJiYoC6unCuuOIKxcTEaN++fbr55psD3U6zzJgxQ5s2bdK7776rvn37etY7HA41NDTo2LFjXrMsgfoMXvSBpUePHurRo0ezag8dOqSbbrpJSUlJWrNmjddLHiUpNTVVv/vd73Tq1Cl16tRJkrRlyxYNHDgwYKeD/BlfS+zatUuXXnppQF6Q1ZpjS01N1RNPPKHq6mr17NlT0o/HLioqSkOGDGmVn3E+zmesu3btUmhoqGdc7UVYWJiSkpJUWlrquRPN7XartLRUM2bMCGxzrezEiRP6/PPPNWHChEC30ur69+8vh8Oh0tJST0Cpra3Vhx9+2OQdie3Rl19+qa+//tornNnKGKOZM2fqtddeU1lZ2RmXNCQlJalTp04qLS3VXXfdJUnas2ePKisrlZqaGpCG0Qxffvml+eUvf2luvvlm8+WXX5qvvvrKs5x27NgxExsbayZMmGB2795tiouLTWRkpHnhhRcC2HnzHThwwOzcudM89thj5pJLLjE7d+40O3fuNMePHzfGGLNx40bzpz/9yfz3v/81e/fuNc8995yJjIw0eXl5Ae68aU2N7YcffjBDhw41Y8eONbt27TIlJSWmR48eJjc3N8Cd+6e8vNw8/fTTZteuXebzzz83r7zyiunRo4eZOHFioFtrkeLiYhMeHm5eeukl87///c/cf//9plu3bl53c7VHDz/8sCkrKzNffPGFef/9901aWpqJiYkx1dXVgW6tRY4fP+75TEkyy5cvNzt37jQHDhwwxhizZMkS061bN/OPf/zDfPTRR+b22283/fv3N99//32AO2/aucZ2/Phx89vf/tY4nU7zxRdfmLfeesuMGDHCDBgwwJw8eTLQrTfpgQceMNHR0aasrMzrb9p3333nqZk2bZq5/PLLzdtvv222b99uUlNTTWpqakD6JbA005o1a4wkn8vP/ec//zHXXXedCQ8PN3369DFLliwJUMf+mzRpks/xvfPOO8YYY/75z3+axMREc8kll5guXbqYhIQE
U1RUZBobGwPbeDM0NTZjjNm/f7+55ZZbTOfOnU1MTIx5+OGHzalTpwLXdAtUVFSYlJQUEx0dbSIiIszgwYPNH/7wh3bxy/NsVq5caS6//HITFhZmkpOTzQcffBDols5bdna26dWrlwkLCzN9+vQx2dnZZt++fYFuq8Xeeecdn5+vSZMmGWN+vLV50aJFJjY21oSHh5ubb77Z7NmzJ7BNN9O5xvbdd9+ZsWPHmh49ephOnTqZfv36malTp7abQH22v2lr1qzx1Hz//ffmwQcfNJdeeqmJjIw0d9xxh9f/qLelkP/fNAAAgLUu+ruEAACA/QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALDe/wOjg8AYvXRe+QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from torch import no_grad, tensor\n",
    "\n",
    "# Change player and age tokens here!\n",
    "# You can find these values in player_tokens.csv and age_tokens.csv\n",
    "# You must provide exactly 8 player tokens and 8 age tokens for each team.\n",
    "\n",
    "# Denver Nuggets first game of 2023-24 season roster\n",
    "home_player_tokens = [5035, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
    "home_age_tokens = [14, 16, 13, 12, 10, 19, 8, 8]\n",
    "\n",
    "# Uncomment to take Jokic off team, replace with Peyton Watson\n",
    "# home_player_tokens = [4331, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
    "# home_age_tokens = [6, 16, 13, 12, 10, 19, 8, 8]\n",
    "\n",
    "# Boston Celtics final game of 2023-24 season roster\n",
    "away_player_tokens = [5042, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
    "away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
    "\n",
    "# Uncomment to take Tatum off team, replace with Pritchard\n",
    "# away_player_tokens = [4999, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
    "# away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
    "\n",
    "# The model usually gives the home team a bump in win probability.\n",
    "# Change this to \"True\" to swap home and away teams.\n",
    "swap_home_away = False\n",
    "if swap_home_away:\n",
    "    home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
    "    home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens\n",
    "\n",
    "# Fail early with a readable message if any roster has the wrong length.\n",
    "for name, tokens in [\n",
    "    ('home_player_tokens', home_player_tokens),\n",
    "    ('home_age_tokens', home_age_tokens),\n",
    "    ('away_player_tokens', away_player_tokens),\n",
    "    ('away_age_tokens', away_age_tokens),\n",
    "]:\n",
    "    assert len(tokens) == players_per_team, (\n",
    "        f'{name} needs exactly {players_per_team} entries, got {len(tokens)}'\n",
    "    )\n",
    "\n",
    "def as_batch_row(tokens):\n",
    "    \"\"\"Return `tokens` as a (1, len(tokens)) int32 tensor created directly on `device`.\"\"\"\n",
    "    return tensor(tokens, dtype=int32, device=device).unsqueeze(0)\n",
    "\n",
    "# Only the home sequences are prefixed with an extra token (id num_player_tokens+1 /\n",
    "# num_age_tokens+1), presumably a reserved special token; TODO confirm against model.py.\n",
    "batch = {\n",
    "    'home_player_tokens': as_batch_row([num_player_tokens + 1] + home_player_tokens),\n",
    "    'home_age_tokens': as_batch_row([num_age_tokens + 1] + home_age_tokens),\n",
    "    'away_player_tokens': as_batch_row(away_player_tokens),\n",
    "    'away_age_tokens': as_batch_row(away_age_tokens),\n",
    "}\n",
    "\n",
    "# Inference only: no_grad() avoids recording an autograd graph for the forward pass.\n",
    "with no_grad():\n",
    "    logits, _ = model(**batch)\n",
    "\n",
    "# The model is configured for bfloat16; upcast before softmax so the probabilities\n",
    "# (and the sums below) keep float32 precision.\n",
    "label_probs = logits.squeeze().float().softmax(dim=0).tolist()\n",
    "\n",
    "# Label index i stands for a home net score of i - zero_index (zero_index == 21 here).\n",
    "# Index 0 and any index above num_net_score_tokens are presumably special labels, and a\n",
    "# net score of 0 (a tie) is not a real outcome, so all of them are left out.\n",
    "zero_index = num_net_score_tokens // 2 + 1\n",
    "probs = {\n",
    "    i - zero_index: p\n",
    "    for i, p in enumerate(label_probs)\n",
    "    if 1 <= i <= num_net_score_tokens and i != zero_index\n",
    "}\n",
    "loss_prob = sum(p for net_score, p in probs.items() if net_score < 0)\n",
    "win_prob = sum(p for net_score, p in probs.items() if net_score > 0)\n",
    "\n",
    "# Renormalize over decisive outcomes so the win and loss probabilities sum to 1 even\n",
    "# if the model puts some mass on the excluded labels.\n",
    "home_win_prob = win_prob / (win_prob + loss_prob)\n",
    "print(f\"Home team win probability: {home_win_prob:.2f}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(list(probs.keys()), list(probs.values()))\n",
    "ax.set(\n",
    "    title='Predicted final net score (positive = home win)',\n",
    "    xlabel='Home net score',\n",
    "    ylabel='Probability',\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}