shlomoc committed on
Commit
fca455f
β€’
1 Parent(s): 44fc24d

update description

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -14,7 +14,7 @@ class_names = ["pizza", "steak", "sushi"]
14
 
15
  # Create VIT model
16
  vit, vit_transforms = create_vit_model(
17
- num_classes=3, # len(class_names) would also work
18
  )
19
 
20
  # Load saved weights
@@ -28,6 +28,8 @@ vit.load_state_dict(
28
  ### 3. Predict function ###
29
 
30
  # Create predict function
 
 
31
  def predict(img) -> Tuple[Dict, float]:
32
  """Transforms and performs a prediction on img and returns prediction and time taken.
33
  """
@@ -44,7 +46,8 @@ def predict(img) -> Tuple[Dict, float]:
44
  pred_probs = torch.softmax(vit(img), dim=1)
45
 
46
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
47
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
 
48
 
49
  # Calculate the prediction time
50
  pred_time = round(timer() - start_time, 5)
@@ -54,19 +57,21 @@ def predict(img) -> Tuple[Dict, float]:
54
 
55
  ### 4. Gradio app ###
56
 
 
57
  # Create title, description and article strings
58
  title = "FoodVision Mini πŸ•πŸ₯©πŸ£"
59
- description = "An vit feature extractor computer vision model to classify images of food as pizza, steak or sushi."
60
- article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
61
 
62
  # Create examples list from "examples/" directory
63
  example_list = [["examples/" + example] for example in os.listdir("examples")]
64
 
65
  # Create the Gradio demo
66
- demo = gr.Interface(fn=predict, # mapping function from input to output
67
- inputs=gr.Image(type="pil"), # what are the inputs?
68
- outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
69
- gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
 
70
  # Create examples list from "examples/" directory
71
  examples=example_list,
72
  title=title,
 
14
 
15
  # Create VIT model
16
  vit, vit_transforms = create_vit_model(
17
+ num_classes=3, # len(class_names) would also work
18
  )
19
 
20
  # Load saved weights
 
28
  ### 3. Predict function ###
29
 
30
  # Create predict function
31
+
32
+
33
  def predict(img) -> Tuple[Dict, float]:
34
  """Transforms and performs a prediction on img and returns prediction and time taken.
35
  """
 
46
  pred_probs = torch.softmax(vit(img), dim=1)
47
 
48
  # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
49
+ pred_labels_and_probs = {class_names[i]: float(
50
+ pred_probs[0][i]) for i in range(len(class_names))}
51
 
52
  # Calculate the prediction time
53
  pred_time = round(timer() - start_time, 5)
 
57
 
58
  ### 4. Gradio app ###
59
 
60
+
61
  # Create title, description and article strings
62
  title = "FoodVision Mini πŸ•πŸ₯©πŸ£"
63
+ description = "A ViT-B/16 feature extractor (ViT for short) computer vision model to classify images of food as pizza, steak or sushi."
64
+ article = "Created based on [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
65
 
66
  # Create examples list from "examples/" directory
67
  example_list = [["examples/" + example] for example in os.listdir("examples")]
68
 
69
  # Create the Gradio demo
70
+ demo = gr.Interface(fn=predict, # mapping function from input to output
71
+ inputs=gr.Image(type="pil"), # what are the inputs?
72
+ outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
73
+ # our fn has two outputs, therefore we have two outputs
74
+ gr.Number(label="Prediction time (s)")],
75
  # Create examples list from "examples/" directory
76
  examples=example_list,
77
  title=title,