Jsmithwek nateraw committed on
Commit
ac1bcf0
0 Parent(s):

Duplicate from nateraw/detr-object-detection

Browse files

Co-authored-by: Nate Raw <nateraw@users.noreply.huggingface.co>

Files changed (4) hide show
  1. .gitattributes +16 -0
  2. README.md +34 -0
  3. app.py +81 -0
  4. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Detr Object Detection
3
+ emoji: 🤯
4
+ colorFrom: pink
5
+ colorTo: green
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ duplicated_from: nateraw/detr-object-detection
10
+ ---
11
+
12
+ # Configuration
13
+
14
+ `title`: _string_
15
+ Display title for the Space
16
+
17
+ `emoji`: _string_
18
+ Space emoji (emoji-only character allowed)
19
+
20
+ `colorFrom`: _string_
21
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
+
23
+ `colorTo`: _string_
24
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
+
26
+ `sdk`: _string_
27
+ Can be either `gradio` or `streamlit`
28
+
29
+ `app_file`: _string_
30
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
31
+ Path is relative to the root of the repository.
32
+
33
+ `pinned`: _boolean_
34
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import matplotlib.pyplot as plt
4
+ import requests
5
+ import streamlit as st
6
+ import torch
7
+ from PIL import Image
8
+ from transformers import DetrFeatureExtractor, DetrForObjectDetection
9
+
10
# Palette for visualization: six RGB triples (floats in [0, 1]) that the
# drawing code cycles through, one color per detected box.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
19
+
20
+
21
# NOTE(review): `st.cache` is deprecated/removed in recent Streamlit releases
# (replaced by `st.cache_resource` for models); requirements.txt pins no
# Streamlit version — confirm the deployed version before migrating.
@st.cache(allow_output_mutation=True)
def get_hf_components(model_name_or_path):
    """Download (and cache) the DETR feature extractor and detection model.

    Returns the pair ``(feature_extractor, model)`` with the model switched
    to eval mode.
    """
    extractor = DetrFeatureExtractor.from_pretrained(model_name_or_path)
    detector = DetrForObjectDetection.from_pretrained(model_name_or_path)
    detector.eval()
    return extractor, detector
27
+
28
+
29
@st.cache
def get_img_from_url(url):
    """Fetch an image over HTTP and return it as a PIL Image (cached per URL).

    Raises ``requests.HTTPError`` on a non-2xx response instead of handing
    error-page bytes to PIL, and bounds the request with a timeout instead
    of potentially blocking the Streamlit app forever.
    """
    response = requests.get(url, timeout=10)  # bounded wait, no indefinite hang
    response.raise_for_status()
    # Decode from the fully downloaded body. The previous `stream=True` +
    # `response.raw` approach bypasses requests' transparent content-encoding
    # (gzip/deflate) handling and can give PIL undecodable bytes.
    return Image.open(io.BytesIO(response.content))
32
+
33
+
34
def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL Image."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)  # rewind so PIL reads from the start of the PNG bytes
    return Image.open(buffer)
40
+
41
+
42
def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
    """Draw detections scoring above *threshold* onto the image.

    Parameters:
        pil_img: source image to draw over.
        output_dict: post-processed model output holding "scores", "boxes"
            and "labels" tensors (as returned by ``make_prediction``).
        threshold: minimum confidence for a detection to be drawn.
        id2label: optional mapping from class index to human-readable name.

    Returns:
        A PIL Image of the rendered figure.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100  # repeat the palette so long detection lists never run out
    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
        ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")
    img = fig2img(fig)
    # Close the figure explicitly: pyplot keeps every figure alive in its
    # global registry, so Streamlit reruns would otherwise leak memory on
    # each slider/selectbox interaction.
    plt.close(fig)
    return img
59
+
60
+
61
def make_prediction(img, feature_extractor, model):
    """Run the detection model on a single PIL image.

    Parameters:
        img: PIL Image (anything exposing ``.size`` as (width, height)).
        feature_extractor: DETR feature extractor (preprocess + post_process).
        model: DETR detection model.

    Returns:
        The post-processed output dict for the single image
        ("scores", "boxes", "labels").
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only — disable autograd so no computation graph is built
    # and activations are not retained in memory.
    with torch.no_grad():
        outputs = model(**inputs)
    # post_process expects (height, width); PIL's .size is (width, height).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    return processed_outputs[0]
67
+
68
+
69
def main():
    """Streamlit entry point: pick a model, fetch an image, show detections."""
    model_name = st.selectbox(
        "Which model should we use?",
        ("facebook/detr-resnet-50", "facebook/detr-resnet-101"),
    )
    feature_extractor, model = get_hf_components(model_name)
    url = st.text_input("URL to some image", "http://images.cocodataset.org/val2017/000000039769.jpg")
    img = get_img_from_url(url)
    outputs = make_prediction(img, feature_extractor, model)
    threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.7)
    annotated = visualize_prediction(img, outputs, threshold, model.config.id2label)
    st.image(annotated)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ https://download.pytorch.org/whl/cpu/torch-1.8.1%2Bcpu-cp38-cp38-linux_x86_64.whl
3
+ git+https://github.com/huggingface/transformers.git
4
+ Pillow
5
+ matplotlib
6
+ timm