ysharma HF staff commited on
Commit
5767af2
1 Parent(s): 1a530c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
  import torch
4
- import mediapy
5
  import sa_handler
6
 
7
  # init models
@@ -11,10 +10,10 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
11
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
12
  scheduler=scheduler
13
  ).to("cuda")
14
- #pipeline.enable_sequential_cpu_offload()
15
  pipeline.enable_model_cpu_offload()
16
  pipeline.enable_vae_slicing()
17
-
18
  handler = sa_handler.Handler(pipeline)
19
  sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
20
  share_layer_norm=False,
@@ -26,43 +25,43 @@ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
26
 
27
  handler.register(sa_args, )
28
 
29
- # example of set of prompts
30
- sets_of_prompts = [
31
- "a toy train. macro photo. 3d game asset",
32
- "a toy airplane. macro photo. 3d game asset",
33
- "a toy bicycle. macro photo. 3d game asset",
34
- "a toy car. macro photo. 3d game asset",
35
- "a toy boat. macro photo. 3d game asset",
36
- ]
37
-
38
- # run StyleAligned
39
  def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
40
- sets_of_prompts = [ prompt + ". " + style_prompt for prompt in [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5,]]
41
- images = pipeline(sets_of_prompts,).images
42
- #mediapy.show_images(images)
43
- print(images)
44
- return images
 
 
 
45
 
46
  with gr.Blocks() as demo:
47
  with gr.Group():
48
  with gr.Column():
49
  with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
50
  with gr.Row(variant='panel'):
 
51
  initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
52
  initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
53
  initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
54
  initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
55
  initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
56
  with gr.Row():
 
57
  style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
 
58
  btn = gr.Button("Generate a set of Style-aligned SDXL images",)
 
59
  output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)
60
-
 
61
  btn.click(fn=style_aligned_sdxl,
62
  inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
63
  outputs=output,
64
  api_name="style_aligned_sdxl")
65
 
 
66
  gr.Examples(examples=[
67
  ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
68
  ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
@@ -74,5 +73,5 @@ with gr.Blocks() as demo:
74
  outputs=[output],
75
  fn=style_aligned_sdxl)
76
 
77
- demo.launch()
78
-
 
1
  import gradio as gr
2
  from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
  import torch
 
4
  import sa_handler
5
 
6
  # init models
 
10
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
11
  scheduler=scheduler
12
  ).to("cuda")
13
+ # Configure the pipeline for CPU offloading and VAE slicing#pipeline.enable_sequential_cpu_offload()
14
  pipeline.enable_model_cpu_offload()
15
  pipeline.enable_vae_slicing()
16
+ # Initialize the style-aligned handler
17
  handler = sa_handler.Handler(pipeline)
18
  sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
19
  share_layer_norm=False,
 
25
 
26
  handler.register(sa_args, )
27
 
28
+ # Define the function to generate style-aligned images
 
 
 
 
 
 
 
 
 
29
  def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
30
+ try:
31
+ # Combine the style prompt with each initial prompt
32
+ sets_of_prompts = [ prompt + ". " + style_prompt for prompt in [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5,]]
33
+ # Generate images using the pipeline
34
+ images = pipeline(sets_of_prompts,).images
35
+ return images
36
+ except Exception as e:
37
+ raise gr.Error(f"Error in generating images: {e}")
38
 
39
  with gr.Blocks() as demo:
40
  with gr.Group():
41
  with gr.Column():
42
  with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
43
  with gr.Row(variant='panel'):
44
+ # Textboxes for initial prompts
45
  initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
46
  initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
47
  initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
48
  initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
49
  initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
50
  with gr.Row():
51
+ # Textbox for the style prompt
52
  style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
53
+ # Button to generate images
54
  btn = gr.Button("Generate a set of Style-aligned SDXL images",)
55
+ # Display the generated images
56
  output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)
57
+
58
+ # Button click event
59
  btn.click(fn=style_aligned_sdxl,
60
  inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
61
  outputs=output,
62
  api_name="style_aligned_sdxl")
63
 
64
+ # Providing Example inputs for the demo
65
  gr.Examples(examples=[
66
  ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
67
  ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
 
73
  outputs=[output],
74
  fn=style_aligned_sdxl)
75
 
76
+ # Launch the Gradio demo
77
+ demo.launch()