import streamlit as st
from diffusers import StableDiffusionInpaintPipeline
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import warnings
from huggingface_hub import hf_hub_download

# Silenced before the torch/torchvision imports below so their import-time
# warnings are suppressed too.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET

import argparse
from enum import Enum
from rembg import remove
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StableFashionCLIArgs:
    """Settings gathered from the Streamlit widgets and consumed by process_image().

    Every field defaults to ``None``; the UI code at the bottom of the file fills
    them in before process_image() is called.
    """

    image: Any = None  # uploaded file handed straight to PIL.Image.open
    part: Optional[str] = None  # "Upper" or "Lower"; upper-cased to pick a Parts label
    resolution: Optional[int] = None  # side length (px) of the square working image
    prompt: Optional[str] = None  # text prompt for the inpainting pipeline
    num_steps: Optional[int] = None  # num_inference_steps for the pipeline
    guidance_scale: Optional[float] = None  # guidance_scale for the pipeline
    rembg: Optional[bool] = None  # strip the background before segmenting


class Parts:
    """Label ids in the U2NET segmentation output that select each garment region."""

    UPPER = 1
    LOWER = 2


def _get_device():
    """Return "cuda" when a GPU is available, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


@st.cache(allow_output_mutation=True)
def load_u2net():
    """Download the cloth-segmentation U2NET checkpoint and return the model in eval mode.

    Wrapped in Streamlit's cache so script reruns reuse the already-loaded model.
    """
    device = _get_device()
    checkpoint_path = hf_hub_download(repo_id="maiti/cloth-segmentation", filename="cloth_segm_u2net_latest.pth")
    net = U2NET(in_ch=3, out_ch=4)
    net = load_checkpoint_mgpu(net, checkpoint_path)
    return net.to(device).eval()


def change_bg_color(rgba_image, color):
    """Composite *rgba_image* over a solid *color* canvas and return it as RGB.

    The image's own alpha channel is used as the paste mask, so transparent
    pixels take on *color*.
    """
    new_image = Image.new("RGBA", rgba_image.size, color)
    new_image.paste(rgba_image, (0, 0), rgba_image)
    return new_image.convert("RGB")


@st.cache(allow_output_mutation=True)
def load_inpainting_pipeline():
    """Load the runwayml Stable Diffusion inpainting pipeline onto the current device.

    Uses float16 on CUDA and float32 on CPU. The Hugging Face token is read from
    the ``hf_auth_token`` environment variable; when it is unset, ``None`` is
    passed instead of crashing with KeyError.
    """
    device = _get_device()
    inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        use_auth_token=os.environ.get("hf_auth_token"),
    ).to(device)
    return inpainting_pipeline


def _segment_part(img, net, part, device):
    """Return an "L"-mode mask that is 255 where *net* predicts *part* and 0 elsewhere.

    Args:
        img: RGB PIL image to segment.
        net: U2NET model from load_u2net().
        part: body-part name, matched case-insensitively against Parts attributes.
        device: "cuda" or "cpu".

    Raises:
        ValueError: if *part* does not name a Parts label.
    """
    # Resolve the label first, by name. The original eval()'d user-controlled text.
    label = getattr(Parts, str(part).upper(), None)
    if not isinstance(label, int):
        raise ValueError(f"Unknown body part {part!r}; expected 'Upper' or 'Lower'")

    transform_rgb = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])
    image_tensor = transform_rgb(img).unsqueeze(0).to(device)

    # Inference only: no_grad stops autograd from retaining activations.
    with torch.no_grad(), torch.autocast(device_type=device):
        outputs = net(image_tensor)
        # Only the first of the network's outputs is used for the label map.
        log_probs = F.log_softmax(outputs[0], dim=1)
        label_map = torch.max(log_probs, dim=1, keepdim=True)[1]
    label_map = label_map.squeeze(0).squeeze(0).cpu().numpy()

    mask = (label_map == label).astype(np.uint8) * 255
    return Image.fromarray(mask, mode="L")


def process_image(args, inpainting_pipeline, net):
    """Repaint the selected garment in ``args.image`` using Stable Diffusion inpainting.

    Args:
        args: a StableFashionCLIArgs with every field filled in.
        inpainting_pipeline: pipeline returned by load_inpainting_pipeline().
        net: segmentation model returned by load_u2net().

    Returns:
        ``(result, mask)``: the inpainted RGB image flattened onto a white
        background, and the "L"-mode mask that was inpainted.
    """
    device = _get_device()

    img = Image.open(args.image).convert("RGB")
    img = img.resize((args.resolution, args.resolution))

    if args.rembg:
        # Replace the background with flat green before segmenting and inpainting.
        img_with_green_bg = change_bg_color(remove(img), color="GREEN").convert("RGB")
    else:
        img_with_green_bg = img

    mask_PIL = _segment_part(img_with_green_bg, net, args.part, device)

    clothed_image = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_PIL,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]
    # Strip the background from the generated image and flatten it onto white.
    clothed_image = change_bg_color(remove(clothed_image), "WHITE")
    return clothed_image.convert("RGB"), mask_PIL


net = load_u2net()
inpainting_pipeline = load_inpainting_pipeline()

# NOTE(review): unsafe_allow_html is enabled, but this text contains no HTML tags,
# so the links named here may have lost their markup. Confirm against the intended banner.
st.markdown(
    """

ovshake Github | Stable Fashion Github | Stable Fashion Demo
Follow me for more!

""",
    unsafe_allow_html=True,
)
st.title("Stable Fashion Huggingface Spaces")

# Widgets are created in display order; their values populate the args object below.
file_name = st.file_uploader("Upload a clear full length picture of yourself, preferably in a less noisy background")
body_part = st.radio("Would you like to try clothes on your upper body (such as shirts, kurtas etc) or lower (Jeans, Pants etc)? ", ('Upper', 'Lower'))
resolution = st.radio("Which resolution would you like to get the resulting picture in? (Keep in mind, higher the resolution, higher the queue times)", (128, 256, 512), index=2)
rembg_status = st.radio("Would you like to remove background in your image before putting new clothes on you? (Sometimes it results in better images)", ("Yes", "No"), index=0)
guidance_scale = st.slider("Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5)
prompt = st.text_input('Write the description of cloth you want to try', 'a bright yellow t shirt')
num_steps = st.slider("No. of inference steps for the diffusion process", 5, 50, value=25)

stable_fashion_args = StableFashionCLIArgs(
    image=file_name,
    part=body_part,
    resolution=resolution,
    prompt=prompt,
    num_steps=num_steps,
    guidance_scale=guidance_scale,
    rembg=(rembg_status == "Yes"),
)

if file_name is not None:
    result_image, mask_PIL = process_image(stable_fashion_args, inpainting_pipeline, net)
    st.image(result_image, caption='Result')
    st.image(mask_PIL, caption='Mask')
else:
    stock_image = Image.open('assets/abhishek_yellow.jpg')
    st.image(stock_image, caption='Result')