hysts HF staff commited on
Commit
c81e62d
1 Parent(s): aa8c79c

Remove the tab for semi-supervised models

Browse files
Files changed (2) hide show
  1. app.py +22 -72
  2. model.py +34 -63
app.py CHANGED
@@ -4,86 +4,36 @@ import pathlib
4
 
5
  import gradio as gr
6
 
7
- from model import FULLY_SUPERVISED_MODELS, SEMI_SUPERVISED_MODELS, Model
8
 
9
  DESCRIPTION = '''# CutLER
10
 
11
  This is an unofficial demo for [https://github.com/facebookresearch/CutLER](https://github.com/facebookresearch/CutLER).
12
  '''
13
 
14
- model = Model()
15
  paths = sorted(pathlib.Path('CutLER/cutler/demo/imgs').glob('*.jpg'))
16
 
17
-
18
- def create_unsupervised_demo():
19
- with gr.Blocks() as demo:
20
- with gr.Row():
21
- with gr.Column():
22
- image = gr.Image(label='Input image', type='filepath')
23
- model_name = gr.Text(label='Model',
24
- value='Unsupervised',
25
- visible=False)
26
- score_threshold = gr.Slider(label='Score threshold',
27
- minimum=0,
28
- maximum=1,
29
- value=0.5,
30
- step=0.05)
31
- run_button = gr.Button('Run')
32
- with gr.Column():
33
- result = gr.Image(label='Result', type='numpy')
34
- with gr.Row():
35
- gr.Examples(examples=[[path.as_posix()] for path in paths],
36
- inputs=[image])
37
-
38
- run_button.click(fn=model,
39
- inputs=[
40
- image,
41
- model_name,
42
- score_threshold,
43
- ],
44
- outputs=result)
45
-
46
- return demo
47
-
48
-
49
- def create_supervised_demo():
50
- model_names = list(SEMI_SUPERVISED_MODELS.keys()) + list(
51
- FULLY_SUPERVISED_MODELS.keys())
52
- with gr.Blocks() as demo:
53
- with gr.Row():
54
- with gr.Column():
55
- image = gr.Image(label='Input image', type='filepath')
56
- model_name = gr.Dropdown(label='Model',
57
- choices=model_names,
58
- value=model_names[-1])
59
- score_threshold = gr.Slider(label='Score threshold',
60
- minimum=0,
61
- maximum=1,
62
- value=0.5,
63
- step=0.05)
64
- run_button = gr.Button('Run')
65
- with gr.Column():
66
- result = gr.Image(label='Result', type='numpy')
67
- with gr.Row():
68
- gr.Examples(examples=[[path.as_posix()] for path in paths],
69
- inputs=[image])
70
-
71
- run_button.click(fn=model,
72
- inputs=[
73
- image,
74
- model_name,
75
- score_threshold,
76
- ],
77
- outputs=result)
78
-
79
- return demo
80
-
81
-
82
  with gr.Blocks(css='style.css') as demo:
83
  gr.Markdown(DESCRIPTION)
84
- with gr.Tabs():
85
- with gr.TabItem('Zero-shot unsupervised'):
86
- create_unsupervised_demo()
87
- with gr.TabItem('Semi/Fully-supervised'):
88
- create_supervised_demo()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  demo.queue().launch()
4
 
5
  import gradio as gr
6
 
7
+ from model import run_model
8
 
9
  DESCRIPTION = '''# CutLER
10
 
11
  This is an unofficial demo for [https://github.com/facebookresearch/CutLER](https://github.com/facebookresearch/CutLER).
12
  '''
13
 
 
14
  paths = sorted(pathlib.Path('CutLER/cutler/demo/imgs').glob('*.jpg'))
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  with gr.Blocks(css='style.css') as demo:
17
  gr.Markdown(DESCRIPTION)
18
+ with gr.Row():
19
+ with gr.Column():
20
+ image = gr.Image(label='Input image', type='filepath')
21
+ score_threshold = gr.Slider(label='Score threshold',
22
+ minimum=0,
23
+ maximum=1,
24
+ value=0.5,
25
+ step=0.05)
26
+ run_button = gr.Button('Run')
27
+ with gr.Column():
28
+ result = gr.Image(label='Result', type='numpy')
29
+ with gr.Row():
30
+ gr.Examples(examples=[[path.as_posix()] for path in paths],
31
+ inputs=[image])
32
+
33
+ run_button.click(fn=run_model,
34
+ inputs=[
35
+ image,
36
+ score_threshold,
37
+ ],
38
+ outputs=result)
39
  demo.queue().launch()
model.py CHANGED
@@ -21,32 +21,6 @@ from predictor import VisualizationDemo
21
 
22
  mp.set_start_method('spawn', force=True)
23
 
