ysharma (HF staff) committed
Commit 6d9201f · 1 Parent(s): 2783d9a

Update app.py

Files changed (1)
  1. app.py +26 -17
app.py CHANGED
@@ -1,17 +1,17 @@
 import gradio as gr
 import torch
 from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
-import mediapy
 import sa_handler
 import pipeline_calls
 
 
 # init models
-model_ckpt = "stabilityai/stable-diffusion-2-base"
+model_ckpt = "stabilityai/stable-diffusion-2-base"
 scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
 pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
     model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
 ).to("cuda")
+# Configure the pipeline for CPU offloading and VAE slicing
 pipeline.enable_model_cpu_offload()
 pipeline.enable_vae_slicing()
 sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,
@@ -21,42 +21,50 @@ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,
                                       adain_keys=True,
                                       adain_values=False,
                                       )
+# Initialize the style-aligned handler
 handler = sa_handler.Handler(pipeline)
 handler.register(sa_args)
 
 
-# run MultiDiffusion with StyleAligned
+# Define the function to run MultiDiffusion with StyleAligned
 def style_aligned_multidiff(ref_style_prompt, img_generation_prompt):
-    view_batch_size = 25 # adjust according to VRAM size
-    reference_latent = torch.randn(1, 4, 64, 64,)
-    images = pipeline_calls.panorama_call(pipeline,
-                                          [ref_style_prompt, img_generation_prompt],
-                                          reference_latent=reference_latent,
-                                          view_batch_size=view_batch_size)
-
-    return images, gr.Image(value=images[0], visible=True)
-
+    try:
+        view_batch_size = 25 # adjust according to VRAM size
+        reference_latent = torch.randn(1, 4, 64, 64,)
+        images = pipeline_calls.panorama_call(pipeline,
+                                              [ref_style_prompt, img_generation_prompt],
+                                              reference_latent=reference_latent,
+                                              view_batch_size=view_batch_size)
+
+        return images, gr.Image(value=images[0], visible=True)
+    except Exception as e:
+        raise gr.Error(f"Error in generating images:{e}")
 
+# Create a Gradio UI
 with gr.Blocks() as demo:
+    gr.HTML('<h1 style="text-align: center;">Style-aligned with MultiDiffusion</h1>')
     with gr.Row():
         with gr.Column(variant='panel'):
+            # Textbox for reference style prompt
             ref_style_prompt = gr.Textbox(
                 label='Reference style prompt',
                 info='Enter a Prompt to generate the reference image',
                 placeholder='A poster in a papercut art style.'
             )
+            # Image display for the reference style image
             ref_style_image = gr.Image(visible=False, label='Reference style image')
 
         with gr.Column(variant='panel'):
+            # Textbox for prompt for MultiDiffusion panoramas
             img_generation_prompt = gr.Textbox(
                 label='MultiDiffusion Prompt',
-                info='Enter a Prompt to generate panaromic images using Style-aligned combined with MultiDiffusion',
+                info='Enter a Prompt to generate panoramic images using Style-aligned combined with MultiDiffusion',
                 placeholder= 'A village in a papercut art style.'
             )
 
-
+    # Button to trigger image generation
     btn = gr.Button('Style-aligned MultiDiffusion - Generate', size='sm')
-
+    # Gallery to display generated style image and the panorama
     gallery = gr.Gallery(label='Style-Aligned ControlNet - Generated images',
                          elem_id='gallery',
                          columns=5,
@@ -66,13 +74,13 @@ with gr.Blocks() as demo:
                          allow_preview=True,
                          preview=True,
                          )
-
+    # Button click event
     btn.click(fn=style_aligned_multidiff,
              inputs=[ref_style_prompt, img_generation_prompt],
              outputs=[gallery, ref_style_image],
             api_name='style_aligned_multidiffusion')
 
-
+    # Example inputs for the Gradio demo
    gr.Examples(
        examples=[
            ['A poster in a papercut art style.', 'A village in a papercut art style.'],
@@ -87,4 +95,5 @@ with gr.Blocks() as demo:
        fn=style_aligned_multidiff,
    )
 
+# Launch the Gradio demo
 demo.launch()
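
Because the click handler registers api_name='style_aligned_multidiffusion', the updated demo can also be driven programmatically. A minimal sketch with gradio_client is below; the Space identifier "user/style-aligned-space" is a placeholder (the actual Space name is not part of this diff), and the exact shape of the returned gallery payload depends on the Gradio and gradio_client versions.

# Minimal sketch (assumption): calling the named endpoint remotely with gradio_client.
# "user/style-aligned-space" is a placeholder -- substitute the real Space id or URL.
from gradio_client import Client

client = Client("user/style-aligned-space")
result = client.predict(
    "A poster in a papercut art style.",   # ref_style_prompt
    "A village in a papercut art style.",  # img_generation_prompt
    api_name="/style_aligned_multidiffusion",
)
# result bundles the generated gallery and the reference style image;
# its exact structure (file paths vs. dicts) depends on the client version.
print(result)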