kittendev committed on
Commit 8274714 (1 parent: 3fec50d)

Update app.py

Files changed (1)
  1. app.py +178 -178
app.py CHANGED
@@ -1,178 +1,178 @@
-# Copyright (C) 2020 * Ltd. All rights reserved.
-# author : Sanghyeon Jo <josanghyeokn@gmail.com>
-
-import gradio as gr
-
-import os
-import sys
-import copy
-import shutil
-import random
-import argparse
-import numpy as np
-
-import imageio
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from torchvision import transforms
-from torch.utils.tensorboard import SummaryWriter
-
-from torch.utils.data import DataLoader
-
-from core.puzzle_utils import *
-from core.networks import *
-from core.datasets import *
-
-from tools.general.io_utils import *
-from tools.general.time_utils import *
-from tools.general.json_utils import *
-
-from tools.ai.log_utils import *
-from tools.ai.demo_utils import *
-from tools.ai.optim_utils import *
-from tools.ai.torch_utils import *
-from tools.ai.evaluate_utils import *
-
-from tools.ai.augment_utils import *
-from tools.ai.randaugment import *
-
-parser = argparse.ArgumentParser()
-
-###############################################################################
-# Dataset
-###############################################################################
-parser.add_argument('--seed', default=2606, type=int)
-parser.add_argument('--num_workers', default=4, type=int)
-parser.add_argument('--data_dir', default='../VOCtrainval_11-May-2012/', type=str)
-
-###############################################################################
-# Network
-###############################################################################
-parser.add_argument('--architecture', default='DeepLabv3+', type=str)
-parser.add_argument('--backbone', default='resnet50', type=str)
-parser.add_argument('--mode', default='fix', type=str)
-parser.add_argument('--use_gn', default=True, type=str2bool)
-
-###############################################################################
-# Inference parameters
-###############################################################################
-parser.add_argument('--tag', default='', type=str)
-
-parser.add_argument('--domain', default='val', type=str)
-
-parser.add_argument('--scales', default='0.5,1.0,1.5,2.0', type=str)
-parser.add_argument('--iteration', default=10, type=int)
-
-if __name__ == '__main__':
-    ###################################################################################
-    # Arguments
-    ###################################################################################
-    args = parser.parse_args()
-
-    model_dir = create_directory('./experiments/models/')
-    model_path = model_dir + f'DeepLabv3+@ResNeSt-101@Fix@GN.pth'
-
-    if 'train' in args.domain:
-        args.tag += '@train'
-    else:
-        args.tag += '@' + args.domain
-
-    args.tag += '@scale=%s' % args.scales
-    args.tag += '@iteration=%d' % args.iteration
-
-    set_seed(args.seed)
-    log_func = lambda string='': print(string)
-
-    ###################################################################################
-    # Transform, Dataset, DataLoader
-    ###################################################################################
-    imagenet_mean = [0.485, 0.456, 0.406]
-    imagenet_std = [0.229, 0.224, 0.225]
-
-    normalize_fn = Normalize(imagenet_mean, imagenet_std)
-
-    # for mIoU
-    meta_dic = read_json('./data/VOC_2012.json')
-
-    ###################################################################################
-    # Network
-    ###################################################################################
-    if args.architecture == 'DeepLabv3+':
-        model = DeepLabv3_Plus(args.backbone, num_classes=meta_dic['classes'] + 1, mode=args.mode,
-                               use_group_norm=args.use_gn)
-    elif args.architecture == 'Seg_Model':
-        model = Seg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
-    elif args.architecture == 'CSeg_Model':
-        model = CSeg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
-
-    model = model.cuda()
-    model.eval()
-
-    log_func('[i] Architecture is {}'.format(args.architecture))
-    log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
-    log_func()
-
-    load_model(model, model_path, parallel=False)
-
-    #################################################################################################
-    # Evaluation
-    #################################################################################################
-    eval_timer = Timer()
-    scales = [float(scale) for scale in args.scales.split(',')]
-
-    model.eval()
-    eval_timer.tik()
-
-
-    def inference(images, image_size):
-        images = images.cuda()
-
-        logits = model(images)
-        logits = resize_for_tensors(logits, image_size)
-
-        logits = logits[0] + logits[1].flip(-1)
-        logits = get_numpy_from_tensor(logits).transpose((1, 2, 0))
-        return logits
-
-
-    def predict_image(ori_image):
-        with torch.no_grad():
-            ori_w, ori_h = ori_image.size
-
-            cams_list = []
-
-            for scale in scales:
-                image = copy.deepcopy(ori_image)
-                image = image.resize((round(ori_w * scale), round(ori_h * scale)), resample=PIL.Image.BICUBIC)
-
-                image = normalize_fn(image)
-                image = image.transpose((2, 0, 1))
-
-                image = torch.from_numpy(image)
-                flipped_image = image.flip(-1)
-
-                images = torch.stack([image, flipped_image])
-
-                cams = inference(images, (ori_h, ori_w))
-                cams_list.append(cams)
-
-            preds = np.sum(cams_list, axis=0)
-            preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()
-
-            if args.iteration > 0:
-                preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
-                pred_mask = np.argmax(preds, axis=0)
-            else:
-                pred_mask = np.argmax(preds, axis=-1)
-
-            return pred_mask.astype(np.uint8)
-
-
-    demo = gr.Interface(
-        fn=predict_image,
-        inputs="image",
-        outputs="image"
-    )
+# Copyright (C) 2020 * Ltd. All rights reserved.
+# author : Sanghyeon Jo <josanghyeokn@gmail.com>
+
+import gradio as gr
+
+import os
+import sys
+import copy
+import shutil
+import random
+import argparse
+import numpy as np
+
+import imageio
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torchvision import transforms
+from torch.utils.tensorboard import SummaryWriter
+
+from torch.utils.data import DataLoader
+
+from core.puzzle_utils import *
+from core.networks import *
+from core.datasets import *
+
+from tools.general.io_utils import *
+from tools.general.time_utils import *
+from tools.general.json_utils import *
+
+from tools.ai.log_utils import *
+from tools.ai.demo_utils import *
+from tools.ai.optim_utils import *
+from tools.ai.torch_utils import *
+from tools.ai.evaluate_utils import *
+
+from tools.ai.augment_utils import *
+from tools.ai.randaugment import *
+
+parser = argparse.ArgumentParser()
+
+###############################################################################
+# Dataset
+###############################################################################
+parser.add_argument('--seed', default=2606, type=int)
+parser.add_argument('--num_workers', default=4, type=int)
+parser.add_argument('--data_dir', default='../VOCtrainval_11-May-2012/', type=str)
+
+###############################################################################
+# Network
+###############################################################################
+parser.add_argument('--architecture', default='DeepLabv3+', type=str)
+parser.add_argument('--backbone', default='resnet50', type=str)
+parser.add_argument('--mode', default='fix', type=str)
+parser.add_argument('--use_gn', default=True, type=str2bool)
+
+###############################################################################
+# Inference parameters
+###############################################################################
+parser.add_argument('--tag', default='', type=str)
+
+parser.add_argument('--domain', default='val', type=str)
+
+parser.add_argument('--scales', default='0.5,1.0,1.5,2.0', type=str)
+parser.add_argument('--iteration', default=10, type=int)
+
+if __name__ == '__main__':
+    ###################################################################################
+    # Arguments
+    ###################################################################################
+    args = parser.parse_args()
+
+    model_dir = create_directory('./experiments/models/')
+    model_path = model_dir + f'DeepLabv3+@ResNet-50@Fix@GN.pth'
+
+    if 'train' in args.domain:
+        args.tag += '@train'
+    else:
+        args.tag += '@' + args.domain
+
+    args.tag += '@scale=%s' % args.scales
+    args.tag += '@iteration=%d' % args.iteration
+
+    set_seed(args.seed)
+    log_func = lambda string='': print(string)
+
+    ###################################################################################
+    # Transform, Dataset, DataLoader
+    ###################################################################################
+    imagenet_mean = [0.485, 0.456, 0.406]
+    imagenet_std = [0.229, 0.224, 0.225]
+
+    normalize_fn = Normalize(imagenet_mean, imagenet_std)
+
+    # for mIoU
+    meta_dic = read_json('./data/VOC_2012.json')
+
+    ###################################################################################
+    # Network
+    ###################################################################################
+    if args.architecture == 'DeepLabv3+':
+        model = DeepLabv3_Plus(args.backbone, num_classes=meta_dic['classes'] + 1, mode=args.mode,
+                               use_group_norm=args.use_gn)
+    elif args.architecture == 'Seg_Model':
+        model = Seg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
+    elif args.architecture == 'CSeg_Model':
+        model = CSeg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
+
+    model = model.cuda()
+    model.eval()
+
+    log_func('[i] Architecture is {}'.format(args.architecture))
+    log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
+    log_func()
+
+    load_model(model, model_path, parallel=False)
+
+    #################################################################################################
+    # Evaluation
+    #################################################################################################
+    eval_timer = Timer()
+    scales = [float(scale) for scale in args.scales.split(',')]
+
+    model.eval()
+    eval_timer.tik()
+
+
+    def inference(images, image_size):
+        images = images.cuda()
+
+        logits = model(images)
+        logits = resize_for_tensors(logits, image_size)
+
+        logits = logits[0] + logits[1].flip(-1)
+        logits = get_numpy_from_tensor(logits).transpose((1, 2, 0))
+        return logits
+
+
+    def predict_image(ori_image):
+        with torch.no_grad():
+            ori_w, ori_h = ori_image.size
+
+            cams_list = []
+
+            for scale in scales:
+                image = copy.deepcopy(ori_image)
+                image = image.resize((round(ori_w * scale), round(ori_h * scale)), resample=PIL.Image.BICUBIC)
+
+                image = normalize_fn(image)
+                image = image.transpose((2, 0, 1))
+
+                image = torch.from_numpy(image)
+                flipped_image = image.flip(-1)
+
+                images = torch.stack([image, flipped_image])
+
+                cams = inference(images, (ori_h, ori_w))
+                cams_list.append(cams)
+
+            preds = np.sum(cams_list, axis=0)
+            preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()
+
+            if args.iteration > 0:
+                preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
+                pred_mask = np.argmax(preds, axis=0)
+            else:
+                pred_mask = np.argmax(preds, axis=-1)
+
+            return pred_mask.astype(np.uint8)
+
+
+    demo = gr.Interface(
+        fn=predict_image,
+        inputs="image",
+        outputs="image"
+    )
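
Note on the demo output: predict_image returns a uint8 mask of VOC class indices (0 for background, 1-20 for the object classes), which the plain "image" output will typically render as a nearly black picture. The sketch below is a hypothetical post-processing helper, not part of this repository (tools/ai/demo_utils may already ship a similar decoder); it maps class indices to the standard PASCAL VOC palette so the displayed mask is readable.

# Hypothetical helper for visualizing the mask returned by predict_image.
# voc_colormap and colorize_mask are illustrative names, not functions from this repo.
import numpy as np

def voc_colormap(n=256):
    """Standard PASCAL VOC color map built with the usual bit-interleaving scheme."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap

def colorize_mask(pred_mask):
    """Map an (H, W) uint8 array of class indices to an (H, W, 3) RGB image."""
    return voc_colormap()[pred_mask]

With such a helper, the interface could be wired as gr.Interface(fn=lambda img: colorize_mask(predict_image(img)), inputs="image", outputs="image"), so the Space shows a colored segmentation map rather than raw label indices.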