File size: 4,077 Bytes
0c43c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a97974
0c43c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a97974
0c43c79
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans

from scripts.convertor import rgb2df, df2rgba

import gradio as gr
import huggingface_hub
import onnxruntime as rt
import copy
from PIL import Image

import segmentation_refinement as refine


# Declare Execution Providers
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# Download and host the model
model_path = huggingface_hub.hf_hub_download(
    "skytnt/anime-seg", "isnetis.onnx")
rmbg_model = rt.InferenceSession(model_path, providers=providers)

def get_mask(img, s=1024):
    img = (img / 255).astype(np.float32)
    dim = img.shape[2]
    if dim == 4:
        img = img[..., :3]
        dim = 3
    h, w = h0, w0 = img.shape[:-1]
    h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
    ph, pw = s - h, s - w
    img_input = np.zeros([s, s, dim], dtype=np.float32)
    img_input[ph // 2:ph // 2 + h, pw //
              2:pw // 2 + w] = cv2.resize(img, (w, h))
    img_input = np.transpose(img_input, (2, 0, 1))
    img_input = img_input[np.newaxis, :]
    mask = rmbg_model.run(None, {'img': img_input})[0][0]
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
    return mask

def assign_tile(row, tile_width, tile_height):
    tile_x = row['x_l'] // tile_width
    tile_y = row['y_l'] // tile_height
    return f"tile_{tile_y}_{tile_x}"

def rmbg_fn(img):
    mask = get_mask(img)
    img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
    mask = (mask * 255).astype(np.uint8)
    img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
    mask = mask.repeat(3, axis=2)
    return mask, img

def refinement(img, mask, fast, psp_L):
    mask =  cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    refiner = refine.Refiner(device='cpu') # device can also be 'cpu'

    # Fast - Global step only.
    # Smaller L -> Less memory usage; faster in fast mode.
    mask = refiner.refine(img, mask, fast=fast, L=psp_L) 

    return mask


def get_foreground(img, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
    if td_abg_enabled == True:
        mask = get_mask(img)
        mask = (mask * 255).astype(np.uint8)
        mask = mask.repeat(3, axis=2)
        if cascadePSP_enabled == True:
            mask =  refinement(img, mask, fast, psp_L)
            mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
        df = rgb2df(img)

        image_width = img.shape[1] 
        image_height = img.shape[0] 

        num_horizontal_splits = h_split
        num_vertical_splits = v_split
        tile_width = image_width // num_horizontal_splits
        tile_height = image_height // num_vertical_splits

        df['tile'] = df.apply(assign_tile, args=(tile_width, tile_height), axis=1)

        cls = MiniBatchKMeans(n_clusters=n_cluster, batch_size=100)
        cls.fit(df[["r","g","b"]])
        df["label"] = cls.labels_

        mask_df = rgb2df(mask)
        mask_df['bg_label'] = (mask_df['r'] > alpha) & (mask_df['g'] > alpha) & (mask_df['b'] > alpha)

        img_df = df.copy()
        img_df["bg_label"] = mask_df["bg_label"]
        img_df["label"] = img_df["label"].astype(str) + "-" + img_df["tile"]
        bg_rate = img_df.groupby("label").sum()["bg_label"]/img_df.groupby("label").count()["bg_label"]
        img_df['bg_cls'] = (img_df['label'].isin(bg_rate[bg_rate > th_rate].index)).astype(int)
        img_df.loc[img_df['bg_cls'] == 0, ['a']] = 0
        img_df.loc[img_df['bg_cls'] != 0, ['a']] = 255
        img = df2rgba(img_df)

    if cascadePSP_enabled == True and td_abg_enabled == False:
        mask = get_mask(img)
        mask = (mask * 255).astype(np.uint8)
        refiner = refine.Refiner(device='cpu')
        mask = refiner.refine(img, mask, fast=fast, L=psp_L)
        img = np.dstack((img, mask))
    
    if cascadePSP_enabled == False and td_abg_enabled == False:
        mask, img = rmbg_fn(img)

    return mask, img