Spaces:
Sleeping
Sleeping
Updated app.py with confidence threshold of 0.8
Browse files
app.py
CHANGED
@@ -6,6 +6,8 @@ import torch
|
|
6 |
import torchvision.models as models
|
7 |
from torch import nn
|
8 |
|
|
|
|
|
9 |
from albumentations import (
|
10 |
HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
|
11 |
Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
|
@@ -98,11 +100,21 @@ def predict_image(image):
|
|
98 |
# Assuming the output is a tensor representing class probabilities
|
99 |
probabilities = torch.nn.functional.softmax(output[0], dim=0).numpy()
|
100 |
|
|
|
|
|
101 |
# Get the class with the highest probability
|
102 |
predicted_class = np.argmax(probabilities)
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
# Return the class label
|
105 |
-
return
|
106 |
|
107 |
# create a gradio interface
|
108 |
gr.Interface(fn=predict_image, inputs="image", outputs="text").launch()
|
|
|
6 |
import torchvision.models as models
|
7 |
from torch import nn
|
8 |
|
9 |
+
from model import *
|
10 |
+
|
11 |
from albumentations import (
|
12 |
HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
|
13 |
Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
|
|
|
100 |
# Assuming the output is a tensor representing class probabilities
|
101 |
probabilities = torch.nn.functional.softmax(output[0], dim=0).numpy()
|
102 |
|
103 |
+
conf_threshold = 0.8
|
104 |
+
|
105 |
# Get the class with the highest probability
|
106 |
predicted_class = np.argmax(probabilities)
|
107 |
|
108 |
+
# If the probability is less than the threshold, return unknown else return the class
|
109 |
+
if probabilities[predicted_class] < conf_threshold:
|
110 |
+
string_to_return = "Predicted Class: Unknown"
|
111 |
+
|
112 |
+
else:
|
113 |
+
# Return the class label
|
114 |
+
string_to_return = f"Predicted Class: {classes[predicted_class+1]} with probability: {probabilities[predicted_class] * 100:.2f}%"
|
115 |
+
|
116 |
# Return the class label
|
117 |
+
return string_to_return
|
118 |
|
119 |
# create a gradio interface
|
120 |
gr.Interface(fn=predict_image, inputs="image", outputs="text").launch()
|