|
import streamlit as st |
|
from PIL import Image, ImageOps |
|
import cv2 |
|
import numpy as np |
|
import random |
|
import time |
|
import seaborn as sns |
|
|
|
from cv_funcs import * |
|
from torchvision_funcs import * |
|
|
|
|
|
from backgroundremover.utilities import download_download_files_from_github |
|
import torch |
|
import os |
|
from hsh.library.hash import Hasher |
|
from torchvision import transforms |
|
|
|
def load_model(model_name: str = "u2net"):
    """Build a U^2-Net variant, load its pretrained weights and return it in eval mode.

    Args:
        model_name: One of ``"u2net"``, ``"u2netp"`` or ``"u2net_human_seg"``.

    Returns:
        The network with weights loaded, moved to CUDA when available and
        switched to evaluation mode.

    Raises:
        KeyError: If ``model_name`` is not one of the supported names.
        FileNotFoundError: If the checkpoint file is still missing after the
            download attempt.
    """
    import errno  # local import: ``errno`` is not imported at file level

    # model name -> (expected checkpoint MD5, remote file id, env var that
    # overrides the checkpoint path). The remote file id is kept for reference
    # only; nothing in this function reads it.
    checkpoints = {
        "u2netp": (
            "e4f636406ca4e2af789941e7f139ee2e",
            "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
            "U2NET_PATH",
        ),
        "u2net": (
            "09fb4e49b7f785c9f855baf94916840a",
            "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
            "U2NET_PATH",
        ),
        "u2net_human_seg": (
            "347c3d51b01528e5c6c071e3cff1cb55",
            "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
            "U2NET_PATH",
        ),
    }

    # Validate first so an unsupported name fails with a clear message before
    # any model object is built or any file is touched.
    if model_name not in checkpoints:
        raise KeyError("Choose between u2net, u2net_human_seg or u2netp")
    expected_md5, _remote_id, path_env_var = checkpoints[model_name]

    # NOTE(review): ``u2net`` is not imported explicitly in this file; it is
    # presumably provided by one of the star imports -- confirm.
    # Only "u2netp" uses the U2NETP architecture; it is assumed to accept the
    # same (3, 1) constructor arguments as U2NET -- TODO confirm.
    net_cls = u2net.U2NETP if model_name == "u2netp" else u2net.U2NET
    net = net_cls(3, 1)

    default_path = os.path.expanduser(
        os.path.join("~", ".u2net", model_name + ".pth")
    )
    path = os.environ.get(path_env_var, default_path)

    # (Re)download when the checkpoint is missing or its MD5 does not match.
    if not os.path.exists(path) or Hasher().md5(path) != expected_md5:
        # NOTE(review): this is the name the file imports from
        # ``backgroundremover.utilities``; verify it exists in that package.
        download_download_files_from_github(path, model_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        net.load_state_dict(torch.load(path, map_location=device))
    except FileNotFoundError as err:
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
        ) from err

    net.to(device)
    net.eval()

    return net
|
|
|
def norm_pred(d):
    """Min-max normalise a tensor so its values span [0, 1].

    Args:
        d: Tensor to normalise.

    Returns:
        ``(d - min(d)) / (max(d) - min(d))``. When every element of ``d`` is
        equal (zero range) the result is all zeros instead of NaN.
    """
    lowest = torch.min(d)
    span = torch.max(d) - lowest

    # A constant input would give 0 / 0 = NaN; dividing by 1 instead leaves
    # the (all-zero) numerator untouched.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)

    return (d - lowest) / safe_span
