File size: 5,445 Bytes
6724ca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
from diffusers import StableDiffusionInpaintPipeline
import os

from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import warnings
from huggingface_hub import hf_hub_download

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
import argparse
from enum import Enum
from rembg import remove
from dataclasses import dataclass


@dataclass
class StableFashionCLIArgs:
    """Settings consumed by ``process_image``.

    Every field has a default so the object can be created empty and filled
    in attribute by attribute, which is how the Streamlit UI uses it.

    Attributes:
        image: Picture to edit (whatever ``process_image`` accepts as
            ``args.image``).
        part: Name of the garment region to repaint; matched
            case-insensitively against the attributes of ``Parts``.
        resolution: Edge length, in pixels, of the square working image.
        prompt: Text description of the clothing to generate.
        num_steps: Number of diffusion inference steps.
        guidance_scale: Guidance scale passed to the inpainting pipeline.
        rembg: Whether to remove the background before segmentation.
    """

    image: object = None
    part: str = "upper"
    resolution: int = 512
    prompt: str = ""
    num_steps: int = 25
    guidance_scale: float = 7.5
    rembg: bool = False


class Parts:
    """Integer labels for the garment regions a user can choose.

    ``process_image`` looks a label up by its upper-cased name and compares it
    with the per-pixel class index derived from the segmentation network.
    """

    UPPER, LOWER = 1, 2


def load_u2net():
    """Build the cloth-segmentation U2NET and load its pretrained weights.

    The checkpoint is fetched from the ``maiti/cloth-segmentation`` repository
    on the Hugging Face Hub. The returned network is in evaluation mode and
    lives on the GPU when CUDA is available, otherwise on the CPU.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    weights_file = hf_hub_download(
        repo_id="maiti/cloth-segmentation",
        filename="cloth_segm_u2net_latest.pth",
    )
    # Three input channels, four output channels.
    model = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights_file)
    return model.to(target).eval()

def change_bg_color(rgba_image, color):
    """Return ``rgba_image`` flattened onto a solid ``color`` background.

    The image is pasted onto a same-sized canvas filled with ``color``, using
    the image itself as the paste mask, and the result is converted to RGB.
    """
    canvas = Image.new("RGBA", rgba_image.size, color)
    canvas.paste(rgba_image, box=(0, 0), mask=rgba_image)
    flattened = canvas.convert("RGB")
    return flattened

def load_inpainting_pipeline():
    """Load the Stable Diffusion inpainting pipeline on the best device.

    The ``fp16`` revision of ``runwayml/stable-diffusion-inpainting`` is
    requested. With CUDA available the pipeline is loaded as float16 and moved
    to the GPU; otherwise it is loaded as float32 and kept on the CPU.
    """
    if torch.cuda.is_available():
        target, dtype = "cuda", torch.float16
    else:
        target, dtype = "cpu", torch.float32

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=dtype,
    )
    return pipe.to(target)


def process_image(args, inpainting_pipeline, net):
    """Repaint the selected garment region of a photo with Stable Diffusion.

    The picture is resized to a square, optionally placed on a green
    background, and segmented with ``net`` to find the requested garment
    region. That region is then inpainted from ``args.prompt``. Finally the
    background of the generated picture is removed and replaced with white.

    Args:
        args: Settings object exposing ``image`` (an already opened
            ``PIL.Image.Image``, or anything ``Image.open`` accepts such as a
            path or a binary file object), ``part`` (name of a ``Parts``
            attribute, case-insensitive), ``resolution``, ``rembg``,
            ``prompt``, ``guidance_scale`` and ``num_steps``.
        inpainting_pipeline: Callable pipeline whose result has an ``images``
            sequence; the first image is used.
        net: Segmentation network. The first element of its output is treated
            as per-class scores along dimension 1.

    Returns:
        The generated picture as an RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If ``args.part`` does not name an attribute of ``Parts``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve the requested region up front, with a plain attribute lookup
    # instead of eval(), so that user-provided text is never executed and a
    # bad value fails before any expensive work is done.
    mask_code = getattr(Parts, args.part.upper(), None)
    if mask_code is None:
        raise ValueError(f"Unknown body part: {args.part!r}")

    # Callers may pass either an opened image or something to open.
    source = args.image
    img = source if isinstance(source, Image.Image) else Image.open(source)
    img = img.convert("RGB").resize((args.resolution, args.resolution))

    if args.rembg:
        # change_bg_color already returns an RGB image.
        img_with_green_bg = change_bg_color(remove(img), color="GREEN")
    else:
        img_with_green_bg = img

    transform_rgb = transforms.Compose(
        [transforms.ToTensor(), Normalize_image(0.5, 0.5)]
    )
    image_tensor = transform_rgb(img_with_green_bg).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        # Index of the highest-scoring class for every pixel.
        labels = torch.max(output_tensor, dim=1, keepdim=True)[1]
    labels = torch.squeeze(labels, dim=0)
    labels = torch.squeeze(labels, dim=0)
    output_arr = labels.cpu().numpy()

    # White (255) where the pixel belongs to the requested part, black
    # elsewhere; the pipeline repaints the white area.
    mask_arr = (output_arr == mask_code).astype("uint8") * 255
    mask_PIL = Image.fromarray(mask_arr, mode="L")

    clothed_image_from_pipeline = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_PIL,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]
    clothed_image_from_pipeline = remove(clothed_image_from_pipeline)
    clothed_image_from_pipeline = change_bg_color(clothed_image_from_pipeline, "WHITE")
    return clothed_image_from_pipeline.convert("RGB")


