File size: 17,688 Bytes
5f26252 d81c324 5f26252 d81c324 5f26252 d81c324 5f26252 d81c324 5f26252 d81c324 5f26252 7a4fc48 21e77ce 7a4fc48 5f26252 |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from model import NBAModel, NBAConfig\n",
"from torch import device as torch_device, load as torch_load, int32, Tensor, bfloat16\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Inference runs on the CPU; switch to \"cuda\" here if a GPU is available.\n",
"device = torch_device(\"cpu\")\n",
"\n",
"# Vocabulary sizes the checkpoint was trained with. Every table in the config\n",
"# reserves 2 extra ids beyond these counts (the inference cell, for example,\n",
"# prepends id num_player_tokens + 1 to the home roster).\n",
"num_age_tokens = 32\n",
"num_player_tokens = 5141\n",
"num_net_score_tokens = 41\n",
"players_per_team = 8\n",
"\n",
"# The architecture settings below must match the configuration weights.pt was\n",
"# trained with, otherwise the state dict will not line up with the model.\n",
"model_config = NBAConfig(\n",
"    players_per_team=players_per_team,\n",
"    player_tokens=num_player_tokens + 2,\n",
"    age_tokens=num_age_tokens + 2,\n",
"    num_labels=num_net_score_tokens + 2,\n",
"    n_layer=4,\n",
"    n_head=4,\n",
"    n_embd=1024,\n",
"    dropout=0.0,\n",
"    bias=False,\n",
"    dtype=bfloat16,\n",
"    seed=29,\n",
")\n",
"\n",
"model = NBAModel(model_config).to(device)\n",
"# weights_only=True restricts unpickling to tensors and plain containers, so a\n",
"# tampered checkpoint cannot run arbitrary code while loading. map_location\n",
"# reuses `device` so the weights land wherever the model lives.\n",
"state_dict = torch_load('weights.pt', map_location=device, weights_only=True)\n",
"model.load_state_dict(state_dict)\n",
"model = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Home team win probability: 0.65\n"
]
},
{
"data": {
"text/plain": [
"<BarContainer object of 40 artists>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjzklEQVR4nO3dfVCVdf7/8ReogJhgiXK8IbFd8yYVEoWwG2tjhIa2qNbIadQcx8byNlpTXIXKbXEtzVKK3Jls29ZwnZ1cM5fJKGyLU66o29qWaZNi2gGtBKUE43x+f/Tz1Pl6FA4i58Px+Zi5prh4Xxfvz1wdePW57kKMMUYAAAAWCw10AwAAAE0hsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArNcx0A20BrfbrcOHD6tr164KCQkJdDsAAKAZjDE6fvy4evfurdDQc8+hBEVgOXz4sOLi4gLdBgAAaIGDBw+qb9++56wJisDStWtXST8OOCoqKsDdAACA5qitrVVcXJzn7/i5BEVgOX0aKCoqisACAEA705zLObjoFgAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6HQPdAAAACKz4+W80WbN/SWYbdHJ2zLAAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwXosCS2FhoeLj4xUREaGUlBRt27btnPXr16/XoEGDFBERoWHDhmnz5s1e3z9x4oRmzJihvn37qnPnzhoyZIiKiopa0hoAAAhCfgeWdevWKScnR/n5+dqxY4cSEhKUnp6u6upqn/Xl5eUaP368pkyZop07dyorK0tZWVnavXu3pyYnJ0clJSV65ZVX9Mknn2jOnDmaMWOGNm7c2PKRAQCAoBFijDH+bJCSkqJRo0Zp1apVkiS32624uDjNnDlT8+fPP6M+OztbdXV12rRpk2fdNddco8TERM8sytChQ5Wdna1FixZ5apKSknTLLbfo97//fZM91dbWKjo6WjU1NYqKivJnOAAAXPTi57/RZM3+JZmt/nP9+fvt1wxLQ0ODKioqlJaW9tMOQkOVlpYmp9Ppcxun0+lVL0np6ele9aNHj9bGjRt16NAhGWP0zjvv6LPPPtPYsWN97rO+vl61tbVeCwAACF5+BZajR4+qsbFRsbGxXutjY2Plcrl8buNyuZqsX7lypYYMGaK+ffsqLCxMGRkZKiws1A033OBznwUFBYqOjvYscXFx/gwDAAC0M1bcJbRy5Up98MEH2rhxoyoqKrRs2TJNnz5db731ls/63Nxc1dTUeJaDBw+2cccAAKAtdfSnOCYmRh06dFBVVZXX+qqqKjkcDp/bOByOc9Z///33WrBggV577TVlZv54fmz48OHatWuXnnrqqTNOJ0lSeHi4wsPD/WkdAAC0Y37NsISFhSkpKUmlpaWedW63W6WlpUpNTfW5TWpqqle9JG3ZssVTf+rUKZ06dUqhod6tdOjQQW6325/2AABAkPJrhkX68RbkSZMmaeTIkUpOTtaKFStUV1enyZMnS5ImTpyoPn36qKCgQJI0e/ZsjRkzRsuWLVNmZqaKi4u1fft2rV69WpIUFRWlMWPGaO7cue
rcubP69eunrVu36uWXX9by5ctbcagAAKC98juwZGdn68iRI8rLy5PL5VJiYqJKSko8F9ZWVlZ6zZaMHj1aa9eu1cKFC7VgwQINGDBAGzZs0NChQz01xcXFys3N1b333qtvvvlG/fr10xNPPKFp06a1whABAEB75/dzWGzEc1gAAGi5oHsOCwAAQCAQWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrtSiwFBYWKj4+XhEREUpJSdG2bdvOWb9+/XoNGjRIERERGjZsmDZv3nxGzSeffKLbbrtN0dHR6tKli0aNGqXKysqWtAcAAIKM34Fl3bp1ysnJUX5+vnbs2KGEhASlp6erurraZ315ebnGjx+vKVOmaOfOncrKylJWVpZ2797tqfn888913XXXadCgQSorK9NHH32kRYsWKSIiouUjAwAAQSPEGGP82SAlJUWjRo3SqlWrJElut1txcXGaOXOm5s+ff0Z9dna26urqtGnTJs+6a665RomJiSoqKpIk3XPPPerUqZP+8pe/tGgQtbW1io6OVk1NjaKiolq0DwAALlbx899osmb/ksxW/7n+/P32a4aloaFBFRUVSktL+2kHoaFKS0uT0+n0uY3T6fSql6T09HRPvdvt1htvvKErr7xS6enp6tmzp1JSUrRhw4az9lFfX6/a2lqvBQAABC+/AsvRo0fV2Nio2NhYr/WxsbFyuVw+t3G5XOesr66u1okTJ7RkyRJlZGTozTff1B133KE777xTW7du9bnPgoICRUdHe5a4uDh/hgEAANqZgN8l5Ha7JUm33367HnroISUmJmr+/Pm69dZbPaeM/q/c3FzV1NR4loMHD7ZlywAAoI119Kc4JiZGHTp0UFVVldf6qqoqORwOn9s4HI5z1sfExKhjx44aMmSIV83gwYP13nvv+dxneHi4wsPD/WkdAAC0Y37NsISFhSkpKUmlpaWedW63W6WlpUpNTfW5TWpqqle9JG3ZssVTHxYWplGjRmnPnj1eNZ999pn69evnT3sAACBI+TXDIkk5OTmaNGmSRo4cqeTkZK1YsUJ1dXWaPHmyJGnixInq06ePCgoKJEmzZ8/WmDFjtGzZMmVmZqq4uFjbt2/X6tWrPfucO3eusrOzdcMNN+imm25SSUmJXn/9dZWVlbXOKAEAQLvmd2DJzs7WkSNHlJeXJ5fLpcTERJWUlHgurK2srFRo6E8TN6NHj9batWu1cOFCLViwQAMGDNCGDRs0dOhQT80dd9yhoqIiFRQUaNasWRo4cKD+/ve/67rrrmuFIQIAgPbO7+ew2IjnsAAA0HJB9xwWAACAQCCwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAeg
QWAABgvY6BbgAAcP7i57/RZM3+JZlt0AlwYTDDAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANbrGOgGAOBiEj//jSZr9i/JbINOgPaFGRYAAGA9AgsAALAep4QAwAdO3QB2IbAAAM6pqfBGcENbILAAuGgwawK0X1zDAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWa1FgKSwsVHx8vCIiIpSSkqJt27ads379+vUaNGiQIiIiNGzYMG3evPmstdOmTVNISIhWrFjRktYAIGjEz3+jyQW4WPgdWNatW6ecnBzl5+drx44dSkhIUHp6uqqrq33Wl5eXa/z48ZoyZYp27typrKwsZWVlaffu3WfUvvbaa/rggw/Uu3dv/0cCAACClt+BZfny5Zo6daomT56sIUOGqKioSJGRkXrxxRd91j/zzDPKyMjQ3LlzNXjwYC1evFgjRozQqlWrvOoOHTqkmTNn6q9//as6derUstEAAICg5FdgaWhoUEVFhdLS0n7aQWio0tLS5HQ6fW7jdDq96iUpPT3dq97tdmvChAmaO3eurrrqqib7qK+vV21trdcCAACCl1+B5ejRo2psbFRsbKzX+tjYWLlcLp/buFyuJuv/+Mc/qmPHjpo1a1az+igoKFB0dLRniYuL82cYAACgnQn4XUIVFRV65pln9NJLLykkJKRZ2+Tm5qqmpsazHDx48AJ3CQAAAsmvwBITE6MOHTqoqqrKa31VVZUcDofPbRwOxznr//Wvf6m6ulqXX365OnbsqI4dO+rAgQN6+OGHFR8f73Of4eHhioqK8loAAEDw6uhPcVhYmJKSklRaWqqsrCxJP15/UlpaqhkzZvjcJjU1VaWlpZozZ45n3ZYtW5SamipJmjBhgs9rXCZMmKDJkyf70x4ABERzbi/evySzDToBgpdfgUWScnJyNGnSJI0cOVLJyclasWKF6urqPOFi4sSJ6tOnjwoKCiRJs2fP1pgxY7Rs2TJlZmaquLhY27dv1+rVqyVJ3bt3V/fu3b1+RqdOneRwODRw4MDzHR8AAAgCfgeW7OxsHTlyRHl5eXK5XEpMTFRJSYnnwtrKykqFhv50pmn06NFau3atFi5cqAULFmjAgAHasGGDhg4d2nqjAAAAQc3vwCJJM2bMOOspoLKysjPWjRs3TuPGjWv2/vfv39+StgAAQJAK+F1CAAAATSGwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWa9GTbgEA8KWpF0HyEki0FDMsAADAegQWAABgPU4JAWjXmjoFIXEaAggGzLAAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD1efghcYLycDwDOHzMsAADAegQWAABgPQILAACwHoEFAABYj4tugYsAF/4CaO+YYQEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6/FofgBAQDT1ygheF4GfY4YFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9nsMCwEtTz8aQeD4GgLZHYAGAi0x7DKXtsWe0Lk4JAQAA6xFYAACA9TglBLQA70D5EdP0ANoKgaUZ+KUMAEBgcUoIAABYj8ACAACsxykhoJ3iVCWAiwkzLAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA67UosBQWFio+Pl4RERFKSUnRtm3bzlm/fv16DRo0SBERERo2bJg2b97s+d6pU6
c0b948DRs2TF26dFHv3r01ceJEHT58uCWtAQCAIOT3g+PWrVunnJwcFRUVKSUlRStWrFB6err27Nmjnj17nlFfXl6u8ePHq6CgQLfeeqvWrl2rrKws7dixQ0OHDtV3332nHTt2aNGiRUpISNC3336r2bNn67bbbtP27dtbZZAAAo8H3QE4H37PsCxfvlxTp07V5MmTNWTIEBUVFSkyMlIvvviiz/pnnnlGGRkZmjt3rgYPHqzFixdrxIgRWrVqlSQpOjpaW7Zs0d13362BAwfqmmuu0apVq1RRUaHKysrzGx0AAAgKfgWWhoYGVVRUKC0t7acdhIYqLS1NTqfT5zZOp9OrXpLS09PPWi9JNTU1CgkJUbdu3Xx+v76+XrW1tV4LAAAIXn6dEjp69KgaGxsVGxvrtT42Nlaffvqpz21cLpfPepfL5bP+5MmTmjdvnsaPH6+oqCifNQUFBXrsscf8aR1oFzhtAgC+WXWX0KlTp3T33XfLGKPnn3/+rHW5ubmqqanxLAcPHmzDLgEAQFvza4YlJiZGHTp0UFVVldf6qqoqORwOn9s4HI5m1Z8OKwcOHNDbb7991tkVSQoPD1d4eLg/rQMALhLMVAYnv2ZYwsLClJSUpNLSUs86t9ut0tJSpaam+twmNTXVq16StmzZ4lV/Oqzs3btXb731lrp37+5PWwAAIMj5fVtzTk6OJk2apJEjRyo5OVkrVqxQXV2dJk+eLEmaOHGi+vTpo4KCAknS7NmzNWbMGC1btkyZmZkqLi7W9u3btXr1akk/hpXf/OY32rFjhzZt2qTGxkbP9S2XXXaZwsLCWmusAACgnfI7sGRnZ+vIkSPKy8uTy+VSYmKiSkpKPBfWVlZWKjT0p4mb0aNHa+3atVq4cKEWLFigAQMGaMOGDRo6dKgk6dChQ9q4caMkKTEx0etnvfPOO7rxxhtbODQAABAs/A4skjRjxgzNmDHD5/fKysrOWDdu3DiNGzfOZ318fLyMMS1pA2hVTZ335pw3AASOVXcJAQAA+NKiGRYAAIIBdxS1H8ywAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsx3NY0K7wzAQAuDgxwwIAAKzHDAuCGu8HAoDgwAwLAACwHjMsAAA0w4W6ho5r85qHGRYAAGA9AgsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOtxWzOajVvvAACBwgwLAACwHjMsAAAEoWCbFSewIOCC7UMFAGh9BBZcEIQQAEBr4hoWAABgPQILAACwHqeELnKcugGA9uNi/p1NYAEAoJVdzMHiQuGUEAAAsB6BBQAAWI9TQkGIqUgAQLBhhgUAAFiPwAIAAKxHYAEAANYjsAAAAOtx0W0AcXEsAADNwwwLAACwHoEFAABYj8ACAACsR2ABAADW46LbdoILdAEAFzNmWAAAgPUILAAAwHoEFgAAYD0CCwAAsB6BBQAAWI/AAgAArEdgAQAA1iOwAAAA6xFYAACA9QgsAADAegQWAABgPQILAACwHi8/bGW8pBAAgNbHDAsAALAegQUAAFiPwAIAAKxHYAEAANYjsAAAAOsRWAAAgPUILAAAwHoEFgAAYD0CCwAAsF6LAkthYaHi4+MVERGhlJQUbdu27Zz169ev16BBgxQREaFhw4Zp8+bNXt83xigvL0+9evVS586dlZaWpr1797akNQAAEIT8Dizr1q1TTk6O8vPztWPHDiUkJCg9PV3V1dU+68vLyzV+/HhNmTJFO3fuVFZWlrKysrR7925PzdKlS/Xss8+qqKhIH374obp06aL09HSdPHmy5SMDAABBw+/Asnz5ck2dOlWTJ0/WkCFDVFRUpMjISL344os+65955hllZGRo7ty5Gjx4sBYvXqwRI0Zo1apVkn6cXVmxYoUWLlyo22+/XcOHD9fLL7+sw4cPa8OGDec1OAAAEBz8evlhQ0ODKioqlJub61kXGhqqtLQ0OZ1On9s4nU7l5O
R4rUtPT/eEkS+++EIul0tpaWme70dHRyslJUVOp1P33HPPGfusr69XfX295+uamhpJUm1trT/DaTZ3/XdN1pz+2dTaU9ucehtqf15Prf+fYRv6Deba5tTbUPvzemrt+Xw2d5/GmKaLjR8OHTpkJJny8nKv9XPnzjXJyck+t+nUqZNZu3at17rCwkLTs2dPY4wx77//vpFkDh8+7FUzbtw4c/fdd/vcZ35+vpHEwsLCwsLCEgTLwYMHm8wgfs2w2CI3N9dr1sbtduubb75R9+7dFRISckF/dm1treLi4nTw4EFFRUVd0J8VCME8PsbWPgXz2KTgHh9ja7/aanzGGB0/fly9e/dustavwBITE6MOHTqoqqrKa31VVZUcDofPbRwOxznrT/+zqqpKvXr18qpJTEz0uc/w8HCFh4d7revWrZs/QzlvUVFRQfkf6WnBPD7G1j4F89ik4B4fY2u/2mJ80dHRzarz66LbsLAwJSUlqbS01LPO7XartLRUqampPrdJTU31qpekLVu2eOr79+8vh8PhVVNbW6sPP/zwrPsEAAAXF79PCeXk5GjSpEkaOXKkkpOTtWLFCtXV1Wny5MmSpIkTJ6pPnz4qKCiQJM2ePVtjxozRsmXLlJmZqeLiYm3fvl2rV6+WJIWEhGjOnDn6/e9/rwEDBqh///5atGiRevfuraysrNYbKQAAaLf8DizZ2dk6cuSI8vLy5HK5lJiYqJKSEsXGxkqSKisrFRr608TN6NGjtXbtWi1cuFALFizQgAEDtGHDBg0dOtRT88gjj6iurk7333+/jh07puuuu04lJSWKiIhohSG2rvDwcOXn559xSipYBPP4GFv7FMxjk4J7fIyt/bJxfCHGNOdeIgAAgMDhXUIAAMB6BBYAAGA9AgsAALAegQUAAFiPwNJM+/fv15QpU9S/f3917txZv/jFL5Sfn6+Ghgavuo8++kjXX3+9IiIiFBcXp6VLlwaoY/898cQTGj16tCIjI8/6IL6QkJAzluLi4rZttAWaM7bKykplZmYqMjJSPXv21Ny5c/XDDz+0baOtID4+/oxjtGTJkkC31WKFhYWKj49XRESEUlJStG3btkC3dN4effTRM47RoEGDAt1Wi7377rv69a9/rd69eyskJOSMF9caY5SXl6devXqpc+fOSktL0969ewPTrJ+aGtt99913xrHMyMgITLN+Kigo0KhRo9S1a1f17NlTWVlZ2rNnj1fNyZMnNX36dHXv3l2XXHKJ7rrrrjMeBttWCCzN9Omnn8rtduuFF17Qxx9/rKefflpFRUVasGCBp6a2tlZjx45Vv379VFFRoSeffFKPPvqo55kztmtoaNC4ceP0wAMPnLNuzZo1+uqrrzxLe3heTlNja2xsVGZmphoaGlReXq4///nPeumll5SXl9fGnbaOxx9/3OsYzZw5M9Attci6deuUk5Oj/Px87dixQwkJCUpPT1d1dXWgWztvV111ldcxeu+99wLdUovV1dUpISFBhYWFPr+/dOlSPfvssyoqKtKHH36oLl26KD09XSdPnmzjTv3X1NgkKSMjw+tYvvrqq23YYctt3bpV06dP1wcffKAtW7bo1KlTGjt2rOrq6jw1Dz30kF5//XWtX79eW7du1eHDh3XnnXcGpuEm3zaEs1q6dKnp37+/5+vnnnvOXHrppaa+vt6zbt68eWbgwIGBaK/F1qxZY6Kjo31+T5J57bXX2rSf1nS2sW3evNmEhoYal8vlWff888+bqKgor+PZHvTr1888/fTTgW6jVSQnJ5vp06d7vm5sbDS9e/c2BQUFAezq/OXn55uEhIRAt3FB/N/fEW632zgcDvPkk0961h07dsyEh4ebV199NQAdtpyv33+TJk0yt99+e0D6aW3V1dVGktm6dasx5sfj1KlTJ7N+/XpPzSeffGIkGafT2eb9McNyHmpqanTZZZd5vnY6nbrhhhsUFhbmWZeenq49e/bo22+/DUSLF8T06dMVExOj5O
Rkvfjii817LbjlnE6nhg0b5nkAovTjsautrdXHH38cwM5aZsmSJerevbuuvvpqPfnkk+3y1FZDQ4MqKiqUlpbmWRcaGqq0tDQ5nc4AdtY69u7dq969e+uKK67Qvffeq8rKykC3dEF88cUXcrlcXscxOjpaKSkpQXEcJamsrEw9e/bUwIED9cADD+jrr78OdEstUlNTI0mev2sVFRU6deqU17EbNGiQLr/88oAcu3b5tmYb7Nu3TytXrtRTTz3lWedyudS/f3+vutN/AF0uly699NI27fFCePzxx/WrX/1KkZGRevPNN/Xggw/qxIkTmjVrVqBbOy8ul8srrEjex649mTVrlkaMGKHLLrtM5eXlys3N1VdffaXly5cHujW/HD16VI2NjT6Py6effhqgrlpHSkqKXnrpJQ0cOFBfffWVHnvsMV1//fXavXu3unbtGuj2WtXpz4+v49jePlu+ZGRk6M4771T//v31+eefa8GCBbrlllvkdDrVoUOHQLfXbG63W3PmzNG1117reRK9y+VSWFjYGdf9BerYXfQzLPPnz/d5IenPl//7y/HQoUPKyMjQuHHjNHXq1AB13jwtGd+5LFq0SNdee62uvvpqzZs3T4888oiefPLJCziCs2vtsdnMn7Hm5OToxhtv1PDhwzVt2jQtW7ZMK1euVH19fYBHgdNuueUWjRs3TsOHD1d6ero2b96sY8eO6W9/+1ugW4Of7rnnHt12220aNmyYsrKytGnTJv373/9WWVlZoFvzy/Tp07V7926rb6K46GdYHn74Yd13333nrLniiis8/3748GHddNNNGj169BkX0zocjjOunj79tcPhaJ2G/eTv+PyVkpKixYsXq76+vs3fOdGaY3M4HGfcfRLoY/dz5zPWlJQU/fDDD9q/f78GDhx4Abq7MGJiYtShQwefnykbjklr6tatm6688krt27cv0K20utPHqqqqSr169fKsr6qqUmJiYoC6unCuuOIKxcTEaN++fbr55psD3U6zzJgxQ5s2bdK7776rvn37etY7HA41NDTo2LFjXrMsgfoMXvSBpUePHurRo0ezag8dOqSbbrpJSUlJWrNmjddLHiUpNTVVv/vd73Tq1Cl16tRJkrRlyxYNHDgwYKeD/BlfS+zatUuXXnppQF6Q1ZpjS01N1RNPPKHq6mr17NlT0o/HLioqSkOGDGmVn3E+zmesu3btUmhoqGdc7UVYWJiSkpJUWlrquRPN7XartLRUM2bMCGxzrezEiRP6/PPPNWHChEC30ur69+8vh8Oh0tJST0Cpra3Vhx9+2OQdie3Rl19+qa+//tornNnKGKOZM2fqtddeU1lZ2RmXNCQlJalTp04qLS3VXXfdJUnas2ePKisrlZqaGpCG0Qxffvml+eUvf2luvvlm8+WXX5qvvvrKs5x27NgxExsbayZMmGB2795tiouLTWRkpHnhhRcC2HnzHThwwOzcudM89thj5pJLLjE7d+40O3fuNMePHzfGGLNx40bzpz/9yfz3v/81e/fuNc8995yJjIw0eXl5Ae68aU2N7YcffjBDhw41Y8eONbt27TIlJSWmR48eJjc3N8Cd+6e8vNw8/fTTZteuXebzzz83r7zyiunRo4eZOHFioFtrkeLiYhMeHm5eeukl87///c/cf//9plu3bl53c7VHDz/8sCkrKzNffPGFef/9901aWpqJiYkx1dXVgW6tRY4fP+75TEkyy5cvNzt37jQHDhwwxhizZMkS061bN/OPf/zDfPTRR+b22283/fv3N99//32AO2/aucZ2/Phx89vf/tY4nU7zxRdfmLfeesuMGDHCDBgwwJw8eTLQrTfpgQceMNHR0aasrMzrb9p3333nqZk2bZq5/PLLzdtvv222b99uUlNTTWpqakD6JbA005o1a4wkn8vP/ec//zHXXXedCQ8PN3369DFLliwJUMf+mzRpks/xvfPOO8YYY/75z3+axMREc8kll5guXbqYhIQEU1RUZB
obGwPbeDM0NTZjjNm/f7+55ZZbTOfOnU1MTIx5+OGHzalTpwLXdAtUVFSYlJQUEx0dbSIiIszgwYPNH/7wh3bxy/NsVq5caS6//HITFhZmkpOTzQcffBDols5bdna26dWrlwkLCzN9+vQx2dnZZt++fYFuq8Xeeecdn5+vSZMmGWN+vLV50aJFJjY21oSHh5ubb77Z7NmzJ7BNN9O5xvbdd9+ZsWPHmh49ephOnTqZfv36malTp7abQH22v2lr1qzx1Hz//ffmwQcfNJdeeqmJjIw0d9xxh9f/qLelkP/fNAAAgLUu+ruEAACA/QgsAADAegQWAABgPQILAACwHoEFAABYj8ACAACsR2ABAADWI7AAAADrEVgAAID1CCwAAMB6BBYAAGA9AgsAALDe/wOjg8AYvXRe+QAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from torch import no_grad\n",
"\n",
"# Change player and age tokens here!\n",
"# You can find these values in player_tokens.csv and age_tokens.csv\n",
"# You must provide exactly 8 player tokens and 8 age tokens for each team.\n",
"\n",
"# Denver Nuggets first game of 2023-24 season roster\n",
"home_player_tokens = [5035, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
"home_age_tokens = [14, 16, 13, 12, 10, 19, 8, 8]\n",
"\n",
"# Uncomment to take Jokic off team, replace with Peyton Watson\n",
"# home_player_tokens = [4331, 4298, 4626, 4690, 4750, 5082, 4286, 4311]\n",
"# home_age_tokens = [6, 16, 13, 12, 10, 19, 8, 8]\n",
"\n",
"# Boston Celtics final game of 2023-24 season roster\n",
"away_player_tokens = [5042, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
"away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
"\n",
"# Uncomment to take Tatum off team, replace with Pritchard\n",
"# away_player_tokens = [4999, 5039, 5027, 4981, 4972, 5004, 4416, 4983]\n",
"# away_age_tokens = [11, 12, 19, 14, 23, 11, 13, 13]\n",
"\n",
"# The model usually gives the home team a bump in win probability.\n",
"# Change this to \"True\" to swap home and away teams.\n",
"swap_home_away = False\n",
"if swap_home_away:\n",
"    home_player_tokens, away_player_tokens = away_player_tokens, home_player_tokens\n",
"    home_age_tokens, away_age_tokens = away_age_tokens, home_age_tokens\n",
"\n",
"# Every roster list must have exactly players_per_team entries.\n",
"rosters = {\n",
"    'home_player_tokens': home_player_tokens,\n",
"    'home_age_tokens': home_age_tokens,\n",
"    'away_player_tokens': away_player_tokens,\n",
"    'away_age_tokens': away_age_tokens,\n",
"}\n",
"for roster_name, roster in rosters.items():\n",
"    assert len(roster) == players_per_team, (\n",
"        f'{roster_name} has {len(roster)} entries; expected {players_per_team}'\n",
"    )\n",
"\n",
"\n",
"def _as_row(tokens):\n",
"    \"\"\"Return `tokens` as an int32 tensor of shape (1, len(tokens)) on `device`.\"\"\"\n",
"    return Tensor(tokens).to(device=device, dtype=int32).unsqueeze(0)\n",
"\n",
"\n",
"# The home sequences are prefixed with the id one past the real vocabulary\n",
"# (num_*_tokens + 1); the away sequences are passed unchanged.\n",
"batch = {\n",
"    'home_player_tokens': _as_row([num_player_tokens + 1] + home_player_tokens),\n",
"    'home_age_tokens': _as_row([num_age_tokens + 1] + home_age_tokens),\n",
"    'away_player_tokens': _as_row(away_player_tokens),\n",
"    'away_age_tokens': _as_row(away_age_tokens),\n",
"}\n",
"\n",
"# Prediction needs no gradients; skipping autograd saves memory and time.\n",
"with no_grad():\n",
"    logits, _ = model(**batch)\n",
"\n",
"# The model runs in bfloat16; upcast before the softmax so the probabilities\n",
"# keep full float32 precision.\n",
"label_probs = logits.squeeze().float().softmax(dim=0).tolist()\n",
"\n",
"# Label id i encodes a home-team net score of (i - score_offset). Ids whose\n",
"# net score is 0 or falls outside [-max_margin, max_margin] are not valid\n",
"# final margins and are dropped.\n",
"max_margin = num_net_score_tokens // 2\n",
"score_offset = max_margin + 1\n",
"probs = {\n",
"    i - score_offset: p\n",
"    for i, p in enumerate(label_probs)\n",
"    if 0 < abs(i - score_offset) <= max_margin\n",
"}\n",
"\n",
"loss_prob = sum(p for margin, p in probs.items() if margin < 0)\n",
"win_prob = sum(p for margin, p in probs.items() if margin > 0)\n",
"\n",
"# NBA games cannot end tied, so condition on a decisive outcome. Otherwise any\n",
"# mass the model puts on the dropped labels would deflate the reported number.\n",
"decisive_prob = win_prob + loss_prob\n",
"home_win_prob = win_prob / decisive_prob if decisive_prob > 0 else float('nan')\n",
"\n",
"print(f\"Home team win probability: {home_win_prob:.2f}\")\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.bar(list(probs.keys()), list(probs.values()))\n",
"ax.set(\n",
"    title='Predicted final margin distribution',\n",
"    xlabel='Home team net score',\n",
"    ylabel='Probability',\n",
");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "nba",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|