Spaces:
Runtime error
Runtime error
File size: 4,662 Bytes
6724ca0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from diffusers import StableDiffusionInpaintPipeline
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET
import argparse
from enum import Enum
from rembg import remove
class Parts:
    """Segmentation label indices for the garments this script can repaint.

    The values are compared directly against the per-pixel class indices
    produced by the cloth-segmentation network (see ``process_image``). The
    network has more output classes than the two listed here.
    """

    UPPER, LOWER = 1, 2  # selected via --part upper / --part lower
def parse_arguments(argv=None):
    """Parse the command-line options for the Stable Fashion script.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, which matches the previous behaviour. Passing an
            explicit list makes the function usable from notebooks and tests.

    Returns:
        ``argparse.Namespace`` with ``image``, ``part``, ``resolution``,
        ``prompt``, ``num_steps``, ``guidance_scale``, ``rembg`` and
        ``output``. Unrecognised arguments are ignored rather than rejected.
    """
    parser = argparse.ArgumentParser(
        description="Stable Fashion API, allows you to picture yourself in any cloth your imagination can think of!"
    )
    parser.add_argument('--image', type=str, required=True, help='path to image')
    parser.add_argument('--part', choices=['upper', 'lower'], default='upper', type=str,
                        help='garment region to repaint')
    parser.add_argument('--resolution', choices=[256, 512, 1024, 2048], default=256, type=int,
                        help='side length (pixels) of the square the image is resized to')
    parser.add_argument('--prompt', type=str, default="A pink cloth",
                        help='text describing the new garment')
    parser.add_argument('--num_steps', type=int, default=5,
                        help='number of inference steps for the inpainting pipeline')
    parser.add_argument('--guidance_scale', type=float, default=7.5,
                        help='guidance scale passed to the inpainting pipeline')
    parser.add_argument('--rembg', action='store_true',
                        help='remove the input background (replaced with green) first')
    parser.add_argument('--output', default='output.jpg', type=str,
                        help='path the result image is saved to')
    # parse_known_args discards unknown flags instead of exiting. Presumably
    # this lets the script run where extra flags are injected (e.g.
    # notebooks) — confirm before tightening to parse_args.
    args, _ = parser.parse_known_args(argv)
    return args
def load_u2net(checkpoint_path=None):
    """Build the U2NET cloth segmenter, load its weights and prepare it for inference.

    Args:
        checkpoint_path: Path to the weights file. ``None`` (the default)
            uses ``trained_checkpoint/cloth_segm_u2net_latest.pth``, resolved
            against the current working directory as before.

    Returns:
        The network (3 input channels, 4 output channels) in eval mode, moved
        to CUDA when available, otherwise CPU.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = U2NET(in_ch=3, out_ch=4)
    net = load_checkpoint_mgpu(net, checkpoint_path)
    # eval() switches off training-only behaviour; both to() and eval()
    # return the module itself.
    return net.to(device).eval()
def change_bg_color(rgba_image, color):
    """Flatten ``rgba_image`` onto a solid ``color`` background and return RGB.

    The image's own alpha channel is used as the paste mask, so transparent
    pixels show the background colour. ``color`` is any value accepted by
    ``PIL.Image.new`` (this file passes "GREEN" and "WHITE").
    """
    canvas = Image.new("RGBA", rgba_image.size, color)
    canvas.paste(rgba_image, box=(0, 0), mask=rgba_image)
    flattened = canvas.convert("RGB")
    return flattened
def load_inpainting_pipeline():
    """Load the Stable Diffusion inpainting pipeline onto the best available device.

    Uses float16 weights on CUDA and float32 on CPU.
    """
    on_gpu = torch.cuda.is_available()
    # NOTE(review): the "fp16" revision is requested even on CPU, where
    # torch_dtype is float32. Presumably diffusers upcasts those weights;
    # confirm whether the default revision is preferable there.
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16 if on_gpu else torch.float32,
    )
    return pipeline.to("cuda" if on_gpu else "cpu")
def _part_mask(net, image, part, device):
    """Return an 8-bit PIL mask: 255 where ``net`` labels a pixel of ``image`` as ``part``, else 0."""
    to_tensor = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])
    batch = to_tensor(image).unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        # The first network output holds per-class scores along dim 1.
        scores = net(batch)[0]
    # Argmax over classes, then drop the batch dim -> 2-D label map. A
    # log_softmax beforehand is unnecessary: it shifts every class score of a
    # pixel by the same amount, so the argmax is unchanged.
    labels = torch.argmax(scores, dim=1).squeeze(0).cpu().numpy()
    # Attribute lookup instead of eval() on a caller-supplied string.
    mask_code = getattr(Parts, part.upper())
    mask = np.where(labels == mask_code, 255, 0).astype("uint8")
    # A 2-D uint8 array becomes a single-channel "L" image.
    return Image.fromarray(mask)


def process_image(args, inpainting_pipeline, net):
    """Repaint one garment in ``args.image`` according to ``args.prompt``.

    Steps:
    1. Load the photo and resize it to a square of ``args.resolution``.
    2. Optionally (``args.rembg``) cut out the subject and place it on green.
    3. Segment the ``args.part`` garment with ``net``.
    4. Inpaint that region with ``inpainting_pipeline``.
    5. Cut the subject out of the result and composite it onto white.

    Args:
        args: Namespace as produced by ``parse_arguments``. Reads ``image``,
            ``resolution``, ``rembg``, ``part``, ``prompt``,
            ``guidance_scale`` and ``num_steps``.
        inpainting_pipeline: Called with prompt/image/mask_image/size
            keywords; ``.images[0]`` of its result is used.
        net: Segmentation network. Expected to live on the device chosen
            here (CUDA if available, else CPU).

    Returns:
        The repainted image as an RGB ``PIL.Image``.

    Raises:
        AttributeError: if ``args.part`` names no attribute of ``Parts``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Context manager releases the file handle. convert() returns a new,
    # fully loaded image that stays valid after the file is closed.
    with Image.open(args.image) as src:
        img = src.convert("RGB")
    img = img.resize((args.resolution, args.resolution))

    if args.rembg:
        # change_bg_color already returns an RGB image.
        model_input = change_bg_color(remove(img), color="GREEN")
    else:
        model_input = img

    mask = _part_mask(net, model_input, args.part, device)

    generated = inpainting_pipeline(
        prompt=args.prompt,
        image=model_input,
        mask_image=mask,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]
    # Cut the subject out of the generated image and flatten it onto white
    # (always applied, independent of --rembg).
    return change_bg_color(remove(generated), "WHITE")
if __name__ == '__main__':
    # Script entry: parse options, load both models once, repaint, save.
    cli_args = parse_arguments()
    segmentation_net = load_u2net()
    pipeline = load_inpainting_pipeline()
    output_image = process_image(cli_args, pipeline, segmentation_net)
    output_image.save(cli_args.output)
|