# stable-fashion / app.py
# ovshake
# add mask pil
# ae6d3bd
# raw
# history blame
# 6.66 kB
import streamlit as st
from diffusers import StableDiffusionInpaintPipeline
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import warnings
from huggingface_hub import hf_hub_download
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
import argparse
from enum import Enum
from rembg import remove
from dataclasses import dataclass
@dataclass
class StableFashionCLIArgs:
    """Options bundle consumed by ``process_image``.

    Every field defaults to ``None`` so an instance can be created empty and
    filled in attribute by attribute, or populated through keyword arguments.
    """

    # Input picture; handed to ``PIL.Image.open`` (path or file-like object).
    image: object = None
    # Garment region name; upper-cased and looked up on ``Parts``.
    part: "str | None" = None
    # Side length in pixels; the input is resized to a square of this size.
    resolution: "int | None" = None
    # Text description passed to the inpainting pipeline.
    prompt: "str | None" = None
    # Number of inference steps passed to the inpainting pipeline.
    num_steps: "int | None" = None
    # Guidance scale passed to the inpainting pipeline.
    guidance_scale: "float | None" = None
    # When truthy, the background is removed before segmentation.
    rembg: "bool | None" = None

    # Legacy misspelling of ``prompt``, kept as a plain class attribute so any
    # code that still reads ``args.promt`` keeps working.
    promt = None
class Parts:
    """Integer labels compared against the segmentation class map.

    ``process_image`` upper-cases the requested part name and resolves it to
    one of these values before building the inpainting mask.
    """

    UPPER, LOWER = 1, 2
@st.cache(allow_output_mutation=True)
def load_u2net():
    """Build the cloth-segmentation U2NET and return it in eval mode.

    The checkpoint ``cloth_segm_u2net_latest.pth`` is fetched from the
    ``maiti/cloth-segmentation`` Hub repository, loaded into a freshly
    constructed ``U2NET(in_ch=3, out_ch=4)``, and the model is moved to CUDA
    when available (CPU otherwise).
    """
    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    weights_file = hf_hub_download(
        repo_id="maiti/cloth-segmentation",
        filename="cloth_segm_u2net_latest.pth",
    )
    model = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights_file)
    return model.to(target).eval()
def change_bg_color(rgba_image, color):
    """Return ``rgba_image`` flattened onto a solid ``color`` background.

    Args:
        rgba_image: PIL image that is also used as its own paste mask.
        color: Fill colour accepted by ``Image.new`` (e.g. ``"GREEN"``).

    Returns:
        A new image in ``"RGB"`` mode with the same size as the input.
    """
    canvas = Image.new("RGBA", rgba_image.size, color)
    # The image is passed as its own mask, so the fill colour stays visible
    # wherever the mask lets it through.
    canvas.paste(rgba_image, box=(0, 0), mask=rgba_image)
    flattened = canvas.convert("RGB")
    return flattened
@st.cache(allow_output_mutation=True)
def load_inpainting_pipeline():
    """Load the Stable Diffusion inpainting pipeline onto the best device.

    Uses the ``fp16`` revision of ``runwayml/stable-diffusion-inpainting``.
    Weights are loaded as float16 on CUDA and float32 on CPU. The Hub token is
    read from the ``hf_auth_token`` environment variable.

    Raises:
        KeyError: if ``hf_auth_token`` is not set in the environment.
    """
    if torch.cuda.is_available():
        target, weights_dtype = "cuda", torch.float16
    else:
        target, weights_dtype = "cpu", torch.float32

    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=weights_dtype,
        use_auth_token=os.environ["hf_auth_token"],
    )
    return pipeline.to(target)
