ysharma (HF staff) committed on
Commit 420a964
1 Parent(s): 4b1b707

Update app.py

Files changed (1):
  1. app.py +34 -28

app.py CHANGED
@@ -3,14 +3,12 @@ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, Auto
 from diffusers.utils import load_image
 from transformers import DPTImageProcessor, DPTForDepthEstimation
 import torch
-import mediapy
 import sa_handler
 import pipeline_calls
 
 
 
-# init models
-
+# Initialize models
 depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
 feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
 
@@ -29,9 +27,11 @@ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
     use_safetensors=True,
     torch_dtype=torch.float16,
 ).to("cuda")
+# Configure pipeline for CPU offloading and VAE slicing
 pipeline.enable_model_cpu_offload()
 pipeline.enable_vae_slicing()
 
+# Initialize style-aligned handler
 sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                       share_layer_norm=False,
                                       share_attention=True,
@@ -43,50 +43,56 @@ handler = sa_handler.Handler(pipeline)
 handler.register(sa_args, )
 
 
-
-
-# run ControlNet depth with StyleAligned
+# Function to run ControlNet depth with StyleAligned
 def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt):
-    if depth_map == True:
-        image = load_image(ref_image)
-        depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
-    else:
-        depth_image = load_image(ref_image).resize((1024, 1024))
-    controlnet_conditioning_scale = 0.8
-    num_images_per_prompt = 3 # adjust according to VRAM size
-    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
-    latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
-    images = pipeline_calls.controlnet_call(pipeline, [ref_style_prompt, img_generation_prompt],
-                                            image=depth_image,
-                                            num_inference_steps=50,
-                                            controlnet_conditioning_scale=controlnet_conditioning_scale,
-                                            num_images_per_prompt=num_images_per_prompt,
-                                            latents=latents)
-    #mediapy.show_images([images[0], depth_image2] + images[1:], titles=["reference", "depth"] + [f'result {i}' for i in range(1, len(images))])
-    return [images[0], depth_image] + images[1:], gr.Image(value=images[0], visible=True)
-
-
+    try:
+        if depth_map == True:
+            image = load_image(ref_image)
+            depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
+        else:
+            depth_image = load_image(ref_image).resize((1024, 1024))
+        controlnet_conditioning_scale = 0.8
+        num_images_per_prompt = 3  # adjust according to VRAM size
+        latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
+        latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
+        images = pipeline_calls.controlnet_call(pipeline, [ref_style_prompt, img_generation_prompt],
+                                                image=depth_image,
+                                                num_inference_steps=50,
+                                                controlnet_conditioning_scale=controlnet_conditioning_scale,
+                                                num_images_per_prompt=num_images_per_prompt,
+                                                latents=latents)
+        return [images[0], depth_image] + images[1:], gr.Image(value=images[0], visible=True)
+    except Exception as e:
+        raise gr.Error(f"Error in generating images: {e}")
+
+# Create a Gradio UI
 with gr.Blocks() as demo:
     gr.HTML('<h1 style="text-align: center;">Style-aligned with ControlNet Depth</h1>')
     with gr.Row():
 
         with gr.Column(variant='panel'):
+            # Textbox for reference style prompt
             ref_style_prompt = gr.Textbox(
                 label='Reference style prompt',
                 info="Enter a Prompt to generate the reference image", placeholder='a poster in <style name> style'
             )
+            # Checkbox for using controlnet depth-map
            depth_map = gr.Checkbox(label='Depth-map',)
+            # Image display for the generated reference style image
            ref_style_image = gr.Image(visible=False, label='Reference style image')
 
        with gr.Column(variant='panel'):
+            # Image upload option for uploading a reference image for controlnet
            ref_image = gr.Image(label="Upload the reference image",
                                 type='filepath' )
+            # Textbox for ControlNet prompt
            img_generation_prompt = gr.Textbox(
                label='ControlNet Prompt',
                info="Enter a Prompt to generate images using ControlNet and Style-aligned",
            )
-
+            # Button to trigger image generation
            btn = gr.Button("Generate", size='sm')
+            # Gallery to display generated images
            gallery = gr.Gallery(label="Style-Aligned ControlNet - Generated images",
                                 elem_id="gallery",
                                 columns=5,
@@ -101,7 +107,7 @@ with gr.Blocks() as demo:
                         api_name="style_aligned_controlnet")
 
 
-
+    # Example inputs for the Gradio interface
    gr.Examples(
        examples=[
            ['A poster in a papercut art style.', False, 'example_image/A.png', 'Letter A in a papercut art style.'],
@@ -116,5 +122,5 @@ with gr.Blocks() as demo:
        fn=style_aligned_controlnet,
    )
 
-
+# Launch the Gradio demo
demo.launch()
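
A note on the latent setup in style_aligned_controlnet: the function draws 1 + num_images_per_prompt latent tensors, then resamples only latents[1:]. Index 0 is the noise for the reference-style image, and the ControlNet-guided targets are denoised in the same batch so that the shared-attention handler can align their style to it. A minimal sketch of that batch layout (the 4x128x128 shape is SDXL's latent grid at 1024x1024):

import torch

num_images_per_prompt = 3
# latents[0]  -> reference style image (kept across the batch)
# latents[1:] -> ControlNet-guided results, style-aligned to latents[0]
latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128)
latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128)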
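
pipeline_calls.get_depth_map itself is not part of this commit. A minimal sketch of such a DPT depth helper, following the standard diffusers ControlNet-depth recipe; the repo's actual implementation may differ:

import numpy as np
import torch
from PIL import Image

def get_depth_map(image, feature_processor, depth_estimator):
    # Preprocess with the DPT processor and predict depth
    inputs = feature_processor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth = depth_estimator(inputs).predicted_depth
    # Resize to the SDXL resolution and normalize to [0, 1]
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1), size=(1024, 1024), mode="bicubic", align_corners=False,
    )
    dmin = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    dmax = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    depth = (depth - dmin) / (dmax - dmin)
    # Replicate to 3 channels and convert to a PIL image for ControlNet conditioning
    depth = torch.cat([depth] * 3, dim=1).permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((depth * 255.0).clip(0, 255).astype(np.uint8))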
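
Because the click handler registers api_name="style_aligned_controlnet", the endpoint can also be called programmatically once the Space is running. A sketch using gradio_client; the Space id below is a placeholder assumption, and the inputs mirror the first gr.Examples row:

from gradio_client import Client

client = Client("user/style-aligned-controlnet")  # placeholder: substitute the actual Space id
result = client.predict(
    "A poster in a papercut art style.",   # ref_style_prompt
    False,                                 # depth_map checkbox: False -> use ref_image directly as depth conditioning
    "example_image/A.png",                 # ref_image (filepath)
    "Letter A in a papercut art style.",   # img_generation_prompt
    api_name="/style_aligned_controlnet",
)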