loooooong committed
Commit 60e529d
1 Parent(s): 22d51f0

stable garment app

app.py CHANGED
@@ -1,18 +1,185 @@
-import gradio as gr
-from transformers import pipeline
-
-pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
-
-def predict(input_img):
-    predictions = pipeline(input_img)
-    return input_img, {p["label"]: p["score"] for p in predictions}
-
-gradio_app = gr.Interface(
-    predict,
-    inputs=gr.Image(label="Select hot dog candidate", sources=['upload', 'webcam'], type="pil"),
-    outputs=[gr.Image(label="Processed Image"), gr.Label(label="Result", num_top_classes=2)],
-    title="Hot Dog? Or Not?",
-)
-
-if __name__ == "__main__":
-    gradio_app.launch()
+# adapted from https://huggingface.co/spaces/HumanAIGC/OutfitAnyone/blob/main/app.py
+import torch
+import spaces
+import gradio as gr
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import UniPCMultistepScheduler
+from diffusers import AutoencoderKL
+from diffusers import StableDiffusionPipeline
+from diffusers.loaders import LoraLoaderMixin
+
+from stablegarment.models import AppearanceEncoderModel,ControlNetModel
+from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline
+
+import os
+from os.path import join as opj
+
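+# Pick the device/dtype and load every model component once at startup.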
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if "cuda"==device else torch.float32
+height = 512
+width = 384
+
+base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch_dtype,device=device)
+scheduler = UniPCMultistepScheduler.from_pretrained("runwayml/stable-diffusion-v1-5",subfolder="scheduler")
+
+pretrained_garment_encoder_path = "StableGarment_text2img"
+garment_encoder = AppearanceEncoderModel.from_pretrained(pretrained_garment_encoder_path,torch_dtype=torch_dtype,subfolder="garment_encoder")
+garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
+
+pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype, variant="fp16").to(device=device)
+# pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype, variant="fp16").to(device=device)
+pipeline_t2i.scheduler = scheduler
+
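+# The ControlNet try-on pipeline is not ready yet, so it stays disabled and the UI is locked to t2i-only mode.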
+pipeline_tryon = None
+'''
+# not ready
+pretrained_model_path = "part_module_controlnet_imp2"
+controlnet = ControlNetModel.from_pretrained(pretrained_model_path,subfolder="controlnet")
+text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder='text_encoder')
+tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder='tokenizer')
+pipeline_tryon = StableGarmentControlNetPipeline(
+    vae,
+    text_encoder,
+    tokenizer,
+    pipeline_t2i.unet,
+    controlnet,
+    scheduler,
+).to(device=device,dtype=torch_dtype)
+'''
+
+
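+# Build the ControlNet conditioning: threshold each agnostic mask to {0,1} and invert it, scale DensePose maps to [0,1], then stack the concatenated mask+pose channels as CHW tensors.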
+def prepare_controlnet_inputs(agn_mask_list,densepose_list):
+    for i,agn_mask_img in enumerate(agn_mask_list):
+        agn_mask_img = np.array(agn_mask_img.convert("L"))
+        agn_mask_img = np.expand_dims(agn_mask_img, axis=-1)
+        agn_mask_img = (agn_mask_img >= 128).astype(np.float32) # 0 or 1
+        agn_mask_list[i] = 1. - agn_mask_img
+    densepose_list = [np.array(img)/255. for img in densepose_list]
+    controlnet_inputs = []
+    for mask,pose in zip(agn_mask_list,densepose_list):
+        controlnet_inputs.append(torch.tensor(np.concatenate([mask, pose], axis=-1)).permute(2,0,1))
+    controlnet_inputs = torch.stack(controlnet_inputs)
+    return controlnet_inputs
+
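+# Try-on path: the selected model image is matched by filename to its precomputed agnostic image, agnostic mask, and DensePose map in parse_dir.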
+@spaces.GPU(enable_queue=True)
+def tryon(prompt,init_image,garment_top,garment_down,):
+    basename = os.path.splitext(os.path.basename(init_image))[0]
+    image_agn = Image.open(opj(parse_dir,basename+"_agn.jpg")).resize((width,height))
+    image_agn_mask = Image.open(opj(parse_dir,basename+"_mask.png")).resize((width,height))
+    densepose_image = Image.open(opj(parse_dir,basename+"_densepose.png")).resize((width,height))
+    garment_top = Image.open(garment_top).resize((width,height))
+
+    garment_images = [garment_top,]
+    prompt = [prompt,]
+    cloth_prompt = ["",]
+    controlnet_condition = prepare_controlnet_inputs([image_agn_mask],[densepose_image])
+
+    images = pipeline_tryon(prompt, negative_prompt="",cloth_prompt=cloth_prompt, # negative_cloth_prompt = n_prompt,
+                            height=height,width=width,num_inference_steps=25,guidance_scale=1.5,eta=0.0,
+                            controlnet_condition=controlnet_condition,reference_image=garment_images,
+                            garment_encoder=garment_encoder,condition_extra=image_agn,
+                            generator=None,).images
+    return images[0]
+
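+# Text-to-image path: resize and center-crop the garment to the target aspect ratio, then generate with the garment encoder conditioning the base pipeline.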
+@spaces.GPU(enable_queue=True)
+def text2image(prompt,init_image,garment_top,garment_down,style_fidelity=1.):
+
+    garment_top = Image.open(garment_top).resize((width,height))
+    garment_top = transforms.CenterCrop((height,width))(transforms.Resize(max(height, width))(garment_top))
+
+    garment_images = [garment_top,]
+    prompt = [prompt,]
+    cloth_prompt = ["",]
+    n_prompt = "nsfw, unsaturated, abnormal, unnatural, artifact"
+    negative_prompt = [n_prompt]
+    images = pipeline_t2i(prompt,negative_prompt=negative_prompt,cloth_prompt=cloth_prompt,height=height,width=width,
+                          num_inference_steps=30,guidance_scale=4,num_images_per_prompt=1,style_fidelity=style_fidelity,
+                          garment_encoder=garment_encoder,garment_image=garment_images,).images
+    return images[0]
+
+# def text2image(prompt,init_image,garment_top,garment_down,):
+#     return pipeline(prompt).images[0]
+
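+# Route a run to text2image or tryon depending on the mode checkbox.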
+def infer(prompt,init_image,garment_top,garment_down,t2i_only,style_fidelity):
+    if t2i_only:
+        return text2image(prompt,init_image,garment_top,garment_down,style_fidelity)
+    else:
+        return tryon(prompt,init_image,garment_top,garment_down)
+
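+# Module-level UI state: set_mode swaps the model-image and prompt widgets when the mode checkbox is toggled, keeping the last non-empty values.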
+init_state,prompt_state = None,""
+t2i_only_state = True
+def set_mode(t2i_only,person_condition,prompt):
+    global init_state, prompt_state, t2i_only_state
+    t2i_only_state = not t2i_only_state
+    init_state, prompt_state = person_condition or init_state, prompt_state or prompt
+    if t2i_only:
+        return [gr.Image(sources='clipboard', type="filepath", label="model",value=None, interactive=False),
+                gr.Textbox(placeholder="", label="prompt(for t2i)", value=prompt_state, interactive=True),
+                ]
+    else:
+        return [gr.Image(sources='clipboard', type="filepath", label="model",value=init_state, interactive=False),
+                gr.Textbox(placeholder="", label="prompt(for t2i)", value="", interactive=False),
+                ]
+
+def example_fn(inputs,):
+    if t2i_only_state:
+        return gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False)
+    return gr.Image(sources='clipboard', type="filepath", label="model",value=inputs, interactive=False)
+
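+# Static assets shipped with the Space: model photos, garment images, and precomputed parses.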
+gr.set_static_paths(paths=["assets/images/model"])
+model_dir = opj(os.path.dirname(__file__), "assets/images/model")
+garment_dir = opj(os.path.dirname(__file__), "assets/images/garment")
+parse_dir = opj(os.path.dirname(__file__), "assets/images/image_parse")
+
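+# Layout: a model-image column with clickable examples, a garment/prompt/controls column, and an output column.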
+model = opj(model_dir, "13987_00.jpg")
+all_person = [opj(model_dir,fname) for fname in os.listdir(model_dir) if fname.endswith(".jpg")]
+with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px !important} ", ) as gradio_app:
+    gr.Markdown("# StableGarment")
+    with gr.Row():
+        with gr.Column():
+            init_image = gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False)
+            example = gr.Examples(inputs=gr.Image(visible=False), #init_image,
+                                  examples_per_page=4,
+                                  examples=all_person,
+                                  run_on_click=True,
+                                  outputs=init_image,
+                                  fn=example_fn,)
+        with gr.Column():
+            with gr.Row():
+                images_top = [opj(garment_dir,fname) for fname in os.listdir(garment_dir) if fname.endswith(".jpg")]
+                garment_top = gr.Image(sources='upload', type="filepath", label="top garment",value=images_top[0]) # ,interactive=False
+                example_top = gr.Examples(inputs=garment_top,
+                                          examples_per_page=4,
+                                          examples=images_top)
+                images_down = []
+                garment_down = gr.Image(sources='upload', type="filepath", label="lower garment",interactive=False, visible=False)
+                example_down = gr.Examples(inputs=garment_down,
+                                           examples_per_page=4,
+                                           examples=images_down)
+            prompt = gr.Textbox(placeholder="", label="prompt(for t2i)",) # interactive=False
+            with gr.Row():
+                t2i_only = gr.Checkbox(label="t2i with garment", info="Only text and garment.", elem_id="t2i_switch", value=True, interactive=False,)
+                run_button = gr.Button(value="Run")
+            style_fidelity = gr.Slider(0, 1, value=1, label="fidelity(for t2i)") # , info=""
+            t2i_only.change(fn=set_mode,inputs=[t2i_only,init_image,prompt],outputs=[init_image,prompt,])
+        with gr.Column():
+            gallery = gr.Image()
+    run_button.click(fn=infer,
+                     inputs=[
+                         prompt,
+                         init_image,
+                         garment_top,
+                         garment_down,
+                         t2i_only,
+                         style_fidelity,
+                     ],
+                     outputs=[gallery],)
+
+if __name__ == "__main__":
+    gradio_app.launch()
assets/images/garment/00126_00.jpg ADDED
assets/images/garment/04743_00.jpg ADDED
assets/images/garment/13987_00.jpg ADDED
assets/images/image_parse/01163_00_agn.jpg ADDED
assets/images/image_parse/01163_00_densepose.png ADDED
assets/images/image_parse/01163_00_mask.png ADDED
assets/images/image_parse/01827_00_agn.jpg ADDED
assets/images/image_parse/01827_00_densepose.png ADDED
assets/images/image_parse/01827_00_mask.png ADDED
assets/images/image_parse/13987_00_agn.jpg ADDED
assets/images/image_parse/13987_00_densepose.png ADDED
assets/images/image_parse/13987_00_mask.png ADDED
assets/images/model/01163_00.jpg ADDED
assets/images/model/01827_00.jpg ADDED
assets/images/model/13987_00.jpg ADDED
requirements.txt CHANGED
@@ -1,2 +1,8 @@
 transformers
-torch
+torch
+torchvision
+diffusers
+spaces
+Pillow
+numpy
+git+https://$ACCESS_TOKEN@github.com/logn-2024/StableGarment.git