24
- UNSUPERVISED_MODELS = {
25
- 'Unsupervised': {
26
- 'config_path':
27
- 'CutLER/cutler/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml',
28
- 'weight_url':
29
- 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth',
30
- }
31
- }
32
- SEMI_SUPERVISED_MODELS = {
33
- f'Semi-supervised with COCO ({perc}%)': {
34
- 'config_path':
35
- f'CutLER/cutler/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_{perc}perc.yaml',
36
- 'weight_url':
37
- f'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_semi_{perc}perc.pth',
38
- }
39
- for perc in [1, 2, 5, 10, 20, 30, 40, 50, 60, 80]
40
- }
41
- FULLY_SUPERVISED_MODELS = {
42
- 'Fully-supervised with COCO': {
43
- 'config_path':
44
- f'CutLER/cutler/model_zoo/configs/COCO-Semisupervised/cascade_mask_rcnn_R_50_FPN_100perc.yaml',
45
- 'weight_url':
46
- f'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_fully_100perc.pth',
47
- }
48
- }
49
-
50
 
51
  def setup_cfg(args):
52
  # load config from file and command-line arguments
@@ -108,40 +82,37 @@ def get_parser():
108
  return parser
109
 
110
 
111
- class Model:
112
- MODEL_DICT = UNSUPERVISED_MODELS | SEMI_SUPERVISED_MODELS | FULLY_SUPERVISED_MODELS
113
-
114
- def __init__(self):
115
- self.model_dir = pathlib.Path('checkpoints')
116
- self.model_dir.mkdir(exist_ok=True)
117
-
118
- def load_model(self, model_name: str,
119
- score_threshold: float) -> VisualizationDemo:
120
- model_info = self.MODEL_DICT[model_name]
121
- weight_url = model_info['weight_url']
122
- weight_path = self.model_dir / weight_url.split('/')[-1]
123
- if not weight_path.exists():
124
- weight_path.parent.mkdir(exist_ok=True)
125
- subprocess.run(shlex.split(f'wget {weight_url} -O {weight_path}'))
126
-
127
- arg_list = [
128
- '--config-file', model_info['config_path'],
129
- '--confidence-threshold',
130
- str(score_threshold), '--opts', 'MODEL.WEIGHTS',
131
- weight_path.as_posix(), 'MODEL.DEVICE',
132
- 'cuda:0' if torch.cuda.is_available() else 'cpu'
133
- ]
134
- if model_name in UNSUPERVISED_MODELS:
135
- arg_list += ['DATASETS.TEST', '()']
136
- args = get_parser().parse_args(arg_list)
137
- cfg = setup_cfg(args)
138
- return VisualizationDemo(cfg)
139
-
140
- def __call__(self,
141
- image_path: str,
142
- model_name: str,
143
- score_threshold: float = 0.5) -> np.ndarray:
144
- model = self.load_model(model_name, score_threshold)
145
- image = read_image(image_path, format='BGR')
146
- _, res = model.run_on_image(image)
147
- return res.get_image()
21
 
22
  mp.set_start_method('spawn', force=True)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def setup_cfg(args):
26
  # load config from file and command-line arguments
82
  return parser
83
 
84
 
85
+ CONFIG_PATH = 'CutLER/cutler/model_zoo/configs/CutLER-ImageNet/cascade_mask_rcnn_R_50_FPN.yaml'
86
+ WEIGHT_URL = 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_final.pth'
87
+
88
+
89
+ def load_model(score_threshold: float) -> VisualizationDemo:
90
+ model_dir = pathlib.Path('checkpoints')
91
+ model_dir.mkdir(exist_ok=True)
92
+ weight_path = model_dir / WEIGHT_URL.split('/')[-1]
93
+ if not weight_path.exists():
94
+ subprocess.run(shlex.split(f'wget {WEIGHT_URL} -O {weight_path}'))
95
+
96
+ arg_list = [
97
+ '--config-file',
98
+ CONFIG_PATH,
99
+ '--confidence-threshold',
100
+ str(score_threshold),
101
+ '--opts',
102
+ 'MODEL.WEIGHTS',
103
+ weight_path.as_posix(),
104
+ 'MODEL.DEVICE',
105
+ 'cuda:0' if torch.cuda.is_available() else 'cpu',
106
+ 'DATASETS.TEST',
107
+ '()',
108
+ ]
109
+ args = get_parser().parse_args(arg_list)
110
+ cfg = setup_cfg(args)
111
+ return VisualizationDemo(cfg)
112
+
113
+
114
+ def run_model(image_path: str, score_threshold: float = 0.5) -> np.ndarray:
115
+ model = load_model(score_threshold)
116
+ image = read_image(image_path, format='BGR')
117
+ _, res = model.run_on_image(image)
118
+ return res.get_image()