Xu Ma commited on
Commit
c89c010
·
1 Parent(s): 1b90f20
Files changed (3) hide show
  1. app.py +207 -138
  2. config/base.yaml +19 -3
  3. main.py +339 -0
app.py CHANGED
@@ -9,124 +9,193 @@ import torch
9
  import yaml
10
  from PIL import Image
11
  from subprocess import call
12
-
13
- ROOT_PATH = sys.path[0] # 根目录
14
-
15
- # 模型路径
16
- model_path = "ultralytics/yolov5"
17
- # 模型名称临时变量
18
- model_name_tmp = ""
19
- # 设备临时变量
20
- device_tmp = ""
21
- # 文件后缀
22
- suffix_list = [".csv", ".yaml"]
23
- def parse_args(known=False):
24
- parser = argparse.ArgumentParser(description="Gradio LIVE")
25
- parser.add_argument(
26
- "--model_name", "-mn", default="yolov5s", type=str, help="model name"
27
- )
28
- parser.add_argument(
29
- "--model_cfg",
30
- "-mc",
31
- default="./model_config/model_name_p5_all.yaml",
32
- type=str,
33
- help="model config",
34
- )
35
- parser.add_argument(
36
- "--cls_name",
37
- "-cls",
38
- default="./cls_name/cls_name.yaml",
39
- type=str,
40
- help="cls name",
41
- )
42
- parser.add_argument(
43
- "--nms_conf",
44
- "-conf",
45
- default=0.5,
46
- type=float,
47
- help="model NMS confidence threshold",
48
- )
49
- parser.add_argument(
50
- "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold"
51
- )
52
-
53
- parser.add_argument(
54
- "--label_dnt_show",
55
- "-lds",
56
- action="store_false",
57
- default=True,
58
- help="label show",
59
- )
60
- parser.add_argument(
61
- "--device",
62
- "-dev",
63
- default="cpu",
64
- type=str,
65
- help="cuda or cpu, hugging face only cpu",
66
- )
67
- parser.add_argument(
68
- "--inference_size", "-isz", default=640, type=int, help="model inference size"
69
- )
70
-
71
- args = parser.parse_known_args()[0] if known else parser.parse_args()
72
- return args
73
- # 模型加载
74
- def model_loading(model_name, device):
75
-
76
- # 加载本地模型
77
- model = torch.hub.load(model_path, model_name, force_reload=True, device=device)
78
-
79
- return model
80
- # 检测信息
81
- def export_json(results, model, img_size):
82
-
83
- return [
84
- [
85
- {
86
- "id": int(i),
87
- "class": int(result[i][5]),
88
- "class_name": model.model.names[int(result[i][5])],
89
- "normalized_box": {
90
- "x0": round(result[i][:4].tolist()[0], 6),
91
- "y0": round(result[i][:4].tolist()[1], 6),
92
- "x1": round(result[i][:4].tolist()[2], 6),
93
- "y1": round(result[i][:4].tolist()[3], 6),
94
- },
95
- "confidence": round(float(result[i][4]), 2),
96
- "fps": round(1000 / float(results.t[1]), 2),
97
- "width": img_size[0],
98
- "height": img_size[1],
99
- }
100
- for i in range(len(result))
101
- ]
102
- for result in results.xyxyn
103
- ]
104
- def yolo_det(img, experiment_id, device=None, model_name=None, inference_size=None, conf=None, iou=None, label_opt=None, model_cls=None):
105
-
106
- global model, model_name_tmp, device_tmp
107
-
108
- if model_name_tmp != model_name:
109
- # 模型判断,避免反复加载
110
- model_name_tmp = model_name
111
- model = model_loading(model_name_tmp, device)
112
- elif device_tmp != device:
113
- device_tmp = device
114
- model = model_loading(model_name_tmp, device)
115
-
116
- # -----------模型调参-----------
117
- model.conf = conf # NMS 置信度阈值
118
- model.iou = iou # NMS IOU阈值
119
- model.max_det = 1000 # 最大检测框数
120
- model.classes = model_cls # 模型类别
121
-
122
- results = model(img, size=inference_size) # 检测
123
- results.render(labels=label_opt) # 渲染
124
-
125
- det_img = Image.fromarray(results.imgs[0]) # 检测图片
126
-
127
- det_json = export_json(results, model, img.size)[0] # 检测信息
128
-
129
- return det_img, det_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
 
