Junfeng5 committed
Commit 943258e
1 Parent(s): 5c34bad

Update app.py

Files changed (1)
  1. app.py +53 -10
app.py CHANGED
@@ -114,6 +114,20 @@ GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
 GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
 GLEEmodel_swin.eval()
 
+
+cfg_eva02 = get_cfg()
+add_deeplab_config(cfg_eva02)
+add_glee_config(cfg_eva02)
+conf_files_swin = 'GLEE/configs/EVA02.yaml'
+checkpoints_eva = torch.load('GLEE/GLEE_{}.pth'.format(args.version))
+cfg_eva02.merge_from_file(conf_files_swin)
+GLEEmodel_eva02 = GLEE_Model(cfg_eva02, None, device, None, True).to(device)
+GLEEmodel_eva02.load_state_dict(checkpoints_eva, strict=False)
+GLEEmodel_eva02.eval()
+# inference_type = 'LSJ'
+
+
+
 pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
 pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
 normalizer = lambda x: (x - pixel_mean) / pixel_std
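The hunk above loads a third backbone, EVA02, served as "GLEE-Pro" alongside the existing R50 and SwinL models; note the EVA02 yaml path is assigned to the leftover name `conf_files_swin` from the SwinL block. All three models follow the same load pattern, so a hypothetical helper (not part of this commit, reusing the config and model symbols already imported in app.py) could factor it out:

# Hypothetical refactor sketch, not in this commit: one loader for all
# three backbones, using the helpers app.py already imports.
def load_glee(config_file, weight_file):
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_glee_config(cfg)
    cfg.merge_from_file(config_file)
    model = GLEE_Model(cfg, None, device, None, True).to(device)
    model.load_state_dict(torch.load(weight_file), strict=False)
    return model.eval()

# e.g. GLEEmodel_eva02 = load_glee('GLEE/configs/EVA02.yaml',
#                                  'GLEE/GLEE_{}.pth'.format(args.version))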
@@ -130,16 +144,26 @@ TEXT_Y_OFFSET_SCALE = 1e-2
 if inference_type != 'LSJ':
     resizer = torchvision.transforms.Resize(inference_size,antialias=True)
     videoresizer = torchvision.transforms.Resize(video_inference_size,antialias=True)
+else:
+    resizer = torchvision.transforms.Resize(size = 1535, max_size=1536, antialias=True)
+    videoresizer = torchvision.transforms.Resize(size = 1535, max_size=1536, antialias=True)
+
 
 
 def segment_image(img, prompt_mode, categoryname, custom_category, expressiong, results_select, num_inst_select, threshold_select, mask_image_mix_ration, model_selection):
     torch.cuda.empty_cache()
     if model_selection == 'GLEE-Plus (SwinL)':
         GLEEmodel = GLEEmodel_swin
+        inference_type = 'resize_shot'
         print('use GLEE-Plus')
-    else:
+    elif model_selection == 'GLEE-Lite (R50)':
+        inference_type = 'resize_shot'
         GLEEmodel = GLEEmodel_r50
         print('use GLEE-Lite')
+    else:
+        GLEEmodel = GLEEmodel_eva02
+        print('use GLEE-Pro')
+        inference_type = 'LSJ'
 
     copyed_img = img['background'][:,:,:3].copy()
 
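With an integer `size` and `max_size` set, torchvision scales the shorter edge toward 1535 but caps the longer edge at 1536, so every input fits the fixed 1536x1536 padding canvas used below. Note also that the module-level resizers are built once from the startup `inference_type`, while `segment_image` and `segment_video` now override `inference_type` per call, so the per-call 'LSJ' branch relies on the LSJ resizers having been selected at startup. A quick self-contained check of the geometry:

import torch
import torchvision

# Same construction as the LSJ resizer in the hunk above.
resizer = torchvision.transforms.Resize(size=1535, max_size=1536, antialias=True)

frame = torch.rand(1, 3, 720, 1280)   # a sample 720p frame
print(resizer(frame).shape)           # torch.Size([1, 3, 864, 1536]): long edge capped at 1536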
 
@@ -148,8 +172,12 @@ def segment_image(img, prompt_mode, categoryname, custom_category, expressiong,
     _,_, ori_height, ori_width = ori_image.shape
 
     if inference_type == 'LSJ':
-        infer_image = torch.zeros(1,3,1024,1024).to(ori_image)
-        infer_image[:,:,:inference_size,:inference_size] = ori_image
+        resize_image = resizer(ori_image)
+        image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
+        re_size = resize_image.shape[-2:]
+        infer_image = torch.zeros(1,3,1536,1536).to(ori_image)
+        infer_image[:,:,:image_size[0],:image_size[1]] = resize_image
+        padding_size = (1536,1536)
     else:
         resize_image = resizer(ori_image)
         image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
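The old LSJ path copied the raw image into a 1024x1024 buffer, which cropped anything larger than `inference_size`; the new path resizes first, then zero-pads bottom-right into a fixed 1536x1536 canvas. A minimal self-contained sketch of that geometry (names illustrative; the same logic reappears for video frames in `process_frames` below):

import torch
import torchvision

def lsj_preprocess(ori_image, canvas=1536):
    # Resize so the longer edge is at most `canvas`, then zero-pad
    # bottom/right into a fixed canvas x canvas square.
    resizer = torchvision.transforms.Resize(size=canvas - 1, max_size=canvas,
                                            antialias=True)
    resize_image = resizer(ori_image)
    h, w = resize_image.shape[-2:]
    infer_image = torch.zeros(1, 3, canvas, canvas).to(ori_image)
    infer_image[:, :, :h, :w] = resize_image
    return infer_image, (h, w)

padded, hw = lsj_preprocess(torch.rand(1, 3, 720, 1280))
print(padded.shape, hw)   # torch.Size([1, 3, 1536, 1536]) (864, 1536)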
@@ -309,8 +337,9 @@ def segment_image(img, prompt_mode, categoryname, custom_category, expressiong,
 
     fakemask = torch.from_numpy(fakemask).unsqueeze(0).to(ori_image)
     if inference_type == 'LSJ':
-        infer_visual_prompt = torch.zeros(1,1024,1024).to(ori_image)
-        infer_visual_prompt[:,:inference_size,:inference_size] = fakemask
+        resize_fakemask = resizer(fakemask)
+        infer_visual_prompt = torch.zeros(1,1536,1536).to(resize_fakemask)
+        infer_visual_prompt[:,:image_size[0],:image_size[1]] = resize_fakemask
     else:
         resize_fakemask = resizer(fakemask)
         if size_divisibility > 1:
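The visual prompt (the user-drawn mask) now goes through the same resize before being written into the same 1536x1536 canvas, so each prompt pixel stays registered to the image pixel it annotates. A small consistency check, assuming identical input geometry for both tensors:

import torch
import torchvision

resizer = torchvision.transforms.Resize(size=1535, max_size=1536, antialias=True)

image = torch.rand(1, 3, 720, 1280)   # demo image
mask = torch.zeros(1, 720, 1280)      # demo binary prompt mask
mask[:, 100:300, 200:500] = 1.0

# Same resizer, same aspect ratio -> same output geometry. Note the
# antialiased (bilinear) resize softens mask edges, so downstream code
# may need to re-threshold if it expects a hard mask.
assert resizer(image).shape[-2:] == resizer(mask).shape[-2:]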
@@ -377,8 +406,12 @@ def process_frames(frame_list):
     _,_, ori_height, ori_width = ori_image.shape
 
     if inference_type == 'LSJ':
-        infer_image = torch.zeros(1,3,1024,1024).to(ori_image)
-        infer_image[:,:,:inference_size,:inference_size] = ori_image
+        resize_image = resizer(ori_image)
+        image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
+        re_size = resize_image.shape[-2:]
+        infer_image = torch.zeros(1,3,1536,1536).to(ori_image)
+        infer_image[:,:,:image_size[0],:image_size[1]] = resize_image
+        padding_size = (1536,1536)
     else:
         resize_image = videoresizer(ori_image)
         image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
@@ -414,14 +447,23 @@ def match_from_embds(tgt_embds, cur_embds):
 def segment_video(video, prompt_mode, categoryname, custom_category, expressiong, results_select, num_inst_select, threshold_select, mask_image_mix_ration, model_selection,video_frames_select, prompter):
     torch.cuda.empty_cache()
     ### model selection
+
+
     if model_selection == 'GLEE-Plus (SwinL)':
         GLEEmodel = GLEEmodel_swin
+        inference_type = 'resize_shot'
         print('use GLEE-Plus')
         clip_length = 2 #batchsize
-    else:
+    elif model_selection == 'GLEE-Lite (R50)':
+        inference_type = 'resize_shot'
         GLEEmodel = GLEEmodel_r50
         print('use GLEE-Lite')
         clip_length = 4 #batchsize
+    else:
+        GLEEmodel = GLEEmodel_eva02
+        print('use GLEE-Pro')
+        inference_type = 'LSJ'
+        clip_length = 1 #batchsize
 
     # read video and get sparse frames
     cap = cv2.VideoCapture(video)
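`clip_length` is the per-forward batch of video frames; it drops from 4 (R50) and 2 (SwinL) to 1 for the larger EVA02 model, trading throughput for GPU memory. A hypothetical helper (not in the commit) illustrating what such a batch size controls:

# Hypothetical sketch: frames processed in contiguous clips of at most
# clip_length frames per forward pass.
def make_clips(frames, clip_length):
    return [frames[i:i + clip_length] for i in range(0, len(frames), clip_length)]

print(make_clips(list(range(10)), clip_length=4))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]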
@@ -678,8 +720,9 @@ def segment_video(video, prompt_mode, categoryname, custom_category, expressiong
 
     fakemask = torch.from_numpy(fakemask).unsqueeze(0).to(ori_image)
     if inference_type == 'LSJ':
-        infer_visual_prompt = torch.zeros(1,1024,1024).to(ori_image)
-        infer_visual_prompt[:,:inference_size,:inference_size] = fakemask
+        resize_fakemask = resizer(fakemask)
+        infer_visual_prompt = torch.zeros(1,1536,1536).to(resize_fakemask)
+        infer_visual_prompt[:,:image_size[0],:image_size[1]] = resize_fakemask
     else:
         resize_fakemask = videoresizer(fakemask)
         if size_divisibility > 1:
 