jhj0517 commited on
Commit
11d7b39
1 Parent(s): 4af6ba2

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ models/
2
+ modules/__pycache__/
3
+ outputs/
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules import sam
4
+ from modules.ui_utils import *
5
+ from modules.html_constants import *
6
+
7
+
8
+ class App:
9
+ def __init__(self):
10
+ #download_sam_model_url()
11
+ self.app = gr.Blocks(css=CSS)
12
+ self.sam = sam.SamInference()
13
+
14
+ def launch(self):
15
+ with self.app:
16
+ with gr.Row():
17
+ gr.Markdown(MARKDOWN_NOTE, elem_id="md_pgroject")
18
+ with gr.Row().style(equal_height=True): # bug https://github.com/gradio-app/gradio/issues/3202
19
+ with gr.Column(scale=5):
20
+ img_input = gr.Image(label="Input image here")
21
+ with gr.Column(scale=5):
22
+ # Tuable Params
23
+ nb_points_per_side = gr.Number(label="points_per_side", value=32)
24
+ sld_pred_iou_thresh = gr.Slider(label="pred_iou_thresh", value=0.88, minimum=0, maximum=1)
25
+ sld_stability_score_thresh = gr.Slider(label="stability_score_thresh", value=0.95, minimum=0,
26
+ maximum=1)
27
+ nb_crop_n_layers = gr.Number(label="crop_n_layers", value=0)
28
+ nb_crop_n_points_downscale_factor = gr.Number(label="crop_n_points_downscale_factor", value=1)
29
+ nb_min_mask_region_area = gr.Number(label="min_mask_region_area", value=0)
30
+ html_param_explain = gr.HTML(PARAMS_EXPLANATION, elem_id="html_param_explain")
31
+
32
+ with gr.Row():
33
+ btn_generate = gr.Button("GENERATE", variant="primary")
34
+ with gr.Row():
35
+ gallery_output = gr.Gallery(label="Output will be shown here", show_label=True).style(grid=5,
36
+ height="auto")
37
+ btn_open_folder = gr.Button("📁\n(PSD)").style(full_width=False)
38
+
39
+ params = [nb_points_per_side, sld_pred_iou_thresh, sld_stability_score_thresh, nb_crop_n_layers,
40
+ nb_crop_n_points_downscale_factor, nb_min_mask_region_area]
41
+ btn_generate.click(fn=self.sam.generate_mask_app, inputs=[img_input] + params, outputs=gallery_output)
42
+ btn_open_folder.click(fn=lambda: open_folder("outputs\psd"), inputs=None, outputs=None)
43
+
44
+ self.app.queue(api_open=False).launch()
45
+
46
+
47
+ if __name__ == "__main__":
48
+ app = App()
49
+ app.launch()
models/model file will be saved here.txt ADDED
File without changes
modules/__init__.py ADDED
File without changes
modules/html_constants.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CSS = """
2
+ #md_project a {
3
+ color: black;
4
+ text-decoration: none;
5
+ }
6
+ #md_project a:hover {
7
+ text-decoration: underline;
8
+ }
9
+ """
10
+
11
+
12
+ PROJECT_NAME = """
13
+ # [Layer-Divider-WebUI](https://github.com/jhj0517/Layer-Divider-WebUI)
14
+ """
15
+
16
+ MARKDOWN_NOTE = """
17
+ ## This space only support CPU because it's free huggingface space.
18
+ ## If you want to run CUDA version , check this [repository](https://github.com/jhj0517/Layer-Divider-WebUI)
19
+ """
20
+
21
+ PARAMS_EXPLANATION = """
22
+ <!DOCTYPE html>
23
+ <html lang="en">
24
+ <head>
25
+ <meta charset="UTF-8">
26
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
27
+ <style>
28
+ table {
29
+ border-collapse: collapse;
30
+ width: 100%;
31
+ }
32
+ th, td {
33
+ border: 1px solid #dddddd;
34
+ text-align: left;
35
+ padding: 8px;
36
+ }
37
+ th {
38
+ background-color: #f2f2f2;
39
+ }
40
+ </style>
41
+ </head>
42
+ <body>
43
+
44
+ <details>
45
+ <summary>Explanation of Each Parameter</summary>
46
+ <table>
47
+ <thead>
48
+ <tr>
49
+ <th>Parameter</th>
50
+ <th>Description</th>
51
+ </tr>
52
+ </thead>
53
+ <tbody>
54
+ <tr>
55
+ <td>points_per_side</td>
56
+ <td>The number of points to be sampled along one side of the image. The total number of points is points_per_side**2. If None, 'point_grids' must provide explicit point sampling.</td>
57
+ </tr>
58
+ <tr>
59
+ <td>pred_iou_thresh</td>
60
+ <td>A filtering threshold in [0,1], using the model's predicted mask quality.</td>
61
+ </tr>
62
+ <tr>
63
+ <td>stability_score_thresh</td>
64
+ <td>A filtering threshold in [0,1], using the stability of the mask under changes to the cutoff used to binarize the model's mask predictions.</td>
65
+ </tr>
66
+ <tr>
67
+ <td>crops_n_layers</td>
68
+ <td>If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where each layer has 2**i_layer number of image crops.</td>
69
+ </tr>
70
+ <tr>
71
+ <td>crop_n_points_downscale_factor</td>
72
+ <td>The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.</td>
73
+ </tr>
74
+ <tr>
75
+ <td>min_mask_region_area</td>
76
+ <td>If >0, postprocessing will be applied to remove disconnected regions and holes in masks with area smaller than min_mask_region_area. Requires opencv.</td>
77
+ </tr>
78
+ </tbody>
79
+ </table>
80
+ </details>
81
+
82
+ </body>
83
+ </html>
84
+ """
modules/mask_utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from pycocotools import mask as coco_mask
4
+ from pytoshop import layers
5
+ import pytoshop
6
+ from pytoshop.enums import BlendMode
7
+ from datetime import datetime
8
+
9
+
10
+ def generate_random_color():
11
+ return np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)
12
+
13
+
14
+ def create_base_layer(image):
15
+ rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
16
+ return [rgba_image]
17
+
18
+
19
+ def create_mask_layers(image, masks):
20
+ layer_list = []
21
+
22
+ for result in masks:
23
+ rle = result['segmentation']
24
+ mask = coco_mask.decode(rle).astype(np.uint8)
25
+ rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
26
+ rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
27
+
28
+ layer_list.append(rgba_image)
29
+
30
+ return layer_list
31
+
32
+
33
+ def create_mask_gallery(image, masks):
34
+ mask_array_list = []
35
+ label_list = []
36
+
37
+ for index, result in enumerate(masks):
38
+ rle = result['segmentation']
39
+ mask = coco_mask.decode(rle).astype(np.uint8)
40
+
41
+ rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
42
+ rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
43
+
44
+ mask_array_list.append(rgba_image)
45
+ label_list.append(f'Part {index}')
46
+
47
+ return [[img, label] for img, label in zip(mask_array_list, label_list)]
48
+
49
+
50
+ def create_mask_combined_images(image, masks):
51
+ final_result = np.zeros_like(image)
52
+
53
+ for result in masks:
54
+ rle = result['segmentation']
55
+ mask = coco_mask.decode(rle).astype(np.uint8)
56
+
57
+ color = generate_random_color()
58
+ colored_mask = np.zeros_like(image)
59
+ colored_mask[mask == 1] = color
60
+
61
+ final_result = cv2.addWeighted(final_result, 1, colored_mask, 0.5, 0)
62
+
63
+ combined_image = cv2.addWeighted(image, 1, final_result, 0.5, 0)
64
+ return [combined_image, "masked"]
65
+
66
+
67
+ def insert_psd_layer(psd, image_data, layer_name, blending_mode):
68
+ channel_data = [layers.ChannelImageData(image=image_data[:, :, i], compression=1) for i in range(4)]
69
+
70
+ layer_record = layers.LayerRecord(
71
+ channels={-1: channel_data[3], 0: channel_data[0], 1: channel_data[1], 2: channel_data[2]},
72
+ top=0, bottom=image_data.shape[0], left=0, right=image_data.shape[1],
73
+ blend_mode=blending_mode,
74
+ name=layer_name,
75
+ opacity=255,
76
+ )
77
+ psd.layer_and_mask_info.layer_info.layer_records.append(layer_record)
78
+ return psd
79
+
80
+
81
+ def save_psd(input_image_data, layer_data, layer_names, blending_modes):
82
+ psd_file = pytoshop.core.PsdFile(num_channels=3, height=input_image_data.shape[0], width=input_image_data.shape[1])
83
+
84
+ for index, layer in enumerate(layer_data):
85
+ psd_file = insert_psd_layer(psd_file, layer, layer_names[index], blending_modes[index])
86
+
87
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
88
+ with open(f"outputs/psd/result-{timestamp}.psd", 'wb') as output_file:
89
+ psd_file.write(output_file)
90
+
91
+
92
+ def save_psd_with_masks(image, masks):
93
+ original_layer = create_base_layer(image)
94
+ mask_layers = create_mask_layers(image, masks)
95
+ names = [f'Part {i}' for i in range(len(mask_layers))]
96
+ modes = [BlendMode.normal] * (len(mask_layers)+1)
97
+ save_psd(image, original_layer+mask_layers, ['Original_Image']+names, modes)
98
+
99
+
modules/model_downloader.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ AVAILABLE_MODELS = {
4
+ "ViT-H SAM model": ["sam_vit_h_4b8939.pth", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"],
5
+ "ViT-L SAM model": ["sam_vit_l_0b3195.pth", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"],
6
+ "ViT-B SAM model": ["sam_vit_b_01ec64.pth", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"],
7
+ }
8
+
9
+
10
+ def download_sam_model_url():
11
+ torch.hub.download_url_to_file(AVAILABLE_MODELS["ViT-H SAM model"][1],
12
+ f'models/{AVAILABLE_MODELS["ViT-H SAM model"][0]}')
modules/sam.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
2
+ import os
3
+
4
+ from modules.mask_utils import *
5
+ from modules.model_downloader import *
6
+
7
+
8
+ class SamInference:
9
+ def __init__(self):
10
+ self.model = None
11
+ self.model_path = f"models/sam_vit_h_4b8939.pth"
12
+ self.device = "cuda"
13
+ self.mask_generator = None
14
+
15
+ # Tuable Parameters , All default values
16
+ self.tunable_params = {
17
+ 'points_per_side': 32,
18
+ 'pred_iou_thresh': 0.88,
19
+ 'stability_score_thresh': 0.95,
20
+ 'crop_n_layers': 0,
21
+ 'crop_n_points_downscale_factor': 1,
22
+ 'min_mask_region_area': 0
23
+ }
24
+
25
+ def set_mask_generator(self):
26
+ print("applying configs to model..")
27
+ if not os.path.exists(self.model_path):
28
+ print("No needed SAM model detected. downloading VIT H SAM model....")
29
+ download_sam_model_url()
30
+
31
+ self.model = sam_model_registry["default"](checkpoint=self.model_path)
32
+ self.model.to(device=self.device)
33
+ self.mask_generator = SamAutomaticMaskGenerator(
34
+ self.model,
35
+ points_per_side=self.tunable_params['points_per_side'],
36
+ pred_iou_thresh=self.tunable_params['pred_iou_thresh'],
37
+ stability_score_thresh=self.tunable_params['stability_score_thresh'],
38
+ crop_n_layers=self.tunable_params['crop_n_layers'],
39
+ crop_n_points_downscale_factor=self.tunable_params['crop_n_points_downscale_factor'],
40
+ min_mask_region_area=self.tunable_params['min_mask_region_area'],
41
+ output_mode="coco_rle",
42
+ )
43
+
44
+ def generate_mask(self, image):
45
+ return [self.mask_generator.generate(image)]
46
+
47
+ def generate_mask_app(self, image, *params):
48
+ tunable_params = {
49
+ 'points_per_side': int(params[0]),
50
+ 'pred_iou_thresh': float(params[1]),
51
+ 'stability_score_thresh': float(params[2]),
52
+ 'crop_n_layers': int(params[3]),
53
+ 'crop_n_points_downscale_factor': int(params[4]),
54
+ 'min_mask_region_area': int(params[5]),
55
+ }
56
+
57
+ if self.model is None or self.mask_generator is None or self.tunable_params != tunable_params:
58
+ self.tunable_params = tunable_params
59
+ self.set_mask_generator()
60
+
61
+ masks = self.mask_generator.generate(image)
62
+ save_psd_with_masks(image, masks)
63
+ combined_image = create_mask_combined_images(image, masks)
64
+ gallery = create_mask_gallery(image, masks)
65
+ return [combined_image] + gallery
modules/ui_utils.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ def open_folder(folder_path):
5
+ if os.path.exists(folder_path):
6
+ os.system(f"start {folder_path}")
7
+ else:
8
+ print(f"The folder {folder_path} does not exist.")
outputs/psd/psd file will be saved here.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu117
2
+ torch
3
+ --extra-index-url https://download.pytorch.org/whl/cu117
4
+ torchvision
5
+ git+https://github.com/facebookresearch/segment-anything.git
6
+ opencv-python
7
+ pycocotools
8
+ matplotlib
9
+ onnxruntime
10
+ onnx
11
+ gradio
12
+ pytoshop==1.2.0
screenshot.png ADDED