ramirjf commited on
Commit
e771aac
1 Parent(s): 9d02e76
Files changed (2) hide show
  1. .gitignore +3 -0
  2. app.py +8 -2
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /.DS_Store
2
+ /.venv
3
+ /__pycache__
app.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import torch
5
 
6
  from model import createVITModel
 
7
  from timeit import default_timer as timer
8
  from typing import Tuple, Dict
9
 
@@ -11,10 +12,15 @@ from typing import Tuple, Dict
11
  with open("classes.txt", "r") as f: # reading them in from class_names.txt
12
  class_names = [food_name.strip() for food_name in f.readlines()]
13
 
14
- model, vit_transform = createVITModel()
15
  model.load_state_dict(torch.load('VIT_32_20_003.pth'))
16
  model = model.to('cpu')
17
 
 
 
 
 
 
18
  def predict(img) -> Tuple[Dict, float]:
19
  """Transforms and performs a prediction on img and returns prediction and time taken.
20
  """
@@ -56,4 +62,4 @@ demo = gr.Interface(fn=predict, # mapping function from input to output
56
  article=article)
57
 
58
  # Launch the demo!
59
- demo.launch(debug=False) # generate a publically shareable URL?
 
4
  import torch
5
 
6
  from model import createVITModel
7
+ from pathlib import Path
8
  from timeit import default_timer as timer
9
  from typing import Tuple, Dict
10
 
 
12
  with open("classes.txt", "r") as f: # reading them in from class_names.txt
13
  class_names = [food_name.strip() for food_name in f.readlines()]
14
 
15
+ model, vit_transform = createVITModel(out_features=len(class_names))
16
  model.load_state_dict(torch.load('VIT_32_20_003.pth'))
17
  model = model.to('cpu')
18
 
19
+ # Create a list of example inputs to our Gradio demo
20
+ examples_source_dir = Path("examples")
21
+ examples_source_paths = list(examples_source_dir.glob("*.jpg"))
22
+ example_list = [str(filepath) for filepath in examples_source_paths]
23
+
24
  def predict(img) -> Tuple[Dict, float]:
25
  """Transforms and performs a prediction on img and returns prediction and time taken.
26
  """
 
62
  article=article)
63
 
64
  # Launch the demo!
65
+ demo.launch(debug=False) # generate a publically shareable URL?