# app.py — eyeglass-removal demo (FastCUT GAN + CodeFormer restoration)
# Origin: ash11sh's Hugging Face Space, commit 9bf2d4d ("Update app.py").
import io
import os
import sys
import cv2
import requests
import numpy as np
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageFile, ImageFilter, ImageEnhance, ImageOps
# from misc import get_potrait
import torch
import contextlib
from data.base_dataset import get_transform
from models.cut_model import CUTModel
from util.util import tensor2im
from argparse import Namespace
from pathlib import Path
from copy import deepcopy
from codeformer.app import inference_app
from rembg import remove, new_session
import gradio as gr
# CUTGAN input options
OPT = Namespace(
batch_size=1,
checkpoints_dir="cyclegan",
crop_size=256,
# dataroot=".",
dataset_mode="unaligned",
direction="AtoB",
display_id=-1,
display_winsize=256,
epoch="latest",
eval=False,
gpu_ids=[],
nce_layers="0,4,8,12,16",
nce_idt=False,
lambda_NCE=10.0,
lambda_GAN=1.0,
init_gain=0.02,
nce_includes_all_negatives_from_minibatch=False,
init_type="xavier",
normG="instance",
no_antialias=False,
no_antialias_up=False,
netF="mlp_sample",
netF_nc=256,
nce_T=0.07,
num_patches=256,
CUT_mode="FastCUT",
input_nc=3,
isTrain=False,
load_iter=0,
load_size=256,
max_dataset_size=float("inf"),
model="CUT",
n_layers_D=3,
name=None,
ndf=64,
netD="basic",
netG="resnet_9blocks",
ngf=64,
no_dropout=True,
no_flip=True,
num_test=50,
num_threads=4,
output_nc=3,
phase="test",
preprocess="scale_width",
random_scale_max=3.0,
results_dir="./results/",
serial_batches=True,
suffix="",
verbose=False,
)
class SingleImageDataset(torch.utils.data.Dataset):
"""dataset with precisely one image"""
def __init__(self, img, preprocess):
img = preprocess(img)
self.img = img
def __getitem__(self, i):
return self.img
def __len__(self):
return 1
fp = "cyclegan/EyeFastcut/latest_net_G.pth"
opt = deepcopy(OPT)
model_name = "EyeFastcut"
opt.name = model_name
if opt.verbose:
# model = load_model(opt, model_fp)
model = CUTModel(opt).netG
model.load_state_dict(torch.load(fp))
else:
with contextlib.redirect_stdout(io.StringIO()):
# model = load_model(opt, model_fp)
model = CUTModel(opt).netG
model.load_state_dict(torch.load(fp))
# Inference code for a single image through the CUT GAN.
# Reference: https://www.jeremyafisher.com/running-cyclegan-programmatically.html
def cutgan(img: Image) -> Image:
img = img.convert("RGB")
data_loader = torch.utils.data.DataLoader(
SingleImageDataset(img, get_transform(opt)), batch_size=1
)
data = next(iter(data_loader))
with torch.no_grad():
pred = model(data)
pred_arr = tensor2im(pred)
pred_img = Image.fromarray(pred_arr)
return pred_img
# image resize function
def imsize(img, max_size=512, maintain_aspect_ratio=True):
# calculate desired dimensions
if maintain_aspect_ratio:
if img.height > max_size or img.width > max_size:
# if width > height:
if img.width > img.height:
desired_width = max_size
desired_height = int(img.height / (img.width / max_size))
# if height > width:
elif img.height > img.width:
desired_height = max_size
desired_width = int(img.width / (img.height / max_size))
else:
desired_height = max_size
desired_width = max_size
else:
desired_width = img.width
desired_height = img.height
else:
desired_width = max_size
desired_height = max_size
# round desired dimensions to nearest multiple of 8
desired_width = (desired_width // 8) * 8
desired_height = (desired_height // 8) * 8
# resize image
desired_dimensions = (desired_width, desired_height)
transition_image = img.resize(desired_dimensions)
return transition_image
def rem_glass(input_img):
w_0, h_0 = input_img.size
#resizing
im = imsize(input_img, max_size=256, maintain_aspect_ratio=False)
width, height = im.size
ori_im = im.copy()
# get cutout and mask
session = new_session("u2net_human_seg")
im = remove(ori_im,False,240,10,20, session, only_mask=False)
mask = remove(ori_im,False,240,10,20, session, only_mask=True)
# send image to model to remove glasses
im = cutgan(im)
# composite original image and output based on mask
w, h = im.size
ori_im = ori_im.resize((w, h))
mask = mask.resize((w, h))
img = Image.composite(im, ori_im, mask)
# upscale the image
# img = upscale(img, model_cran_v2)
# scale image to original size
img = img.resize((w_0, h_0))
img.save("removal.png")
inference_app(
image="removal.png",
background_enhance=False,
face_upsample=False,
upscale=2,
codeformer_fidelity=0.5,)
return Image.open('output/out.png')
demo = gr.Interface(rem_glass, gr.inputs.Image(type="pil"), gr.outputs.Image(type="pil"),)
if __name__ == "__main__":
demo.launch()