amanmibra commited on
Commit
44430bc
1 Parent(s): c79316a

Add assertion in demo

Browse files
Files changed (1) hide show
  1. notebooks/Demo.ipynb +10 -59
notebooks/Demo.ipynb CHANGED
@@ -3,7 +3,7 @@
3
  {
4
  "cell_type": "code",
5
  "execution_count": 19,
6
- "id": "9ef0e433",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
@@ -437,17 +437,17 @@
437
  },
438
  {
439
  "cell_type": "code",
440
- "execution_count": 42,
441
  "id": "32621820",
442
  "metadata": {},
443
  "outputs": [
444
  {
445
  "data": {
446
  "text/plain": [
447
- "tensor([[1., 0., 0.]], grad_fn=<SoftmaxBackward0>)"
448
  ]
449
  },
450
- "execution_count": 42,
451
  "metadata": {},
452
  "output_type": "execute_result"
453
  }
@@ -455,79 +455,30 @@
455
  "source": [
456
  "input = wav.unsqueeze(0)\n",
457
  "output = model(input)\n",
458
- "prediction = train_dataset.labels[output]"
 
459
  ]
460
  },
461
  {
462
  "cell_type": "code",
463
- "execution_count": 28,
464
  "id": "9455a818",
465
  "metadata": {},
466
  "outputs": [
467
  {
468
  "data": {
469
  "text/plain": [
470
- "torch.Size([1, 128, 9, 31])"
471
- ]
472
- },
473
- "execution_count": 28,
474
- "metadata": {},
475
- "output_type": "execute_result"
476
- }
477
- ],
478
- "source": [
479
- "prediction = torch.argmax(output, 1)"
480
- ]
481
- },
482
- {
483
- "cell_type": "code",
484
- "execution_count": 34,
485
- "id": "e23f2494",
486
- "metadata": {},
487
- "outputs": [
488
- {
489
- "data": {
490
- "text/plain": [
491
- "torch.Size([1, 60032])"
492
  ]
493
  },
494
- "execution_count": 34,
495
  "metadata": {},
496
  "output_type": "execute_result"
497
  }
498
  ],
499
  "source": [
500
- "torch.nn.Flatten()(wav).shape"
501
  ]
502
- },
503
- {
504
- "cell_type": "code",
505
- "execution_count": 43,
506
- "id": "246fd00b",
507
- "metadata": {},
508
- "outputs": [
509
- {
510
- "data": {
511
- "text/plain": [
512
- "0"
513
- ]
514
- },
515
- "execution_count": 43,
516
- "metadata": {},
517
- "output_type": "execute_result"
518
- }
519
- ],
520
- "source": [
521
- "actual_output"
522
- ]
523
- },
524
- {
525
- "cell_type": "code",
526
- "execution_count": null,
527
- "id": "0dabc10c",
528
- "metadata": {},
529
- "outputs": [],
530
- "source": []
531
  }
532
  ],
533
  "metadata": {
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": 19,
6
+ "id": "27deb847",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
 
437
  },
438
  {
439
  "cell_type": "code",
440
+ "execution_count": 46,
441
  "id": "32621820",
442
  "metadata": {},
443
  "outputs": [
444
  {
445
  "data": {
446
  "text/plain": [
447
+ "0"
448
  ]
449
  },
450
+ "execution_count": 46,
451
  "metadata": {},
452
  "output_type": "execute_result"
453
  }
 
455
  "source": [
456
  "input = wav.unsqueeze(0)\n",
457
  "output = model(input)\n",
458
+ "prediction = torch.argmax(output, 1).item()\n",
459
+ "prediction"
460
  ]
461
  },
462
  {
463
  "cell_type": "code",
464
+ "execution_count": 49,
465
  "id": "9455a818",
466
  "metadata": {},
467
  "outputs": [
468
  {
469
  "data": {
470
  "text/plain": [
471
+ "True"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  ]
473
  },
474
+ "execution_count": 49,
475
  "metadata": {},
476
  "output_type": "execute_result"
477
  }
478
  ],
479
  "source": [
480
+ "actual_output == prediction"
481
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
  }
483
  ],
484
  "metadata": {