import streamlit as st from diffusers import StableDiffusionInpaintPipeline import os from tqdm import tqdm from PIL import Image import numpy as np import cv2 import warnings from huggingface_hub import hf_hub_download warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=DeprecationWarning) import torch import torch.nn.functional as F import torchvision.transforms as transforms from data.base_dataset import Normalize_image from utils.saving_utils import load_checkpoint_mgpu from networks import U2NET import argparse from enum import Enum from rembg import remove from dataclasses import dataclass @dataclass class StableFashionCLIArgs: image = None part = None resolution = None promt = None num_steps = None guidance_scale = None rembg = None class Parts: UPPER = 1 LOWER = 2 @st.cache(allow_output_mutation=True) def load_u2net(): device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint_path = hf_hub_download(repo_id="maiti/cloth-segmentation", filename="cloth_segm_u2net_latest.pth") net = U2NET(in_ch=3, out_ch=4) net = load_checkpoint_mgpu(net, checkpoint_path) net = net.to(device) net = net.eval() return net def change_bg_color(rgba_image, color): new_image = Image.new("RGBA", rgba_image.size, color) new_image.paste(rgba_image, (0, 0), rgba_image) return new_image.convert("RGB") @st.cache(allow_output_mutation=True) def load_inpainting_pipeline(): device = "cuda" if torch.cuda.is_available() else "cpu" inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, use_auth_token=os.environ["hf_auth_token"] ).to(device) return inpainting_pipeline def process_image(args, inpainting_pipeline, net): device = "cuda" if torch.cuda.is_available() else "cpu" image_path = args.image transforms_list = [] transforms_list += [transforms.ToTensor()] transforms_list += [Normalize_image(0.5, 0.5)] transform_rgb = transforms.Compose(transforms_list) img = Image.open(image_path) img = img.convert("RGB") img = img.resize((args.resolution, args.resolution)) if args.rembg: img_with_green_bg = remove(img) img_with_green_bg = change_bg_color(img_with_green_bg, color="GREEN") img_with_green_bg = img_with_green_bg.convert("RGB") else: img_with_green_bg = img image_tensor = transform_rgb(img_with_green_bg) image_tensor = image_tensor.unsqueeze(0) with torch.autocast(device_type=device): output_tensor = net(image_tensor.to(device)) output_tensor = F.log_softmax(output_tensor[0], dim=1) output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1] output_tensor = torch.squeeze(output_tensor, dim=0) output_tensor = torch.squeeze(output_tensor, dim=0) output_arr = output_tensor.cpu().numpy() mask_code = eval(f"Parts.{args.part.upper()}") mask = (output_arr == mask_code) output_arr[mask] = 1 output_arr[~mask] = 0 output_arr *= 255 mask_PIL = Image.fromarray(output_arr.astype("uint8"), mode="L") clothed_image_from_pipeline = inpainting_pipeline(prompt=args.prompt, image=img_with_green_bg, mask_image=mask_PIL, width=args.resolution, height=args.resolution, guidance_scale=args.guidance_scale, num_inference_steps=args.num_steps).images[0] clothed_image_from_pipeline = remove(clothed_image_from_pipeline) clothed_image_from_pipeline = change_bg_color(clothed_image_from_pipeline, "WHITE") return clothed_image_from_pipeline.convert("RGB"). mask_PIL net = load_u2net() inpainting_pipeline = load_inpainting_pipeline() st.markdown( """
ovshake Github | Stable Fashion Github | Stable Fashion Demo
Follow me for more!