def process_image(args, inpainting_pipeline, net):
    """Repaint the selected garment region of a photo via inpainting.

    The picture is loaded and resized to a square, optionally placed on a
    green background, segmented with ``net``, and the class matching
    ``args.part`` becomes a binary mask. The mask is inpainted from
    ``args.prompt`` and the result is placed on a white background.

    Args:
        args: Options object; the attributes read are ``image``,
            ``resolution``, ``rembg``, ``part``, ``prompt``,
            ``guidance_scale`` and ``num_steps``.
        inpainting_pipeline: Callable invoked with ``prompt``, ``image``,
            ``mask_image``, ``width``, ``height``, ``guidance_scale`` and
            ``num_inference_steps``; its ``.images[0]`` is used.
        net: Segmentation model called on a normalised image batch.

    Returns:
        A tuple ``(result, mask)``: the repainted RGB image on a white
        background, and the ``"L"``-mode mask (255 inside the selected
        region, 0 elsewhere).

    Raises:
        AttributeError: if ``args.part.upper()`` is not an attribute of
            ``Parts``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve the label with getattr instead of eval: same lookup, but the
    # string is never executed as code. Done first so a bad part name fails
    # before any model runs.
    mask_code = getattr(Parts, args.part.upper())

    transform_rgb = transforms.Compose(
        [transforms.ToTensor(), Normalize_image(0.5, 0.5)]
    )

    img = Image.open(args.image).convert("RGB")
    img = img.resize((args.resolution, args.resolution))

    if args.rembg:
        # change_bg_color already returns an RGB image, so no further
        # conversion is needed.
        img_with_green_bg = change_bg_color(remove(img), color="GREEN")
    else:
        img_with_green_bg = img

    image_tensor = transform_rgb(img_with_green_bg).unsqueeze(0)

    # no_grad: this is inference only, so no autograd graph should be kept.
    # NOTE(review): SOURCE indentation was lost; the inpainting call is kept
    # inside the autocast context on the assumption that the original ran it
    # there -- confirm.
    with torch.no_grad(), torch.autocast(device_type=device):
        output_tensor = net(image_tensor.to(device))
        # Only the first element of the model output is used; the per-pixel
        # label is the index of the largest score along dim 1.
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
        # Two squeezes on dim 0 -- presumably dropping the batch and the
        # kept class dimension to leave an (H, W) map; TODO confirm.
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_tensor = torch.squeeze(output_tensor, dim=0)
        class_map = output_tensor.cpu().numpy()

        # 255 where the pixel belongs to the requested part, 0 elsewhere.
        mask_arr = np.where(class_map == mask_code, 255, 0).astype("uint8")
        mask_PIL = Image.fromarray(mask_arr, mode="L")

        clothed_image = inpainting_pipeline(
            prompt=args.prompt,
            image=img_with_green_bg,
            mask_image=mask_PIL,
            width=args.resolution,
            height=args.resolution,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_steps,
        ).images[0]

    # Strip whatever background the pipeline produced and replace it with
    # white; change_bg_color returns the final RGB image.
    clothed_image = change_bg_color(remove(clothed_image), "WHITE")
    return clothed_image, mask_PIL
# Load both models up front through the cached loader functions.
net = load_u2net()
inpainting_pipeline = load_inpainting_pipeline()

# Centered header with project and social links, rendered as raw HTML.
_HEADER_HTML = """
<p style='text-align: center'>
<a href='https://github.com/ovshake' target='_blank'>ovshake Github</a> | <a href='https://github.com/ovshake/stable-fashion' target='_blank'>Stable Fashion Github</a> | <a href='https://huggingface.co/spaces/maiti/stable-fashion' target='_blank'>Stable Fashion Demo</a>
<br />
Follow me for more! <a href='https://twitter.com/o_v_shake' target='_blank'> <img src="https://img.icons8.com/color/48/000000/twitter--v1.png" height="30"></a><a href='https://github.com/ovshake' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/github.png" height="27"></a><a href='https://www.linkedin.com/in/ovshake/' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.title("Stable Fashion Huggingface Spaces")

# --- Input widgets, declared in the order they appear on the page ---------
file_name = st.file_uploader(
    "Upload a clear full length picture of yourself, preferably in a less noisy background"
)
body_part = st.radio(
    "Would you like to try clothes on your upper body (such as shirts, kurtas etc) or lower (Jeans, Pants etc)? ",
    ('Upper', 'Lower'),
)
resolution = st.radio(
    "Which resolution would you like to get the resulting picture in? (Keep in mind, higher the resolution, higher the queue times)",
    (128, 256, 512),
    index=2,
)
rembg_status = st.radio(
    "Would you like to remove background in your image before putting new clothes on you? (Sometimes it results in better images)",
    ("Yes", "No"),
    index=0,
)
guidance_scale = st.slider(
    "Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5
)
prompt = st.text_input(
    'Write the description of cloth you want to try', 'a bright yellow t shirt'
)
num_steps = st.slider(
    "No. of inference steps for the diffusion process", 5, 50, value=25
)

# --- Collect the widget values into the options object --------------------
stable_fashion_args = StableFashionCLIArgs()
stable_fashion_args.image = file_name
stable_fashion_args.part = body_part
stable_fashion_args.resolution = resolution
stable_fashion_args.rembg = (rembg_status == "Yes")
stable_fashion_args.guidance_scale = guidance_scale
stable_fashion_args.prompt = prompt
stable_fashion_args.num_steps = num_steps

# --- Output: stock picture until a file is uploaded, then the real result --
if file_name is None:
    stock_image = Image.open('assets/abhishek_yellow.jpg')
    st.image(stock_image, caption='Result')
else:
    result_image, mask_PIL = process_image(stable_fashion_args, inpainting_pipeline, net)
    # Log the distinct mask values to stdout.
    print(np.unique(np.asarray(mask_PIL)))
    st.image(result_image, caption='Result')
    st.image(mask_PIL, caption='Mask')