| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | import torch.nn.functional as F |
| |
|
def pad_to_multiple_of_16(latent, pad_value, patch_size=16):
    """Pad the two trailing spatial dims of ``latent`` up to a multiple of ``patch_size``.

    The padding is centred: each axis gets half of the missing amount on
    either side, and when that amount is odd the extra row/column goes to
    the bottom/right.

    Args:
        latent: Tensor whose sizes are read from dims 2 and 3
            (presumably laid out as (B, C, H, W) -- confirm with callers).
        pad_value: Constant written into the padded area.
        patch_size: Multiple to pad up to; 16 by default, despite the
            function name.

    Returns:
        The padded tensor.
    """
    height, width = latent.shape[2], latent.shape[3]

    def _split_padding(length):
        # Round up to the next multiple of patch_size, then halve the gap;
        # divmod hands back the odd leftover that goes to the trailing side.
        target = ((length - 1) // patch_size + 1) * patch_size
        leading, leftover = divmod(target - length, 2)
        return leading, leading + leftover

    top, bottom = _split_padding(height)
    left, right = _split_padding(width)

    # F.pad expects the pads of the last dimension first: (left, right, top, bottom).
    return F.pad(latent, (left, right, top, bottom), mode='constant', value=pad_value)
| |
|
def split_into_blocks(latent, patch_size=16):
    """Cut a (b, c, h, w) tensor into non-overlapping square tiles.

    Tiles are ordered batch first, then row by row, left to right within
    each row. ``h`` and ``w`` must be divisible by ``patch_size``; otherwise
    the reshaping view does not match the number of elements and fails.

    Args:
        latent: 4-D tensor to tile.
        patch_size: Side length of each square tile.

    Returns:
        Tensor of shape (b * (h // patch_size) * (w // patch_size), c,
        patch_size, patch_size).
    """
    batch, channels, height, width = latent.shape
    rows, cols = height // patch_size, width // patch_size

    grid = latent.view(batch, channels, rows, patch_size, cols, patch_size)
    # (b, c, rows, p, cols, p) -> (b, rows, cols, c, p, p)
    tiles = grid.permute(0, 2, 4, 1, 3, 5).contiguous()

    return tiles.view(-1, channels, patch_size, patch_size)
| |
|
def merge_blocks(blocks, original_shape, patch_size=16):
    """Reassemble square tiles into a (b, c, h, w) tensor.

    Expects the tile order produced by a batch-first, row-by-row,
    left-to-right tiling of a tensor of shape ``original_shape``.

    Args:
        blocks: Tensor of shape (b * rows * cols, c, patch_size, patch_size).
        original_shape: The (b, c, h, w) shape to rebuild; ``h`` and ``w``
            are taken as the tiled (already padded) sizes.
        patch_size: Side length of each square tile.

    Returns:
        Tensor of shape (b, c, h, w).
    """
    batch, channels, height, width = original_shape
    rows = height // patch_size
    cols = width // patch_size

    grid = blocks.view(batch, rows, cols, channels, patch_size, patch_size)
    # (b, rows, cols, c, p, p) -> (b, c, rows, p, cols, p)
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()

    return grid.view(batch, channels, height, width)
| |
|
def crop_to_original_shape(blocks, original_shape):
    """Centre-crop a padded 4-D tensor back to its pre-padding spatial size.

    The crop window starts at half of the surplus (rounded down) on each
    axis, so an odd surplus is removed from the bottom/right.

    Args:
        blocks: Padded tensor with exactly four dimensions.
        original_shape: Shape-like whose entries 2 and 3 give the target
            height and width.

    Returns:
        ``blocks`` sliced to (b, c, original_height, original_width).
    """
    padded_h, padded_w = blocks.shape[2:]
    target_h, target_w = original_shape[2], original_shape[3]

    top = (padded_h - target_h) // 2
    left = (padded_w - target_w) // 2

    return blocks[:, :, top:top + target_h, left:left + target_w]
| |
|
def _split_lengths(total, num_blocks, max_length):
    """Return ``num_blocks`` segment lengths summing to ``total``.

    With ``num_blocks == ceil(total / max_length)`` no returned length
    exceeds ``max_length``.
    """
    base, remainder = divmod(total, num_blocks)
    if base + remainder <= max_length:
        # The whole remainder fits in the final segment: keep the simple
        # "equal blocks, last one takes the rest" layout.
        return [base] * (num_blocks - 1) + [base + remainder]
    # Otherwise the final segment would outgrow max_length and the padding
    # step would end up cropping data. Hand one extra element to each of
    # the last ``remainder`` segments instead.
    return [base] * (num_blocks - remainder) + [base + 1] * remainder


def adaptively_split_and_pad(image_tensor, pad_value, target_patch_size=16):
    """Split an image batch into a grid of blocks and pad each to a square patch.

    The image is divided into ``ceil(h / target_patch_size)`` by
    ``ceil(w / target_patch_size)`` blocks of near-equal size. Every block
    is then centre-padded with ``pad_value`` to
    ``target_patch_size x target_patch_size`` (an odd pad amount puts the
    extra row/column at the bottom/right).

    Args:
        image_tensor: Tensor of shape (N, c, h, w).
        pad_value: Constant used to fill the padded area.
        target_patch_size: Side length of every returned patch.

    Returns:
        patches_tensor: Tensor of shape
            (N * num_blocks_h * num_blocks_w, c, target_patch_size,
            target_patch_size). Blocks are concatenated along dim 0 in
            row-major block order, each block contributing its N batch
            items consecutively.
        patch_sizes: List with one ``(height, width)`` tuple per block, in
            the same block order, giving the block size before padding.
        num_blocks_h: Number of block rows.
        num_blocks_w: Number of block columns.

    Raises:
        ValueError: If ``target_patch_size`` is not positive or the image
            has an empty spatial dimension.
    """
    if target_patch_size <= 0:
        raise ValueError("target_patch_size must be a positive integer")

    h, w = image_tensor.size(2), image_tensor.size(3)
    if h == 0 or w == 0:
        raise ValueError("image_tensor must have non-empty height and width")

    # Ceiling division: the fewest blocks such that each can fit in a patch.
    num_blocks_h = -(-h // target_patch_size)
    num_blocks_w = -(-w // target_patch_size)

    block_heights = _split_lengths(h, num_blocks_h, target_patch_size)
    block_widths = _split_lengths(w, num_blocks_w, target_patch_size)

    patches = []
    patch_sizes = []

    start_h = 0
    for block_h in block_heights:
        pad_top = (target_patch_size - block_h) // 2
        pad_bottom = target_patch_size - block_h - pad_top

        start_w = 0
        for block_w in block_widths:
            pad_left = (target_patch_size - block_w) // 2
            pad_right = target_patch_size - block_w - pad_left

            patch = image_tensor[:, :, start_h:start_h + block_h, start_w:start_w + block_w]
            # F.pad expects the pads of the last dimension first.
            patches.append(
                F.pad(patch, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=pad_value)
            )
            patch_sizes.append((block_h, block_w))

            start_w += block_w
        start_h += block_h

    patches_tensor = torch.cat(patches, dim=0)
    return patches_tensor, patch_sizes, num_blocks_h, num_blocks_w
| |
|
| |
|
def crop_and_reconstruct(patches, patch_sizes, num_blocks_h, num_blocks_w, target_patch_size=16):
    """Inverse of ``adaptively_split_and_pad``: strip the padding and stitch the blocks.

    Args:
        patches: Tensor of shape
            (N * num_blocks_h * num_blocks_w, c, target_patch_size,
            target_patch_size), stored in row-major block order with the N
            batch items of each block adjacent along dim 0.
        patch_sizes: One ``(height, width)`` tuple per block, in row-major
            block order, giving the un-padded size of that block. Each
            entry is expected not to exceed ``target_patch_size``.
        num_blocks_h: Number of block rows.
        num_blocks_w: Number of block columns.
        target_patch_size: Side length of the padded patches.

    Returns:
        The reconstructed image. For a single image (N == 1) the result has
        shape (c, H, W); for N > 1 it has shape (N, c, H, W).

    Raises:
        ValueError: If the block counts are not positive or the number of
            patches is not a multiple of ``num_blocks_h * num_blocks_w``.
    """
    num_blocks = num_blocks_h * num_blocks_w
    if num_blocks <= 0:
        raise ValueError("num_blocks_h and num_blocks_w must both be positive")
    if patches.size(0) % num_blocks != 0:
        raise ValueError(
            "number of patches must be a multiple of num_blocks_h * num_blocks_w"
        )

    batch_size = patches.size(0) // num_blocks
    # Group the flat patch list by block so that a whole batch is cropped at
    # once: (num_blocks, N, c, target_patch_size, target_patch_size).
    grouped = patches.reshape(num_blocks, batch_size, *patches.shape[1:])

    reconstructed_rows = []
    for i in range(num_blocks_h):
        row_patches = []
        for j in range(num_blocks_w):
            index = i * num_blocks_w + j
            patch_height, patch_width = patch_sizes[index]

            # Same offsets the centre padding used on the way in.
            top = (target_patch_size - patch_height) // 2
            left = (target_patch_size - patch_width) // 2

            row_patches.append(
                grouped[index][..., top:top + patch_height, left:left + patch_width]
            )
        # Blocks in one row sit side by side along the width axis.
        reconstructed_rows.append(torch.cat(row_patches, dim=-1))

    # Rows are stacked along the height axis.
    reconstructed_image = torch.cat(reconstructed_rows, dim=-2)

    # Keep the historical (c, H, W) result for the single-image case.
    return reconstructed_image[0] if batch_size == 1 else reconstructed_image
| |
|
def save_image(tensor, file_path):
    """Write an image tensor to ``file_path`` and print a confirmation line.

    The tensor is copied to the CPU and detached, a leading dimension of
    size 1 (if any) is dropped, values are clamped to [0, 1], scaled by 255
    and truncated to ``uint8``. Dim 0 of the squeezed tensor is moved last
    before the array is handed to PIL, i.e. the data is treated as
    channel-first.

    Args:
        tensor: Image tensor; after the squeeze it must be 3-D.
        file_path: Destination passed to ``PIL.Image.Image.save``.
    """
    pixels = tensor.to('cpu').clone().detach().squeeze(0)
    pixels = torch.clamp(pixels, 0, 1)

    # Channel-first -> channel-last, then scale into the 8-bit range.
    scaled = pixels.permute(1, 2, 0).numpy() * 255
    Image.fromarray(scaled.astype(np.uint8)).save(file_path)

    print(f"Image saved to {file_path}")
| |
|
| |
|
if __name__ == "__main__":
    # Round-trip demo on a random image whose sides are not multiples of 16:
    # split into padded patches, save each patch, then rebuild and save the image.
    N, C, H, W = 1, 3, 36, 33
    image_tensor = torch.rand(N, C, H, W)

    target_patch_size, pad_value = 16, 0

    split_result = adaptively_split_and_pad(image_tensor, pad_value, target_patch_size)
    patches_tensor, patch_sizes, num_blocks_h, num_blocks_w = split_result

    # One PNG per padded patch, numbered in the order the patches are stored.
    for i in range(patches_tensor.size(0)):
        save_image(patches_tensor[i], f"patch_{i}.png")

    reconstructed_image = crop_and_reconstruct(
        patches_tensor,
        patch_sizes,
        num_blocks_h,
        num_blocks_w,
        target_patch_size,
    )

    save_image(reconstructed_image, "reconstructed_image.png")