mrdbourke commited on
Commit
fa9a065
β€’
1 Parent(s): 802aea5

update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -1,4 +1,4 @@
1
- ### 1. Imports and class names setup ###
2
  import gradio as gr
3
  import os
4
  import torch
@@ -8,18 +8,18 @@ from timeit import default_timer as timer
8
  from typing import Tuple, Dict
9
 
10
  # Setup class names
11
- with open("class_names.txt", "r") as f: # reading them in from class_names.txt
12
- class_names = [food_name.strip() for food_name in f.readlines()]
13
-
14
- ### 2. Model and transforms preparation ###
15
 
16
  # Create model
17
- model, transforms = create_effnetb2_model(
18
- num_classes=101, # could also use len(class_names)
19
  )
20
 
21
  # Load saved weights
22
- model.load_state_dict(
23
  torch.load(
24
  f="09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
25
  map_location=torch.device("cpu"), # load to CPU
@@ -34,25 +34,28 @@ def predict(img) -> Tuple[Dict, float]:
34
  """
35
  # Start the timer
36
  start_time = timer()
37
-
38
  # Transform the target image and add a batch dimension
39
  img = effnetb2_transforms(img).unsqueeze(0)
40
-
41
  # Put model into evaluation mode and turn on inference mode
42
  effnetb2.eval()
43
  with torch.inference_mode():
44
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
45
  pred_probs = torch.softmax(effnetb2(img), dim=1)
46
-
47
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
48
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
49
-
 
 
50
  # Calculate the prediction time
51
  pred_time = round(timer() - start_time, 5)
52
-
53
- # Return the prediction dictionary and prediction time
54
  return pred_labels_and_probs, pred_time
55
 
 
56
  ### 4. Gradio app ###
57
 
58
  # Create title, description and article strings
@@ -63,7 +66,7 @@ article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io
63
  # Create examples list from "examples/" directory
64
  example_list = [["examples/" + example] for example in os.listdir("examples")]
65
 
66
- # Create Gradio interface
67
  demo = gr.Interface(
68
  fn=predict,
69
  inputs=gr.Image(type="pil"),
 
1
+ ### 1. Imports and class names setup ###
2
  import gradio as gr
3
  import os
4
  import torch
 
8
  from typing import Tuple, Dict
9
 
10
  # Setup class names
11
+ with open("class_names.txt", "r") as f: # reading them in from class_names.txt
12
+ class_names = [food_name.strip() for food_name in f.readlines()]
13
+
14
+ ### 2. Model and transforms preparation ###
15
 
16
  # Create model
17
+ effnetb2, effnetb2_transforms = create_effnetb2_model(
18
+ num_classes=101, # could also use len(class_names)
19
  )
20
 
21
  # Load saved weights
22
+ effnetb2.load_state_dict(
23
  torch.load(
24
  f="09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
25
  map_location=torch.device("cpu"), # load to CPU
 
34
  """
35
  # Start the timer
36
  start_time = timer()
37
+
38
  # Transform the target image and add a batch dimension
39
  img = effnetb2_transforms(img).unsqueeze(0)
40
+
41
  # Put model into evaluation mode and turn on inference mode
42
  effnetb2.eval()
43
  with torch.inference_mode():
44
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
45
  pred_probs = torch.softmax(effnetb2(img), dim=1)
46
+
47
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
48
+ pred_labels_and_probs = {
49
+ class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
50
+ }
51
+
52
  # Calculate the prediction time
53
  pred_time = round(timer() - start_time, 5)
54
+
55
+ # Return the prediction dictionary and prediction time
56
  return pred_labels_and_probs, pred_time
57
 
58
+
59
  ### 4. Gradio app ###
60
 
61
  # Create title, description and article strings
 
66
  # Create examples list from "examples/" directory
67
  example_list = [["examples/" + example] for example in os.listdir("examples")]
68
 
69
+ # Create Gradio interface
70
  demo = gr.Interface(
71
  fn=predict,
72
  inputs=gr.Image(type="pil"),