Spaces:
Runtime error
Runtime error
update app.py
Browse files
app.py
CHANGED
@@ -27,11 +27,26 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
27 |
import gradio as gr
|
28 |
import misclas_helper
|
29 |
import gradcam_helper
|
|
|
30 |
from misclas_helper import display_cifar_misclassified_data
|
31 |
from gradcam_helper import display_gradcam_output
|
|
|
|
|
32 |
|
33 |
fileName = None
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
|
36 |
if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
|
37 |
fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
|
@@ -49,11 +64,6 @@ misClass_demo = gr.Interface(
|
|
49 |
|
50 |
############
|
51 |
|
52 |
-
targets = None
|
53 |
-
device = torch.device("cpu")
|
54 |
-
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
55 |
-
'dog', 'frog', 'horse', 'ship', 'truck')
|
56 |
-
|
57 |
|
58 |
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
|
59 |
if(DoYouWantToShowGradCAMMedImages.lower() == "yes"):
|
|
|
27 |
import gradio as gr
|
28 |
import misclas_helper
|
29 |
import gradcam_helper
|
30 |
+
import lightningmodel
|
31 |
from misclas_helper import display_cifar_misclassified_data
|
32 |
from gradcam_helper import display_gradcam_output
|
33 |
+
from misclas_helper import get_misclassified_data2
|
34 |
+
from lightningmodel import LitResnet
|
35 |
|
36 |
fileName = None
|
37 |
|
38 |
+
targets = None
|
39 |
+
device = torch.device("cpu")
|
40 |
+
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
41 |
+
'dog', 'frog', 'horse', 'ship', 'truck')
|
42 |
+
|
43 |
+
model = LitResnet(lr=0.05).load_from_checkpoint("/content/weights.ckpt")
|
44 |
+
|
45 |
+
device = torch.device("cpu")
|
46 |
+
|
47 |
+
# Get the misclassified data from test dataset
|
48 |
+
misclassified_data = get_misclassified_data2(model, device, 20)
|
49 |
+
|
50 |
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
|
51 |
if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
|
52 |
fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
|
|
|
64 |
|
65 |
############
|
66 |
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
|
69 |
if(DoYouWantToShowGradCAMMedImages.lower() == "yes"):
|