loooooong committed on
Commit 2524aed
1 Parent(s): 2c587e6

remove all optimization, add safety checker

Files changed (1)
  1. app.py +9 -26
app.py CHANGED
@@ -12,12 +12,12 @@ import numpy as np
 from torchvision import transforms
 
 from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.models.clip.image_processing_clip import CLIPImageProcessor
 
 from diffusers import UniPCMultistepScheduler
 from diffusers import AutoencoderKL
 from diffusers import StableDiffusionPipeline
-from diffusers.loaders import LoraLoaderMixin
-import intel_extension_for_pytorch as ipex
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 from stablegarment.models import GarmentEncoderModel,ControlNetModel
 from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline
@@ -38,27 +38,8 @@ garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
 pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype, use_safetensors=True,).to(device=device) # variant="fp16"
 # pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype).to(device=device)
 pipeline_t2i.scheduler = scheduler
-
-if device=="cpu":
-    # speed up for cpu
-    # to channels last
-    pipeline_t2i.unet = pipeline_t2i.unet.to(memory_format=torch.channels_last)
-    pipeline_t2i.vae = pipeline_t2i.vae.to(memory_format=torch.channels_last)
-    pipeline_t2i.text_encoder = pipeline_t2i.text_encoder.to(memory_format=torch.channels_last)
-    # pipeline_t2i.safety_checker = pipeline_t2i.safety_checker.to(memory_format=torch.channels_last)
-
-    # Create random input to enable JIT compilation
-    sample = torch.randn(2,4,64,48).type(torch_dtype)
-    timestep = torch.rand(1)*999
-    encoder_hidden_status = torch.randn(2,77,768).type(torch_dtype)
-    input_example = (sample, timestep, encoder_hidden_status)
-
-    # optimize with IPEX
-    pipeline_t2i.unet = ipex.optimize(pipeline_t2i.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
-    pipeline_t2i.vae = ipex.optimize(pipeline_t2i.vae.eval(), dtype=torch.bfloat16, inplace=True)
-    pipeline_t2i.text_encoder = ipex.optimize(pipeline_t2i.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
-    # pipeline_t2i.safety_checker = ipex.optimize(pipeline_t2i.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
-
+pipeline_t2i.safety_checker = StableDiffusionSafetyChecker.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch_dtype, subfolder="safety_checker").to(device=device)
+pipeline_t2i.feature_extractor = CLIPImageProcessor.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch_dtype, subfolder="feature_extractor")
 
 pipeline_tryon = None
 '''
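The two added lines restore the stock Stable Diffusion v1.5 safety components that the removed IPEX block had commented out. For context, here is a minimal, illustrative sketch of how a diffusers-style pipeline typically uses them after decoding (it mirrors `StableDiffusionPipeline.run_safety_checker`); the standalone setup and the helper name are assumptions for illustration, not code from this repo.

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

repo = "runwayml/stable-diffusion-v1-5"  # same checkpoint the commit pulls the components from
safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo, subfolder="safety_checker")
feature_extractor = CLIPImageProcessor.from_pretrained(repo, subfolder="feature_extractor")

def run_safety_checker(images_np, device="cpu", dtype=torch.float32):
    # images_np: (batch, H, W, 3) float array in [0, 1], i.e. decoded output before PIL conversion
    pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in images_np]
    clip_input = feature_extractor(pil_images, return_tensors="pt").pixel_values
    # The checker blacks out any image it flags and reports which ones were flagged.
    images_np, has_nsfw = safety_checker(
        images=images_np, clip_input=clip_input.to(device=device, dtype=dtype)
    )
    return images_np, has_nsfw
```

In the committed app.py the components are simply assigned onto `pipeline_t2i`; whether they are invoked this way depends on StableGarmentPipeline following the stock pipeline's behavior.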
@@ -77,7 +58,6 @@ pipeline_tryon = StableGarmentControlNetPipeline(
 ).to(device=device,dtype=torch_dtype)
 '''
 
-
 def prepare_controlnet_inputs(agn_mask_list,densepose_list):
     for i,agn_mask_img in enumerate(agn_mask_list):
         agn_mask_img = np.array(agn_mask_img.convert("L"))
@@ -101,7 +81,7 @@ def tryon(prompt,init_image,garment_top,garment_down,):
     garment_images = [garment_top,]
     prompt = [prompt,]
     cloth_prompt = ["",]
-    controlnet_condition = prepare_controlnet_inputs([image_agn_mask],[densepose_image])
+    controlnet_condition = prepare_controlnet_inputs([image_agn_mask],[densepose_image]).type(torch_dtype)
 
     images = pipeline_tryon(prompt, negative_prompt="",cloth_prompt=cloth_prompt, # negative_cloth_prompt = n_prompt,
                             height=height,width=width,num_inference_steps=25,guidance_scale=1.5,eta=0.0,
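The only change in this hunk is the trailing `.type(torch_dtype)` cast on the ControlNet condition. A small illustrative snippet, with made-up shapes and an example dtype, of why that cast is needed:

```python
import torch

torch_dtype = torch.bfloat16                       # example: dtype the pipeline weights were loaded in
controlnet_condition = torch.rand(1, 512, 384, 4)  # placeholder; condition tensors default to float32
# float32 inputs against bfloat16/float16 ControlNet weights fail with a dtype-mismatch RuntimeError,
# so the condition is cast to the pipeline dtype before the forward pass.
controlnet_condition = controlnet_condition.type(torch_dtype)
```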
@@ -128,7 +108,7 @@ def text2image(prompt,init_image,garment_top,garment_down,style_fidelity=1.):
                          garment_encoder=garment_encoder,garment_image=garment_images,).images
     return images[0]
 
-# def text2image(prompt,init_image,garment_top,garment_down,):
+# def text2image(prompt,init_image,garment_top,garment_down,*args,**kwargs):
 #     return pipeline(prompt).images[0]
 
 def infer(prompt,init_image,garment_top,garment_down,t2i_only,style_fidelity):
@@ -166,6 +146,8 @@ model = opj(model_dir, "13987_00.jpg")
 all_person = [opj(model_dir,fname) for fname in os.listdir(model_dir) if fname.endswith(".jpg")]
 with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px !important} ", ) as gradio_app:
     gr.Markdown("# StableGarment")
+    gr.Markdown("Demo for [StableGarment: Garment-Centric Generation via Stable Diffusion](https://arxiv.org/abs/2403.10783).")
+    gr.Markdown("*Running on cpu, so it is super slow. Feel free to duplicate the space or visit [StableGarment](https://github.com/logn-2024/StableGarment) for more info.*")
     with gr.Row():
         with gr.Column():
             init_image = gr.Image(sources='clipboard', type="filepath", label="model", value=None, interactive=False)
@@ -207,6 +189,7 @@ with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px
             style_fidelity,
         ],
         outputs=[gallery],)
+    gr.Markdown("We borrow some code from [OutfitAnyone](https://huggingface.co/spaces/HumanAIGC/OutfitAnyone), thanks. This demo is not safe for all audiences, which may reflect implicit bias and other defects of base model.")
 
 if __name__ == "__main__":
     gradio_app.launch()