132
  def run_cmd(command):
@@ -150,25 +219,25 @@ run_cmd("python main.py --config config/base.yaml --experiment experiment_5x1 --
150
 
151
 
152
 
153
- # yaml文件解析
154
- def yaml_parse(file_path):
155
- return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
156
-
157
-
158
- # yaml csv 文件解析
159
- def yaml_csv(file_path, file_tag):
160
- file_suffix = Path(file_path).suffix
161
- if file_suffix == suffix_list[0]:
162
- # 模型名称
163
- file_names = [i[0] for i in list(csv.reader(open(file_path)))] # csv版
164
- elif file_suffix == suffix_list[1]:
165
- # 模型名称
166
- file_names = yaml_parse(file_path).get(file_tag) # yaml版
167
- else:
168
- print(f"{file_path}格式不正确!程序退出!")
169
- sys.exit()
170
-
171
- return file_names
172
 
173
 
174
  def main(args):
@@ -223,7 +292,7 @@ def main(args):
223
 
224
  # Interface
225
  gr.Interface(
226
- fn=yolo_det,
227
  inputs=inputs,
228
  outputs=[outputs, outputs02],
229
  title=title,
 
9
  import yaml
10
  from PIL import Image
11
  from subprocess import call
12
+ import pydiffvg
13
+ import torch
14
+ import cv2
15
+ import matplotlib.pyplot as plt
16
+ import random
17
+ import argparse
18
+ import math
19
+ import errno
20
+ from tqdm import tqdm
21
+ import yaml
22
+ from easydict import EasyDict as edict
23
+ from main import main_func
24
+
25
+ def parse_args():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument('--debug', action='store_true', default=False)
28
+ parser.add_argument("--config", default="config/base.yaml", type=str)
29
+ parser.add_argument("--experiment", type=str)
30
+ parser.add_argument("--seed", type=int)
31
+ parser.add_argument("--target", type=str, help="target image path")
32
+ parser.add_argument('--log_dir', metavar='DIR', default="log/debug")
33
+ parser.add_argument('--initial', type=str, default="random", choices=['random', 'circle'])
34
+ parser.add_argument('--signature', nargs='+', type=str)
35
+ parser.add_argument('--seginit', nargs='+', type=str)
36
+ parser.add_argument("--num_segments", type=int, default=4)
37
+ # parser.add_argument("--num_paths", type=str, default="1,1,1")
38
+ # parser.add_argument("--num_iter", type=int, default=500)
39
+ # parser.add_argument('--free', action='store_true')
40
+ # Please ensure that image resolution is divisible by pool_size; otherwise the performance would drop a lot.
41
+ # parser.add_argument('--pool_size', type=int, default=40, help="the pooled image size for next path initialization")
42
+ # parser.add_argument('--save_loss', action='store_true')
43
+ # parser.add_argument('--save_init', action='store_true')
44
+ # parser.add_argument('--save_image', action='store_true')
45
+ # parser.add_argument('--save_video', action='store_true')
46
+ # parser.add_argument('--print_weight', action='store_true')
47
+ # parser.add_argument('--circle_init_radius', type=float)
48
+ cfg = edict()
49
+ args = parser.parse_args()
50
+ cfg.debug = args.debug
51
+ cfg.config = args.config
52
+ cfg.experiment = args.experiment
53
+ cfg.seed = args.seed
54
+ cfg.target = args.target
55
+ cfg.log_dir = args.log_dir
56
+ cfg.initial = args.initial
57
+ cfg.signature = args.signature
58
+ # set cfg num_segments in command
59
+ cfg.num_segments = args.num_segments
60
+ if args.seginit is not None:
61
+ cfg.seginit = edict()
62
+ cfg.seginit.type = args.seginit[0]
63
+ if cfg.seginit.type == 'circle':
64
+ cfg.seginit.radius = float(args.seginit[1])
65
+ return cfg
66
+
67
+
68
+
69
+
70
+
71
+ def run_live(img, experiment_id):
72
+ main_func(img, experiment_id)
73
+ return 0, 1
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+ # ROOT_PATH = sys.path[0] # 根目录
84
+ # # 模型路径
85
+ # model_path = "ultralytics/yolov5"
86
+ # # 模型名称临时变量
87
+ # model_name_tmp = ""
88
+ # # 设备临时变量
89
+ # device_tmp = ""
90
+ # # 文件后缀
91
+ # suffix_list = [".csv", ".yaml"]
92
+ # def parse_args(known=False):
93
+ # parser = argparse.ArgumentParser(description="Gradio LIVE")
94
+ # parser.add_argument(
95
+ # "--model_name", "-mn", default="yolov5s", type=str, help="model name"
96
+ # )
97
+ # parser.add_argument(
98
+ # "--model_cfg",
99
+ # "-mc",
100
+ # default="./model_config/model_name_p5_all.yaml",
101
+ # type=str,
102
+ # help="model config",
103
+ # )
104
+ # parser.add_argument(
105
+ # "--cls_name",
106
+ # "-cls",
107
+ # default="./cls_name/cls_name.yaml",
108
+ # type=str,
109
+ # help="cls name",
110
+ # )
111
+ # parser.add_argument(
112
+ # "--nms_conf",
113
+ # "-conf",
114
+ # default=0.5,
115
+ # type=float,
116
+ # help="model NMS confidence threshold",
117
+ # )
118
+ # parser.add_argument(
119
+ # "--nms_iou", "-iou", default=0.45, type=float, help="model NMS IoU threshold"
120
+ # )
121
+ #
122
+ # parser.add_argument(
123
+ # "--label_dnt_show",
124
+ # "-lds",
125
+ # action="store_false",
126
+ # default=True,
127
+ # help="label show",
128
+ # )
129
+ # parser.add_argument(
130
+ # "--device",
131
+ # "-dev",
132
+ # default="cpu",
133
+ # type=str,
134
+ # help="cuda or cpu, hugging face only cpu",
135
+ # )
136
+ # parser.add_argument(
137
+ # "--inference_size", "-isz", default=640, type=int, help="model inference size"
138
+ # )
139
+ #
140
+ # args = parser.parse_known_args()[0] if known else parser.parse_args()
141
+ # return args
142
+ # # 模型加载
143
+ # def model_loading(model_name, device):
144
+ #
145
+ # # 加载本地模型
146
+ # model = torch.hub.load(model_path, model_name, force_reload=True, device=device)
147
+ #
148
+ # return model
149
+ # # 检测信息
150
+ # def export_json(results, model, img_size):
151
+ #
152
+ # return [
153
+ # [
154
+ # {
155
+ # "id": int(i),
156
+ # "class": int(result[i][5]),
157
+ # "class_name": model.model.names[int(result[i][5])],
158
+ # "normalized_box": {
159
+ # "x0": round(result[i][:4].tolist()[0], 6),
160
+ # "y0": round(result[i][:4].tolist()[1], 6),
161
+ # "x1": round(result[i][:4].tolist()[2], 6),
162
+ # "y1": round(result[i][:4].tolist()[3], 6),
163
+ # },
164
+ # "confidence": round(float(result[i][4]), 2),
165
+ # "fps": round(1000 / float(results.t[1]), 2),
166
+ # "width": img_size[0],
167
+ # "height": img_size[1],
168
+ # }
169
+ # for i in range(len(result))
170
+ # ]
171
+ # for result in results.xyxyn
172
+ # ]
173
+ # def yolo_det(img, experiment_id, device=None, model_name=None, inference_size=None, conf=None, iou=None, label_opt=None, model_cls=None):
174
+ #
175
+ # global model, model_name_tmp, device_tmp
176
+ #
177
+ # if model_name_tmp != model_name:
178
+ # # 模型判断,避免反复加载
179
+ # model_name_tmp = model_name
180
+ # model = model_loading(model_name_tmp, device)
181
+ # elif device_tmp != device:
182
+ # device_tmp = device
183
+ # model = model_loading(model_name_tmp, device)
184
+ #
185
+ # # -----------模型调参-----------
186
+ # model.conf = conf # NMS 置信度阈值
187
+ # model.iou = iou # NMS IOU阈值
188
+ # model.max_det = 1000 # 最大检测框数
189
+ # model.classes = model_cls # 模型类别
190
+ #
191
+ # results = model(img, size=inference_size) # 检测
192
+ # results.render(labels=label_opt) # 渲染
193
+ #
194
+ # det_img = Image.fromarray(results.imgs[0]) # 检测图片
195
+ #
196
+ # det_json = export_json(results, model, img.size)[0] # 检测信息
197
+ #
198
+ # return det_img, det_json
199
 
200
 
201
  def run_cmd(command):
 
219
 
220
 
221
 
222
+ # # yaml文件解析
223
+ # def yaml_parse(file_path):
224
+ # return yaml.safe_load(open(file_path, "r", encoding="utf-8").read())
225
+ #
226
+ #
227
+ # # yaml csv 文件解析
228
+ # def yaml_csv(file_path, file_tag):
229
+ # file_suffix = Path(file_path).suffix
230
+ # if file_suffix == suffix_list[0]:
231
+ # # 模型名称
232
+ # file_names = [i[0] for i in list(csv.reader(open(file_path)))] # csv版
233
+ # elif file_suffix == suffix_list[1]:
234
+ # # 模型名称
235
+ # file_names = yaml_parse(file_path).get(file_tag) # yaml版
236
+ # else:
237
+ # print(f"{file_path}格式不正确!程序退出!")
238
+ # sys.exit()
239
+ #
240
+ # return file_names
241
 
242
 
243
  def main(args):
 
292
 
293
  # Interface
294
  gr.Interface(
295
+ fn=run_live,
296
  inputs=inputs,
297
  outputs=[outputs, outputs02],
298
  title=title,
config/base.yaml CHANGED
@@ -5,10 +5,10 @@ default:
5
  type: circle
6
  radius: 5
7
  save:
8
- init: true
9
- image: true
10
  output: true
11
- video: true
12
  loss: false
13
  trainable:
14
  bg: False
@@ -66,3 +66,19 @@ experiment_1357:
66
  type: list
67
  schedule: [1, 3, 5, 7]
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  type: circle
6
  radius: 5
7
  save:
8
+ init: false
9
+ image: false
10
  output: true
11
+ video: false
12
  loss: false
13
  trainable:
14
  bg: False
 
66
  type: list
67
  schedule: [1, 3, 5, 7]
68
 
69
+
70
+ experiment_exp2_256:
71
+ path_schedule:
72
+ type: exp
73
+ base: 2
74
+ max_path: 256
75
+ max_path_per_iter: 32
76
+
77
+
78
+ experiment_exp2_128:
79
+ path_schedule:
80
+ type: exp
81
+ base: 2
82
+ max_path: 128
83
+ max_path_per_iter: 32
84
+
main.py CHANGED
@@ -344,6 +344,345 @@ class linear_decay_lrlambda_f(object):
344
  lr = lr_s * (1-r) + lr_e * r
345
  return lr
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  if __name__ == "__main__":
349
 
 
344
  lr = lr_s * (1-r) + lr_e * r
345
  return lr
346
 
347
+ def main_func(target, experiment):
348
+ cfg_arg = parse_args()
349
+ with open(cfg_arg.config, 'r') as f:
350
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
351
+ cfg_default = edict(cfg['default'])
352
+ cfg = edict(cfg[cfg_arg.experiment])
353
+ cfg.update(cfg_default)
354
+ cfg.update(cfg_arg)
355
+ cfg.exid = get_experiment_id(cfg.debug)
356
+
357
+ cfg.experiment_dir = \
358
+ osp.join(cfg.log_dir, '{}_{}'.format(cfg.exid, '_'.join(cfg.signature)))
359
+ cfg.target = target
360
+ cfg.experiment = experiment
361
+
362
+ configfile = osp.join(cfg.experiment_dir, 'config.yaml')
363
+ check_and_create_dir(configfile)
364
+ with open(osp.join(configfile), 'w') as f:
365
+ yaml.dump(edict_2_dict(cfg), f)
366
+
367
+ # Use GPU if available
368
+ pydiffvg.set_use_gpu(torch.cuda.is_available())
369
+ device = pydiffvg.get_device()
370
+
371
+ gt = np.array(PIL.Image.open(cfg.target))
372
+ print(f"Input image shape is: {gt.shape}")
373
+ if len(gt.shape) == 2:
374
+ print("Converting the gray-scale image to RGB.")
375
+ gt = gt.unsqueeze(dim=-1).repeat(1,1,3)
376
+ if gt.shape[2] == 4:
377
+ print("Input image includes alpha channel, simply dropout alpha channel.")
378
+ gt = gt[:, :, :3]
379
+ gt = (gt/255).astype(np.float32)
380
+ gt = torch.FloatTensor(gt).permute(2, 0, 1)[None].to(device)
381
+ if cfg.use_ycrcb:
382
+ gt = ycrcb_conversion(gt)
383
+ h, w = gt.shape[2:]
384
+
385
+ path_schedule = get_path_schedule(**cfg.path_schedule)
386
+
387
+ if cfg.seed is not None:
388
+ random.seed(cfg.seed)
389
+ npr.seed(cfg.seed)
390
+ torch.manual_seed(cfg.seed)
391
+ render = pydiffvg.RenderFunction.apply
392
+
393
+ shapes_record, shape_groups_record = [], []
394
+
395
+ region_loss = None
396
+ loss_matrix = []
397
+
398
+ para_point, para_color = {}, {}
399
+ if cfg.trainable.stroke:
400
+ para_stroke_width, para_stroke_color = {}, {}
401
+
402
+ pathn_record = []
403
+ # Background
404
+ if cfg.trainable.bg:
405
+ # meancolor = gt.mean([2, 3])[0]
406
+ para_bg = torch.tensor([1., 1., 1.], requires_grad=True, device=device)
407
+ else:
408
+ if cfg.use_ycrcb:
409
+ para_bg = torch.tensor([219/255, 0, 0], requires_grad=False, device=device)
410
+ else:
411
+ para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=device)
412
+
413
+ ##################
414
+ # start_training #
415
+ ##################
416
+
417
+ loss_weight = None
418
+ loss_weight_keep = 0
419
+ if cfg.coord_init.type == 'naive':
420
+ pos_init_method = naive_coord_init(
421
+ para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
422
+ elif cfg.coord_init.type == 'sparse':
423
+ pos_init_method = sparse_coord_init(
424
+ para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
425
+ elif cfg.coord_init.type == 'random':
426
+ pos_init_method = random_coord_init([h, w])
427
+ else:
428
+ raise ValueError
429
+
430
+ lrlambda_f = linear_decay_lrlambda_f(cfg.num_iter, 0.4)
431
+ optim_schedular_dict = {}
432
+
433
+ for path_idx, pathn in enumerate(path_schedule):
434
+ loss_list = []
435
+ print("=> Adding [{}] paths, [{}] ...".format(pathn, cfg.seginit.type))
436
+ pathn_record.append(pathn)
437
+ pathn_record_str = '-'.join([str(i) for i in pathn_record])
438
+
439
+ # initialize new shapes related stuffs.
440
+ if cfg.trainable.stroke:
441
+ shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var = init_shapes(
442
+ pathn, cfg.num_segments, (h, w),
443
+ cfg.seginit, len(shapes_record),
444
+ pos_init_method,
445
+ trainable_stroke=True,
446
+ gt=gt, )
447
+ para_stroke_width[path_idx] = stroke_width_var
448
+ para_stroke_color[path_idx] = stroke_color_var
449
+ else:
450
+ shapes, shape_groups, point_var, color_var = init_shapes(
451
+ pathn, cfg.num_segments, (h, w),
452
+ cfg.seginit, len(shapes_record),
453
+ pos_init_method,
454
+ trainable_stroke=False,
455
+ gt=gt, )
456
+
457
+ shapes_record += shapes
458
+ shape_groups_record += shape_groups
459
+
460
+ if cfg.save.init:
461
+ filename = os.path.join(
462
+ cfg.experiment_dir, "svg-init",
463
+ "{}-init.svg".format(pathn_record_str))
464
+ check_and_create_dir(filename)
465
+ pydiffvg.save_svg(
466
+ filename, w, h,
467
+ shapes_record, shape_groups_record)
468
+
469
+ para = {}
470
+ if (cfg.trainable.bg) and (path_idx == 0):
471
+ para['bg'] = [para_bg]
472
+ para['point'] = point_var
473
+ para['color'] = color_var
474
+ if cfg.trainable.stroke:
475
+ para['stroke_width'] = stroke_width_var
476
+ para['stroke_color'] = stroke_color_var
477
+
478
+ pg = [{'params' : para[ki], 'lr' : cfg.lr_base[ki]} for ki in sorted(para.keys())]
479
+ optim = torch.optim.Adam(pg)
480
+
481
+ if cfg.trainable.record:
482
+ scheduler = LambdaLR(
483
+ optim, lr_lambda=lrlambda_f, last_epoch=-1)
484
+ else:
485
+ scheduler = LambdaLR(
486
+ optim, lr_lambda=lrlambda_f, last_epoch=cfg.num_iter)
487
+ optim_schedular_dict[path_idx] = (optim, scheduler)
488
+
489
+ # Inner loop training
490
+ t_range = tqdm(range(cfg.num_iter))
491
+ for t in t_range:
492
+
493
+ for _, (optim, _) in optim_schedular_dict.items():
494
+ optim.zero_grad()
495
+
496
+ # Forward pass: render the image.
497
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
498
+ w, h, shapes_record, shape_groups_record)
499
+ img = render(w, h, 2, 2, t, None, *scene_args)
500
+
501
+ # Compose img with white background
502
+ img = img[:, :, 3:4] * img[:, :, :3] + \
503
+ para_bg * (1 - img[:, :, 3:4])
504
+
505
+ if cfg.save.video:
506
+ filename = os.path.join(
507
+ cfg.experiment_dir, "video-png",
508
+ "{}-iter{}.png".format(pathn_record_str, t))
509
+ check_and_create_dir(filename)
510
+ if cfg.use_ycrcb:
511
+ imshow = ycrcb_conversion(
512
+ img, format='[2D x 3]', reverse=True).detach().cpu()
513
+ else:
514
+ imshow = img.detach().cpu()
515
+ pydiffvg.imwrite(imshow, filename, gamma=gamma)
516
+
517
+ x = img.unsqueeze(0).permute(0, 3, 1, 2) # HWC -> NCHW
518
+
519
+ if cfg.use_ycrcb:
520
+ color_reweight = torch.FloatTensor([255/219, 255/224, 255/255]).to(device)
521
+ loss = ((x-gt)*(color_reweight.view(1, -1, 1, 1)))**2
522
+ else:
523
+ loss = ((x-gt)**2)
524
+
525
+ if cfg.loss.use_l1_loss:
526
+ loss = abs(x-gt)
527
+
528
+ if cfg.loss.use_distance_weighted_loss:
529
+ if cfg.use_ycrcb:
530
+ raise ValueError
531
+ shapes_forsdf = copy.deepcopy(shapes)
532
+ shape_groups_forsdf = copy.deepcopy(shape_groups)
533
+ for si in shapes_forsdf:
534
+ si.stroke_width = torch.FloatTensor([0]).to(device)
535
+ for sg_idx, sgi in enumerate(shape_groups_forsdf):
536
+ sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(device)
537
+ sgi.shape_ids = torch.LongTensor([sg_idx]).to(device)
538
+
539
+ sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
540
+ w, h, shapes_forsdf, shape_groups_forsdf)
541
+ with torch.no_grad():
542
+ im_forsdf = render(w, h, 2, 2, 0, None, *sargs_forsdf)
543
+ # use alpha channel is a trick to get 0-1 image
544
+ im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
545
+ loss_weight = get_sdf(im_forsdf, normalize='to1')
546
+ loss_weight += loss_weight_keep
547
+ loss_weight = np.clip(loss_weight, 0, 1)
548
+ loss_weight = torch.FloatTensor(loss_weight).to(device)
549
+
550
+ if cfg.save.loss:
551
+ save_loss = loss.squeeze(dim=0).mean(dim=0,keepdim=False).cpu().detach().numpy()
552
+ save_weight = loss_weight.cpu().detach().numpy()
553
+ save_weighted_loss = save_loss*save_weight
554
+ # normalize to [0,1]
555
+ save_loss = (save_loss - np.min(save_loss))/np.ptp(save_loss)
556
+ save_weight = (save_weight - np.min(save_weight))/np.ptp(save_weight)
557
+ save_weighted_loss = (save_weighted_loss - np.min(save_weighted_loss))/np.ptp(save_weighted_loss)
558
+
559
+ # save
560
+ plt.imshow(save_loss, cmap='Reds')
561
+ plt.axis('off')
562
+ # plt.colorbar()
563
+ filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-mseloss.png".format(pathn_record_str, t))
564
+ check_and_create_dir(filename)
565
+ plt.savefig(filename, dpi=800)
566
+ plt.close()
567
+
568
+ plt.imshow(save_weight, cmap='Greys')
569
+ plt.axis('off')
570
+ # plt.colorbar()
571
+ filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-sdfweight.png".format(pathn_record_str, t))
572
+ plt.savefig(filename, dpi=800)
573
+ plt.close()
574
+
575
+ plt.imshow(save_weighted_loss, cmap='Reds')
576
+ plt.axis('off')
577
+ # plt.colorbar()
578
+ filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-weightedloss.png".format(pathn_record_str, t))
579
+ plt.savefig(filename, dpi=800)
580
+ plt.close()
581
+
582
+
583
+
584
+
585
+
586
+ if loss_weight is None:
587
+ loss = loss.sum(1).mean()
588
+ else:
589
+ loss = (loss.sum(1)*loss_weight).mean()
590
+
591
+ # if (cfg.loss.bis_loss_weight is not None) and (cfg.loss.bis_loss_weight > 0):
592
+ # loss_bis = bezier_intersection_loss(point_var[0]) * cfg.loss.bis_loss_weight
593
+ # loss = loss + loss_bis
594
+ if (cfg.loss.xing_loss_weight is not None) \
595
+ and (cfg.loss.xing_loss_weight > 0):
596
+ loss_xing = xing_loss(point_var) * cfg.loss.xing_loss_weight
597
+ loss = loss + loss_xing
598
+
599
+
600
+ loss_list.append(loss.item())
601
+ t_range.set_postfix({'loss': loss.item()})
602
+ loss.backward()
603
+
604
+ # step
605
+ for _, (optim, scheduler) in optim_schedular_dict.items():
606
+ optim.step()
607
+ scheduler.step()
608
+
609
+ for group in shape_groups_record:
610
+ group.fill_color.data.clamp_(0.0, 1.0)
611
+
612
+ if cfg.loss.use_distance_weighted_loss:
613
+ loss_weight_keep = loss_weight.detach().cpu().numpy() * 1
614
+
615
+ if not cfg.trainable.record:
616
+ for _, pi in pg.items():
617
+ for ppi in pi:
618
+ pi.require_grad = False
619
+ optim_schedular_dict = {}
620
+
621
+ if cfg.save.image:
622
+ filename = os.path.join(
623
+ cfg.experiment_dir, "demo-png", "{}.png".format(pathn_record_str))
624
+ check_and_create_dir(filename)
625
+ if cfg.use_ycrcb:
626
+ imshow = ycrcb_conversion(
627
+ img, format='[2D x 3]', reverse=True).detach().cpu()
628
+ else:
629
+ imshow = img.detach().cpu()
630
+ pydiffvg.imwrite(imshow, filename, gamma=gamma)
631
+
632
+ if cfg.save.output:
633
+ filename = os.path.join(
634
+ cfg.experiment_dir, "output-svg", "{}.svg".format(pathn_record_str))
635
+ check_and_create_dir(filename)
636
+ pydiffvg.save_svg(filename, w, h, shapes_record, shape_groups_record)
637
+
638
+ loss_matrix.append(loss_list)
639
+
640
+ # calculate the pixel loss
641
+ # pixel_loss = ((x-gt)**2).sum(dim=1, keepdim=True).sqrt_() # [N,1,H, W]
642
+ # region_loss = adaptive_avg_pool2d(pixel_loss, cfg.region_loss_pool_size)
643
+ # loss_weight = torch.softmax(region_loss.reshape(1, 1, -1), dim=-1)\
644
+ # .reshape_as(region_loss)
645
+
646
+ pos_init_method = naive_coord_init(x, gt)
647
+
648
+ if cfg.coord_init.type == 'naive':
649
+ pos_init_method = naive_coord_init(x, gt)
650
+ elif cfg.coord_init.type == 'sparse':
651
+ pos_init_method = sparse_coord_init(x, gt)
652
+ elif cfg.coord_init.type == 'random':
653
+ pos_init_method = random_coord_init([h, w])
654
+ else:
655
+ raise ValueError
656
+
657
+ if cfg.save.video:
658
+ print("saving iteration video...")
659
+ img_array = []
660
+ for ii in range(0, cfg.num_iter):
661
+ filename = os.path.join(
662
+ cfg.experiment_dir, "video-png",
663
+ "{}-iter{}.png".format(pathn_record_str, ii))
664
+ img = cv2.imread(filename)
665
+ # cv2.putText(
666
+ # img, "Path:{} \nIteration:{}".format(pathn_record_str, ii),
667
+ # (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
668
+ img_array.append(img)
669
+
670
+ videoname = os.path.join(
671
+ cfg.experiment_dir, "video-avi",
672
+ "{}.avi".format(pathn_record_str))
673
+ check_and_create_dir(videoname)
674
+ out = cv2.VideoWriter(
675
+ videoname,
676
+ # cv2.VideoWriter_fourcc(*'mp4v'),
677
+ cv2.VideoWriter_fourcc(*'FFV1'),
678
+ 20.0, (w, h))
679
+ for iii in range(len(img_array)):
680
+ out.write(img_array[iii])
681
+ out.release()
682
+ # shutil.rmtree(os.path.join(cfg.experiment_dir, "video-png"))
683
+
684
+ print("The last loss is: {}".format(loss.item()))
685
+
686
 
687
  if __name__ == "__main__":
688