Arekku21 commited on
Commit
e41aa53
1 Parent(s): baffa9d

Updated app.py with confidence threshold of 0.8

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -6,6 +6,8 @@ import torch
6
  import torchvision.models as models
7
  from torch import nn
8
 
 
 
9
  from albumentations import (
10
  HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
11
  Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
@@ -98,11 +100,21 @@ def predict_image(image):
98
  # Assuming the output is a tensor representing class probabilities
99
  probabilities = torch.nn.functional.softmax(output[0], dim=0).numpy()
100
 
 
 
101
  # Get the class with the highest probability
102
  predicted_class = np.argmax(probabilities)
103
 
 
 
 
 
 
 
 
 
104
  # Return the class label
105
- return "Predicted Class: " + classes[predicted_class+1]
106
 
107
  # create a gradio interface
108
  gr.Interface(fn=predict_image, inputs="image", outputs="text").launch()
 
6
  import torchvision.models as models
7
  from torch import nn
8
 
9
+ from model import *
10
+
11
  from albumentations import (
12
  HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
13
  Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
 
100
  # Assuming the output is a tensor representing class probabilities
101
  probabilities = torch.nn.functional.softmax(output[0], dim=0).numpy()
102
 
103
+ conf_threshold = 0.8
104
+
105
  # Get the class with the highest probability
106
  predicted_class = np.argmax(probabilities)
107
 
108
+ # If the probability is less than the threshold, return unknown else return the class
109
+ if probabilities[predicted_class] < conf_threshold:
110
+ string_to_return = "Predicted Class: Unknown"
111
+
112
+ else:
113
+ # Return the class label
114
+ string_to_return = f"Predicted Class: {classes[predicted_class+1]} with probability: {probabilities[predicted_class] * 100:.2f}%"
115
+
116
  # Return the class label
117
+ return string_to_return
118
 
119
  # create a gradio interface
120
  gr.Interface(fn=predict_image, inputs="image", outputs="text").launch()