# ---------------------------------------------------------------------------
# Streamlit UI: collect the user's choices, run the try-on, show the result.
# ---------------------------------------------------------------------------
st.title("Stable Fashion Huggingface Spaces")
file_name = st.file_uploader("Upload a clear full length picture of yourself, preferably in a less noisy background")

# Streamlit re-executes this script on every widget interaction, so keep the
# heavy models in Streamlit's resource cache when the installed version
# offers one instead of reloading them each time.
_cache_models = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
if _cache_models is not None:
    net = _cache_models(load_u2net)()
    inpainting_pipeline = _cache_models(load_inpainting_pipeline)()
else:
    net = load_u2net()
    inpainting_pipeline = load_inpainting_pipeline()

if file_name is not None:
    stable_fashion_args = StableFashionCLIArgs()
    # Hand over the uploaded file object itself; process_image opens it.
    stable_fashion_args.image = file_name
    body_part = st.radio("Would you like to try clothes on your upper body (such as shirts, kurtas etc) or lower (Jeans, Pants etc)? ", ('Upper', 'Lower'))
    stable_fashion_args.part = body_part
    resolution = st.radio("Which resolution would you like to get the resulting picture in? (Keep in mind, higher the resolution, higher the queue times)", (128, 256, 512))
    stable_fashion_args.resolution = resolution
    rembg_status = st.radio("Would you like to remove background in your image before putting new clothes on you? (Sometimes it results in better images)", ("Yes", "No"))
    stable_fashion_args.rembg = (rembg_status == "Yes")
    guidance_scale = st.slider("Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5)
    stable_fashion_args.guidance_scale = guidance_scale
    prompt = st.text_input('Write the description of cloth you want to try', 'a bright yellow t shirt')
    # Store the text the user typed (not the guidance scale) as the prompt.
    stable_fashion_args.prompt = prompt
    num_steps = st.slider("No. of inference steps for the diffusion process", 5, 50, value=25)
    # process_image reads args.num_steps, so the slider value must be stored.
    stable_fashion_args.num_steps = num_steps

    result_image = process_image(stable_fashion_args, inpainting_pipeline, net)
    st.image(result_image, caption='Generated try-on result')