sergio-sanz-rodriguez commited on
Commit
e82f64b
Β·
1 Parent(s): 3ff958a

new version of the app with lightweight models

Browse files
Files changed (2) hide show
  1. app.py +55 -24
  2. model.py +38 -0
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import json
6
  import string
7
  import gradio as gr
8
- from model import create_vitbase_model, create_effnetb0
9
  from timeit import default_timer as timer
10
  from typing import Tuple, Dict
11
  from torchvision.transforms import v2
@@ -56,10 +56,27 @@ vitbase_model_101 = create_vitbase_model(
56
  compile=True
57
  )
58
 
59
- vitbase_model_102 = create_vitbase_model(
 
 
 
 
 
 
 
 
 
60
  model_weights_dir=".",
61
- model_weights_name="vitbase16_102_2025-01-27_epoch19.pth",
62
- image_size=384,
 
 
 
 
 
 
 
 
63
  num_classes=num_classes_102,
64
  compile=True
65
  )
@@ -84,12 +101,23 @@ transforms_vit = v2.Compose([
84
  std=[0.229, 0.224, 0.225])
85
  ])
86
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Put models into evaluation mode and turn on inference mode
89
  effnetb0_model_1.eval()
90
  effnetb0_model_2.eval()
91
  vitbase_model_101.eval()
92
- vitbase_model_102.eval()
 
 
93
 
94
  # Set thresdholds
95
  BINARY_CLASSIF_THR_1 = 0.8310611844062805
@@ -99,8 +127,8 @@ MULTICLASS_CLASSIF_THR = 0.5
99
  ENTROPY_THR = 2.7
100
 
101
  # Set model names
102
- lite_model = "⚑ ViT Lite ⚑ faster, less accurate prediction"
103
- pro_model = "πŸ’Ž ViT Pro πŸ’Ž slower, more accurate prediction"
104
 
105
  # Computes the entropy
106
  def entropy(pred_probs):
@@ -172,18 +200,18 @@ def classify_food(image, model=pro_model) -> Tuple[Dict, str, str]:
172
  # If the picture is food
173
  if predict(image_eff, effnetb0_model_1)[:,1] >= BINARY_CLASSIF_THR_1:
174
 
175
- # πŸ’Ž ViT Pro πŸ’Ž
176
  if model == pro_model:
177
 
178
  # If the image is likely to be an known category
179
  if predict(image_eff, effnetb0_model_2)[:,1] >= BINARY_CLASSIF_THR_2:
180
 
181
  # Preproces the image for the ViTs
182
- image_vit = transforms_vit(image).unsqueeze(0)
183
 
184
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
185
- pred_probs_102 = predict(image_vit, vitbase_model_102)
186
- pred_probs_101 = predict(image_vit, vitbase_model_101)
187
 
188
  # Calculate entropy
189
  entropy_102 = entropy(pred_probs_102)
@@ -239,7 +267,7 @@ def classify_food(image, model=pro_model) -> Tuple[Dict, str, str]:
239
  # Get the top predicted class
240
  top_class = "unknown"
241
 
242
- # ⚑ ViT Lite ⚑
243
  else:
244
 
245
  # Preproces the image for the ViTs
@@ -301,7 +329,7 @@ def classify_food(image, model=pro_model) -> Tuple[Dict, str, str]:
301
  # Create title, description, and examples
302
  title = "Transform-Eats Large<br>πŸ₯ͺπŸ₯—πŸ₯£πŸ₯©πŸπŸ£πŸ°"
303
  description = f"""
304
- A cutting-edge Vision Transformer (ViT) model to classify 101 delicious food types. Discover the power of AI in culinary recognition.
305
  """
306
 
307
  # Group food items alphabetically
@@ -351,12 +379,14 @@ with gr.Blocks(theme="ocean") as demo:
351
  mirror_webcam=False
352
  )
353
 
