Spaces:
Sleeping
Sleeping
File size: 3,617 Bytes
6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d 6bdded7 ed9b64d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import PIL
import numpy as np
from PIL import Image
class Croper:
    """Crop the masked region of an image into a padded square, and paste results back.

    ``corp_mask_image`` finds the bounding box of the non-zero pixels of
    ``target_mask``, grows it by ``mask_expansion`` pixels on every side
    (clamped to the image), centres that crop on a square canvas and resizes
    the canvas to ``mask_size`` x ``mask_size``. ``restore_result`` reverses
    the geometry to paste a generated square image back into a copy of the
    original image, using the cropped mask as the paste mask.
    """

    def __init__(
        self,
        input_image: "Image.Image",
        target_mask: np.ndarray,
        mask_size: int = 256,
        mask_expansion: int = 20
    ):
        """Store the inputs; no processing happens until ``corp_mask_image``.

        Args:
            input_image: Source image the mask refers to.
            target_mask: Array whose non-zero entries mark the region of
                interest; indexed as ``[row, column]``.
                NOTE(review): presumably the same height/width as
                ``input_image`` -- confirm against callers.
            mask_size: Side length of the resized square outputs.
            mask_expansion: Padding, in pixels, added around the mask's
                bounding box before cropping.
        """
        self.input_image = input_image
        self.target_mask = target_mask
        self.mask_size = mask_size
        self.mask_expansion = mask_expansion

    @staticmethod
    def _expanded_span(indices, expansion, limit):
        """Return the half-open ``(start, end)`` span covering *indices*, padded and clamped."""
        start = max(int(indices.min()) - expansion, 0)
        # The largest index is inclusive while slice/crop ends are exclusive,
        # hence the +1; without it the last masked row/column is cut off.
        end = min(int(indices.max()) + 1 + expansion, limit)
        return start, end

    def corp_mask_image(self):
        """Build the square crop of the image and mask around the masked region.

        Side effects: stores the crop geometry (``origin_*``, ``square_*``,
        ``square_length``) and the produced images (``square_mask_image``,
        ``square_image``, ``resized_square_mask_image``,
        ``resized_square_image``) on the instance for ``restore_result``.

        Returns:
            The square mask image resized to ``mask_size`` x ``mask_size``.

        Raises:
            ValueError: If ``target_mask`` has no non-zero element.
        """
        target_mask = self.target_mask
        mask_indices = np.where(target_mask)
        if mask_indices[0].size == 0:
            raise ValueError("target_mask has no non-zero pixels to crop around")

        input_image = self.input_image
        mask_expansion = self.mask_expansion
        original_width, original_height = input_image.size

        start_y, end_y = self._expanded_span(mask_indices[0], mask_expansion, original_height)
        start_x, end_x = self._expanded_span(mask_indices[1], mask_expansion, original_width)
        mask_height = end_y - start_y
        mask_width = end_x - start_x

        # The square canvas is as large as the longer side of the crop.
        max_side_length = max(mask_height, mask_width)

        # Offsets that centre the (possibly rectangular) crop inside the square.
        crop_mask = target_mask[start_y:end_y, start_x:end_x]
        crop_mask_start_y = (max_side_length - mask_height) // 2
        crop_mask_end_y = crop_mask_start_y + mask_height
        crop_mask_start_x = (max_side_length - mask_width) // 2
        crop_mask_end_x = crop_mask_start_x + mask_width

        # Zero-filled square mask with the cropped mask placed in the centre.
        square_mask = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
        square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
        # NOTE(review): the *255 scaling assumes boolean or 0/1 mask values;
        # larger integer values would wrap on the uint8 cast -- confirm.
        square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))

        crop_image = input_image.crop((start_x, start_y, end_x, end_y))
        square_image = Image.new("RGB", (max_side_length, max_side_length))
        square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))

        # Remember the geometry so restore_result can invert it.
        self.origin_start_x = start_x
        self.origin_start_y = start_y
        self.origin_end_x = end_x
        self.origin_end_y = end_y

        self.square_start_x = crop_mask_start_x
        self.square_start_y = crop_mask_start_y
        self.square_end_x = crop_mask_end_x
        self.square_end_y = crop_mask_end_y

        self.square_length = max_side_length
        self.square_mask_image = square_mask_image
        self.square_image = square_image

        mask_size = self.mask_size
        self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
        self.resized_square_image = square_image.resize((mask_size, mask_size))

        return self.resized_square_mask_image

    def restore_result(self, generated_image):
        """Paste a generated square image back into a copy of the input image.

        Relies on the geometry stored by ``corp_mask_image``, which must have
        been called first.

        Args:
            generated_image: Image corresponding to the square crop; it is
                resized to ``square_length`` x ``square_length`` before use.

        Returns:
            A copy of ``input_image`` with the generated content pasted at
            the original crop location, using the cropped square mask as the
            paste mask.
        """
        square_length = self.square_length
        generated_image = generated_image.resize((square_length, square_length))
        square_mask_image = self.square_mask_image

        # Strip the square padding so only the original crop rectangle remains.
        square_box = (self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y)
        cropped_generated_image = generated_image.crop(square_box)
        cropped_square_mask_image = square_mask_image.crop(square_box)

        restored_image = self.input_image.copy()
        restored_image.paste(
            cropped_generated_image,
            (self.origin_start_x, self.origin_start_y),
            cropped_square_mask_image,
        )

        return restored_image
|