# Source: utils/shared_utils.py (commit 88bec8b, "Update utils/shared_utils.py", author NikeZoldyck)
from pathlib import Path
from rembg import remove
import io
# Apply the transformations needed
from torch import autocast, nn
import torch
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision.utils as utils
import torch.nn as nn
import pyrootutils
from PIL import Image
import numpy as np
from utils.photo_wct import PhotoWCT
from utils.photo_smooth import Propagator
#from utils.smooth_filter import smooth_filter
# Load the stylization (PhotoWCT) and smoothing (Propagator) models once at
# import time; style_transfer() and smoother() below reuse these instances.
# NOTE(review): the checkpoint path is resolved against the current working
# directory, so this module presumably must be imported from the project root.
root = Path.cwd()
device = "cuda" if torch.cuda.is_available() else "cpu"
p_wct = PhotoWCT().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from GPU tensors still loads on a CPU-only host instead of raising.
p_wct.load_state_dict(
    torch.load(root / "models/components/photo_wct.pth", map_location=device)
)
p_pro = Propagator().to(device)
stylization_module = p_wct
smoothing_module = p_pro
# Dependencies - to be installed:
#   pip install replicate
# Authentication - the Replicate API token must never be committed to source
# control. Provide it through the environment instead:
#   export REPLICATE_API_TOKEN=<your token>
import replicate
import os
import requests
def stableDiffusionAPICall(text_prompt):
    """Generate an image from *text_prompt* with Stable Diffusion on Replicate.

    Authentication comes from the ``REPLICATE_API_TOKEN`` environment
    variable. No token is stored in source, and an existing value is never
    overwritten.

    Args:
        text_prompt: Prompt forwarded to the model's ``predict`` call.

    Returns:
        A ``PIL.Image.Image`` decoded from the first prediction output.

    Raises:
        RuntimeError: If ``REPLICATE_API_TOKEN`` is not set.
        requests.HTTPError: If downloading the generated image fails.
    """
    if not os.environ.get("REPLICATE_API_TOKEN"):
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it before calling "
            "stableDiffusionAPICall()."
        )
    model = replicate.models.get("stability-ai/stable-diffusion")
    # The first prediction output is downloaded over HTTP below, i.e. it is a URL.
    gen_bg_img = model.predict(prompt=text_prompt)[0]
    # Bound the wait and surface HTTP errors instead of decoding an error page.
    response = requests.get(gen_bg_img, timeout=60)
    response.raise_for_status()
    # Image.open reads lazily; the returned image keeps the buffer referenced.
    return Image.open(io.BytesIO(response.content))
def memory_limit_image_resize(cont_img):
# prevent too small or too big images
MINSIZE=400
MAXSIZE=800
orig_width = cont_img.width
orig_height = cont_img.height
if max(cont_img.width,cont_img.height) < MINSIZE:
if cont_img.width > cont_img.height:
cont_img.thumbnail((int(cont_img.width*1.0/cont_img.height*MINSIZE), MINSIZE), Image.BICUBIC)
else:
cont_img.thumbnail((MINSIZE, int(cont_img.height*1.0/cont_img.width*MINSIZE)), Image.BICUBIC)
if min(cont_img.width,cont_img.height) > MAXSIZE:
if cont_img.width > cont_img.height:
cont_img.thumbnail((MAXSIZE, int(cont_img.height*1.0/cont_img.width*MAXSIZE)), Image.BICUBIC)
else:
cont_img.thumbnail(((int(cont_img.width*1.0/cont_img.height*MAXSIZE), MAXSIZE)), Image.BICUBIC)
print("Resize image: (%d,%d)->(%d,%d)" % (orig_width, orig_height, cont_img.width, cont_img.height))
return cont_img.width, cont_img.height
def superimpose(input_img, back_img):
    """Composite the foreground of *input_img* onto *back_img*.

    ``rembg.remove`` strips the background from *input_img*. The result is
    pasted at the top-left corner of *back_img* and also serves as its own
    paste mask. It presumably carries an alpha channel; TODO confirm against
    the rembg version in use. *back_img* is modified in place.

    Args:
        input_img: Foreground image to cut out.
        back_img: Background image that receives the paste.

    Returns:
        Tuple ``(back_img, input_img)``: the composited background and the
        foreground object that was passed in.
    """
    cutout = remove(input_img)
    back_img.paste(cutout, box=(0, 0), mask=cutout)
    return back_img, input_img
def style_transfer(cont_img, styl_img):
    """Transfer the style of *styl_img* onto *cont_img* with the PhotoWCT module.

    Both PIL images are first passed through ``memory_limit_image_resize``,
    which resizes them **in place**, so callers see the resized images.
    Stylization runs without segmentation masks (empty label arrays).

    Args:
        cont_img: Content PIL image (resized in place).
        styl_img: Style PIL image (resized in place).

    Returns:
        The stylized result as a ``PIL.Image.Image``, sized like the
        (resized) content image.
    """
    with torch.no_grad():
        # Both calls mutate their argument; their return values are not needed.
        memory_limit_image_resize(cont_img)
        memory_limit_image_resize(styl_img)
        cw, ch = cont_img.width, cont_img.height

        # The encoder presumably expects 3-channel input (upstream
        # FastPhotoStyle converts to RGB on load). Converting keeps RGBA or
        # palette PNGs from producing a 4- or 1-channel tensor. TODO confirm.
        to_tensor = transforms.ToTensor()
        cont_tensor = to_tensor(cont_img.convert("RGB")).unsqueeze(0).to(device)
        styl_tensor = to_tensor(styl_img.convert("RGB")).unsqueeze(0).to(device)
        stylization_module.to(device)

        # No semantic segmentation is used: pass empty label arrays.
        cont_seg = np.asarray([])
        styl_seg = np.asarray([])
        stylized = stylization_module.transform(cont_tensor, styl_tensor, cont_seg, styl_seg)

        # Restore the content size if the module's output differs from it.
        # (The previous check compared two sizes read after the in-place
        # resize, which were always equal, so it never fired.)
        if stylized.shape[-2:] != (ch, cw):
            stylized = nn.functional.interpolate(
                stylized, size=(ch, cw), mode="bilinear", align_corners=False
            )

        grid = utils.make_grid(stylized, nrow=1, padding=0)
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(ndarr)
def smoother(stylized_img, over_img):
    """Refine a stylized image with the module-level smoothing module.

    Args:
        stylized_img: Stylized image, e.g. the output of ``style_transfer``.
        over_img: Second argument for ``smoothing_module.process``. It is
            presumably the content/composite guide image; TODO confirm.

    Returns:
        Whatever ``smoothing_module.process`` returns.
    """
    use_gpu = device == 'cuda'
    if use_gpu:
        smoothing_module.to(device)
    return smoothing_module.process(stylized_img, over_img)
if __name__ == "__main__":
    # Demo: composite a foreground portrait onto a background and save it.
    root = pyrootutils.setup_root(__file__, pythonpath=True)
    fg_path = root / "notebooks/profile_new.png"
    bg_path = root / "notebooks/back_img.png"
    fg_img = Image.open(fg_path).resize((800, 800))
    bg_img = Image.open(bg_path).resize((800, 800))
    # superimpose() returns (composite, foreground); only the composite is saved.
    composite, _ = superimpose(fg_img, bg_img)
    composite.save(root / "notebooks/overlay.png")