yourusername committed on
Commit
d3a3d62
1 Parent(s): 0d5c2ba

:beers: cheers

Files changed (2)
  1. app.py +76 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,76 @@
+ import io
+
+ import matplotlib.pyplot as plt
+ import requests
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from transformers import DetrFeatureExtractor, DetrForObjectDetection
+
+ # colors for visualization
+ COLORS = [
+     [0.000, 0.447, 0.741],
+     [0.850, 0.325, 0.098],
+     [0.929, 0.694, 0.125],
+     [0.494, 0.184, 0.556],
+     [0.466, 0.674, 0.188],
+     [0.301, 0.745, 0.933],
+ ]
+
+ @st.cache(allow_output_mutation=True)
+ def get_hf_components(model_name_or_path):
+     feature_extractor = DetrFeatureExtractor.from_pretrained(model_name_or_path)
+     model = DetrForObjectDetection.from_pretrained(model_name_or_path)
+     model.eval()
+     return feature_extractor, model
+
+ @st.cache
+ def get_img_from_url(url):
+     return Image.open(requests.get(url, stream=True).raw)
+
+ def fig2img(fig):
+     buf = io.BytesIO()
+     fig.savefig(buf)
+     buf.seek(0)
+     img = Image.open(buf)
+     return img
+
+
+ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
+     keep = output_dict["scores"] > threshold
+     boxes = output_dict["boxes"][keep].tolist()
+     scores = output_dict["scores"][keep].tolist()
+     labels = output_dict["labels"][keep].tolist()
+     if id2label is not None:
+         labels = [id2label[x] for x in labels]
+
+     plt.figure(figsize=(16, 10))
+     plt.imshow(pil_img)
+     ax = plt.gca()
+     colors = COLORS * 100
+     for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
+         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
+         ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
+     plt.axis("off")
+     return fig2img(plt.gcf())
+
+ def make_prediction(img, feature_extractor, model):
+     inputs = feature_extractor(img, return_tensors="pt")
+     outputs = model(**inputs)
+     img_size = torch.tensor([tuple(reversed(img.size))])
+     processed_outputs = feature_extractor.post_process(outputs, img_size)
+     return processed_outputs[0]
+
+ def main():
+     option = st.selectbox("Which model should we use?", ("facebook/detr-resnet-50", "facebook/detr-resnet-101"))
+     feature_extractor, model = get_hf_components(option)
+     url = st.text_input("URL to some image", "http://images.cocodataset.org/val2017/000000039769.jpg")
+     img = get_img_from_url(url)
+     processed_outputs = make_prediction(img, feature_extractor, model)
+     threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.7)
+     viz_img = visualize_prediction(img, processed_outputs, threshold, model.config.id2label)
+     st.image(viz_img)
+
+
+ if __name__ == '__main__':
+     main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ streamlit
+ https://download.pytorch.org/whl/cpu/torch-1.8.1%2Bcpu-cp38-cp38-linux_x86_64.whl
+ git+https://github.com/huggingface/transformers.git
+ Pillow
+ matplotlib
+ timm
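
For anyone trying this commit outside the Space (the app itself is launched with the Streamlit CLI), below is a minimal sketch of the same inference flow that app.py wires together, run as a plain script rather than through the UI. It reuses the model name, sample image URL, post-processing call, and 0.7 threshold from the diff above; the plain print loop at the end is illustrative and not part of the commit.

# Minimal non-Streamlit sketch of the inference flow in app.py (assumptions noted above).
import requests
import torch
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForObjectDetection

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
model.eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
img = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    inputs = feature_extractor(img, return_tensors="pt")
    outputs = model(**inputs)

# post_process expects (height, width); PIL's img.size is (width, height), hence the reversal
img_size = torch.tensor([tuple(reversed(img.size))])
result = feature_extractor.post_process(outputs, img_size)[0]

# Print detections above the same 0.7 threshold the app uses as its slider default
for score, label in zip(result["scores"], result["labels"]):
    if score > 0.7:
        print(model.config.id2label[label.item()], round(score.item(), 2))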