354
- model_radio = gr.Radio(
355
- choices=[lite_model, pro_model],
356
- value=pro_model,
357
- label="Select the AI algorithm:",
358
- info="ViT Pro is selected by default if none is chosen."
359
- )
 
 
360
 
361
  # Define the status message output field to display error messages
362
  status_output = gr.HTML(label="Status:")
@@ -370,15 +400,16 @@ with gr.Blocks(theme="ocean") as demo:
370
  # Create the Gradio demo
371
  gr.Interface(
372
  fn=classify_food, # mapping function from input to outputs
373
- inputs=[upload_input, model_radio], # inputs
374
  outputs=[gr.Label(num_top_classes=3, label="Prediction"),
375
  gr.Textbox(label="Prediction time:"),
376
  gr.Textbox(label="Food Description:"),
377
  status_output], # outputs
378
- article=article, # Created by...
379
- flagging_mode=flagging_mode, # Only For debugging
380
- flagging_options=["correct", "incorrect"], # Only For debugging
381
- )
 
382
 
383
  # Launch the demo!
384
  demo.launch()
 
5
  import json
6
  import string
7
  import gradio as gr
8
+ from model import create_vitbase_model, create_swin_tiny_model, create_effnetb0
9
  from timeit import default_timer as timer
10
  from typing import Tuple, Dict
11
  from torchvision.transforms import v2
 
56
  compile=True
57
  )
58
 
59
+ #vitbase_model_102 = create_vitbase_model(
60
+ # model_weights_dir=".",
61
+ # model_weights_name="vitbase16_102_2025-01-27_epoch19.pth",
62
+ # image_size=384,
63
+ # num_classes=num_classes_102,
64
+ # compile=True
65
+ #)
66
+
67
+ # Load the Swin-V2-Tiny transformer with input image of 384x384 pixels and 101 + unknown classes
68
+ swint_model_101 = create_swin_tiny_model(
69
  model_weights_dir=".",
70
+ model_weights_name="swinv2tiny_101_2025-02-05_epoch25.pth",
71
+ image_size=256,
72
+ num_classes=num_classes_101,
73
+ compile=True
74
+ )
75
+
76
+ swint_model_102 = create_swin_tiny_model(
77
+ model_weights_dir=".",
78
+ model_weights_name="swinv2tiny_102_2025-02-08_acc_epoch28.pth",
79
+ image_size=256,
80
  num_classes=num_classes_102,
81
  compile=True
82
  )
 
101
  std=[0.229, 0.224, 0.225])
102
  ])
103
 
104
+ # Specify manual transforms for Swins
105
+ transforms_swint = v2.Compose([
106
+ v2.Resize(260),
107
+ v2.CenterCrop((256, 256)),
108
+ v2.ToImage(),
109
+ v2.ToDtype(torch.float32, scale=True),
110
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
111
+ std=[0.229, 0.224, 0.225])
112
+ ])
113
 
114
  # Put models into evaluation mode and turn on inference mode
115
  effnetb0_model_1.eval()
116
  effnetb0_model_2.eval()
117
  vitbase_model_101.eval()
118
+ #vitbase_model_102.eval()
119
+ swint_model_101.eval()
120
+ swint_model_102.eval()
121
 
122
  # Set thresdholds
123
  BINARY_CLASSIF_THR_1 = 0.8310611844062805
 
127
  ENTROPY_THR = 2.7
128
 
129
  # Set model names
130
+ lite_model = "⚑ Lite ⚑ less accurate prediction"
131
+ pro_model = "πŸ’Ž Pro πŸ’Ž more accurate prediction"
132
 
133
  # Computes the entropy
134
  def entropy(pred_probs):
 
200
  # If the picture is food
201
  if predict(image_eff, effnetb0_model_1)[:,1] >= BINARY_CLASSIF_THR_1:
202
 
203
+ # πŸ’Ž Pro πŸ’Ž
204
  if model == pro_model:
205
 
206
  # If the image is likely to be an known category
207
  if predict(image_eff, effnetb0_model_2)[:,1] >= BINARY_CLASSIF_THR_2:
