Thomas J. Trebat commited on
Commit
61ce9e4
1 Parent(s): 83af4db

first commit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +74 -0
  3. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import torch
3
+ import streamlit as st
4
+ from PIL import Image
5
+ from transformers import AutoImageProcessor, SwinForImageClassification
6
+
7
+
8
+ class ImageClassifier(object):
9
+ def __init__(self, model):
10
+ self.model = model
11
+
12
+ def get_top_5_predictions(self, image):
13
+ values, indices = torch.topk(self.get_output_probabilities(image), 5)
14
+ return [
15
+ {'label': self.model.config.id2label[i.item()], 'score': v.item()}
16
+ for i, v in zip(indices, values)
17
+ ]
18
+
19
+ def get_output_probabilities(self, image):
20
+ output = self.classify_image(image)
21
+ return torch.nn.functional.softmax(output.logits[0], dim=0)
22
+
23
+ def classify_image(self, image):
24
+ image_processor = self.create_image_processor()
25
+ inputs = image_processor(image, return_tensors='pt')
26
+ with torch.no_grad():
27
+ return self.model(**inputs)
28
+
29
+ def create_image_processor(self):
30
+ return AutoImageProcessor.from_pretrained(self.model.name_or_path)
31
+
32
+
33
+ class ImageClassificationApp(object):
34
+ def __init__(self, title, classifier):
35
+ self.title = title
36
+ self.classifier = classifier
37
+
38
+ def render(self):
39
+ st.title(self.title)
40
+ uploaded_image = self.get_uploaded_image()
41
+ if uploaded_image is not None:
42
+ self.show_image_and_results(uploaded_image)
43
+
44
+ def get_uploaded_image(self):
45
+ return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])
46
+
47
+ def show_image_and_results(self, uploaded_image):
48
+ self.show_uploaded_image(uploaded_image)
49
+ self.show_classification_results(self.get_image(uploaded_image.read()))
50
+
51
+ def show_uploaded_image(self, uploaded_image):
52
+ st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)
53
+
54
+ def show_classification_results(self, image):
55
+ st.subheader('Classification Results:')
56
+ self.write_top_5_predictions(image)
57
+
58
+ def write_top_5_predictions(self, image):
59
+ for prediction in self.classifier.get_top_5_predictions(image):
60
+ st.write(f"- {prediction['label']}: {prediction['score']:.4f}")
61
+
62
+ def get_image(self, image_data):
63
+ return Image.open(io.BytesIO(image_data))
64
+
65
+
66
+ if __name__ == '__main__':
67
+ model = SwinForImageClassification.from_pretrained(
68
+ 'microsoft/swin-tiny-patch4-window7-224'
69
+ )
70
+ classifier = ImageClassifier(model)
71
+ ImageClassificationApp(
72
+ 'Swin Image Classification App',
73
+ classifier
74
+ ).render()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ streamlit
3
+ transformers