raja5259 commited on
Commit
a9f1949
1 Parent(s): 03380ef

update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -27,11 +27,26 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
27
  import gradio as gr
28
  import misclas_helper
29
  import gradcam_helper
 
30
  from misclas_helper import display_cifar_misclassified_data
31
  from gradcam_helper import display_gradcam_output
 
 
32
 
33
  fileName = None
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
36
  if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
37
  fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
@@ -49,11 +64,6 @@ misClass_demo = gr.Interface(
49
 
50
  ############
51
 
52
- targets = None
53
- device = torch.device("cpu")
54
- classes = ('plane', 'car', 'bird', 'cat', 'deer',
55
- 'dog', 'frog', 'horse', 'ship', 'truck')
56
-
57
 
58
  def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
59
  if(DoYouWantToShowGradCAMMedImages.lower() == "yes"):
 
27
  import gradio as gr
28
  import misclas_helper
29
  import gradcam_helper
30
+ import lightningmodel
31
  from misclas_helper import display_cifar_misclassified_data
32
  from gradcam_helper import display_gradcam_output
33
+ from misclas_helper import get_misclassified_data2
34
+ from lightningmodel import LitResnet
35
 
36
  fileName = None
37
 
38
+ targets = None
39
+ device = torch.device("cpu")
40
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
41
+ 'dog', 'frog', 'horse', 'ship', 'truck')
42
+
43
+ model = LitResnet(lr=0.05).load_from_checkpoint("/content/weights.ckpt")
44
+
45
+ device = torch.device("cpu")
46
+
47
+ # Get the misclassified data from test dataset
48
+ misclassified_data = get_misclassified_data2(model, device, 20)
49
+
50
  def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
51
  if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
52
  fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
 
64
 
65
  ############
66
 
 
 
 
 
 
67
 
68
  def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
69
  if(DoYouWantToShowGradCAMMedImages.lower() == "yes"):