208
 
209
  # Preproces the image for the ViTs
210
+ image_swint = transforms_swint(image).unsqueeze(0)
211
 
212
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
213
+ pred_probs_102 = predict(image_swint, swint_model_102)
214
+ pred_probs_101 = predict(image_swint, swint_model_101)
215
 
216
  # Calculate entropy
217
  entropy_102 = entropy(pred_probs_102)
 
267
  # Get the top predicted class
268
  top_class = "unknown"
269
 
270
+ # ⚑ Lite ⚑
271
  else:
272
 
273
  # Preproces the image for the ViTs
 
329
  # Create title, description, and examples
330
  title = "Transform-Eats Large<br>πŸ₯ͺπŸ₯—πŸ₯£πŸ₯©πŸπŸ£πŸ°"
331
  description = f"""
332
+ A cutting-edge, leightweight Transformer model to classify 101 delicious food types. Discover the power of AI in culinary recognition.
333
  """
334
 
335
  # Group food items alphabetically
 
379
  mirror_webcam=False
380
  )
381
 
382
+ #model_radio = gr.Radio(
383
+ # choices=[lite_model, pro_model],
384
+ # value=pro_model,
385
+ # label="Select the AI algorithm:",
386
+ # info="ViT Pro is selected by default if none is chosen."
387
+ #)
388
+
389
+ food_vision_examples = [["examples/" + example] for example in os.listdir("examples")]
390
 
391
  # Define the status message output field to display error messages
392
  status_output = gr.HTML(label="Status:")
 
400
  # Create the Gradio demo
401
  gr.Interface(
402
  fn=classify_food, # mapping function from input to outputs
403
+ inputs=[upload_input], # inputs
404
  outputs=[gr.Label(num_top_classes=3, label="Prediction"),
405
  gr.Textbox(label="Prediction time:"),
406
  gr.Textbox(label="Food Description:"),
407
  status_output], # outputs
408
+ examples=food_vision_examples, # Create examples list from "examples/" directory
409
+ article=article, # Created by...
410
+ flagging_mode=flagging_mode, # Only For debugging
411
+ flagging_options=["correct", "incorrect"], # Only For debugging
412
+ )
413
 
414
  # Launch the demo!
415
  demo.launch()
model.py CHANGED
@@ -77,6 +77,44 @@ def create_vitbase_model(
77
 
78
  return vitbase16_model
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # Create an EfficientNet-B0 Model
81
  def create_effnetb0(
82
  model_weights_dir: Path,
 
77
 
78
  return vitbase16_model
79
 
80
+ def create_swin_tiny_model(
81
+ model_weights_dir:Path,
82
+ model_weights_name:str,
83
+ image_size:int=224,
84
+ num_classes:int=101,
85
+ compile:bool=False
86
+ ):
87
+
88
+ """
89
+ Creates a Swin-V2-Tiny model with the specified number of classes.
90
+
91
+ Args:
92
+ model_weights_dir: A directory where the model is located.
93
+ model_weights_name: The name of the model to load.
94
+ image_size: The size of the input image.
95
+ num_classes: The number of classes for the classification task.
96
+
97
+ Returns:
98
+ The created ViT-B/16 model.
99
+ """
100
+
101
+ # Instantiate the model
102
+ swint_model = torchvision.models.swin_v2_t().to("cpu")
103
+ swint_model.head = torch.nn.Linear(in_features=768, out_features=num_classes).to("cpu")
104
+
105
+ # Compile the model
106
+ if compile:
107
+ swint_model = torch.compile(swint_model, backend="aot_eager")
108
+
109
+ # Load the trained weights
110
+ swint_model = load_model(
111
+ model=swint_model,
112
+ model_weights_dir=model_weights_dir,
113
+ model_weights_name=model_weights_name
114
+ )
115
+
116
+ return swint_model
117
+
118
  # Create an EfficientNet-B0 Model
119
  def create_effnetb0(
120
  model_weights_dir: Path,