jhj0517 commited on
Commit
e00f944
1 Parent(s): b76be88

Refactor to add type hint

Browse files
Files changed (1) hide show
  1. modules/mask_utils.py +47 -17
modules/mask_utils.py CHANGED
@@ -5,23 +5,34 @@ from pycocotools import mask as coco_mask
5
  from pytoshop import layers
6
  import pytoshop
7
  from pytoshop.enums import BlendMode
 
 
 
 
 
8
 
9
 
10
  def generate_random_color():
11
  return np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)
12
 
13
 
14
- def create_base_layer(image):
15
  rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
16
  return [rgba_image]
17
 
18
 
19
- def create_mask_layers(image, masks):
 
 
 
20
  layer_list = []
21
 
22
- for result in masks:
23
- rle = result['segmentation']
24
- mask = coco_mask.decode(rle).astype(np.uint8)
 
 
 
25
  rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
26
  rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
27
 
@@ -30,13 +41,18 @@ def create_mask_layers(image, masks):
30
  return layer_list
31
 
32
 
33
- def create_mask_gallery(image, masks):
 
 
 
34
  mask_array_list = []
35
  label_list = []
36
 
37
- for index, result in enumerate(masks):
38
- rle = result['segmentation']
39
- mask = coco_mask.decode(rle).astype(np.uint8)
 
 
40
 
41
  rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
42
  rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
@@ -47,12 +63,15 @@ def create_mask_gallery(image, masks):
47
  return [[img, label] for img, label in zip(mask_array_list, label_list)]
48
 
49
 
50
- def create_mask_combined_images(image, masks):
 
 
 
51
  final_result = np.zeros_like(image)
52
 
53
- for result in masks:
54
- rle = result['segmentation']
55
- mask = coco_mask.decode(rle).astype(np.uint8)
56
 
57
  color = generate_random_color()
58
  colored_mask = np.zeros_like(image)
@@ -64,7 +83,12 @@ def create_mask_combined_images(image, masks):
64
  return [combined_image, "masked"]
65
 
66
 
67
- def insert_psd_layer(psd, image_data, layer_name, blending_mode):
 
 
 
 
 
68
  channel_data = [layers.ChannelImageData(image=image_data[:, :, i], compression=1) for i in range(4)]
69
 
70
  layer_record = layers.LayerRecord(
@@ -78,8 +102,14 @@ def insert_psd_layer(psd, image_data, layer_name, blending_mode):
78
  return psd
79
 
80
 
81
- def save_psd(input_image_data, layer_data, layer_names, blending_modes, output_path):
82
- psd_file = pytoshop.core.PsdFile(num_channels=3, height=input_image_data.shape[0], width=input_image_data.shape[1])
 
 
 
 
 
 
83
  psd_file.layer_and_mask_info.layer_info.layer_records.clear()
84
 
85
  for index, layer in enumerate(layer_data):
@@ -91,7 +121,7 @@ def save_psd(input_image_data, layer_data, layer_names, blending_modes, output_p
91
 
92
  def save_psd_with_masks(
93
  image: np.ndarray,
94
- masks: Dict,
95
  output_path: str
96
  ):
97
  original_layer = create_base_layer(image)
 
5
  from pytoshop import layers
6
  import pytoshop
7
  from pytoshop.enums import BlendMode
8
+ from pytoshop.core import PsdFile
9
+
10
+
11
+ def decode_to_mask(seg: np.ndarray[np.bool_]) -> np.ndarray[np.uint8]:
12
+ return seg.astype(np.uint8) * 255
13
 
14
 
15
  def generate_random_color():
16
  return np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)
17
 
18
 
19
+ def create_base_layer(image: np.ndarray):
20
  rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
21
  return [rgba_image]
22
 
23
 
24
+ def create_mask_layers(
25
+ image: np.ndarray,
26
+ masks: List
27
+ ):
28
  layer_list = []
29
 
30
+ sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
31
+
32
+ for info in sorted_masks:
33
+ rle = info['segmentation']
34
+ mask = decode_to_mask(rle)
35
+
36
  rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
37
  rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
38
 
 
41
  return layer_list
42
 
43
 
44
+ def create_mask_gallery(
45
+ image: np.ndarray,
46
+ masks: List
47
+ ):
48
  mask_array_list = []
49
  label_list = []
50
 
51
+ sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
52
+
53
+ for index, info in enumerate(sorted_masks):
54
+ rle = info['segmentation']
55
+ mask = decode_to_mask(rle)
56
 
57
  rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
58
  rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
 
63
  return [[img, label] for img, label in zip(mask_array_list, label_list)]
64
 
65
 
66
+ def create_mask_combined_images(
67
+ image: np.ndarray,
68
+ masks: List
69
+ ):
70
  final_result = np.zeros_like(image)
71
 
72
+ for info in masks:
73
+ rle = info['segmentation']
74
+ mask = decode_to_mask(rle)
75
 
76
  color = generate_random_color()
77
  colored_mask = np.zeros_like(image)
 
83
  return [combined_image, "masked"]
84
 
85
 
86
+ def insert_psd_layer(
87
+ psd: PsdFile,
88
+ image_data: np.ndarray,
89
+ layer_name: str,
90
+ blending_mode: BlendMode
91
+ ):
92
  channel_data = [layers.ChannelImageData(image=image_data[:, :, i], compression=1) for i in range(4)]
93
 
94
  layer_record = layers.LayerRecord(
 
102
  return psd
103
 
104
 
105
+ def save_psd(
106
+ input_image_data: np.ndarray,
107
+ layer_data: List,
108
+ layer_names: List,
109
+ blending_modes: List,
110
+ output_path: str
111
+ ):
112
+ psd_file = PsdFile(num_channels=3, height=input_image_data.shape[0], width=input_image_data.shape[1])
113
  psd_file.layer_and_mask_info.layer_info.layer_records.clear()
114
 
115
  for index, layer in enumerate(layer_data):
 
121
 
122
  def save_psd_with_masks(
123
  image: np.ndarray,
124
+ masks: List,
125
  output_path: str
126
  ):
127
  original_layer = create_base_layer(image)