|
|
|
|
|
def preprocess(image):
    """Wrap an image array in a sample dict and run the U^2-Net input transforms.

    An all-zero placeholder label is built alongside the image because the
    transforms are fed a dict with ``imidx``, ``image`` and ``label`` keys.

    Args:
        image: Array-like exposing ``.shape``. 2-D inputs get a trailing
            channel axis added; 3-D inputs are passed through unchanged.

    Returns:
        Whatever the composed ``RescaleT(320)`` -> ``ToTensorLab(flag=0)``
        pipeline returns for the sample.
    """
    n_dims = len(image.shape)
    zeros_like_image = np.zeros(image.shape)

    if n_dims == 3:
        # First channel of the zero array, with the channel axis put back.
        label = zeros_like_image[:, :, 0][:, :, np.newaxis]
    elif n_dims == 2:
        image = image[:, :, np.newaxis]
        label = zeros_like_image[:, :, np.newaxis]
    else:
        # Any other rank: label is just zeros over the first two axes.
        label = np.zeros(zeros_like_image.shape[0:2])

    pipeline = transforms.Compose(
        [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
    )

    return pipeline({"imidx": np.array([0]), "image": image, "label": label})
|
|
|
|
|
def predict(net, item):
    """Run ``net`` on one image array and return its saliency map as an image.

    Args:
        net: Model whose call returns a sequence of outputs; only the first
            one is used.
        item: Image array, passed to ``preprocess``.

    Returns:
        A PIL image in ``"RGB"`` mode holding the first output's channel 0,
        min-max normalised via ``norm_pred`` and scaled by 255.
    """
    sample = preprocess(item)

    # Add a batch axis and cast to float32. ``.float()`` / ``.cuda()`` replace
    # the legacy ``torch.FloatTensor(...)`` / ``torch.cuda.FloatTensor(...)``
    # constructors.
    inputs_test = sample["image"].unsqueeze(0).float()
    if torch.cuda.is_available():
        inputs_test = inputs_test.cuda()

    with torch.no_grad():
        outputs = net(inputs_test)
        first_output = outputs[0]
        saliency = norm_pred(first_output[:, 0, :, :]).squeeze()

    saliency_np = saliency.cpu().detach().numpy()

    return Image.fromarray(saliency_np * 255).convert("RGB")
|
|
|
def remove_bg(img):
    """Black out the background of a PIL image using the ``u2net`` model.

    Args:
        img: PIL image. It is converted to RGB first so that images in other
            modes (grayscale, RGBA, palette) can be multiplied by the mask.

    Returns:
        A tuple ``(img_masked, index_masked)``:
        ``img_masked`` is the RGB image with every pixel whose mask value is
        0 set to black; ``index_masked`` is the ``(rows, cols)`` pair from
        ``np.where`` locating those pixels, usable to index any 2-D plane of
        the same height and width.
    """
    rgb = img.convert("RGB")
    img_arry = np.array(rgb)

    model = load_model(model_name="u2net")
    mask = predict(model, img_arry).resize(rgb.size)

    # Collapse the mask to a single channel so the returned indices address
    # pixels (row, col) rather than individual channel values.
    mask_2d = np.array(mask.convert("L"))

    # 0/1 keep-mask broadcast across the colour channels.
    keep = (mask_2d > 0).astype(img_arry.dtype)
    img_masked = Image.fromarray(img_arry * keep[:, :, np.newaxis])

    index_masked = np.where(mask_2d == 0)
    return img_masked, index_masked
|
|
|
def show_generated_image(image):
    """Render ``image`` on the Streamlit page."""
    # Deliberately not cached: this function exists only for its display side
    # effect, and a cache hit would skip the body so nothing would be drawn.
    st.image(image)
|
|
|
@st.cache(suppress_st_warning=True)
def randomize_palette_colors(n_rows, n_cols, palettes=('Set1', 'Set3', 'Spectral'), seed=None, n_times=10):
    """Return one shuffled list of colours per seaborn palette.

    Args:
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        palettes: Seaborn palette names; one colour list is produced for each.
        seed: Seed for the shuffle. ``None`` means "use the current time",
            evaluated at call time rather than once at definition time.
        n_times: Multiplier applied to ``n_rows * n_cols`` to get the number
            of colours requested from each palette.

    Returns:
        A list with one entry per palette, each a shuffled sequence of
        ``n_rows * n_cols * n_times`` colours.
    """
    if seed is None:
        seed = time.time()

    # A private generator produces the same shuffle as seeding the module
    # level one, without disturbing global random state.
    rng = random.Random(seed)

    n_colors = n_rows * n_cols * n_times
    colors = [sns.color_palette(palette, n_colors) for palette in palettes]
    for palette_colors in colors:
        rng.shuffle(palette_colors)
    return colors
|
|
|
@st.cache(suppress_st_warning=True)
def remove_image_background(image):
    """Cached wrapper around ``remove_bg``.

    Args:
        image: Image forwarded unchanged to ``remove_bg``.

    Returns:
        The value returned by ``remove_bg(image)``.
    """
    # The original body referenced an undefined name ``img``.
    return remove_bg(image)
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: load an image, optionally mask its background, then render
# an n_rows x n_cols grid of two-tone (three-tone when masked) colourised
# copies of it.
# ---------------------------------------------------------------------------
title = 'Andy Warhol like Image Generator'
st.set_page_config(page_title=title, page_icon='favicon.jpeg', layout='centered')
st.title(title)

# Fall back to the bundled sample so there is always an image to work on.
uploaded_file = st.file_uploader('Choose an image file')
if uploaded_file is None:
    uploaded_file = './sample.jpg'

im = Image.open(uploaded_file)
im.thumbnail((1000, 1000), resample=Image.BICUBIC)

is_masked = st.checkbox('With background masking? (3 colors)')
if is_masked:
    im_masked, index_masked = remove_image_background(im)
    st.image(im_masked, caption='Masked image')
else:
    st.image(im, caption='Original')

# Otsu's method supplies the initial threshold; the slider lets the user
# override it.
im_gray = np.array(im.convert('L'))
thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)

# At least one row and one column, otherwise there is no grid to show.
n_rows = st.number_input('Rows', value=3, min_value=1)
n_cols = st.number_input('Columns', value=3, min_value=1)

thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)
colors = randomize_palette_colors(n_rows, n_cols, seed=0)

if st.button('Shuffle colors'):
    colors = randomize_palette_colors(n_rows, n_cols, seed=time.time())

# True where the pixel is at or below the threshold; the trailing axis lets it
# broadcast against an RGB triple.
is_dark = (im_gray <= thresh)[:, :, np.newaxis]

if is_masked:
    # Only the row/column indices apply to an (H, W, 3) tile; any further
    # index array (e.g. a channel axis) is ignored.
    masked_rows, masked_cols = index_masked[0], index_masked[1]

im_generated = None
for row in range(n_rows):
    im_row = None
    for col in range(n_cols):
        i_color = n_cols * row + col
        # One 0-255 RGB triple per palette for this grid cell.
        rgbs = [np.array(color[i_color])[:3] * 255 for color in colors]

        tile = np.where(is_dark, rgbs[0], rgbs[1])
        if is_masked:
            tile[masked_rows, masked_cols] = rgbs[2]

        im_tile = Image.fromarray(tile.astype(np.uint8))
        im_row = im_tile if im_row is None else get_concat_h(im_row, im_tile)

    if im_generated is None:
        im_generated = im_row
    else:
        im_generated = get_concat_v(im_generated, im_row)

st.image(im_generated)
|
|