yourusername commited on
Commit
bfb8cb9
1 Parent(s): d3a3d62

:lipstick: style

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -17,6 +17,7 @@ COLORS = [
17
  [0.301, 0.745, 0.933],
18
  ]
19
 
 
20
  @st.cache(allow_output_mutation=True)
21
  def get_hf_components(model_name_or_path):
22
  feature_extractor = DetrFeatureExtractor.from_pretrained(model_name_or_path)
@@ -24,10 +25,12 @@ def get_hf_components(model_name_or_path):
24
  model.eval()
25
  return feature_extractor, model
26
 
 
27
  @st.cache
28
  def get_img_from_url(url):
29
  return Image.open(requests.get(url, stream=True).raw)
30
 
 
31
  def fig2img(fig):
32
  buf = io.BytesIO()
33
  fig.savefig(buf)
@@ -54,6 +57,7 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
54
  plt.axis("off")
55
  return fig2img(plt.gcf())
56
 
 
57
  def make_prediction(img, feature_extractor, model):
58
  inputs = feature_extractor(img, return_tensors="pt")
59
  outputs = model(**inputs)
@@ -61,6 +65,7 @@ def make_prediction(img, feature_extractor, model):
61
  processed_outputs = feature_extractor.post_process(outputs, img_size)
62
  return processed_outputs[0]
63
 
 
64
  def main():
65
  option = st.selectbox("Which model should we use?", ("facebook/detr-resnet-50", "facebook/detr-resnet-101"))
66
  feature_extractor, model = get_hf_components(option)
@@ -72,5 +77,5 @@ def main():
72
  st.image(viz_img)
73
 
74
 
75
- if __name__ == '__main__':
76
  main()
 
17
  [0.301, 0.745, 0.933],
18
  ]
19
 
20
+
21
  @st.cache(allow_output_mutation=True)
22
  def get_hf_components(model_name_or_path):
23
  feature_extractor = DetrFeatureExtractor.from_pretrained(model_name_or_path)
 
25
  model.eval()
26
  return feature_extractor, model
27
 
28
+
29
  @st.cache
30
  def get_img_from_url(url):
31
  return Image.open(requests.get(url, stream=True).raw)
32
 
33
+
34
  def fig2img(fig):
35
  buf = io.BytesIO()
36
  fig.savefig(buf)
 
57
  plt.axis("off")
58
  return fig2img(plt.gcf())
59
 
60
+
61
  def make_prediction(img, feature_extractor, model):
62
  inputs = feature_extractor(img, return_tensors="pt")
63
  outputs = model(**inputs)
 
65
  processed_outputs = feature_extractor.post_process(outputs, img_size)
66
  return processed_outputs[0]
67
 
68
+
69
  def main():
70
  option = st.selectbox("Which model should we use?", ("facebook/detr-resnet-50", "facebook/detr-resnet-101"))
71
  feature_extractor, model = get_hf_components(option)
 
77
  st.image(viz_img)
78
 
79
 
80
+ if __name__ == "__main__":
81
  main()