Thomas J. Trebat commited on
Commit
8713bc4
1 Parent(s): 92e317c

Made labels global

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -8,17 +8,14 @@ from timm.data.transforms_factory import create_transform
8
 
9
 
10
  class ImageClassifier(object):
11
- def __init__(self, model_name):
12
- self.model = timm.create_model(
13
- model_name,
14
- pretrained=True
15
- ).eval()
16
 
17
  def get_top_5_predictions(self, image):
18
  values, indices = torch.topk(self.get_output_probabilities(image), 5)
19
- labels = self.get_labels()
20
  return [
21
- {'label': labels[i], 'score': v.item()}
22
  for i, v in zip(indices, values)
23
  ]
24
 
@@ -27,6 +24,7 @@ class ImageClassifier(object):
27
  return torch.nn.functional.softmax(output[0], dim=0)
28
 
29
  def classify_image(self, image):
 
30
  transform = self.create_image_transform()
31
  return self.model(transform(image).unsqueeze(0))
32
 
@@ -34,13 +32,11 @@ class ImageClassifier(object):
34
  return create_transform(**resolve_data_config(
35
  self.model.pretrained_cfg, model=self.model))
36
 
37
- def get_labels(self):
38
- return self.model.pretrained_cfg['label_names']
39
 
40
  class ImageClassificationApp(object):
41
- def __init__(self, title, model_name):
42
  self.title = title
43
- self.classifier = ImageClassifier(model_name)
44
 
45
  def render(self):
46
  st.title(self.title)
@@ -71,7 +67,13 @@ class ImageClassificationApp(object):
71
 
72
 
73
  if __name__ == '__main__':
 
 
 
 
 
 
74
  ImageClassificationApp(
75
  'Pet Image Classification App',
76
- 'hf-hub:nateraw/resnet50-oxford-iiit-pet'
77
  ).render()
 
8
 
9
 
10
  class ImageClassifier(object):
11
+ def __init__(self, model, labels):
12
+ self.model = model
13
+ self.labels = labels
 
 
14
 
15
  def get_top_5_predictions(self, image):
16
  values, indices = torch.topk(self.get_output_probabilities(image), 5)
 
17
  return [
18
+ {'label': self.labels[i], 'score': v.item()}
19
  for i, v in zip(indices, values)
20
  ]
21
 
 
24
  return torch.nn.functional.softmax(output[0], dim=0)
25
 
26
  def classify_image(self, image):
27
+ self.model.eval()
28
  transform = self.create_image_transform()
29
  return self.model(transform(image).unsqueeze(0))
30
 
 
32
  return create_transform(**resolve_data_config(
33
  self.model.pretrained_cfg, model=self.model))
34
 
 
 
35
 
36
  class ImageClassificationApp(object):
37
+ def __init__(self, title, classifier):
38
  self.title = title
39
+ self.classifier = classifier
40
 
41
  def render(self):
42
  st.title(self.title)
 
67
 
68
 
69
  if __name__ == '__main__':
70
+ model = timm.create_model(
71
+ 'hf-hub:nateraw/resnet50-oxford-iiit-pet',
72
+ pretrained=True
73
+ )
74
+ labels = model.pretrained_cfg['label_names']
75
+ classifier = ImageClassifier(model, labels)
76
  ImageClassificationApp(
77
  'Pet Image Classification App',
78
+ classifier
79
  ).render()