stable garment app

- app.py +179 -12
- assets/images/garment/00126_00.jpg +0 -0
- assets/images/garment/04743_00.jpg +0 -0
- assets/images/garment/13987_00.jpg +0 -0
- assets/images/image_parse/01163_00_agn.jpg +0 -0
- assets/images/image_parse/01163_00_densepose.png +0 -0
- assets/images/image_parse/01163_00_mask.png +0 -0
- assets/images/image_parse/01827_00_agn.jpg +0 -0
- assets/images/image_parse/01827_00_densepose.png +0 -0
- assets/images/image_parse/01827_00_mask.png +0 -0
- assets/images/image_parse/13987_00_agn.jpg +0 -0
- assets/images/image_parse/13987_00_densepose.png +0 -0
- assets/images/image_parse/13987_00_mask.png +0 -0
- assets/images/model/01163_00.jpg +0 -0
- assets/images/model/01827_00.jpg +0 -0
- assets/images/model/13987_00.jpg +0 -0
- requirements.txt +7 -1
app.py
CHANGED
@@ -1,18 +1,185 @@
 import gradio as gr
-from …
-… (several more removed lines of the old "Hot Dog? Or Not?" demo were not captured in this view)
-    inputs=gr.Image(label="Select hot dog candidate", sources=['upload', 'webcam'], type="pil"),
-    outputs=[gr.Image(label="Processed Image"), gr.Label(label="Result", num_top_classes=2)],
-    title="Hot Dog? Or Not?",
-)

 if __name__ == "__main__":
-    gradio_app.launch()
+# adapted from https://huggingface.co/spaces/HumanAIGC/OutfitAnyone/blob/main/app.py
+import torch
+import spaces
 import gradio as gr
+from PIL import Image
+import numpy as np
+from torchvision import transforms

+from transformers import CLIPTextModel, CLIPTokenizer

+from diffusers import UniPCMultistepScheduler
+from diffusers import AutoencoderKL
+from diffusers import StableDiffusionPipeline
+from diffusers.loaders import LoraLoaderMixin

+from stablegarment.models import AppearanceEncoderModel,ControlNetModel
+from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline

+import os
+from os.path import join as opj
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if "cuda"==device else torch.float32
+height = 512
+width = 384
+
+base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch_dtype,device=device)
+scheduler = UniPCMultistepScheduler.from_pretrained("runwayml/stable-diffusion-v1-5",subfolder="scheduler")
+
+pretrained_garment_encoder_path = "StableGarment_text2img"
+garment_encoder = AppearanceEncoderModel.from_pretrained(pretrained_garment_encoder_path,torch_dtype=torch_dtype,subfolder="garment_encoder")
+garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
+
+pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype, variant="fp16").to(device=device)
+# pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype, variant="fp16").to(device=device)
+pipeline_t2i.scheduler = scheduler
+
+pipeline_tryon = None
+'''
+# not ready
+pretrained_model_path = "part_module_controlnet_imp2"
+controlnet = ControlNetModel.from_pretrained(pretrained_model_path,subfolder="controlnet")
+text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder='text_encoder')
+tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder='tokenizer')
+pipeline_tryon = StableGarmentControlNetPipeline(
+    vae,
+    text_encoder,
+    tokenizer,
+    pipeline_t2i.unet,
+    controlnet,
+    scheduler,
+).to(device=device,dtype=torch_dtype)
+'''
+
+
+def prepare_controlnet_inputs(agn_mask_list,densepose_list):
+    for i,agn_mask_img in enumerate(agn_mask_list):
+        agn_mask_img = np.array(agn_mask_img.convert("L"))
+        agn_mask_img = np.expand_dims(agn_mask_img, axis=-1)
+        agn_mask_img = (agn_mask_img >= 128).astype(np.float32)  # 0 or 1
+        agn_mask_list[i] = 1. - agn_mask_img
+    densepose_list = [np.array(img)/255. for img in densepose_list]
+    controlnet_inputs = []
+    for mask,pose in zip(agn_mask_list,densepose_list):
+        controlnet_inputs.append(torch.tensor(np.concatenate([mask, pose], axis=-1)).permute(2,0,1))
+    controlnet_inputs = torch.stack(controlnet_inputs)
+    return controlnet_inputs
+
+@spaces.GPU(enable_queue=True)
+def tryon(prompt,init_image,garment_top,garment_down,):
+    basename = os.path.splitext(os.path.basename(init_image))[0]
+    image_agn = Image.open(opj(parse_dir,basename+"_agn.jpg")).resize((width,height))
+    image_agn_mask = Image.open(opj(parse_dir,basename+"_mask.png")).resize((width,height))
+    densepose_image = Image.open(opj(parse_dir,basename+"_densepose.png")).resize((width,height))
+    garment_top = Image.open(garment_top).resize((width,height))
+
+    garment_images = [garment_top,]
+    prompt = [prompt,]
+    cloth_prompt = ["",]
+    controlnet_condition = prepare_controlnet_inputs([image_agn_mask],[densepose_image])
+
+    images = pipeline_tryon(prompt, negative_prompt="",cloth_prompt=cloth_prompt, # negative_cloth_prompt = n_prompt,
+                            height=height,width=width,num_inference_steps=25,guidance_scale=1.5,eta=0.0,
+                            controlnet_condition=controlnet_condition,reference_image=garment_images,
+                            garment_encoder=garment_encoder,condition_extra=image_agn,
+                            generator=None,).images
+    return images[0]
+
+@spaces.GPU(enable_queue=True)
+def text2image(prompt,init_image,garment_top,garment_down,style_fidelity=1.):
+
+    garment_top = Image.open(garment_top).resize((width,height))
+    garment_top = transforms.CenterCrop((height,width))(transforms.Resize(max(height, width))(garment_top))
+
+    garment_images = [garment_top,]
+    prompt = [prompt,]
+    cloth_prompt = ["",]
+    n_prompt = "nsfw, unsaturated, abnormal, unnatural, artifact"
+    negative_prompt = [n_prompt]
+    images = pipeline_t2i(prompt,negative_prompt=negative_prompt,cloth_prompt=cloth_prompt,height=height,width=width,
+                          num_inference_steps=30,guidance_scale=4,num_images_per_prompt=1,style_fidelity=style_fidelity,
+                          garment_encoder=garment_encoder,garment_image=garment_images,).images
+    return images[0]
+
+# def text2image(prompt,init_image,garment_top,garment_down,):
+#     return pipeline(prompt).images[0]
+
+def infer(prompt,init_image,garment_top,garment_down,t2i_only,style_fidelity):
+    if t2i_only:
+        return text2image(prompt,init_image,garment_top,garment_down,style_fidelity)
+    else:
+        return tryon(prompt,init_image,garment_top,garment_down)
+
+init_state,prompt_state = None,""
+t2i_only_state = True
+def set_mode(t2i_only,person_condition,prompt):
+    global init_state, prompt_state, t2i_only_state
+    t2i_only_state = not t2i_only_state
+    init_state, prompt_state = person_condition or init_state, prompt_state or prompt
+    if t2i_only:
+        return [gr.Image(sources='clipboard', type="filepath", label="model",value=None, interactive=False),
+                gr.Textbox(placeholder="", label="prompt(for t2i)", value=prompt_state, interactive=True),
+                ]
+    else:
+        return [gr.Image(sources='clipboard', type="filepath", label="model",value=init_state, interactive=False),
+                gr.Textbox(placeholder="", label="prompt(for t2i)", value="", interactive=False),
+                ]
+
+def example_fn(inputs,):
+    if t2i_only_state:
+        return gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False)
+    return gr.Image(sources='clipboard', type="filepath", label="model",value=inputs, interactive=False)
+
+gr.set_static_paths(paths=["assets/images/model"])
+model_dir = opj(os.path.dirname(__file__), "assets/images/model")
+garment_dir = opj(os.path.dirname(__file__), "assets/images/garment")
+parse_dir = opj(os.path.dirname(__file__), "assets/images/image_parse")
+
+model = opj(model_dir, "13987_00.jpg")
+all_person = [opj(model_dir,fname) for fname in os.listdir(model_dir) if fname.endswith(".jpg")]
+with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px !important} ", ) as gradio_app:
+    gr.Markdown("# StableGarment")
+    with gr.Row():
+        with gr.Column():
+            init_image = gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False)
+            example = gr.Examples(inputs=gr.Image(visible=False), #init_image,
+                                  examples_per_page=4,
+                                  examples=all_person,
+                                  run_on_click=True,
+                                  outputs=init_image,
+                                  fn=example_fn,)
+        with gr.Column():
+            with gr.Row():
+                images_top = [opj(garment_dir,fname) for fname in os.listdir(garment_dir) if fname.endswith(".jpg")]
+                garment_top = gr.Image(sources='upload', type="filepath", label="top garment",value=images_top[0]) # ,interactive=False
+                example_top = gr.Examples(inputs=garment_top,
+                                          examples_per_page=4,
+                                          examples=images_top)
+                images_down = []
+                garment_down = gr.Image(sources='upload', type="filepath", label="lower garment",interactive=False, visible=False)
+                example_down = gr.Examples(inputs=garment_down,
+                                           examples_per_page=4,
+                                           examples=images_down)
+            prompt = gr.Textbox(placeholder="", label="prompt(for t2i)",) # interactive=False
+            with gr.Row():
+                t2i_only = gr.Checkbox(label="t2i with garment", info="Only text and garment.", elem_id="t2i_switch", value=True, interactive=False,)
+                run_button = gr.Button(value="Run")
+            style_fidelity = gr.Slider(0, 1, value=1, label="fidelity(for t2i)") # , info=""
+            t2i_only.change(fn=set_mode,inputs=[t2i_only,init_image,prompt],outputs=[init_image,prompt,])
+        with gr.Column():
+            gallery = gr.Image()
+    run_button.click(fn=infer,
+                     inputs=[
+                         prompt,
+                         init_image,
+                         garment_top,
+                         garment_down,
+                         t2i_only,
+                         style_fidelity,
+                     ],
+                     outputs=[gallery],)
+
 if __name__ == "__main__":
+    gradio_app.launch()
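For orientation when reading the diff above: the try-on branch hinges on prepare_controlnet_inputs, which packs the inverted agnostic mask (one channel) and the DensePose map (three channels) into a 4-channel ControlNet conditioning tensor per sample, which tryon then passes to pipeline_tryon as controlnet_condition. A minimal sketch of that transformation with dummy in-memory images (not part of the commit) might look like this:

import numpy as np
import torch
from PIL import Image

# Dummy stand-ins for <basename>_mask.png and <basename>_densepose.png at the
# app's working resolution (width=384, height=512); the real files live in parse_dir.
width, height = 384, 512
mask = Image.new("L", (width, height), 255)
densepose = Image.new("RGB", (width, height), 0)

mask_arr = (np.array(mask)[..., None] >= 128).astype(np.float32)  # HxWx1, values in {0, 1}
mask_arr = 1.0 - mask_arr                                          # inverted, as in the app
pose_arr = np.array(densepose) / 255.0                             # HxWx3, scaled to [0, 1]

cond = torch.tensor(np.concatenate([mask_arr, pose_arr], axis=-1)).permute(2, 0, 1)
controlnet_condition = torch.stack([cond])
print(controlnet_condition.shape)  # torch.Size([1, 4, 512, 384])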
assets/images/garment/00126_00.jpg
ADDED
assets/images/garment/04743_00.jpg
ADDED
assets/images/garment/13987_00.jpg
ADDED
assets/images/image_parse/01163_00_agn.jpg
ADDED
assets/images/image_parse/01163_00_densepose.png
ADDED
assets/images/image_parse/01163_00_mask.png
ADDED
assets/images/image_parse/01827_00_agn.jpg
ADDED
assets/images/image_parse/01827_00_densepose.png
ADDED
assets/images/image_parse/01827_00_mask.png
ADDED
assets/images/image_parse/13987_00_agn.jpg
ADDED
assets/images/image_parse/13987_00_densepose.png
ADDED
assets/images/image_parse/13987_00_mask.png
ADDED
assets/images/model/01163_00.jpg
ADDED
assets/images/model/01827_00.jpg
ADDED
assets/images/model/13987_00.jpg
ADDED
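The added assets follow the pairing that tryon assumes: for every model image <id>.jpg under assets/images/model, the parse directory provides <id>_agn.jpg, <id>_mask.png, and <id>_densepose.png. A small optional sanity check (not part of the commit, assuming it is run from the repository root) would be:

import os
from os.path import join as opj, splitext

model_dir = "assets/images/model"
parse_dir = "assets/images/image_parse"

for fname in sorted(os.listdir(model_dir)):
    if not fname.endswith(".jpg"):
        continue
    stem = splitext(fname)[0]  # e.g. "13987_00"
    expected = [stem + "_agn.jpg", stem + "_mask.png", stem + "_densepose.png"]
    missing = [f for f in expected if not os.path.exists(opj(parse_dir, f))]
    print(fname, "->", "ok" if not missing else "missing: " + ", ".join(missing))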
requirements.txt
CHANGED
@@ -1,2 +1,8 @@
 transformers
-torch
+torch
+torchvision
+diffusers
+spaces
+Pillow
+numpy
+git+https://$ACCESS_TOKEN@github.com/logn-2024/StableGarment.git
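The last requirement installs StableGarment directly from GitHub. One caveat: pip only expands environment variables written in the braced ${VAR} form inside requirements files, so $ACCESS_TOKEN is passed through to git literally rather than substituted. That generally does not matter for a public repository, but if authentication were ever required, the braced form pip documents would be:

git+https://${ACCESS_TOKEN}@github.com/logn-2024/StableGarment.git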