stefan-it committed
Commit 70d5e84
1 Parent(s): 91f0453

app: rewrite in Gradio Blocks with multi-model support

Files changed (1):
  1. app.py +86 -33
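The commit replaces the single, monolithic gr.Interface(...) call with an explicit gr.Blocks layout, which makes room for tabs (URL input vs. image upload), a model selector, and an explicit Submit button. For orientation, a minimal sketch of the Blocks event-wiring pattern, independent of this app (the greet function and component names below are illustrative only, not part of this commit):

import gradio as gr

def greet(name):
    return f"Hello, {name}!"

with gr.Blocks() as demo:
    gr.Markdown("# Minimal Blocks demo")

    name_input = gr.Textbox(label="Name")
    greeting_output = gr.Textbox(label="Greeting")
    submit = gr.Button("Submit")

    # In Blocks, events are wired explicitly instead of being implied
    # by the positional inputs/outputs of a gr.Interface(...) call.
    submit.click(fn=greet, inputs=name_input, outputs=greeting_output)

demo.launch()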
app.py CHANGED
@@ -22,21 +22,40 @@ from detectron2.utils.visualizer import Visualizer
 from detectron2.data import MetadataCatalog


-model_path = "https://huggingface.co/dbmdz/detectron2-model/resolve/main/model_final.pth"
-
-cfg = get_cfg()
-cfg.merge_from_file("./configs/detectron2/faster_rcnn_R_50_FPN_3x.yaml")
-cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
-cfg.MODEL.WEIGHTS = model_path
-
-my_metadata = MetadataCatalog.get("dbmdz_coco_all")
-my_metadata.thing_classes = ["Illumination", "Illustration"]
-
-if not torch.cuda.is_available():
-    cfg.MODEL.DEVICE = "cpu"
-
-
-def inference(image_url, image, min_score):
+models = [
+    {
+        "name": "Version 1 (2-class)",
+        "model_path": "https://huggingface.co/dbmdz/detectron2-model/resolve/main/model_final.pth",
+        "classes": ["Illumination", "Illustration"],
+        "cfg": None,
+        "metadata": None
+    },
+    {
+        "name": "Version 2 (4-class)",
+        "model_path": "https://huggingface.co/dbmdz/detectron2-v2-model/resolve/main/model_final.pth",
+        "classes": ["ILLUSTRATION", "OTHER", "STAMP", "INITIAL"],
+        "cfg": None,
+        "metadata": None
+    },
+]
+
+model_name_to_id = {model["name"]: id_ for id_, model in enumerate(models)}
+
+for model in models:
+
+    model["cfg"] = get_cfg()
+    model["cfg"].merge_from_file("./configs/detectron2/faster_rcnn_R_50_FPN_3x.yaml")
+    model["cfg"].MODEL.ROI_HEADS.NUM_CLASSES = len(model["classes"])
+    model["cfg"].MODEL.WEIGHTS = model["model_path"]
+
+    model["metadata"] = MetadataCatalog.get(model["name"])
+    model["metadata"].thing_classes = model["classes"]
+
+    if not torch.cuda.is_available():
+        model["cfg"].MODEL.DEVICE = "cpu"
+
+
+def inference(image_url, image, min_score, model_name):
     if image_url:
         r = requests.get(image_url)
         if r:
@@ -46,29 +65,63 @@ def inference(image_url, image, min_score):
     # Model expect BGR!
     im = image[:,:,::-1]

-    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = min_score
-    predictor = DefaultPredictor(cfg)
+    model_id = model_name_to_id[model_name]
+
+    models[model_id]["cfg"].MODEL.ROI_HEADS.SCORE_THRESH_TEST = min_score
+    predictor = DefaultPredictor(models[model_id]["cfg"])

     outputs = predictor(im)

-    v = Visualizer(im, my_metadata, scale=1.2)
+    v = Visualizer(im, models[model_id]["metadata"], scale=1.2)
     out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

     return out.get_image()


-title = "DBMDZ Detectron2 Model Demo"
-description = "This demo introduces an interactive playground for our trained Detectron2 model. <br>The model was trained on manually annotated segments from digitized books to detect Illustration or Illumination segments on a given page."
-article = '<p>Detectron model is available from our repository <a href="">here</a> on the Hugging Face Model Hub.</p>'
-
-gr.Interface(
-    inference,
-    [gr.inputs.Textbox(label="Image URL", placeholder="https://api.digitale-sammlungen.de/iiif/image/v2/bsb10483966_00008/full/500,/0/default.jpg"),
-    gr.inputs.Image(type="numpy", label="Input Image"),
-    gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score"),
-    ],
-    gr.outputs.Image(type="pil", label="Output"),
-    title=title,
-    description=description,
-    article=article,
-    examples=[]).launch()
+title = "# DBMDZ Detectron2 Model Demo"
+description = """
+This demo introduces an interactive playground for our trained Detectron2 model.
+
+Currently, two models are supported that were trained on manually annotated segments from digitized books:
+
+* [Version 1 (2-class)](https://huggingface.co/dbmdz/detectron2-model): This model can detect *Illustration* or *Illumination* segments on a given page.
+* [Version 2 (4-class)](https://huggingface.co/dbmdz/detectron2-v2-model): This model is more powerful and can detect *Illustration*, *Stamp*, *Initial* or *Other* segments on a given page.
+"""
+footer = "Made in Munich with ❤️ and 🥨."
+
+with gr.Blocks() as demo:
+    gr.Markdown(title)
+    gr.Markdown(description)
+
+    with gr.Tab("From URL"):
+        url_input = gr.Textbox(label="Image URL", placeholder="https://api.digitale-sammlungen.de/iiif/image/v2/bsb10483966_00008/full/500,/0/default.jpg")
+
+    with gr.Tab("From Image"):
+        image_input = gr.Image(type="numpy", label="Input Image")
+
+    min_score = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score")
+
+    model_name = gr.Radio(choices=[model["name"] for model in models], value=models[0]["name"], label="Select Detectron2 model")
+
+    output_image = gr.Image(type="pil", label="Output")
+
+    inference_button = gr.Button("Submit")
+
+    inference_button.click(fn=inference, inputs=[url_input, image_input, min_score, model_name], outputs=output_image)
+
+    gr.Markdown(footer)
+
+demo.launch()
+
+#gr.Interface(
+#    inference,
+#    [gr.inputs.Textbox(label="Image URL", placeholder="https://api.digitale-sammlungen.de/iiif/image/v2/bsb10483966_00008/full/500,/0/default.jpg"),
+#    gr.inputs.Image(type="numpy", label="Input Image"),
+#    gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score"),
+#    gr.Radio(choices=[model["name"] for model in models], value=models[0]["name"], label="Select Detectron2 model"),
+#    ],
+#    gr.outputs.Image(type="pil", label="Output"),
+#    title=title,
+#    description=description,
+#    article=article,
+#    examples=[]).launch()
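A side effect of the new models registry is that adding a further checkpoint becomes a one-entry change: the loading loop builds the Detectron2 config and metadata for every registered entry, and NUM_CLASSES is derived from the class list. A hypothetical sketch (the entry's name, weights URL, and class list are placeholders, not a real model; the dict keys come from the diff above):

# Hypothetical third entry; would be appended before the `for model in models:`
# loop runs, so cfg and metadata get populated like the other entries.
models.append({
    "name": "Version 3 (hypothetical)",
    "model_path": "https://huggingface.co/dbmdz/...",  # placeholder URL, not a real checkpoint
    "classes": ["ILLUSTRATION", "OTHER", "STAMP", "INITIAL"],  # NUM_CLASSES = len(classes)
    "cfg": None,       # populated by the loading loop
    "metadata": None,  # populated by the loading loop
})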