File size: 40,694 Bytes
7b5e67a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"output_type = \"png\" # or \"pdf\"\n",
"timevis = \"noB_tnn\"\n",
"dvi = \"parametricUmap_step2_A\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"DATASET = \"mnist\"\n",
"CONTENT_PATH = \"/home/xianglin/projects/DVI_data/resnet18_{}\".format(DATASET)\n",
"content_path = CONTENT_PATH"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_corrs.npy\".format(timevis)))\n",
"train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_ps.npy\".format(timevis)))\n",
"train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_5_tnn.npy\".format(timevis)))\n",
"test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_corrs.npy\".format(timevis)))\n",
"test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_ps.npy\".format(timevis)))\n",
"test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_5_tnn.npy\".format(timevis)))\n",
"\n",
"\n",
"dvi_train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_corrs.npy\".format(dvi)))\n",
"dvi_train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_ps.npy\".format(dvi)))\n",
"dvi_train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_5_tnn.npy\".format(dvi)))\n",
"dvi_test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_corrs.npy\".format(dvi)))\n",
"dvi_test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_ps.npy\".format(dvi)))\n",
"dvi_test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_5_tnn.npy\".format(dvi)))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"selected_idxs = np.argsort(train_corrs[19])[-100:]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(<seaborn.axisgrid.FacetGrid at 0x7fb4dc4bfa50>,\n",
" <seaborn.axisgrid.FacetGrid at 0x7fb4dc6110d0>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAARtklEQVR4nO3db4xlB1nH8e+PbitoK2xlW9d1NwWsCCFScEAsaICCLn1TMGBFhA1WtwQxIITYwAs1vkGjSPwT7AINq0EoQrFFsVBKoZJCYSGlbF2ggEDXbrpTQKmagFseX9zTOA6zu3e3c+5zZ+b7SW7uveeeO+fpZO63Z8/ccydVhSRp9h7QPYAkbVQGWJKaGGBJamKAJamJAZakJpu6B5jGzp0769prr+0eQ5JOVlZauCb2gO++++7uESRp1a2JAEvSemSAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCajBTjJA5N8IslnktyW5PeH5WcmuS7J7cP15rFmkKR5NuYe8LeBp1fVY4HzgJ1JngRcBlxfVecC1w/3JWnDGS3ANfGfw91Th0sBFwF7h+V7gWePNYMkzbNRjwEnOSXJLcBh4Lqquhk4u6oOAQzXZx3lubuT7Euyb3FxccwxJc2xbdt3kGQuLtu271jV/7ZRP5C9qu4FzkvyEOA9SR5zAs/dA+wBWFhYqHEmlDTv7jx4BxdfflP3GABceen5q/r1ZvIuiKr6d+DDwE7griRbAYbrw7OYQZLmzZjvgtgy7PmS5EHAM4DPAdcAu4bVdgFXjzWDJM2zMQ9BbAX2JjmFSejfWVX/kORjwDuTXAJ8DXjeiDNI0twaLcBVdSvwuBWWfx24YKztStJa4ZlwktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1KT0QKcZHuSG5IcSHJbkpcPy38vyb8luWW4XDjWDJI0zzaN+LWPAK+qqk8nOQP4VJLrhsf+tKr+eMRtS9LcGy3AVXUIODTcvifJAWDbWNuTpLVmJseAk5wDPA64eVj0siS3JrkiyeajPGd3kn1J9i0uLs5iTEmaqdEDnOR04N3AK6rqW8AbgUcA5zHZQ/6TlZ5XVXuqaqGqFrZs2TL2mJI0c6MGOMmpTOL7tqq6CqCq7qqqe6vqu8CbgCeOOYMkzasx3wUR4C3Agap6/ZLlW5es9hxg/1gzSNI8G/NdEE8GXgh8Nsktw7LXAM9Pch5QwFeAS0ecQZLm1pjvgvgokBUeet9Y25SktcQz4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMn2JDckOZDktiQvH5afmeS6JLcP15vHmkGS5tmYe8BHgFdV1aOAJwG/meTRwGXA9VV1LnD9cF+SNpzRAlxVh6rq08Pte4ADwDbgImDvsNpe4NljzSBJ82wmx4CTnAM8DrgZOLuqDsEk0sBZR3nO7iT7kuxbXFycxZiSNFOjBzjJ6cC7gVdU1bemfV5V7amqhapa2LJly3gDSlKTUQOc5FQm8X1bVV01LL4rydbh8a3A4TFnkKR5Nea7IAK8BThQVa9f8tA1wK7h9i7g6r
FmkKR5tmnEr/1k4IXAZ5PcMix7DfA64J1JLgG+BjxvxBkkaW6NFuCq+iiQozx8wVjblaS1wjPhJKmJAZakJgZYkpoYYElqYoC14W3bvoMkc3HZtn1H97dDMzTm29CkNeHOg3dw8eU3dY8BwJWXnt89gmbIPWBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmUwU4yZOnWSZJmt60e8B/PuUySdKUjvlXkZP8DHA+sCXJK5c89IPAKWMOJknr3fH+LP1pwOnDemcsWf4t4LljDSVJG8ExA1xVHwE+kuStVfXVGc0kSRvC8faA7/N9SfYA5yx9TlU9fYyhJGkjmDbAfwf8FfBm4N7xxpGkjWPaAB+pqjeOOokkbTDTvg3tvUlemmRrkjPvu4w6mSStc9PuAe8arl+9ZFkBD1/dcSRp45gqwFX1sLEHkaSNZqoAJ3nRSsur6q9XdxxJ2jimPQTxhCW3HwhcAHwaMMCSdJKmPQTxW0vvJ3kw8DejTCRJG8TJfhzlfwPnruYgkrTRTHsM+L1M3vUAkw/heRTwzrGGkqSNYNpjwH+85PYR4KtVdXCEeSRpw5jqEMTwoTyfY/KJaJuB7xzvOUmuSHI4yf4ly34vyb8luWW4XHiyg0vSWjftX8T4JeATwPOAXwJuTnK8j6N8K7BzheV/WlXnDZf3nciwkrSeTHsI4rXAE6rqMECSLcAHgXcd7QlVdWOSc+73hJK0Tk37LogH3BffwddP4LnLvSzJrcMhis1HWynJ7iT7kuxbXFw8yU1Ja8wDNpFkLi7btu/o/m6se9PuAV+b5P3A24f7FwMnc/jgjcAfMHlHxR8AfwL82korVtUeYA/AwsJCrbSOtO589wgXX35T9xQAXHnp+d0jrHvH+5twPwacXVWvTvKLwFOAAB8D3naiG6uqu5Z87TcB/3CiX0OS1ovjHUZ4A3APQFVdVVWvrKrfZrL3+4YT3ViSrUvuPgfYf7R1JWm9O94hiHOq6tblC6tq3/F+wZbk7cBTgYcmOQj8LvDUJOcxOQTxFeDSEx9ZktaH4wX4gcd47EHHemJVPX+FxW857kSStEEc7xDEJ5P8xvKFSS4BPjXOSJK0MRxvD/gVwHuSvID/C+4CcBqTY7iSpJN0zAAP71o4P8nTgMcMi/+xqj40+mSStM5N+3nANwA3jDyLJG0o056IIWmjGc7K03gMsKSVzclZeev5jLyT/TwHSdL9ZIAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaeCKG2mzbvoM7D97RPYbUxgCrzZ0H7/BMK21oHoKQpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKajBbgJFckOZxk/5JlZya5Lsntw/XmsbYvSfNuzD3gtwI7ly27DLi+qs4Frh/uS9KGNFqAq+pG4BvLFl8E7B1u7wWePdb2JWnezfoY8NlVdQhguD7raCsm2Z1kX5J9i4uLMxtQkmZlbn8JV1V7qmqhqha2bNnSPY4krbpZB/iuJFsBhuvDM96+JM2NWQf4GmDXcHsXcPWMty9Jc2PMt6G9HfgY8MgkB5NcArwOeGaS24FnDvclaUPaNNYXrqrnH+WhC8bapiStJXP7SzhJWu8MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsC
Q1McCS1MQAS1ITAyxJTQywJDXZ1LHRJF8B7gHuBY5U1ULHHJLUqSXAg6dV1d2N25ekVh6CkKQmXQEu4ANJPpVk90orJNmdZF+SfYuLiye1kW3bd5Ck/bJt+477872StE51HYJ4clXdmeQs4Lokn6uqG5euUFV7gD0ACwsLdTIbufPgHVx8+U33f9r76cpLz+8eQdIcatkDrqo7h+vDwHuAJ3bMIUmdZh7gJD+Q5Iz7bgM/D+yf9RyS1K3jEMTZwHuS3Lf9v62qaxvmkKRWMw9wVX0ZeOystytJ88a3oUlSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktRkU/cAG8IDNpGkewoATjn1+7j3f77dPYYkDPBsfPcIF19+U/cUAFx56flzNYu0kXkIQpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJi0BTrIzyeeTfDHJZR0zSFK3mQc4ySnAXwLPAh4NPD/Jo2c9hyR169gDfiLwxar6clV9B3gHcFHDHJLUKlU12w0mzwV2VtWvD/dfCPx0Vb1s2Xq7gd3D3UcCn7+fm34ocPf9/Bod1uLca3FmcO5ZWoszw8nPfXdV7Vy+sOOPcq7054G/5/8CVbUH2LNqG032VdXCan29WVmLc6/FmcG5Z2ktzgyrP3fHIYiDwPYl938UuLNhDklq1RHgTwLnJnlYktOAXwauaZhDklrN/BBEVR1J8jLg/cApwBVVddsMNr1qhzNmbC3OvRZnBueepbU4M6zy3DP/JZwkacIz4SSpiQGWpCbrNsBJzkxyXZLbh+vNK6yzPckNSQ4kuS3Jy5tmPeap2Zn4s+HxW5M8vmPO5aaY+wXDvLcmuSnJYzvmXG7aU+GTPCHJvcN711tNM3OSpya5ZfhZ/sisZ1zJFD8jD07y3iSfGeZ+ccecy2a6IsnhJPuP8vjqvR6ral1egD8CLhtuXwb84QrrbAUeP9w+A/gC8OgZz3kK8CXg4cBpwGeWzwBcCPwTk/dQPwm4eQ6+v9PMfT6webj9rLUy95L1PgS8D3juvM8MPAT4F2DHcP+stfC9Bl5z32sT2AJ8Azitee6fAx4P7D/K46v2ely3e8BMTm/eO9zeCzx7+QpVdaiqPj3cvgc4AGyb1YCDaU7Nvgj465r4OPCQJFtnPOdyx527qm6qqm8Odz/O5D3f3aY9Ff63gHcDh2c53FFMM/OvAFdV1dcAqmqtzF3AGUkCnM4kwEdmO+aygapuHOY4mlV7Pa7nAJ9dVYdgElrgrGOtnOQc4HHAzeOP9v9sA+5Ycv8g3/s/gWnWmbUTnekSJnsN3Y47d5JtwHOAv5rhXMcyzff6x4HNST6c5FNJXjSz6Y5umrn/AngUk5OxPgu8vKq+O5vxTtqqvR47TkVeNUk+CPzwCg+99gS/zulM9nZeUVXfWo3ZTmTzKyxb/t7AqU7fnrGpZ0ryNCYBfsqoE01nmrnfAPxOVd072TFrN83Mm4CfAi4AHgR8LMnHq+oLYw93DNPM/QvALcDTgUcA1yX554bX4YlYtdfjmg5wVT3jaI8luSvJ1qo6NPzzYMV/kiU5lUl831ZVV4006rFMc2r2PJ6+PdVMSX4SeDPwrKr6+oxmO5Zp5l4A3jHE96HAhUmOVNXfz2TC7zXtz8jdVfVfwH8luRF4LJPfa3SZZu4XA6+rycHVLyb5V+AngE/MZsSTsmqvx/V8COIaYNdwexdw9fIVhuNObwEOVNXrZzjbUtOcmn0N8KLht69PAv7jvsMrjY47d5IdwFXAC5v3xJY67txV9bCqOqeqzgHeBby0Mb4w3c/I1cDPJtmU5PuBn2byO41O08z9NSZ77SQ5m8knH355plOeuN
V7PXb+tnHk32T+EHA9cPtwfeaw/EeA9w23n8Lknw63Mvln0C3AhQ2zXshkT+VLwGuHZS8BXjLcDpMPsf8Sk+NkC93f3ynnfjPwzSXf233dM08z97J130rzuyCmnRl4NZN3Quxncjht7r/Xw+vxA8PP9X7gV+dg5rcDh4D/YbK3e8lYr0dPRZakJuv5EIQkzTUDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1+V8M8r9g8JjwfQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAayElEQVR4nO3df7Bc5X3f8fcnMsa0NgaKIEKCgaRyG2BiHBSihrTjYE8taFrhTuzKTQzj4ighEIObpgHnj6TT0YzbOia1qUkU7EGkibGaOEVxwQSDHdcTflh2MUKAbSUQLNAgyYEapzMKkr/9Y4+GRVpdrdA9+9y99/2a2dmz3z3n7lfA/XD0nOc8m6pCkjR539e6AUlaqAxgSWrEAJakRgxgSWrEAJakRl7VuoG+rFq1qj772c+2bkOSADKqOG/PgHfv3t26BUma0bwNYEma6wxgSWrEAJakRgxgSWrEAJakRgxgSWrEAJakRgxgSWrEAJakRnoP4CSLkvyfJJ/pXp+U5O4k3+yeTxza9/ok25J8PcnbhurnJ9nSvfeRJCNv65OkaTKJM+BrgMeGXl8H3FNVy4F7utckORtYA5wDrAI+lmRRd8xNwFpgefdYNYG+JalXvQZwkmXAPwNuHiqvBjZ02xuAS4fqt1XVnqp6AtgGXJBkCXB8Vd1Xg+9PunXoGEmaWn2fAf8W8O+B7w3VTq2qHQDd8yldfSnwraH9tne1pd32gfWDJFmbZHOSzbt27ZqVP4Ak9aW3AE7yU8DOqvrKuIeMqNUM9YOLVeurakVVrVi8ePGYHytJbfS5HvCFwL9IcgnwGuD4JP8deDbJkqra0Q0v7Oz23w6cPnT8MuCZrr5sRF3SAvPeq9/P07uff1lt6ckncPONN7Rp6Cj1FsBVdT1wPUCSNwP/rqp+Nsl/AS4HPtg9394dsgn4gyQfBk5jcLHtwaral+SFJCuBB4DLgI/21bekuevp3c/z+ovWvrx27/pG3Ry9Ft+I8UFgY5IrgKeAdwBU1dYkG4FHgb3AVVW1rzvmSuAW4Djgzu4hSVNtIgFcVV8AvtBtfxt4yyH2WwesG1HfDJzbX4eSNHneCSdJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktSIASxJjRjAktTIq1o3IGn+ee/V7+fp3c8fVF968gncfOMNk29ojjKAJc26p3c/z+svWntw/d71DbqZuxyCkKRGDGBJasQAlqRGegvgJK9J8mCSryXZmuQ/dPXfSPJ0koe6xyVDx1yfZFuSryd521D9/CRbuvc+kiR99S1Jk9LnRbg9wEVV9d0kxwBfSnJn994NVfWh4Z2TnA2sAc4BTgM+l+QNVbUPuAlYC9wP3AGsAu5EkqZYb2fANfDd7uUx3aNmOGQ1cFtV7amqJ4BtwAVJlgDHV9V9VVXArcClffUtSZPS6xhwkkVJHgJ2AndX1QPdW1cneTjJJ5Kc2NWWAt8aOnx7V1vabR9YH/V5a5NsTrJ5165ds/lHkaRZ12sAV9W+qjoPWMbgbPZcBsMJPwicB+wAfrPbfdS4bs1QH/V566tqRVWtWLx48VF2L0n9msgsiKp6HvgCsKqqnu2C+XvA7wIXdLttB04fOmwZ8ExXXzaiLklTrc9ZEIuTnNBtHwe8FXi8G9Pd7+3AI932JmBNkmOTnAUsBx6sqh3AC0lWdrMfLgNu76tvSZqUPmdBLAE2JFnEIOg3VtVnkvxekvMYDCM8Cfw8QFVtTbIReBTYC1zVzYAAuBK4BTiOwewHZ0BImnq9BXBVPQy8aUT93TMcsw5YN6K+GTh3VhuUpMa8E06SGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1
iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJamRV/X1g5O8BvgicGz3OX9YVb+e5CTgU8CZwJPAO6vque6Y64ErgH3A+6rqrq5+PnALcBxwB3BNVVVfvUuaHlsf2cLFa95zUH3pySdw8403jPUz3nv1+3l69/Ov+PhXqrcABvYAF1XVd5McA3wpyZ3AvwTuqaoPJrkOuA741SRnA2uAc4DTgM8leUNV7QNuAtYC9zMI4FXAnT32LmlKvFiLeP1Faw+qP33v+rF/xtO7nz/oZxzJ8a9Ub0MQNfDd7uUx3aOA1cCGrr4BuLTbXg3cVlV7quoJYBtwQZIlwPFVdV931nvr0DGSNLV6HQNOsijJQ8BO4O6qegA4tap2AHTPp3S7LwW+NXT49q62tNs+sD7q89Ym2Zxk865du2b1zyJJs63XAK6qfVV1HrCMwdnsuTPsnlE/Yob6qM9bX1UrqmrF4sWLj7hfSZqkicyCqKrngS8wGLt9thtWoHve2e22HTh96LBlwDNdfdmIuiRNtd4COMniJCd028cBbwUeBzYBl3e7XQ7c3m1vAtYkOTbJWcBy4MFumOKFJCuTBLhs6BhJmlp9zoJYAmxIsohB0G+sqs8kuQ/YmOQK4CngHQBVtTXJRuBRYC9wVTcDAuBKXpqGdifOgJA0D/QWwFX1MPCmEfVvA285xDHrgHUj6puBmcaPJWnqeCecJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDViAEtSIwawJDXS59fSS/Pee69+P0/vfv5ltaUnn8DNN97QpiFNFQNYOgpP736e11+09uW1e9c36kbTxiEISWrEAJakRgxgSWrEMWBpHvLi4HQwgKV5yIuD08EhCElqpLcATnJ6ks8neSzJ1iTXdPXfSPJ0koe6xyVDx1yfZFuSryd521D9/CRbuvc+kiR99S1Jk9LnEMRe4Jer6qtJXgd8Jcnd3Xs3VNWHhndOcjawBjgHOA34XJI3VNU+4CZgLXA/cAewCrizx94lqXe9BXBV7QB2dNsvJHkMWDrDIauB26pqD/BEkm3ABUmeBI6vqvsAktwKXIoBrHlq1AU0mOxFNC/iTcZELsIlORN4E/AAcCFwdZLLgM0MzpKfYxDO9w8dtr2rvdhtH1gf9TlrGZwpc8YZZ8zuH0KakFEX0GCyF9G8iDcZvV+ES/Ja4I+Aa6vqOwyGE34QOI/BGfJv7t91xOE1Q/3gYtX6qlpRVSsWL158tK1LUq96DeAkxzAI39+vqk8DVNWzVbWvqr4H/C5wQbf7duD0ocOXAc909WUj6pI01fqcBRHg48BjVfXhofqSod3eDjzSbW8C1iQ5NslZwHLgwW4s+YUkK7ufeRlwe199S9Kk9DkGfCHwbmBLkoe62geAdyU5j8EwwpPAzwNU1dYkG4FHGcyguKqbAQFwJXALcByDi29egJM09fqcBfElRo/f3jHDMeuAdSPqm4FzZ687SWrPO+EkqRHXgpA05xxqLvRjj3+DlRdNvp++GMCSxrL1kS1cvOY9B9X7uEHjUHOh92y5dlY/pzUDWNJYXqxFzW8QmW8cA5akRgxgSWrEAJakRgxgSWrEAJakRgxgSWrEaWiSJmbUXOKFvNC7ASxpYkbNJV7I84gNYEkLxly7xdkAlrRgzLVbnMe6CJfkwnFqkqTxjTsL4qNj1iRJY5pxCCLJPwJ+HFic5N8OvXU8sKjPxiRpvjvcGPCrgdd2+71uqP4d4Kf7akrS5Iy6MDXf1t2dq2YM4Kr6M+DPktxSVX81oZ4kTdCoC1Pzbd3duWrcWRDHJlkPnDl8TFX5/0hJeoXGDeD/Afw2cDOw7zD7SpLGMG4A762qm3rtRJIWmHGnof1Jkl9Msi
TJSfsfvXYmSfPcuGfAl3fPvzJUK+AHZrcdSVo4xgrgqjqr70akPh1qDYCFvBKX2hsrgJNcNqpeVbfObjtSPw61BsBCXolL7Y07BPGjQ9uvAd4CfBUwgCXpFRp3COKXhl8neT3we710JEkLxCv9SqL/ByyfzUYkaaEZdwz4TxjMeoDBIjw/BGzsqylpPhp1IdCLgAvbuGPAHxra3gv8VVVtn+mAJKczGCP+fuB7wPqq+q/d/OFPMbit+UngnVX1XHfM9cAVDO62e19V3dXVzwduAY4D7gCuqapCmiKjLgR6EXBhG2sIoluU53EGK6KdCPztGIftBX65qn4IWAlcleRs4DrgnqpaDtzTvaZ7bw1wDrAK+FiS/Ute3gSsZTDssbx7X5Km2rjfiPFO4EHgHcA7gQeSzLgcZVXtqKqvdtsvAI8BS4HVwIZutw3Apd32auC2qtpTVU8A24ALkiwBjq+q+7qz3luHjpGkqTXuEMSvAT9aVTsBkiwGPgf84TgHJzkTeBPwAHBqVe2AQUgnOaXbbSlw/9Bh27vai932gfVRn7OWwZkyZ5xxxjitSVIz486C+L794dv59rjHJnkt8EfAtVX1nZl2HVGrGeoHF6vWV9WKqlqxePHicdqTpGbGPQP+bJK7gE92r/8Vg4thM0pyDIPw/f2q+nRXfjbJku7sdwmwP9i3A6cPHb4MeKarLxtRl6SpNuNZbJK/n+TCqvoV4HeAHwbeCNwHzHj5NkmAjwOPVdWHh97axEuL+1wO3D5UX5Pk2CRnMbjY9mA3XPFCkpXdz7xs6BhJmlqHOwP+LeADAN0Z7KcBkqzo3vvnMxx7IfBuYEuSh7raB4APAhuTXAE8xeDCHlW1NclG4FEGMyiuqqr9i79fyUvT0O7sHpI01Q4XwGdW1cMHFqtqc3dh7ZCq6kuMHr+FwVoSo45ZB6wb9XnAuYfpVZKmyuEupL1mhveOm81GJGmhOVwAfznJzx1Y7IYPvtJPS5K0MBxuCOJa4I+T/AwvBe4K4NXA23vsS5LmvRkDuKqeBX48yU/y0hjs/6qqe3vvTJLmuXHXA/488Pmee5GkBeWVrgcsSTpKBrAkNWIAS1Ij464FIWkOGvUtGwCPPf4NVl40+X50ZAxgaYqN+pYNgD1brp18MzpiDkFIUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ1YgBLUiMGsCQ14q3I0hhcc0F9MIClMbjmgvrgEIQkNWIAS1IjBrAkNWIAS1IjBrAkNWIAS1IjBrAkNdLbPOAknwB+CthZVed2td8Afg7Y1e32gaq6o3vveuAKYB/wvqq6q6ufD9wCHAfcAVxTVdVX31pYtj6yhYvXvOdltaUnn8DNN97QqCMtJH3eiHELcCNw6wH1G6rqQ8OFJGcDa4BzgNOAzyV5Q1XtA24C1gL3MwjgVcCdPfatBeTFWnTQDRZP37u+UTdaaHobgqiqLwJ/Pebuq4HbqmpPVT0BbAMuSLIEOL6q7uvOem8FLu2lYUmasBZjwFcneTjJJ5Kc2NWWAt8a2md7V1vabR9YHynJ2iSbk2zetWvXoXaTpDlh0mtB3AT8R6C6598E/g2QEfvWDPWRqmo9sB5gxYoVjhMvUKMWznHRHM1FEw3gqnp2/3aS3wU+073cDpw+tOsy4JmuvmxEXTqkUQvnuGiO5qKJDkF0Y7r7vR14pNveBKxJcmySs4DlwINVtQN4IcnKJAEuA26fZM+S1Jc+p6F9EngzcHKS7cCvA29Och6DYYQngZ8HqKqtSTYCjwJ7gau6GRAAV/LSNLQ7cQaEpHmitwCuqneNKH98hv3XAetG1DcD585ia5I0J3gnnCQ1YgBLUiMGsCQ1YgBLUiMGsCQ14rcia0451Ne/u0KZ5iMDWHPKob7+3RXKNB8ZwFJDo9Yjdt2KhcMAlhoatR6x61YsHF6Ek6RGDGBJasQAlqRGDGBJasQAlqRGDGBJasQAlqRGDGBJasQAlqRGvB
NOs8rFdKTxGcCaVS6mI43PIQhJasQAlqRGDGBJasQAlqRGvAg3z4yaheAMBGluMoDnmVGzEJyBIM1NDkFIUiMGsCQ1YgBLUiOOAesVG3XBz2/0lcbXWwAn+QTwU8DOqjq3q50EfAo4E3gSeGdVPde9dz1wBbAPeF9V3dXVzwduAY4D7gCuqarqq2+Nb9QFP7/RVxpfn0MQtwCrDqhdB9xTVcuBe7rXJDkbWAOc0x3zsSSLumNuAtYCy7vHgT9TkqZSbwFcVV8E/vqA8mpgQ7e9Abh0qH5bVe2pqieAbcAFSZYAx1fVfd1Z761Dx0jSVJv0RbhTq2oHQPd8SldfCnxraL/tXW1pt31gfaQka5NsTrJ5165ds9q4JM22uTILIiNqNUN9pKpaX1UrqmrF4sWLZ605SerDpAP42W5Yge55Z1ffDpw+tN8y4JmuvmxEXZKm3qSnoW0CLgc+2D3fPlT/gyQfBk5jcLHtwaral+SFJCuBB4DLgI9OuGdJC9DWR7Zw8Zr3vKw22+uq9DkN7ZPAm4GTk2wHfp1B8G5McgXwFPAOgKrammQj8CiwF7iqqvZ1P+pKXpqGdmf3kKRevViLel9XpbcArqp3HeKttxxi/3XAuhH1zcC5s9iaGpjE2YQ0bbwTThMxibMJadrMlVkQkrTgeAa8gI1ay+HJb36DM5e/4WU1hwqkfhjAC9iotRye33KtQwXShDgEIUmNeAasqeAsCs1HBrCmgrMoNB85BCFJjRjAktSIQxCNjJoC5pimtLAYwI2MmgLmmKa0sDgEIUmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1Ii3Ik+pUWtJADz2+DdYedHk+5F05AzgKTVqLQmAPVuunXwzkl4RhyAkqREDWJIaMYAlqREDWJIaMYAlqRFnQcwho756HfyqImm+ahLASZ4EXgD2AXurakWSk4BPAWcCTwLvrKrnuv2vB67o9n9fVd3VoO3ejfrqdfCriqT5quUQxE9W1XlVtaJ7fR1wT1UtB+7pXpPkbGANcA6wCvhYkkUtGpak2TSXhiBWA2/utjcAXwB+tavfVlV7gCeSbAMuAO5r0OMrMuquNe9Yk9QqgAv40yQF/E5VrQdOraodAFW1I8kp3b5LgfuHjt3e1Q6SZC2wFuCMM87oq/cjNuquNe9Yk9QqgC+sqme6kL07yeMz7JsRtRq1Yxfk6wFWrFgxch9JmiuajAFX1TPd807gjxkMKTybZAlA97yz2307cPrQ4cuAZybXrST1Y+IBnOTvJnnd/m3gnwKPAJuAy7vdLgdu77Y3AWuSHJvkLGA58OBku5ak2ddiCOJU4I+T7P/8P6iqzyb5MrAxyRXAU8A7AKpqa5KNwKPAXuCqqtrXoG9JmlUTD+Cq+kvgjSPq3wbecohj1gHrem5NkibKW5ElqREDWJIaMYAlqREDWJIaMYAlqREDWJIamUuL8UyVUQvsuG6vpCNhAL9CoxbYcd1eSUfCIQhJasQz4MMYNdQArucr6egZwIcxaqgBXM9X0tFzCEKSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRA1iSGjGAJakRF2QfMurbL/zmC0l9MYCHjPr2C7/5QlJfHIKQpEamJoCTrEry9STbklzXuh9JOlpTEcBJFgH/DbgYOBt4V5Kz23YlSUdnKgIYuADYVlV/WVV/C9wGrG7ckyQdlVRV6x4OK8lPA6uq6r3d63cDP1ZVVx+w31pg/1W0fwB8vce2TgZ29/jzZ4t9zr5p6XVa+oTp6fWV9rm7qlYdWJyWWRAZUTvo/xxVtR5Y3387kGRzVa2YxGcdDfucfdPS67T0CdPT62z3OS1DENuB04deLwOeadSLJM2KaQngLwPLk5yV5NXAGmBT454k6ahMxRBEVe1NcjVwF7AI+ERVbW3c1kSGOmaBfc6+ael1Wv
qE6el1VvuciotwkjQfTcsQhCTNOwawJDViAI8pyUlJ7k7yze75xBH7nJ7k80keS7I1yTUT7G/GW7Uz8JHu/YeT/MikejvCPn+m6+/hJH+e5I0t+ux6Gev29yQ/mmRfN1994sbpM8mbkzzU/Xf5Z5PusevhcP/uX5/kT5J8revzPY36/ESSnUkeOcT7s/e7VFU+xngA/xm4rtu+DvhPI/ZZAvxIt/064BvA2RPobRHwF8APAK8Gvnbg5wKXAHcymFO9EnigwT/Dcfr8ceDEbvviFn2O2+vQfvcCdwA/PRf7BE4AHgXO6F6fMkf7/MD+3ytgMfDXwKsb9PpPgB8BHjnE+7P2u+QZ8PhWAxu67Q3ApQfuUFU7quqr3fYLwGPA0gn0Ns6t2quBW2vgfuCEJEsm0NsR9VlVf15Vz3Uv72cw57uFcW9//yXgj4Cdk2xuyDh9/mvg01X1FEBVteh1nD4LeF2SAK9lEMB7J9smVNUXu88+lFn7XTKAx3dqVe2AQdACp8y0c5IzgTcBD/TfGkuBbw293s7BwT/OPn070h6uYHCm0cJhe02yFHg78NsT7OtA4/wzfQNwYpIvJPlKkssm1t1LxunzRuCHGNxktQW4pqq+N5n2jsis/S5NxTzgSUnyOeD7R7z1a0f4c17L4Kzo2qr6zmz0driPHFE7cH7hWLdz92zsHpL8JIMA/oleOzq0cXr9LeBXq2rf4KStiXH6fBVwPvAW4DjgviT3V9U3+m5uyDh9vg14CLgI+EHg7iT/e0K/Q0di1n6XDOAhVfXWQ72X5NkkS6pqR/fXjZF/jUtyDIPw/f2q+nRPrR5onFu158Lt3GP1kOSHgZuBi6vq2xPq7UDj9LoCuK0L35OBS5Lsrar/OZEOB8b9d7+7qv4G+JskXwTeyOAaxaSM0+d7gA/WYKB1W5IngH8IPDiZFsc2a79LDkGMbxNwebd9OXD7gTt0Y1cfBx6rqg9PsLdxbtXeBFzWXcFdCfzf/UMqc6nPJGcAnwbePeEztAMdtteqOquqzqyqM4E/BH5xwuE7Vp8M/lv9x0leleTvAD/G4PrEXOvzKQZn6SQ5lcGKhn850S7HM3u/S5O+wjitD+DvAfcA3+yeT+rqpwF3dNs/weCvIg8z+KvUQ8AlE+rvEgZnNH8B/FpX+wXgF7rtMFjU/i8YjK+taPTP8XB93gw8N/TPb3PDf+cz9nrAvrfQYBbEuH0Cv8JgJsQjDIbG5lyf3e/Sn3b/fT4C/GyjPj8J7ABeZHC2e0Vfv0veiixJjTgEIUmNGMCS1IgBLEmNGMCS1IgBLEmNGMCS1IgBLEmN/H8Z8DWgBMVBpwAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sns.displot(train_corrs[0][selected_idxs]),sns.displot(train_corrs[0])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"epoch_num = train_corrs.shape[0]\n",
"train_num = train_corrs.shape[1]\n",
"test_num = test_corrs.shape[1]\n",
"\n",
"train_data = np.zeros((epoch_num*train_num, 2))\n",
"for i in range(len(train_corrs)):\n",
" train_data[i*train_num:(i+1)*train_num][:,0] = train_corrs[i]\n",
" train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n",
"test_data = np.zeros((epoch_num*test_num, 2))\n",
"for i in range(len(test_corrs)):\n",
" test_data[i*test_num:(i+1)*test_num][:,0] = test_corrs[i]\n",
" test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n",
"data = np.concatenate((train_data, test_data), axis=0)\n",
"type = [\"Train\" for _ in range(len(train_data))] + [\"Test\" for _ in range(len(test_data))]\n",
"method = [\"TimeVis\" for _ in range(len(data))]\n",
"\n",
"dvi_train_data = np.zeros((epoch_num*train_num, 2))\n",
"for i in range(len(dvi_train_corrs)):\n",
" dvi_train_data[i*train_num:(i+1)*train_num][:,0] = dvi_train_corrs[i]\n",
" dvi_train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n",
"dvi_test_data = np.zeros((epoch_num*test_num, 2))\n",
"for i in range(len(dvi_test_corrs)):\n",
" dvi_test_data[i*test_num:(i+1)*test_num][:,0] = dvi_test_corrs[i]\n",
" dvi_test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n",
"dvi_data = np.concatenate((dvi_train_data, dvi_test_data), axis=0)\n",
"dvi_type = [\"Train\" for _ in range(len(dvi_train_data))]+[\"Test\" for _ in range(len(dvi_test_data))]\n",
"dvi_method = [\"DVI\" for _ in range(len(dvi_data))]\n",
"\n",
"data = np.concatenate((data, dvi_data), axis=0)\n",
"type = type + dvi_type\n",
"method = method + dvi_method\n",
"\n",
"df = pd.DataFrame(data,columns=[\"corr\", \"epoch\"])\n",
"df2 = df.assign(type = type)\n",
"df3 = df2.assign(method = method)\n",
"df3[[\"epoch\"]] = df[[\"epoch\"]].astype(int)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x400 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.rcParams['figure.dpi'] = 100 # figure resolution (DPI)\n",
"sns.set_theme(style='darkgrid')\n",
"plt.style.use('ggplot')\n",
"plt.title(\"MNIST\")\n",
"fg = sns.lineplot(x=\"epoch\", y=\"corr\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3)\n",
"plt.savefig(\n",
" \"./plot_results/corr_3_{}.{}\".format(\"mnist\", output_type),\n",
" dpi=300,\n",
" bbox_inches=\"tight\",\n",
" pad_inches=0.0,\n",
")\n",
"plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x400 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epoch_num = train_tnn.shape[0]\n",
"train_num = train_tnn.shape[1]\n",
"test_num = test_tnn.shape[1]\n",
"\n",
"train_data = np.zeros((epoch_num*train_num, 2))\n",
"for i in range(len(train_tnn)):\n",
" train_data[i*train_num:(i+1)*train_num][:,0] = train_tnn[i]\n",
" train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n",
"test_data = np.zeros((epoch_num*test_num, 2))\n",
"for i in range(len(test_tnn)):\n",
" test_data[i*test_num:(i+1)*test_num][:,0] = test_tnn[i]\n",
" test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n",
"data = np.concatenate((train_data, test_data), axis=0)\n",
"type = [\"Train\" for _ in range(len(train_data))] + [\"Test\" for _ in range(len(test_data))]\n",
"method = [\"TimeVis\" for _ in range(len(data))]\n",
"\n",
"dvi_train_data = np.zeros((epoch_num*train_num, 2))\n",
"for i in range(len(dvi_train_tnn)):\n",
" dvi_train_data[i*train_num:(i+1)*train_num][:,0] = dvi_train_tnn[i]\n",
" dvi_train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n",
"dvi_test_data = np.zeros((epoch_num*test_num, 2))\n",
"for i in range(len(dvi_test_tnn)):\n",
" dvi_test_data[i*test_num:(i+1)*test_num][:,0] = dvi_test_tnn[i]\n",
" dvi_test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n",
"dvi_data = np.concatenate((dvi_train_data, dvi_test_data), axis=0)\n",
"dvi_type = [\"Train\" for _ in range(len(dvi_train_data))]+[\"Test\" for _ in range(len(dvi_test_data))]\n",
"dvi_method = [\"DVI\" for _ in range(len(dvi_data))]\n",
"\n",
"data = np.concatenate((data, dvi_data), axis=0)\n",
"type = type + dvi_type\n",
"method = method + dvi_method\n",
"\n",
"df = pd.DataFrame(data,columns=[\"tnn\", \"epoch\"])\n",
"df2 = df.assign(type = type)\n",
"df3 = df2.assign(method = method)\n",
"df3[[\"epoch\"]] = df[[\"epoch\"]].astype(int)\n",
"plt.rcParams['figure.dpi'] = 100 # figure resolution (DPI)\n",
"sns.set_theme(style='darkgrid')\n",
"plt.style.use('ggplot')\n",
"plt.title(\"MNIST\")\n",
"fg = sns.lineplot(x=\"epoch\", y=\"tnn\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3)\n",
"plt.savefig(\n",
" \"./plot_results/tnn_{}.{}\".format(\"mnist\", output_type),\n",
" dpi=300,\n",
" bbox_inches=\"tight\",\n",
" pad_inches=0.0,\n",
")\n",
"plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"DATASET = \"fmnist\"\n",
"CONTENT_PATH = \"/home/xianglin/projects/DVI_data/resnet18_{}\".format(DATASET)\n",
"content_path = CONTENT_PATH"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_corrs.npy\".format(timevis)))\n",
"train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_ps.npy\".format(timevis)))\n",
"train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_3_5_tnn.npy\".format(timevis)))\n",
"test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_corrs.npy\".format(timevis)))\n",
"test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_ps.npy\".format(timevis)))\n",
"test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"{}_test_3_5_tnn.npy\".format(timevis)))\n",
"\n",
"\n",
"dvi_train_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_corrs.npy\".format(dvi)))\n",
"dvi_train_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_ps.npy\".format(dvi)))\n",
"dvi_train_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_train_{}_3_5_tnn.npy\".format(dvi)))\n",
"dvi_test_corrs = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_corrs.npy\".format(dvi)))\n",
"dvi_test_ps = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_ps.npy\".format(dvi)))\n",
"dvi_test_tnn = np.load(os.path.join(CONTENT_PATH, \"Model\", \"DVI_test_{}_3_5_tnn.npy\".format(dvi)))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"epoch_num = train_corrs.shape[0]\n",
"train_num = train_corrs.shape[1]\n",
"test_num = test_corrs.shape[1]\n",
"\n",
"train_data = np.zeros((epoch_num*train_num, 2))\n",
"for i in range(len(train_corrs)):\n",
" train_data[i*train_num:(i+1)*train_num][:,0] = train_corrs[i]\n",
" train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n",
"test_data = np.zeros((epoch_num*test_num, 2))\n",
"for i in range(len(test_corrs)):\n",
" test_data[i*test_num:(i+1)*test_num][:,0] = test_corrs[i]\n",
" test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n",
"data = np.concatenate((train_data, test_data), axis=0)\n",
"type = [\"Train\" for _ in range(len(train_data))] + [\"Test\" for _ in range(len(test_data))]\n",
"method = [\"TimeVis\" for _ in range(len(data))]\n",
"\n",
"dvi_train_data = np.zeros((epoch_num*train_num, 2))\n",
"for i in range(len(dvi_train_corrs)):\n",
" dvi_train_data[i*train_num:(i+1)*train_num][:,0] = dvi_train_corrs[i]\n",
" dvi_train_data[i*train_num:(i+1)*train_num][:,1] = i+1\n",
"dvi_test_data = np.zeros((epoch_num*test_num, 2))\n",
"for i in range(len(dvi_test_corrs)):\n",
" dvi_test_data[i*test_num:(i+1)*test_num][:,0] = dvi_test_corrs[i]\n",
" dvi_test_data[i*test_num:(i+1)*test_num][:,1] = i+1\n",
"dvi_data = np.concatenate((dvi_train_data, dvi_test_data), axis=0)\n",
"dvi_type = [\"Train\" for _ in range(len(dvi_train_data))]+[\"Test\" for _ in range(len(dvi_test_data))]\n",
"dvi_method = [\"DVI\" for _ in range(len(dvi_data))]\n",
"\n",
"data = np.concatenate((data, dvi_data), axis=0)\n",
"type = type + dvi_type\n",
"method = method + dvi_method\n",
"\n",
"df = pd.DataFrame(data,columns=[\"corr\", \"epoch\"])\n",
"df2 = df.assign(type = type)\n",
"df3 = df2.assign(method = method)\n",
"df3[[\"epoch\"]] = df[[\"epoch\"]].astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x400 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.rcParams['figure.dpi'] = 100 # figure resolution (DPI)\n",
"sns.set_theme(style='darkgrid')\n",
"plt.style.use('ggplot')\n",
"plt.title(\"FMNIST\")\n",
"sns.lineplot(x=\"epoch\", y=\"corr\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3)\n",
"\n",
"plt.savefig(\n",
" \"./plot_results/corr_3_{}.{}\".format(\"fmnist\", output_type),\n",
" dpi=300,\n",
" bbox_inches=\"tight\",\n",
" pad_inches=0.0,\n",
")\n",
"plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x400 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Long-form table for seaborn with columns (tnn, epoch, type, method); rows are ordered\n",
"# TimeVis/Train, TimeVis/Test, DVI/Train, DVI/Test. Row i of each array is tagged epoch i+1.\n",
"# Sizes are taken from each array itself, so a DVI array with a different number of\n",
"# epochs/samples cannot leave zero-filled (epoch 0) rows in the plot. The split column is\n",
"# set through assign(type=...) so the builtin `type` is not shadowed in the namespace.\n",
"frames = []\n",
"for method_name, splits in (\n",
"    (\"TimeVis\", ((\"Train\", train_tnn), (\"Test\", test_tnn))),\n",
"    (\"DVI\", ((\"Train\", dvi_train_tnn), (\"Test\", dvi_test_tnn))),\n",
"):\n",
"    for split_label, values in splits:\n",
"        arr = np.asarray(values)\n",
"        epochs = np.repeat(np.arange(1, arr.shape[0] + 1), arr.shape[1])\n",
"        frames.append(\n",
"            pd.DataFrame({\"tnn\": arr.reshape(-1).astype(float), \"epoch\": epochs})\n",
"            .assign(type=split_label, method=method_name)\n",
"        )\n",
"df3 = pd.concat(frames, ignore_index=True)\n",
"\n",
"plt.rcParams['figure.dpi'] = 100  # figure resolution\n",
"sns.set_theme(style='darkgrid')\n",
"plt.style.use('ggplot')\n",
"plt.title(\"FMNIST\")\n",
"sns.lineplot(x=\"epoch\", y=\"tnn\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3)\n",
"\n",
"plt.savefig(\n",
"    \"./plot_results/tnn_{}.{}\".format(\"fmnist\", output_type),\n",
"    dpi=300,\n",
"    bbox_inches=\"tight\",\n",
"    pad_inches=0.0,\n",
")\n",
"plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Dataset whose TimeVis / DVI evaluation results are loaded and plotted below.\n",
"DATASET = \"cifar10\"\n",
"# Root of the DVI_data tree. Set the DVI_DATA_ROOT environment variable to run on another\n",
"# machine instead of editing this machine-specific default path.\n",
"DATA_ROOT = os.environ.get(\"DVI_DATA_ROOT\", \"/home/xianglin/projects/DVI_data\")\n",
"CONTENT_PATH = os.path.join(DATA_ROOT, \"resnet18_{}\".format(DATASET))\n",
"content_path = CONTENT_PATH"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Load the per-sample evaluation arrays (corrs, ps, tnn) for the train and test splits of\n",
"# both methods from <CONTENT_PATH>/Model; file names are filled with `timevis` / `dvi`.\n",
"train_corrs, train_ps, train_tnn, test_corrs, test_ps, test_tnn = (\n",
"    np.load(os.path.join(CONTENT_PATH, \"Model\", pattern.format(timevis)))\n",
"    for pattern in (\n",
"        \"{}_3_corrs.npy\",\n",
"        \"{}_3_ps.npy\",\n",
"        \"{}_3_5_tnn.npy\",\n",
"        \"{}_test_3_corrs.npy\",\n",
"        \"{}_test_3_ps.npy\",\n",
"        \"{}_test_3_5_tnn.npy\",\n",
"    )\n",
")\n",
"\n",
"dvi_train_corrs, dvi_train_ps, dvi_train_tnn, dvi_test_corrs, dvi_test_ps, dvi_test_tnn = (\n",
"    np.load(os.path.join(CONTENT_PATH, \"Model\", pattern.format(dvi)))\n",
"    for pattern in (\n",
"        \"DVI_train_{}_3_corrs.npy\",\n",
"        \"DVI_train_{}_3_ps.npy\",\n",
"        \"DVI_train_{}_3_5_tnn.npy\",\n",
"        \"DVI_test_{}_3_corrs.npy\",\n",
"        \"DVI_test_{}_3_ps.npy\",\n",
"        \"DVI_test_{}_3_5_tnn.npy\",\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Long-form table for seaborn with columns (corr, epoch, type, method), consumed by the next\n",
"# plot cell; rows are ordered TimeVis/Train, TimeVis/Test, DVI/Train, DVI/Test and row i of\n",
"# each array is tagged epoch i+1. Sizes are taken from each array itself, so a DVI array with\n",
"# a different number of epochs/samples cannot leave zero-filled (epoch 0) rows in the plot.\n",
"# The split column is set through assign(type=...) so the builtin `type` is not shadowed.\n",
"frames = []\n",
"for method_name, splits in (\n",
"    (\"TimeVis\", ((\"Train\", train_corrs), (\"Test\", test_corrs))),\n",
"    (\"DVI\", ((\"Train\", dvi_train_corrs), (\"Test\", dvi_test_corrs))),\n",
"):\n",
"    for split_label, values in splits:\n",
"        arr = np.asarray(values)\n",
"        epochs = np.repeat(np.arange(1, arr.shape[0] + 1), arr.shape[1])\n",
"        frames.append(\n",
"            pd.DataFrame({\"corr\": arr.reshape(-1).astype(float), \"epoch\": epochs})\n",
"            .assign(type=split_label, method=method_name)\n",
"        )\n",
"df3 = pd.concat(frames, ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x400 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Per-epoch corr of TimeVis vs DVI on CIFAR10 (95% CI band over samples), saved to disk.\n",
"plt.rcParams.update({'figure.dpi': 100})\n",
"plt.style.use('ggplot')\n",
"plt.title(\"CIFAR10\")\n",
"sns.lineplot(data=df3, x=\"epoch\", y=\"corr\", hue=\"method\", style=\"type\", ci=95, markers=False)\n",
"plt.savefig(\"./plot_results/corr_3_{}.{}\".format(\"cifar10\", output_type), bbox_inches=\"tight\", pad_inches=0.0, dpi=300)\n",
"plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 600x400 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Long-form table for seaborn with columns (tnn, epoch, type, method); rows are ordered\n",
"# TimeVis/Train, TimeVis/Test, DVI/Train, DVI/Test. Row i of each array is tagged epoch i+1.\n",
"# Sizes are taken from each array itself, so a DVI array with a different number of\n",
"# epochs/samples cannot leave zero-filled (epoch 0) rows in the plot. The split column is\n",
"# set through assign(type=...) so the builtin `type` is not shadowed in the namespace.\n",
"frames = []\n",
"for method_name, splits in (\n",
"    (\"TimeVis\", ((\"Train\", train_tnn), (\"Test\", test_tnn))),\n",
"    (\"DVI\", ((\"Train\", dvi_train_tnn), (\"Test\", dvi_test_tnn))),\n",
"):\n",
"    for split_label, values in splits:\n",
"        arr = np.asarray(values)\n",
"        epochs = np.repeat(np.arange(1, arr.shape[0] + 1), arr.shape[1])\n",
"        frames.append(\n",
"            pd.DataFrame({\"tnn\": arr.reshape(-1).astype(float), \"epoch\": epochs})\n",
"            .assign(type=split_label, method=method_name)\n",
"        )\n",
"df3 = pd.concat(frames, ignore_index=True)\n",
"\n",
"plt.rcParams['figure.dpi'] = 100\n",
"plt.style.use('ggplot')\n",
"plt.title(\"CIFAR10\")\n",
"sns.lineplot(x=\"epoch\", y=\"tnn\", hue=\"method\", style=\"type\", markers=False, ci=95, data=df3)\n",
"plt.savefig(\n",
"    \"./plot_results/tnn_{}.{}\".format(\"cifar10\", output_type),\n",
"    dpi=300,\n",
"    bbox_inches=\"tight\",\n",
"    pad_inches=0.0,\n",
")\n",
"plt.clf()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# simple\n",
"def draw(corrs, ps, corrs2, ps2, title):\n",
"    \"\"\"Plot per-epoch mean +/- std of two (corrs, ps) pairs in two stacked panels.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    corrs, ps : array-like, 2-D\n",
"        Top-panel values; each row is one epoch and mean/std are taken over axis 1.\n",
"        The x-axis runs 1..len(corrs).\n",
"    corrs2, ps2 : array-like, 2-D\n",
"        Bottom-panel values with the same layout (assumed to have len(corrs) rows).\n",
"    title : str\n",
"        Figure super-title.\n",
"\n",
"    Returns\n",
"    -------\n",
"    None. The figure is shown and then closed.\n",
"    \"\"\"\n",
"    fig, axs = plt.subplots(2)\n",
"    fig.suptitle(title)\n",
"\n",
"    epochs = np.arange(1, len(corrs) + 1)\n",
"    panels = ((axs[0], corrs, ps), (axs[1], corrs2, ps2))\n",
"    for ax, panel_corrs, panel_ps in panels:\n",
"        for values, fmt, color, label in (\n",
"            (panel_corrs, \"b.-\", \"b\", \"corr\"),\n",
"            (panel_ps, \"r+-\", \"r\", \"p\"),\n",
"        ):\n",
"            mean = np.mean(values, axis=1)\n",
"            std = np.std(values, axis=1)\n",
"            ax.plot(epochs, mean, fmt, label=label)\n",
"            # Translucent band in the line's colour; an opaque default fill lets the\n",
"            # later band hide the earlier one.\n",
"            ax.fill_between(epochs, mean - std, mean + std, color=color, alpha=0.2)\n",
"        ax.legend()\n",
"\n",
"    plt.show()\n",
"    # Close this figure explicitly: plt.clf() after show() acts on whatever figure is\n",
"    # current afterwards (possibly a new, empty one left in the notebook output).\n",
"    plt.close(fig)"
]
}
],
"metadata": {
"interpreter": {
"hash": "aa7a9f36e1a1e240450dbe9cc8f6d8df1d5301f36681fb271c44fdd883236b60"
},
"kernelspec": {
"display_name": "Python 3.7.11 ('SV': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
|