shvardhan committed
Commit e6ff83d
Parent: ba8abbe

Add application file

app.py ADDED
@@ -0,0 +1,132 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import pathlib
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+
+ from model import AppModel
+
+ DESCRIPTION = '''# MMDetection
+ This is an unofficial demo for [https://github.com/open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection).
+ <img id="overview" alt="overview" src="https://user-images.githubusercontent.com/12907710/137271636-56ba1cd2-b110-4812-8221-b4c120320aa9.png" />
+ '''
+
+ DEFAULT_MODEL_TYPE = 'detection'
+ DEFAULT_MODEL_NAMES = {
+     'detection': 'YOLOX-l',
+     'instance_segmentation': 'QueryInst (R-50-FPN)',
+     'panoptic_segmentation': 'MaskFormer (R-50)',
+ }
+ DEFAULT_MODEL_NAME = DEFAULT_MODEL_NAMES[DEFAULT_MODEL_TYPE]
+
+
+ def update_input_image(image: np.ndarray) -> dict:
+     if image is None:
+         return gr.Image.update(value=None)
+     scale = 1500 / max(image.shape[:2])
+     if scale < 1:
+         image = cv2.resize(image, None, fx=scale, fy=scale)
+     return gr.Image.update(value=image)
+
+
+ def update_model_name(model_type: str) -> dict:
+     model_dict = getattr(AppModel, f'{model_type.upper()}_MODEL_DICT')
+     model_names = list(model_dict.keys())
+     model_name = DEFAULT_MODEL_NAMES[model_type]
+     return gr.Dropdown.update(choices=model_names, value=model_name)
+
+
+ def update_visualization_score_threshold(model_type: str) -> dict:
+     return gr.Slider.update(visible=model_type != 'panoptic_segmentation')
+
+
+ def update_redraw_button(model_type: str) -> dict:
+     return gr.Button.update(visible=model_type != 'panoptic_segmentation')
+
+
+ def set_example_image(example: list) -> dict:
+     return gr.Image.update(value=example[0])
+
+
+ model = AppModel(DEFAULT_MODEL_NAME)
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 input_image = gr.Image(label='Input Image', type='numpy')
+             with gr.Group():
+                 with gr.Row():
+                     model_type = gr.Radio(list(DEFAULT_MODEL_NAMES.keys()),
+                                           value=DEFAULT_MODEL_TYPE,
+                                           label='Model Type')
+                 with gr.Row():
+                     model_name = gr.Dropdown(list(
+                         model.DETECTION_MODEL_DICT.keys()),
+                                              value=DEFAULT_MODEL_NAME,
+                                              label='Model')
+             with gr.Row():
+                 run_button = gr.Button(value='Run')
+                 prediction_results = gr.Variable()
+         with gr.Column():
+             with gr.Row():
+                 visualization = gr.Image(label='Result', type='numpy')
+             with gr.Row():
+                 visualization_score_threshold = gr.Slider(
+                     0,
+                     1,
+                     step=0.05,
+                     value=0.3,
+                     label='Visualization Score Threshold')
+             with gr.Row():
+                 redraw_button = gr.Button(value='Redraw')
+
+     with gr.Row():
+         paths = sorted(pathlib.Path('images').rglob('*.jpg'))
+         example_images = gr.Dataset(components=[input_image],
+                                     samples=[[path.as_posix()]
+                                              for path in paths])
+
+     input_image.change(fn=update_input_image,
+                        inputs=input_image,
+                        outputs=input_image)
+
+     model_type.change(fn=update_model_name,
+                       inputs=model_type,
+                       outputs=model_name)
+     model_type.change(fn=update_visualization_score_threshold,
+                       inputs=model_type,
+                       outputs=visualization_score_threshold)
+     model_type.change(fn=update_redraw_button,
+                       inputs=model_type,
+                       outputs=redraw_button)
+
+     model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
+     run_button.click(fn=model.run,
+                      inputs=[
+                          model_name,
+                          input_image,
+                          visualization_score_threshold,
+                      ],
+                      outputs=[
+                          prediction_results,
+                          visualization,
+                      ])
+     redraw_button.click(fn=model.visualize_detection_results,
+                         inputs=[
+                             input_image,
+                             prediction_results,
+                             visualization_score_threshold,
+                         ],
+                         outputs=visualization)
+     example_images.click(fn=set_example_image,
+                          inputs=example_images,
+                          outputs=input_image)
+
+ demo.queue().launch(show_api=False)
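
The only preprocessing `app.py` applies to uploads happens in `update_input_image`: any image whose longer side exceeds 1500 px is scaled down to that bound before inference, while smaller images pass through unchanged. A minimal standalone sketch of that rule (the name `shrink_to_max_side` is ours, for illustration only):

```python
import cv2
import numpy as np

def shrink_to_max_side(image: np.ndarray, max_side: int = 1500) -> np.ndarray:
    # Same rule as update_input_image: scale so the longer side becomes
    # max_side, but only when the image is larger than that.
    scale = max_side / max(image.shape[:2])
    if scale < 1:
        image = cv2.resize(image, None, fx=scale, fy=scale)
    return image

big = np.zeros((3000, 2000, 3), dtype=np.uint8)
small = np.zeros((450, 450, 3), dtype=np.uint8)
print(shrink_to_max_side(big).shape)    # (1500, 1000, 3)
print(shrink_to_max_side(small).shape)  # (450, 450, 3) -- untouched
```

Because the scale factor is only applied when it is below 1, the rule never upsamples small inputs such as the bundled 450×450 examples.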
configs/_base_/faster-rcnn_r50_fpn_1x_coco.py ADDED
@@ -0,0 +1,114 @@
+ # model settings
+ model = dict(
+     type='FasterRCNN',
+     data_preprocessor=dict(
+         type='DetDataPreprocessor',
+         mean=[123.675, 116.28, 103.53],
+         std=[58.395, 57.12, 57.375],
+         bgr_to_rgb=True,
+         pad_size_divisor=32),
+     backbone=dict(
+         type='ResNet',
+         depth=50,
+         num_stages=4,
+         out_indices=(0, 1, 2, 3),
+         frozen_stages=1,
+         norm_cfg=dict(type='BN', requires_grad=True),
+         norm_eval=True,
+         style='pytorch',
+         init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+     neck=dict(
+         type='FPN',
+         in_channels=[256, 512, 1024, 2048],
+         out_channels=256,
+         num_outs=5),
+     rpn_head=dict(
+         type='RPNHead',
+         in_channels=256,
+         feat_channels=256,
+         anchor_generator=dict(
+             type='AnchorGenerator',
+             scales=[8],
+             ratios=[0.5, 1.0, 2.0],
+             strides=[4, 8, 16, 32, 64]),
+         bbox_coder=dict(
+             type='DeltaXYWHBBoxCoder',
+             target_means=[.0, .0, .0, .0],
+             target_stds=[1.0, 1.0, 1.0, 1.0]),
+         loss_cls=dict(
+             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+         loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+     roi_head=dict(
+         type='StandardRoIHead',
+         bbox_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         bbox_head=dict(
+             type='Shared2FCBBoxHead',
+             in_channels=256,
+             fc_out_channels=1024,
+             roi_feat_size=7,
+             num_classes=80,
+             bbox_coder=dict(
+                 type='DeltaXYWHBBoxCoder',
+                 target_means=[0., 0., 0., 0.],
+                 target_stds=[0.1, 0.1, 0.2, 0.2]),
+             reg_class_agnostic=False,
+             loss_cls=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+             loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
+     # model training and testing settings
+     train_cfg=dict(
+         rpn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.7,
+                 neg_iou_thr=0.3,
+                 min_pos_iou=0.3,
+                 match_low_quality=True,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=256,
+                 pos_fraction=0.5,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=False),
+             allowed_border=-1,
+             pos_weight=-1,
+             debug=False),
+         rpn_proposal=dict(
+             nms_pre=2000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.5,
+                 neg_iou_thr=0.5,
+                 min_pos_iou=0.5,
+                 match_low_quality=False,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=512,
+                 pos_fraction=0.25,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=True),
+             pos_weight=-1,
+             debug=False)),
+     test_cfg=dict(
+         rpn=dict(
+             nms_pre=1000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             score_thr=0.05,
+             nms=dict(type='nms', iou_threshold=0.5),
+             max_per_img=100)
+         # soft-nms is also supported for rcnn testing
+         # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
+     ))
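
This base file is a model-only config: it pins down the Faster R-CNN architecture and its train/test behavior, and leaves datasets, optimizer, and runtime to whatever inherits it. A short sketch of inspecting it with the MMDetection 3.x config loader, which the `data_preprocessor`/`DetDataPreprocessor` keys imply (an assumption worth flagging: the `mmdet==2.25.0` pinned in requirements.txt predates this format and ships `mmcv.Config` instead):

```python
from mmengine.config import Config  # MMEngine loader; assumes the MMDet 3.x stack

cfg = Config.fromfile('configs/_base_/faster-rcnn_r50_fpn_1x_coco.py')
print(cfg.model.type)                            # FasterRCNN
print(cfg.model.roi_head.bbox_head.num_classes)  # 80 -- the COCO default
print(cfg.model.train_cfg.rpn.sampler.num)       # 256 anchors sampled per image
```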
configs/faster-rcnn_r50_fpn_organoid_orgaquant.py ADDED
@@ -0,0 +1,77 @@
+ # Inherit and overwrite part of the config based on this config
+ _base_ = './faster-rcnn_r50_fpn_1x_coco.py'
+
+ data_root = 'data/'  # dataset root
+
+ train_batch_size_per_gpu = 16
+ train_num_workers = 1
+
+ max_epochs = 105
+ base_lr = 0.00001
+
+ metainfo = {
+     'classes': ('orgaquant', ),
+     'palette': [
+         (220, 20, 60),
+     ]
+ }
+
+ train_pipeline = [
+     dict(type='LoadImageFromFile', backend_args=None),
+     dict(type='LoadAnnotations', with_bbox=True),
+     dict(type='RandomFlip', prob=0.5),
+     dict(type='RandomShift', prob=0.5),
+     dict(type='RandomAffine'),
+     dict(type='PhotoMetricDistortion'),
+     dict(type='PackDetInputs')
+ ]
+
+ train_dataloader = dict(
+     batch_size=train_batch_size_per_gpu,
+     num_workers=train_num_workers,
+     dataset=dict(
+         data_root=data_root,
+         metainfo=metainfo,
+         data_prefix=dict(img='train/'),
+         ann_file='train.json',
+         pipeline=train_pipeline))  # a bare top-level train_pipeline is otherwise unused
+
+ val_dataloader = dict(
+     dataset=dict(
+         data_root=data_root,
+         metainfo=metainfo,
+         data_prefix=dict(img='val/'),
+         ann_file='val.json'))
+
+ test_dataloader = val_dataloader
+
+ val_evaluator = dict(ann_file=data_root + 'val.json')
+
+ test_evaluator = val_evaluator
+
+ model = dict(
+     roi_head=dict(
+         bbox_head=dict(num_classes=1)))
+
+ # optimizer
+ optim_wrapper = dict(
+     _delete_=True,
+     type='OptimWrapper',
+     optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
+     paramwise_cfg=dict(
+         norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+ default_hooks = dict(
+     checkpoint=dict(
+         interval=5,
+         max_keep_ckpts=2,  # only keep the latest 2 checkpoints
+         save_best='auto'),
+     logger=dict(type='LoggerHook', interval=5))
+
+ # load COCO pre-trained weights
+ # load_from = './work_dirs/faster-rcnn_r50_fpn_organoid/best_coco_bbox_mAP_epoch_12.pth'
+
+ train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
+ visualizer = dict(
+     vis_backends=[dict(type='LocalVisBackend'),
+                   dict(type='TensorboardVisBackend')])
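
When MMEngine resolves `_base_`, dictionaries merge key by key, so this file only restates what changes: a single `orgaquant` class, an AdamW optimizer (`_delete_=True` replaces the optimizer wrapper wholesale instead of merging into it), and a 105-epoch schedule. A sketch of checking the merged result, under the same MMDet 3.x assumption as above:

```python
from mmengine.config import Config

cfg = Config.fromfile('configs/faster-rcnn_r50_fpn_organoid_orgaquant.py')
print(cfg.model.roi_head.bbox_head.num_classes)  # 1 -- overridden for organoids
print(cfg.model.backbone.depth)                  # 50 -- inherited untouched from the base
print(cfg.optim_wrapper.optimizer.type)          # AdamW
print(cfg.train_cfg.max_epochs)                  # 105
```

Training would then go through MMDetection's standard `tools/train.py` entry point with this file as its argument.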
images/Subset_1_450x450_001.jpg ADDED
images/Subset_1_450x450_002.jpg ADDED
images/Subset_1_450x450_003.jpg ADDED
images/Subset_1_450x450_004.jpg ADDED
images/Subset_1_450x450_005.jpg ADDED
images/Subset_1_450x450_006.jpg ADDED
images/Subset_1_450x450_007.jpg ADDED
images/Subset_1_450x450_008.jpg ADDED
images/Subset_1_450x450_009.jpg ADDED
images/Subset_1_450x450_010.jpg ADDED
model.py ADDED
@@ -0,0 +1,81 @@
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from mmdet.apis import inference_detector, init_detector
+
+
+ class Model:
+
+     # Mapping from display name to config/checkpoint, read by app.py to
+     # populate the model dropdown. Only the organoid detector is served.
+     DETECTION_MODEL_DICT = {
+         'YOLOX-l': {
+             'config': 'configs/_base_/faster-rcnn_r50_fpn_1x_coco.py',
+             'checkpoint': 'models/orgaquant_pretrained.pth',
+         },
+     }
+     # The UI also offers these model types; no weights are bundled for them.
+     INSTANCE_SEGMENTATION_MODEL_DICT: dict = {}
+     PANOPTIC_SEGMENTATION_MODEL_DICT: dict = {}
+
+     def __init__(self, model_name: str):
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.model_name = model_name
+         self.model = self._load_model(model_name)
+
+     def _load_model(self, name: str) -> nn.Module:
+         dic = self.DETECTION_MODEL_DICT[name]
+         return init_detector(dic['config'],
+                              dic['checkpoint'],
+                              device=self.device)
+
+     def set_model(self, name: str) -> None:
+         if name == self.model_name:
+             return
+         self.model_name = name
+         self.model = self._load_model(name)
+
+     def detect_and_visualize(
+         self, image: np.ndarray, score_threshold: float
+     ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
+                                         list[list[np.ndarray]]]
+                | dict[str, np.ndarray], np.ndarray]:
+         out = self.detect(image)
+         vis = self.visualize_detection_results(image, out, score_threshold)
+         return out, vis
+
+     def detect(
+         self, image: np.ndarray
+     ) -> list[np.ndarray] | tuple[
+             list[np.ndarray], list[list[np.ndarray]]] | dict[str, np.ndarray]:
+         out = inference_detector(self.model, image)
+         return out
+
+     def visualize_detection_results(
+             self,
+             image: np.ndarray,
+             detection_results: list[np.ndarray]
+             | tuple[list[np.ndarray], list[list[np.ndarray]]]
+             | dict[str, np.ndarray],
+             score_threshold: float = 0.3) -> np.ndarray:
+         vis = self.model.show_result(image,
+                                      detection_results,
+                                      score_thr=score_threshold,
+                                      bbox_color=None,
+                                      text_color=(200, 200, 200),
+                                      mask_color=None)
+         return vis
+
+
+ class AppModel(Model):
+
+     def run(
+         self, model_name: str, image: np.ndarray, score_threshold: float
+     ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
+                                         list[list[np.ndarray]]]
+                | dict[str, np.ndarray], np.ndarray]:
+         self.set_model(model_name)
+         return self.detect_and_visualize(image, score_threshold)
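
For reference, a minimal offline smoke test of the wrapper as committed. It assumes the checkpoint in `Model.DETECTION_MODEL_DICT` has been pulled from Git LFS; note that `show_result` is the MMDetection 2.x visualization API, so the redraw path relies on the 2.x stack pinned in requirements.txt:

```python
import cv2

from model import AppModel

app_model = AppModel('YOLOX-l')  # display name kept from the upstream demo

# The Gradio app passes RGB arrays; cv2.imread returns BGR, so convert first.
image = cv2.cvtColor(cv2.imread('images/Subset_1_450x450_001.jpg'),
                     cv2.COLOR_BGR2RGB)
results, vis = app_model.run('YOLOX-l', image, score_threshold=0.3)
cv2.imwrite('prediction.jpg', cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
```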
models/orgaquant_pretrained.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94f9c7f8e33727b7838bb72614b7a3af0c66071e8138708463e1fc1eaac928a2
+ size 495354591
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ mmcv-full==1.5.2
+ mmdet==2.25.0
+ numpy==1.22.4
+ opencv-python-headless==4.5.5.64
+ openmim==0.1.5
+ torch==1.11.0
+ torchvision==0.12.0