|  | import torch | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  | import math | 
					
						
						|  |  | 
					
						
						|  | from utils.gaussian_splatting import generate_2D_gaussian_splatting_step, generate_2D_gaussian_splatting_step_buffer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def split_and_joint_image(lq, scale_factor, split_size, | 
					
						
						|  | overlap_size, model_g, model_fea2gs, | 
					
						
						|  | scale_modify, crop_size = 2, | 
					
						
						|  | default_step_size = 1.2, mode = 'scale_modify', | 
					
						
						|  | cuda_rendering = True, | 
					
						
						|  | if_dmax = False, | 
					
						
						|  | dmax_mode = 'fix', | 
					
						
						|  | dmax = 25): | 
					
						
						|  | h_lq, w_lq = lq.shape[-2:] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | assert overlap_size > 0 and overlap_size < split_size // 2, f"overlap size is wrong" | 
					
						
						|  |  | 
					
						
						|  | tile_nums_h = math.ceil((h_lq - overlap_size) / (split_size - overlap_size)) | 
					
						
						|  | tile_nums_w = math.ceil((w_lq - overlap_size) / (split_size - overlap_size)) | 
					
						
						|  |  | 
					
						
						|  | pad_h_lq = tile_nums_h * (split_size - overlap_size) + overlap_size - h_lq | 
					
						
						|  | pad_w_lq = tile_nums_w * (split_size - overlap_size) + overlap_size - w_lq | 
					
						
						|  |  | 
					
						
						|  | assert pad_h_lq < h_lq, f'pad_h_lq-{pad_h_lq} should be smaller than h_lq-{h_lq}, please decrease the split_size-{split_size}' | 
					
						
						|  | assert pad_w_lq < w_lq, f'pad_w_lq-{pad_w_lq} should be smaller than w_lq-{w_lq}, please decrease the split_size-{split_size}' | 
					
						
						|  |  | 
					
						
						|  | lq_pad = F.pad(input=lq, pad=(0, pad_w_lq, 0, pad_h_lq), mode='reflect') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | split_size_sr = math.ceil(split_size * scale_factor) | 
					
						
						|  | sr_tile_list = [] | 
					
						
						|  | for h_num in range(tile_nums_h): | 
					
						
						|  | for w_num in range(tile_nums_w): | 
					
						
						|  | tile_lq_position_start_h = h_num * (split_size - overlap_size) | 
					
						
						|  | tile_lq_position_start_w = w_num * (split_size - overlap_size) | 
					
						
						|  | tile_lq_position_end_h = tile_lq_position_start_h + split_size | 
					
						
						|  | tile_lq_position_end_w = tile_lq_position_start_w + split_size | 
					
						
						|  |  | 
					
						
						|  | input_tile = lq_pad[:,:, tile_lq_position_start_h:tile_lq_position_end_h, tile_lq_position_start_w:tile_lq_position_end_w] | 
					
						
						|  |  | 
					
						
						|  | model_g_output = model_g(input_tile) | 
					
						
						|  |  | 
					
						
						|  | scale_vector = scale_modify[0].unsqueeze(0).to(model_g_output.device) | 
					
						
						|  | batch_gs_parameters = model_fea2gs(model_g_output, scale_vector) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | gs_parameters = batch_gs_parameters[0, :] | 
					
						
						|  | b_output = generate_2D_gaussian_splatting_step(sr_size=torch.tensor([split_size_sr, split_size_sr]), gs_parameters=gs_parameters, | 
					
						
						|  | scale=scale_factor, sample_coords=None, | 
					
						
						|  | scale_modify = scale_modify, | 
					
						
						|  | default_step_size = default_step_size, mode = mode, | 
					
						
						|  | cuda_rendering = cuda_rendering, | 
					
						
						|  | if_dmax = if_dmax, | 
					
						
						|  | dmax_mode = dmax_mode, | 
					
						
						|  | dmax = dmax) | 
					
						
						|  | sr_tile_list.append(b_output.unsqueeze(0)) | 
					
						
						|  |  | 
					
						
						|  | tile_sr_h = sr_tile_list[0].shape[2] | 
					
						
						|  | tile_sr_w = sr_tile_list[0].shape[3] | 
					
						
						|  |  | 
					
						
						|  | assert tile_sr_w == split_size_sr and tile_sr_h == split_size_sr, \ | 
					
						
						|  | f'tile_sr_h-{tile_sr_w}, tile_sr_w-{tile_sr_w}, split_size_sr-{split_size_sr} is not the same' | 
					
						
						|  |  | 
					
						
						|  | overlap_sr = math.ceil(overlap_size * scale_factor) | 
					
						
						|  |  | 
					
						
						|  | sr_pad = torch.zeros(lq.shape[0], lq.shape[1], | 
					
						
						|  | (tile_nums_h - 1) * (split_size_sr - overlap_sr) + split_size_sr, | 
					
						
						|  | (tile_nums_w - 1) * (split_size_sr - overlap_sr) + split_size_sr, | 
					
						
						|  | device=lq.device) | 
					
						
						|  |  | 
					
						
						|  | idx = 0 | 
					
						
						|  |  | 
					
						
						|  | if scale_factor != int(scale_factor): | 
					
						
						|  | for h_num in range(tile_nums_h): | 
					
						
						|  | for w_num in range(tile_nums_w): | 
					
						
						|  | tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr) | 
					
						
						|  | tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr | 
					
						
						|  | tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr) | 
					
						
						|  | tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr | 
					
						
						|  | if h_num == 0 and w_num == 0: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx] | 
					
						
						|  | elif h_num == 0 and w_num !=0: | 
					
						
						|  | if w_num != tile_nums_w - 1: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:] | 
					
						
						|  | else: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w] | 
					
						
						|  | elif h_num != 0 and w_num ==0: | 
					
						
						|  | if h_num != tile_nums_h - 1: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:] | 
					
						
						|  | else: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2], | 
					
						
						|  | tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:] | 
					
						
						|  | else: | 
					
						
						|  | if w_num != tile_nums_w - 1 and h_num != tile_nums_h - 1: | 
					
						
						|  | sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:] | 
					
						
						|  | elif w_num == tile_nums_w - 1 and h_num != tile_nums_h - 1: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,:,crop_size:sr_pad.shape[3] - tile_sr_position_start_w] | 
					
						
						|  | elif w_num != tile_nums_w - 1 and h_num == tile_nums_h - 1: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h+crop_size:sr_pad.shape[2], | 
					
						
						|  | tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,:] | 
					
						
						|  | elif w_num == tile_nums_w - 1 and h_num == tile_nums_h - 1: | 
					
						
						|  | sr_pad[:,:,tile_sr_position_start_h+crop_size:sr_pad.shape[2], | 
					
						
						|  | tile_sr_position_start_w+crop_size:sr_pad.shape[3]] = sr_tile_list[idx][:,:,crop_size:sr_pad.shape[2] - tile_sr_position_start_h,crop_size:sr_pad.shape[3] - tile_sr_position_start_w] | 
					
						
						|  | idx = idx + 1 | 
					
						
						|  | else: | 
					
						
						|  | for h_num in range(tile_nums_h): | 
					
						
						|  | for w_num in range(tile_nums_w): | 
					
						
						|  | tile_sr_position_start_w = w_num * (split_size_sr - overlap_sr) | 
					
						
						|  | tile_sr_position_end_w = tile_sr_position_start_w + split_size_sr | 
					
						
						|  | tile_sr_position_start_h = h_num * (split_size_sr - overlap_sr) | 
					
						
						|  | tile_sr_position_end_h = tile_sr_position_start_h + split_size_sr | 
					
						
						|  | if h_num == 0 and w_num == 0: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx] | 
					
						
						|  | elif h_num == 0 and w_num !=0: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,:,crop_size:] | 
					
						
						|  | elif h_num != 0 and w_num ==0: | 
					
						
						|  | sr_pad[:, :, tile_sr_position_start_h+crop_size:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,:] | 
					
						
						|  | else: | 
					
						
						|  | sr_pad[:,:,tile_sr_position_start_h+crop_size:tile_sr_position_end_h, | 
					
						
						|  | tile_sr_position_start_w+crop_size:tile_sr_position_end_w] = sr_tile_list[idx][:,:,crop_size:,crop_size:] | 
					
						
						|  | idx = idx + 1 | 
					
						
						|  |  | 
					
						
						|  | print(f"sr_pad shape is {sr_pad.shape}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sr_final = sr_pad | 
					
						
						|  |  | 
					
						
						|  | return sr_final |