# Hugging Face Spaces page header (not part of the program):
# Space status: Runtime error. File size: 5,445 bytes. Revision: 6724ca0.
# The scraped line-number gutter (1-138) has been omitted.
import streamlit as st
from diffusers import StableDiffusionInpaintPipeline
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import warnings
from huggingface_hub import hf_hub_download
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
import argparse
from enum import Enum
from rembg import remove
from dataclasses import dataclass
@dataclass
class StableFashionCLIArgs:
    """Settings for one virtual try-on run.

    Every field has a default so the object can be created empty and then
    populated attribute by attribute, which is how the UI code uses it.
    """

    # Input photo; whatever ``process_image`` accepts as ``args.image``.
    image: object = None
    # Name of a ``Parts`` attribute, case-insensitive ("upper" or "lower").
    part: str = "upper"
    # Side length in pixels of the square working image.
    resolution: int = 512
    # Text description of the garment to paint in.
    prompt: str = ""
    # Number of diffusion inference steps.
    num_steps: int = 25
    # Guidance scale handed to the inpainting pipeline.
    guidance_scale: float = 7.5
    # Whether to strip the photo's background before segmentation.
    rembg: bool = False
class Parts:
    """Integer codes for the garment regions that can be selected."""

    UPPER, LOWER = 1, 2
def load_u2net():
    """Fetch the cloth-segmentation checkpoint and return a ready U2NET.

    The network is moved to CUDA when available (CPU otherwise) and switched
    to evaluation mode before being returned.
    """
    weights_file = hf_hub_download(
        repo_id="maiti/cloth-segmentation",
        filename="cloth_segm_u2net_latest.pth",
    )
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights_file)
    return model.to(target).eval()
def change_bg_color(rgba_image, color):
    """Flatten ``rgba_image`` onto a solid ``color`` backdrop and return RGB.

    The image is pasted at the top-left corner using itself as the paste mask.
    """
    backdrop = Image.new("RGBA", rgba_image.size, color)
    backdrop.paste(rgba_image, box=(0, 0), mask=rgba_image)
    flattened = backdrop.convert("RGB")
    return flattened
def load_inpainting_pipeline():
    """Build the Stable Diffusion inpainting pipeline on the available device.

    Half precision is requested when CUDA is available, single precision
    otherwise; the pipeline is then moved to that device.
    """
    has_cuda = torch.cuda.is_available()
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16 if has_cuda else torch.float32,
    )
    return pipeline.to("cuda" if has_cuda else "cpu")
def process_image(args, inpainting_pipeline, net):
    """Repaint the selected garment region of a photo with the prompted clothing.

    Args:
        args: Settings object providing ``image``, ``part``, ``resolution``,
            ``rembg``, ``prompt``, ``guidance_scale`` and ``num_steps``.
            ``image`` may be an already opened PIL image or anything
            ``Image.open`` accepts (a path or a file-like object).
        inpainting_pipeline: Callable inpainting pipeline whose result exposes
            ``.images``.
        net: Segmentation network; the first element of its output is read as
            per-class scores along dimension 1.

    Returns:
        The generated picture as an RGB PIL image on a white background.

    Raises:
        AttributeError: If ``args.part`` does not name an attribute of ``Parts``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve the part up front so a bad value fails before any heavy work.
    # getattr performs the same lookup as the former eval() without executing
    # arbitrary text.
    mask_code = getattr(Parts, args.part.upper())

    source = args.image
    # Callers may hand over an opened image; Image.open only takes paths/files.
    img = source if isinstance(source, Image.Image) else Image.open(source)
    img = img.convert("RGB").resize((args.resolution, args.resolution))

    if args.rembg:
        # change_bg_color already returns an RGB image.
        img_with_green_bg = change_bg_color(remove(img), color="GREEN")
    else:
        img_with_green_bg = img

    transform_rgb = transforms.Compose(
        [transforms.ToTensor(), Normalize_image(0.5, 0.5)]
    )
    image_tensor = transform_rgb(img_with_green_bg).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        scores = F.log_softmax(net(image_tensor.to(device))[0], dim=1)
        labels = torch.max(scores, dim=1, keepdim=True)[1]
    # Remove the two leading singleton axes, leaving a 2-D label map.
    labels = torch.squeeze(torch.squeeze(labels, dim=0), dim=0)
    output_arr = labels.cpu().numpy()

    # 255 where the chosen part was predicted, 0 everywhere else.
    mask_values = (output_arr == mask_code).astype("uint8") * 255
    mask_image = Image.fromarray(mask_values, mode="L")

    generated = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_image,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]

    # Cut out the background of the generated picture and flatten onto white.
    return change_bg_color(remove(generated), "WHITE")
# Streamlit page: collect a photo and generation settings, then show the result.
st.title("Stable Fashion Huggingface Spaces")
file_name = st.file_uploader("Upload a clear full length picture of yourself, preferably in a less noisy background")

# NOTE(review): these loaders run on every execution of this script; if the
# installed Streamlit version offers resource caching, consider using it here.
net = load_u2net()
inpainting_pipeline = load_inpainting_pipeline()

if file_name is not None:
    stable_fashion_args = StableFashionCLIArgs()
    # Hand over the uploaded file itself: process_image opens it with
    # Image.open, which needs a path or file-like object, not an opened image.
    stable_fashion_args.image = file_name

    body_part = st.radio("Would you like to try clothes on your upper body (such as shirts, kurtas etc) or lower (Jeans, Pants etc)? ", ('Upper', 'Lower'))
    stable_fashion_args.part = body_part

    resolution = st.radio("Which resolution would you like to get the resulting picture in? (Keep in mind, higher the resolution, higher the queue times)", (128, 256, 512))
    stable_fashion_args.resolution = resolution

    rembg_status = st.radio("Would you like to remove background in your image before putting new clothes on you? (Sometimes it results in better images)", ("Yes", "No"))
    stable_fashion_args.rembg = (rembg_status == "Yes")

    guidance_scale = st.slider("Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5)
    stable_fashion_args.guidance_scale = guidance_scale

    prompt = st.text_input('Write the description of cloth you want to try', 'a bright yellow t shirt')
    # Previously the guidance scale was stored here instead of the text prompt.
    stable_fashion_args.prompt = prompt

    num_steps = st.slider("No. of inference steps for the diffusion process", 5, 50, value=25)
    # Previously the slider value was never stored, so process_image had no
    # num_steps to read.
    stable_fashion_args.num_steps = num_steps

    result_image = process_image(stable_fashion_args, inpainting_pipeline, net)
    # The old caption was a leftover from an unrelated example image.
    st.image(result_image, caption='Result of the virtual try-on')
|