hanquansanren commited on
Commit
125b486
·
1 Parent(s): 3a8784c

Add application file

Browse files
Files changed (2) hide show
  1. .gitignore +1 -1
  2. run_gradio.py +553 -0
.gitignore CHANGED
@@ -3,7 +3,7 @@ vis_hp
3
  assets
4
  images
5
  backup
6
- run_gradio.py
7
  run_foward.py
8
 
9
 
 
3
  assets
4
  images
5
  backup
6
+ # run_gradio.py
7
  run_foward.py
8
 
9
 
run_gradio.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ from datetime import date
4
+ from shutil import copyfile
5
+ import cv2 as cv
6
+ import numpy as np
7
+ import torch
8
+ import torch.backends.cudnn
9
+ import admin.settings as ws_settings
10
+ import os
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torchvision.transforms as transforms
14
+ from torch.utils.data import DataLoader
15
+ import datasets
16
+ from utils_data.image_transforms import ArrayToTensor
17
+ from train_settings.dvd.improved_diffusion import dist_util, logger
18
+ from train_settings.dvd.improved_diffusion.script_util import args_to_dict, create_model_and_diffusion,model_and_diffusion_defaults
19
+ from train_settings.models.geotr.geotr_core import GeoTr_Seg_Inf, reload_segmodel, reload_model, Seg
20
+ from train_settings.models.geotr.unet_model import UNet
21
+ from PIL import Image
22
+ from tqdm import tqdm
23
+ import torch.nn.functional as F
24
+ import torch as th
25
+ from train_settings.dvd.improved_diffusion.gaussian_diffusion import GaussianDiffusion
26
+ from train_settings.dvd.feature_backbones.VGG_features import VGGPyramid
27
+ from train_settings.dvd.eval_utils import extract_raw_features_single,extract_raw_features_single2
28
+ from datasets.utils.warping import register_model2
29
+
30
+ import gradio as gr
31
+
32
+
33
+
34
+
35
+ reg_model_bilin = register_model2((512,512), 'bilinear')
36
+
37
+ def coords_grid_tensor(perturbed_img_shape):
38
+ im_x, im_y = np.mgrid[0:perturbed_img_shape[0]-1:complex(perturbed_img_shape[0]), 0:perturbed_img_shape[1]-1:complex(perturbed_img_shape[1])]
39
+ coords = np.stack((im_y,im_x), axis=2) # 先x后y,行序优先
40
+ coords = th.from_numpy(coords).float().permute(2,0,1).to(dist_util.dev()) # (2, 512, 512)
41
+ return coords.unsqueeze(0) # [2, 512, 512]
42
+
43
+ def run_sample_lr_dewarping(
44
+ settings, logger, diffusion, model, radius, source, feature_size,
45
+ raw_corr, init_flow, c20, source_64, pyramid, doc_mask,
46
+ seg_map_all=None, textline_map=None, init_feat=None
47
+ ):
48
+ model_kwsettings = {'init_flow': init_flow, 'src_feat': c20, 'src_64':None,
49
+ 'y512':source, 'tmode':settings.env.train_mode,
50
+ 'mask_cat': doc_mask,
51
+ 'init_feat': init_feat,
52
+ 'iter': settings.env.iter} # 'trg_feat': trg_feat
53
+ # [1, 81, 64, 64] [1, 2, 64, 64] [1, 64, 64, 64]
54
+ if settings.env.use_gt_mask == False:
55
+ model_kwsettings['mask_y512'] = seg_map_all # [b, 384, 64, 64]
56
+ if settings.env.use_line_mask == True:
57
+ model_kwsettings['line_msk'] = textline_map #
58
+ image_size_h, image_size_w = feature_size, feature_size
59
+
60
+ logger.info(f"\nStarting sampling")
61
+
62
+ sample, _ = diffusion.ddim_sample_loop(
63
+ model,
64
+ (1, 2, image_size_h, image_size_w), # 1,2,64,64
65
+ noise=None,
66
+ clip_denoised=settings.env.clip_denoised, # false
67
+ model_kwargs=model_kwsettings,
68
+ eta=0.0,
69
+ progress=True,
70
+ denoised_fn=None,
71
+ sampling_kwargs={'src_img': source}, # 'trg_img': target
72
+ logger=logger,
73
+ n_batch=settings.env.n_batch,
74
+ time_variant = settings.env.time_variant,
75
+ pyramid=pyramid
76
+ )
77
+
78
+ sample = th.clamp(sample, min=-1, max=1)
79
+ return sample
80
+
81
+ def visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow=None):
82
+ os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
83
+ # warped_src = warp(source_vis.to(sample.device).float(), sample) # [1, 3, 1629, 981]
84
+ warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])
85
+ warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy()#*255. # (1873, 1353, 3)
86
+ warped_src = Image.fromarray((warped_src).astype(np.uint8))
87
+
88
+ return warped_src
89
+
90
+ def visualize_dewarping_single(settings, sample, source_vis):
91
+ os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}/dewarped_pred', exist_ok=True) # pred dewarped
92
+ # warped_src = warp(source_vis.to(sample.device).float(), sample) # [1, 3, 1629, 981]
93
+ warped_src = reg_model_bilin([source_vis.to(sample.device).float(), sample])
94
+ warped_src = warped_src[0].permute(1, 2, 0).detach().cpu().numpy()#*255. # (1873, 1353, 3)
95
+ warped_src = Image.fromarray((warped_src).astype(np.uint8))
96
+
97
+ return warped_src
98
+
99
+
100
+
101
+
102
+
103
+
104
+ def prepare_data(settings, batch_preprocessing, SIZE, data):
105
+ if 'source_image_ori' in data:
106
+ source_vis = data['source_image_ori'] # B, C, 512, 512 torch.uint8 cpu
107
+ else:
108
+ source_vis = data['source_image']
109
+ if 'target_image' in data:
110
+ target_vis = data['target_image']
111
+ else:
112
+ target_vis = None
113
+
114
+ _, _, H_ori, W_ori = source_vis.shape
115
+
116
+ source = data['source_image'].to(dist_util.dev()) # [1, 3, 914, 1380] torch.float32
117
+ if 'source_image_0' in data:
118
+ source_0 = data['source_image_0'].to(dist_util.dev())
119
+ else:
120
+ source_0 = None
121
+ if 'target_image' in data:
122
+ target = data['target_image'] # [1, 3, 914, 1380] torch.float32
123
+ else:
124
+ target = None
125
+ if 'flow_map' in data:
126
+ batch_ori = data['flow_map'] # [1, 2, 914, 1380] torch.float32
127
+ else:
128
+ batch_ori = None
129
+ if 'flow_map_inter' in data:
130
+ batch_ori_inter = data['flow_map_inter'] # [1, 2, 914, 1380] torch.float32
131
+ else:
132
+ batch_ori_inter = None
133
+ if target is not None:
134
+ target = F.interpolate(target, size=512, mode='bilinear', align_corners=False) # [1, 3, 512, 512]
135
+ target_256 = data['target_image_256'].to(dist_util.dev()) # [1, 3, 256, 256]
136
+ else:
137
+ target = None
138
+ target_256 = None
139
+
140
+ if settings.env.eval_dataset == 'hp-240':# false
141
+ source_256 = source
142
+ target_256 = target
143
+
144
+ else: # true
145
+ data['source_image_256'] = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area')
146
+ source_256 = data['source_image_256'].to(dist_util.dev())
147
+
148
+ if 'target_image_256' in data:
149
+ target_256 = data['target_image_256']
150
+ else:
151
+ target_256 = None
152
+ if 'correspondence_mask' in data:
153
+ mask = data['correspondence_mask'] # torch.bool [1, 914, 1380]
154
+ else:
155
+ mask = torch.ones((1, 512, 512), dtype=torch.bool).to(dist_util.dev()) # None
156
+
157
+ return data, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
158
+
159
+ def prepare_data_single(input_image, input_image_ori):
160
+ source_vis = input_image_ori
161
+ target_vis = None
162
+ _, _, H_ori, W_ori = source_vis.shape
163
+ source = input_image.to(dist_util.dev()) # [1, 3, 914, 1380] torch.float32
164
+ source_0 = None
165
+ target = None
166
+ batch_ori = None
167
+ batch_ori_inter = None
168
+ target = None
169
+ target_256 = None
170
+ source_256 = torch.nn.functional.interpolate(input=source.float(), size=(256, 256), mode='area').to(dist_util.dev())
171
+ target_256 = None
172
+ mask = torch.ones((1, 512, 512), dtype=torch.bool).to(dist_util.dev()) # None
173
+
174
+ return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
175
+
176
+
177
+
178
+ def run_evaluation_docunet(
179
+ settings, logger, val_loader, diffusion: GaussianDiffusion, model,
180
+ pretrained_dewarp_model,pretrained_line_seg_model=None,pretrained_seg_model=None
181
+ ):
182
+ os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
183
+ batch_preprocessing = None
184
+ pbar = tqdm(enumerate(val_loader), total=len(val_loader))
185
+ pyramid = VGGPyramid(train=False).to(dist_util.dev())
186
+ SIZE = None
187
+
188
+ # for each document image
189
+
190
+ for i, data in pbar:
191
+ radius = 4
192
+ raw_corr = None
193
+ data_path = data['path']
194
+ source_288 = F.interpolate(data['source_image'], size=(288), mode='bilinear', align_corners=True).to(dist_util.dev())
195
+
196
+ if settings.env.time_variant == True:
197
+ init_feat = torch.zeros((data['source_image'].shape[0], 256, 64, 64), dtype=torch.float32).to(dist_util.dev())
198
+ else:
199
+ init_feat = None
200
+
201
+
202
+ with torch.inference_mode():
203
+ ref_bm, mask_x = pretrained_dewarp_model(source_288) # [1,2,288,288] 0~288 0~1
204
+ ref_flow = ref_bm/287.0 # [-1, 1] # [1,2,288,288]
205
+ if settings.env.use_init_flow:
206
+ init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]
207
+ else:
208
+ init_flow = torch.zeros((data['source_image'].shape[0], 2, 64, 64), dtype=torch.float32).to(dist_util.dev())
209
+
210
+
211
+ (
212
+ data,
213
+ H_ori, # 512
214
+ W_ori, # 512
215
+ source, # [1, 3, 512, 512] 0-1
216
+ target, # None
217
+ batch_ori, # None
218
+ batch_ori_inter, # None
219
+ source_256,# [1, 3, 256, 256] 0-1
220
+ target_256, # None
221
+ source_vis, # [1, 3, H, W] cpu仅用于可视化
222
+ target_vis, # None
223
+ mask, # [1, 512, 512] 全白
224
+ source_0
225
+ ) = prepare_data(settings, batch_preprocessing, SIZE, data)
226
+
227
+
228
+
229
+ with torch.no_grad():
230
+ if settings.env.use_gt_mask == False:
231
+ # ref_bm, mask_x = self.pretrained_dewarp_model(source_288) # [1,2,288,288] bm 0~288 mskx0-256
232
+ mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
233
+ hx6 = F.interpolate(hx6, size=64, mode='bilinear', align_corners=False)
234
+ hx5d = F.interpolate(hx5d, size=64, mode='bilinear', align_corners=False)
235
+ hx4d = F.interpolate(hx4d, size=64, mode='bilinear', align_corners=False)
236
+ hx3d = F.interpolate(hx3d, size=64, mode='bilinear', align_corners=False)
237
+ hx2d = F.interpolate(hx2d, size=64, mode='bilinear', align_corners=False)
238
+ hx1d = F.interpolate(hx1d, size=64, mode='bilinear', align_corners=False)
239
+
240
+ seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1) # [b, 384, 64, 64]
241
+ # tv_save_image(mskx,"vis_hp/debug_vis/mskx.png")
242
+ if settings.env.use_line_mask:
243
+ textline_map, textline_mask = pretrained_line_seg_model(mskx) # [3, 64, 256, 256]
244
+ textline_map = F.interpolate(textline_map, size=64, mode='bilinear', align_corners=False) # [3, 64, 64, 64]
245
+ else:
246
+ seg_map_all = None
247
+ textline_map = None
248
+
249
+
250
+ if settings.env.train_VGG:
251
+ c20 = None
252
+ feature_size = 64
253
+ else:
254
+ feature_size = 64
255
+ if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode =='stage_1_dit_cross':
256
+ with th.no_grad():
257
+ c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
258
+ # 平均互相关,VGG最浅层特征的下采样(512*512->64*64)
259
+ else:
260
+ with th.no_grad():
261
+ c20 = extract_raw_features_single(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
262
+ # 平均互相关,VGG最浅层特征的下采样(512*512->64*64)
263
+
264
+ source_64 = None # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)
265
+ logger.info(f"Starting sampling with VGG Features")
266
+
267
+ sample = run_sample_lr_dewarping(
268
+ settings,
269
+ logger,
270
+ diffusion,
271
+ model,
272
+ radius, # 4
273
+ source, # [B, 3, 512, 512] 0~1
274
+ feature_size, # 64
275
+ raw_corr, # None
276
+ init_flow, # [B, 2, 64, 64] -1~1
277
+ c20, # # [B, 64, 64, 64]
278
+ source_64, # None
279
+ pyramid,
280
+ mask_x, #mask_x, # F.interpolate(mskx, size=(512), mode='bilinear', align_corners=True)[:,:1,:,:] , # mask_x
281
+ seg_map_all,
282
+ textline_map,
283
+ init_feat
284
+ ) # sample: [1, 2, 64, 64] 偏移量 [-1,1]范围 五步DDIM的结果
285
+
286
+
287
+ if settings.env.use_sr_net == False:
288
+ sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] 偏移场
289
+ # sample[:, 0, :, :] = sample[:, 0, :, :] * W_ori
290
+ # sample[:, 1, :, :] = sample[:, 1, :, :] * H_ori
291
+ base = F.interpolate(coords_grid_tensor((512,512))/511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
292
+ # sample = ( ((sample + base.to(sample.device)) )*2 - 1 )
293
+ sample = ( ((sample + base.to(sample.device))*1 )*2 - 1 )*0.987 # (2 * (bm / 286.8) - 1) * 0.99
294
+ ref_flow = None
295
+ if ref_flow is not None:
296
+ ref_flow = F.interpolate(ref_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] 偏移场
297
+ # ref_flow[:, 0, :, :] = ref_flow[:, 0, :, :] * W_ori
298
+ # ref_flow[:, 1, :, :] = ref_flow[:, 1, :, :] * H_ori
299
+ ref_flow = (ref_flow + base.to(ref_flow.device))*2 -1
300
+ # init_flow = F.interpolate(init_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
301
+ else:
302
+ raise ValueError("Invalid value")
303
+
304
+
305
+ if settings.env.visualize:
306
+ output = visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow)
307
+
308
+
309
+
310
+ def run_single_docunet(input_image_ori):
311
+ input_image_ori = np.array(input_image_ori, dtype=np.uint8) # [x, y, 3]
312
+
313
+ # resize to 512x512
314
+ input_image_resized = cv.resize(input_image_ori, (512, 512)) # [512, 512, 3]
315
+
316
+ # transpose to [3, 512, 512]
317
+ input_image_ori = np.transpose(input_image_ori, (2, 0, 1)) # [3, 512, 512]
318
+ input_image = np.transpose(input_image_resized, (2, 0, 1)) # [3, 512, 512]
319
+
320
+ input_image = input_image / 255
321
+
322
+ input_image_ori = torch.tensor(input_image_ori).unsqueeze(0) # [1, 3, 512, 512]
323
+ input_image = torch.tensor(input_image).unsqueeze(0).float() # [1, 3, 512, 512]
324
+
325
+ os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
326
+ batch_preprocessing = None
327
+ pyramid = VGGPyramid(train=False).to(dist_util.dev())
328
+ SIZE = None
329
+
330
+
331
+ radius = 4
332
+ raw_corr = None
333
+ source_288 = F.interpolate(input_image, size=(288), mode='bilinear', align_corners=True).to(dist_util.dev())
334
+
335
+ if settings.env.time_variant == True:
336
+ init_feat = torch.zeros((input_image.shape[0], 256, 64, 64), dtype=torch.float32).to(dist_util.dev())
337
+ else:
338
+ init_feat = None
339
+
340
+ with torch.inference_mode():
341
+ ref_bm, mask_x = pretrained_dewarp_model(source_288) # [1,2,288,288] 0~288 0~1
342
+ ref_flow = ref_bm/287.0 # [-1, 1] # [1,2,288,288]
343
+ if settings.env.use_init_flow:
344
+ init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]
345
+ else:
346
+ init_flow = torch.zeros((input_image.shape[0], 2, 64, 64), dtype=torch.float32).to(dist_util.dev())
347
+
348
+ (
349
+ data,
350
+ H_ori, # 512
351
+ W_ori, # 512
352
+ source, # [1, 3, 512, 512] 0-1
353
+ target, # None
354
+ batch_ori, # None
355
+ batch_ori_inter, # None
356
+ source_256,# [1, 3, 256, 256] 0-1
357
+ target_256, # None
358
+ source_vis, # [1, 3, H, W] cpu仅用于可视化
359
+ target_vis, # None
360
+ mask, # [1, 512, 512] 全白
361
+ source_0
362
+ ) = prepare_data_single(input_image, input_image_ori)
363
+
364
+
365
+
366
+ with torch.no_grad():
367
+ if settings.env.use_gt_mask == False:
368
+ # ref_bm, mask_x = self.pretrained_dewarp_model(source_288) # [1,2,288,288] bm 0~288 mskx0-256
369
+ mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
370
+ hx6 = F.interpolate(hx6, size=64, mode='bilinear', align_corners=False)
371
+ hx5d = F.interpolate(hx5d, size=64, mode='bilinear', align_corners=False)
372
+ hx4d = F.interpolate(hx4d, size=64, mode='bilinear', align_corners=False)
373
+ hx3d = F.interpolate(hx3d, size=64, mode='bilinear', align_corners=False)
374
+ hx2d = F.interpolate(hx2d, size=64, mode='bilinear', align_corners=False)
375
+ hx1d = F.interpolate(hx1d, size=64, mode='bilinear', align_corners=False)
376
+
377
+ seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1) # [b, 384, 64, 64]
378
+ # tv_save_image(mskx,"vis_hp/debug_vis/mskx.png")
379
+ if settings.env.use_line_mask:
380
+ textline_map, textline_mask = pretrained_line_seg_model(mskx) # [3, 64, 256, 256]
381
+ textline_map = F.interpolate(textline_map, size=64, mode='bilinear', align_corners=False) # [3, 64, 64, 64]
382
+ else:
383
+ seg_map_all = None
384
+ textline_map = None
385
+
386
+ if settings.env.train_VGG:
387
+ c20 = None
388
+ feature_size = 64
389
+ else:
390
+ feature_size = 64
391
+ if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode =='stage_1_dit_cross':
392
+ with th.no_grad():
393
+ c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
394
+ # 平均互相关,VGG最浅层特征的下采样(512*512->64*64)
395
+ else:
396
+ with th.no_grad():
397
+ c20 = extract_raw_features_single(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
398
+ # 平均互相关,VGG最浅层特征的下采样(512*512->64*64)
399
+
400
+ source_64 = None # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)
401
+ logger.info(f"Starting sampling with VGG Features")
402
+
403
+ sample = run_sample_lr_dewarping(
404
+ settings,
405
+ logger,
406
+ diffusion,
407
+ model,
408
+ radius, # 4
409
+ source, # [B, 3, 512, 512] 0~1
410
+ feature_size, # 64
411
+ raw_corr, # None
412
+ init_flow, # [B, 2, 64, 64] -1~1
413
+ c20, # # [B, 64, 64, 64]
414
+ source_64, # None
415
+ pyramid,
416
+ mask_x, #mask_x, # F.interpolate(mskx, size=(512), mode='bilinear', align_corners=True)[:,:1,:,:] , # mask_x
417
+ seg_map_all,
418
+ textline_map,
419
+ init_feat
420
+ ) # sample: [1, 2, 64, 64] 偏移量 [-1,1]范围 五步DDIM的结果
421
+
422
+ if settings.env.use_sr_net == False:
423
+ sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] 偏移场
424
+ # sample[:, 0, :, :] = sample[:, 0, :, :] * W_ori
425
+ # sample[:, 1, :, :] = sample[:, 1, :, :] * H_ori
426
+ base = F.interpolate(coords_grid_tensor((512,512))/511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
427
+ # sample = ( ((sample + base.to(sample.device)) )*2 - 1 )
428
+ sample = ( ((sample + base.to(sample.device))*1 )*2 - 1 )*0.987 # (2 * (bm / 286.8) - 1) * 0.99
429
+ ref_flow = None
430
+ if ref_flow is not None:
431
+ ref_flow = F.interpolate(ref_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] 偏移场
432
+ # ref_flow[:, 0, :, :] = ref_flow[:, 0, :, :] * W_ori
433
+ # ref_flow[:, 1, :, :] = ref_flow[:, 1, :, :] * H_ori
434
+ ref_flow = (ref_flow + base.to(ref_flow.device))*2 -1
435
+ # init_flow = F.interpolate(init_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
436
+ else:
437
+ raise ValueError("Invalid value")
438
+
439
+
440
+ output = visualize_dewarping_single(settings, sample, source_vis)
441
+
442
+ return output
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+
453
+ parser = argparse.ArgumentParser(description='Run a sampling scripts in train_settings.')
454
+ parser.add_argument('--train_module', type=str, default='dvd', help='Name of module in the "train_settings/" folder.')
455
+ parser.add_argument('--train_name', type=str, default='val_TDiff', help='Name of the train settings file.')
456
+ parser.add_argument('--cudnn_benchmark', type=bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).')
457
+ parser.add_argument('--seed', type=int, default=1992, help='Pseudo-RNG seed')
458
+ parser.add_argument('--name', type=str, default="gradio", help='Name of the experiment')
459
+ parser.add_argument('--corruption', action='store_true') # 默认为false,触发则为true
460
+
461
+ args = parser.parse_args()
462
+
463
+ args.seed = random.randint(0, 3000000)
464
+ args.seed = torch.initial_seed() & (2 ** 32 - 1)
465
+ print('Seed is {}'.format(args.seed))
466
+ random.seed(int(args.seed))
467
+ np.random.seed(args.seed)
468
+
469
+ cudnn_benchmark=args.cudnn_benchmark
470
+ seed=args.seed
471
+ corruption=args.corruption
472
+ name=args.name
473
+
474
+ # This is needed to avoid strange crashes related to opencv
475
+ cv.setNumThreads(0)
476
+
477
+ torch.backends.cudnn.benchmark = cudnn_benchmark
478
+
479
+ # dd/mm/YY
480
+ today = date.today()
481
+ d1 = today.strftime("%d/%m/%Y")
482
+ print('Sampling: {} {}\nDate: {}'.format(args.train_module, args.train_name, d1))
483
+
484
+ settings = ws_settings.Settings()
485
+ settings.module_name = args.train_module
486
+ settings.script_name = args.train_name
487
+ settings.project_path = 'train_settings/{}/{}'.format(args.train_module, args.train_name) # 'train_settings/DiffMatch/val_DiffMatch'
488
+ settings.seed = seed
489
+ settings.name = name
490
+
491
+ save_dir = os.path.join(settings.env.workspace_dir, settings.project_path) # 'checkpoints+train_settings/DiffMatch/val_DiffMatch'
492
+ if not os.path.exists(save_dir):
493
+ os.makedirs(save_dir)
494
+ copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py'))
495
+
496
+
497
+ settings.severity = 0
498
+ settings.corruption_number = 0
499
+
500
+
501
+ dist_util.setup_dist()
502
+ logger.configure(dir=f"SAMPLING_{settings.env.eval_dataset}_{settings.name}")
503
+ logger.log(f"Corruption Disabled. Evaluating on Original {settings.env.eval_dataset}")
504
+ logger.log("Loading model and diffusion...")
505
+
506
+ model, diffusion = create_model_and_diffusion(
507
+ device=dist_util.dev(),
508
+ train_mode=settings.env.train_mode, # stage 1
509
+ tv=settings.env.time_variant,
510
+ **args_to_dict(settings, model_and_diffusion_defaults().keys()),
511
+ )
512
+ setattr(diffusion, "settings", settings)
513
+
514
+ pretrained_dewarp_model = GeoTr_Seg_Inf()
515
+ reload_segmodel(pretrained_dewarp_model.msk, settings.env.seg_model_path)
516
+ # reload_model(pretrained_dewarp_model.GeoTr, settings.env.dewarping_model_path)
517
+ pretrained_dewarp_model.to(dist_util.dev())
518
+ pretrained_dewarp_model.eval()
519
+
520
+ if settings.env.use_line_mask:
521
+ pretrained_line_seg_model = UNet(n_channels=3, n_classes=1)
522
+ pretrained_seg_model = Seg()
523
+ line_model_ckpt = dist_util.load_state_dict(settings.env.line_seg_model_path, map_location='cpu')['model']
524
+ pretrained_line_seg_model.load_state_dict(line_model_ckpt, strict=True)
525
+ pretrained_line_seg_model.to(dist_util.dev())
526
+ pretrained_line_seg_model.eval()
527
+
528
+ seg_model_ckpt = dist_util.load_state_dict(settings.env.new_seg_model_path, map_location='cpu')['model']
529
+ pretrained_seg_model.load_state_dict(seg_model_ckpt, strict=True)
530
+ pretrained_seg_model.to(dist_util.dev())
531
+ pretrained_seg_model.eval()
532
+
533
+ model.cpu().load_state_dict(dist_util.load_state_dict(settings.env.model_path, map_location="cpu"), strict=False)
534
+ logger.log(f"Model loaded with {settings.env.model_path}")
535
+
536
+ model.to(dist_util.dev())
537
+ model.eval()
538
+
539
+
540
+ if __name__ == '__main__':
541
+ demo = gr.Interface(
542
+ fn=run_single_docunet,
543
+ inputs=[
544
+ gr.Image(type="numpy", label="Input Image"),
545
+ ],
546
+ outputs=[
547
+ gr.Image(type="numpy", label="Output Image"),
548
+ ],
549
+ title="Document Image Dewarping",
550
+ description="This is a demo for document image dewarping using a trained model.",
551
+ )
552
+
553
+ demo.launch(share=True, debug=True, server_name="10.7.88.77")