Commit
·
e82f64b
1
Parent(s):
3ff958a
new version of the app with lightweight models
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch
|
|
| 5 |
import json
|
| 6 |
import string
|
| 7 |
import gradio as gr
|
| 8 |
-
from model import create_vitbase_model, create_effnetb0
|
| 9 |
from timeit import default_timer as timer
|
| 10 |
from typing import Tuple, Dict
|
| 11 |
from torchvision.transforms import v2
|
|
@@ -56,10 +56,27 @@ vitbase_model_101 = create_vitbase_model(
|
|
| 56 |
compile=True
|
| 57 |
)
|
| 58 |
|
| 59 |
-
vitbase_model_102 = create_vitbase_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
model_weights_dir=".",
|
| 61 |
-
model_weights_name="
|
| 62 |
-
image_size=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
num_classes=num_classes_102,
|
| 64 |
compile=True
|
| 65 |
)
|
|
@@ -84,12 +101,23 @@ transforms_vit = v2.Compose([
|
|
| 84 |
std=[0.229, 0.224, 0.225])
|
| 85 |
])
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# Put models into evaluation mode and turn on inference mode
|
| 89 |
effnetb0_model_1.eval()
|
| 90 |
effnetb0_model_2.eval()
|
| 91 |
vitbase_model_101.eval()
|
| 92 |
-
vitbase_model_102.eval()
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# Set thresholds
|
| 95 |
BINARY_CLASSIF_THR_1 = 0.8310611844062805
|
|
@@ -99,8 +127,8 @@ MULTICLASS_CLASSIF_THR = 0.5
|
|
| 99 |
ENTROPY_THR = 2.7
|
| 100 |
|
| 101 |
# Set model names
|
| 102 |
-
lite_model = "β‘
|
| 103 |
-
pro_model = "π
|
| 104 |
|
| 105 |
# Computes the entropy
|
| 106 |
def entropy(pred_probs):
|
|
@@ -172,18 +200,18 @@ def classify_food(image, model=pro_model) -> Tuple[Dict, str, str]:
|
|
| 172 |
# If the picture is food
|
| 173 |
if predict(image_eff, effnetb0_model_1)[:,1] >= BINARY_CLASSIF_THR_1:
|
| 174 |
|
| 175 |
-
# π
|
| 176 |
if model == pro_model:
|
| 177 |
|
| 178 |
# If the image is likely to be a known category
|
| 179 |
if predict(image_eff, effnetb0_model_2)[:,1] >= BINARY_CLASSIF_THR_2:
|
| 180 |
|
| 181 |
# Preprocess the image for the ViTs
|
| 182 |
-
|
| 183 |
|
| 184 |
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
| 185 |
-
pred_probs_102 = predict(
|
| 186 |
-
pred_probs_101 = predict(
|
| 187 |
|
| 188 |
# Calculate entropy
|
| 189 |
entropy_102 = entropy(pred_probs_102)
|
|
@@ -239,7 +267,7 @@ def classify_food(image, model=pro_model) -> Tuple[Dict, str, str]:
|
|
| 239 |
# Get the top predicted class
|
| 240 |
top_class = "unknown"
|
| 241 |
|
| 242 |
-
# β‘
|
| 243 |
else:
|
| 244 |
|
| 245 |
# Preprocess the image for the ViTs
|
|
@@ -301,7 +329,7 @@ def classify_food(image, model=pro_model) -> Tuple[Dict, str, str]:
|
|
| 301 |
# Create title, description, and examples
|
| 302 |
title = "Transform-Eats Large<br>π₯ͺπ₯π₯£π₯©ππ£π°"
|
| 303 |
description = f"""
|
| 304 |
-
A cutting-edge
|
| 305 |
"""
|
| 306 |
|
| 307 |
# Group food items alphabetically
|
|
@@ -351,12 +379,14 @@ with gr.Blocks(theme="ocean") as demo:
|
|
| 351 |
mirror_webcam=False
|
| 352 |
)
|
| 353 |
|
| 354 |
-
model_radio = gr.Radio(
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
)
|
|
|
|
|
|
|
| 360 |
|
| 361 |
# Define the status message output field to display error messages
|
| 362 |
status_output = gr.HTML(label="Status:")
|
|
@@ -370,15 +400,16 @@ with gr.Blocks(theme="ocean") as demo:
|
|
| 370 |
# Create the Gradio demo
|
| 371 |
gr.Interface(
|
| 372 |
fn=classify_food, # mapping function from input to outputs
|
| 373 |
-
inputs=[upload_input
|
| 374 |
outputs=[gr.Label(num_top_classes=3, label="Prediction"),
|
| 375 |
gr.Textbox(label="Prediction time:"),
|
| 376 |
gr.Textbox(label="Food Description:"),
|
| 377 |
status_output], # outputs
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
|
|
|
| 382 |
|
| 383 |
# Launch the demo!
|
| 384 |
demo.launch()
|
|
|
|
| 5 |
import json
|
| 6 |
import string
|
| 7 |
import gradio as gr
|
| 8 |
+
from model import create_vitbase_model, create_swin_tiny_model, create_effnetb0
|
| 9 |
from timeit import default_timer as timer
|
| 10 |
from typing import Tuple, Dict
|
| 11 |
from torchvision.transforms import v2
|
|
|
|
| 56 |
compile=True
|
| 57 |
)
|
| 58 |
|
| 59 |
+
#vitbase_model_102 = create_vitbase_model(
|
| 60 |
+
# model_weights_dir=".",
|
| 61 |
+
# model_weights_name="vitbase16_102_2025-01-27_epoch19.pth",
|
| 62 |
+
# image_size=384,
|
| 63 |
+
# num_classes=num_classes_102,
|
| 64 |
+
# compile=True
|
| 65 |
+
#)
|
| 66 |
+
|
| 67 |
+
# Load the Swin-V2-Tiny transformers with input images of 256x256 pixels: one with 101 classes and one with 101 + unknown classes
|
| 68 |
+
swint_model_101 = create_swin_tiny_model(
|
| 69 |
model_weights_dir=".",
|
| 70 |
+
model_weights_name="swinv2tiny_101_2025-02-05_epoch25.pth",
|
| 71 |
+
image_size=256,
|
| 72 |
+
num_classes=num_classes_101,
|
| 73 |
+
compile=True
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
swint_model_102 = create_swin_tiny_model(
|
| 77 |
+
model_weights_dir=".",
|
| 78 |
+
model_weights_name="swinv2tiny_102_2025-02-08_acc_epoch28.pth",
|
| 79 |
+
image_size=256,
|
| 80 |
num_classes=num_classes_102,
|
| 81 |
compile=True
|
| 82 |
)
|
|
|
|
| 101 |
std=[0.229, 0.224, 0.225])
|
| 102 |
])
|
| 103 |
|
| 104 |
+
# Specify manual transforms for Swins
|
| 105 |
+
transforms_swint = v2.Compose([
|
| 106 |
+
v2.Resize(260),
|
| 107 |
+
v2.CenterCrop((256, 256)),
|
| 108 |
+
v2.ToImage(),
|
| 109 |
+
v2.ToDtype(torch.float32, scale=True),
|
| 110 |
+
v2.Normalize(mean=[0.485, 0.456, 0.406],
|
| 111 |
+
std=[0.229, 0.224, 0.225])
|
| 112 |
+
])
|
| 113 |
|
| 114 |
# Put models into evaluation mode and turn on inference mode
|
| 115 |
effnetb0_model_1.eval()
|
| 116 |
effnetb0_model_2.eval()
|
| 117 |
vitbase_model_101.eval()
|
| 118 |
+
#vitbase_model_102.eval()
|
| 119 |
+
swint_model_101.eval()
|
| 120 |
+
swint_model_102.eval()
|
| 121 |
|
| 122 |
# Set thresholds
|
| 123 |
BINARY_CLASSIF_THR_1 = 0.8310611844062805
|
|
|
|
| 127 |
ENTROPY_THR = 2.7
|
| 128 |
|
| 129 |
# Set model names
|
| 130 |
+
lite_model = "β‘ Lite β‘ less accurate prediction"
|
| 131 |
+
pro_model = "π Pro π more accurate prediction"
|
| 132 |
|
| 133 |
# Computes the entropy
|
| 134 |
def entropy(pred_probs):
|
|
|
|
| 200 |
# If the picture is food
|
| 201 |
if predict(image_eff, effnetb0_model_1)[:,1] >= BINARY_CLASSIF_THR_1:
|
| 202 |
|
| 203 |
+
# π Pro π
|
| 204 |
if model == pro_model:
|
| 205 |
|
| 206 |
# If the image is likely to be a known category
|
| 207 |
if predict(image_eff, effnetb0_model_2)[:,1] >= BINARY_CLASSIF_THR_2:
|
| 208 |
|
| 209 |
# Preprocess the image for the Swin transformers
|
| 210 |
+
image_swint = transforms_swint(image).unsqueeze(0)
|
| 211 |
|
| 212 |
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
| 213 |
+
pred_probs_102 = predict(image_swint, swint_model_102)
|
| 214 |
+
pred_probs_101 = predict(image_swint, swint_model_101)
|
| 215 |
|
| 216 |
# Calculate entropy
|
| 217 |
entropy_102 = entropy(pred_probs_102)
|
|
|
|
| 267 |
# Get the top predicted class
|
| 268 |
top_class = "unknown"
|
| 269 |
|
| 270 |
+
# β‘ Lite β‘
|
| 271 |
else:
|
| 272 |
|
| 273 |
# Preprocess the image for the ViTs
|
|
|
|
| 329 |
# Create title, description, and examples
|
| 330 |
title = "Transform-Eats Large<br>π₯ͺπ₯π₯£π₯©ππ£π°"
|
| 331 |
description = f"""
|
| 332 |
+
A cutting-edge, lightweight Transformer model to classify 101 delicious food types. Discover the power of AI in culinary recognition.
|
| 333 |
"""
|
| 334 |
|
| 335 |
# Group food items alphabetically
|
|
|
|
| 379 |
mirror_webcam=False
|
| 380 |
)
|
| 381 |
|
| 382 |
+
#model_radio = gr.Radio(
|
| 383 |
+
# choices=[lite_model, pro_model],
|
| 384 |
+
# value=pro_model,
|
| 385 |
+
# label="Select the AI algorithm:",
|
| 386 |
+
# info="ViT Pro is selected by default if none is chosen."
|
| 387 |
+
#)
|
| 388 |
+
|
| 389 |
+
food_vision_examples = [["examples/" + example] for example in os.listdir("examples")]
|
| 390 |
|
| 391 |
# Define the status message output field to display error messages
|
| 392 |
status_output = gr.HTML(label="Status:")
|
|
|
|
| 400 |
# Create the Gradio demo
|
| 401 |
gr.Interface(
|
| 402 |
fn=classify_food, # mapping function from input to outputs
|
| 403 |
+
inputs=[upload_input], # inputs
|
| 404 |
outputs=[gr.Label(num_top_classes=3, label="Prediction"),
|
| 405 |
gr.Textbox(label="Prediction time:"),
|
| 406 |
gr.Textbox(label="Food Description:"),
|
| 407 |
status_output], # outputs
|
| 408 |
+
examples=food_vision_examples, # Create examples list from "examples/" directory
|
| 409 |
+
article=article, # Created by...
|
| 410 |
+
flagging_mode=flagging_mode, # Only For debugging
|
| 411 |
+
flagging_options=["correct", "incorrect"], # Only For debugging
|
| 412 |
+
)
|
| 413 |
|
| 414 |
# Launch the demo!
|
| 415 |
demo.launch()
|
model.py
CHANGED
|
@@ -77,6 +77,44 @@ def create_vitbase_model(
|
|
| 77 |
|
| 78 |
return vitbase16_model
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
# Create an EfficientNet-B0 Model
|
| 81 |
def create_effnetb0(
|
| 82 |
model_weights_dir: Path,
|
|
|
|
| 77 |
|
| 78 |
return vitbase16_model
|
| 79 |
|
| 80 |
+
def create_swin_tiny_model(
|
| 81 |
+
model_weights_dir:Path,
|
| 82 |
+
model_weights_name:str,
|
| 83 |
+
image_size:int=224,
|
| 84 |
+
num_classes:int=101,
|
| 85 |
+
compile:bool=False
|
| 86 |
+
):
|
| 87 |
+
|
| 88 |
+
"""
|
| 89 |
+
Creates a Swin-V2-Tiny model with the specified number of classes.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
model_weights_dir: A directory where the model is located.
|
| 93 |
+
model_weights_name: The name of the model to load.
|
| 94 |
+
image_size: The size of the input image.
|
| 95 |
+
num_classes: The number of classes for the classification task.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
The created Swin-V2-Tiny model.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
# Instantiate the model
|
| 102 |
+
swint_model = torchvision.models.swin_v2_t().to("cpu")
|
| 103 |
+
swint_model.head = torch.nn.Linear(in_features=768, out_features=num_classes).to("cpu")
|
| 104 |
+
|
| 105 |
+
# Compile the model
|
| 106 |
+
if compile:
|
| 107 |
+
swint_model = torch.compile(swint_model, backend="aot_eager")
|
| 108 |
+
|
| 109 |
+
# Load the trained weights
|
| 110 |
+
swint_model = load_model(
|
| 111 |
+
model=swint_model,
|
| 112 |
+
model_weights_dir=model_weights_dir,
|
| 113 |
+
model_weights_name=model_weights_name
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
return swint_model
|
| 117 |
+
|
| 118 |
# Create an EfficientNet-B0 Model
|
| 119 |
def create_effnetb0(
|
| 120 |
model_weights_dir: Path,
|