Tonic commited on
Commit
53ad954
β€’
1 Parent(s): 1f46380

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -18
app.py CHANGED
@@ -2,15 +2,10 @@ import gradio as gr
2
  from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
  import torch
4
  import numpy as np
5
- from diffusers.utils import load_image
6
  from PIL import Image
7
  import io
8
  import sys
9
  import os
10
- current_dir = os.path.dirname(os.path.realpath(__file__))
11
- if current_dir not in sys.path:
12
- sys.path.append(current_dir)
13
-
14
  import sa_handler
15
  import inversion
16
 
@@ -26,10 +21,10 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
26
  ).to("cuda")
27
 
28
  # Function to process the image
29
- def process_image(image, prompt, style):
30
- src_prompt = f'Man laying in a bed, {style}.'
31
-
32
- num_inference_steps = 50
33
  x0 = np.array(Image.fromarray(image).resize((1024, 1024)))
34
  zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)
35
 
@@ -38,9 +33,6 @@ def process_image(image, prompt, style):
38
  f"{prompt}, {style}."
39
  ]
40
 
41
- shared_score_shift = np.log(2)
42
- shared_score_scale = 1.0
43
-
44
  handler = sa_handler.Handler(pipeline)
45
  sa_args = sa_handler.StyleAlignedArgs(
46
  share_group_norm=True, share_layer_norm=True, share_attention=True,
@@ -59,23 +51,27 @@ def process_image(image, prompt, style):
59
 
60
  images_a = pipeline(prompts, latents=latents,
61
  callback_on_step_end=inversion_callback,
62
- num_inference_steps=num_inference_steps, guidance_scale=10.0).images
63
 
64
  handler.remove()
65
 
66
  return Image.fromarray(images_a[1])
67
 
68
- # Gradio interface
69
  iface = gr.Interface(
70
  fn=process_image,
71
  inputs=[
72
- gr.inputs.Image(type="numpy"),
73
- gr.inputs.Textbox(label="Enter your prompt"),
74
- gr.inputs.Textbox(label="Enter your style", default="medieval painting")
 
 
 
 
 
75
  ],
76
  outputs="image",
77
  title="Stable Diffusion XL with Style Alignment",
78
  description="Generate images in the style of your choice."
79
  )
80
 
81
- iface.launch()
 
2
  from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
  import torch
4
  import numpy as np
 
5
  from PIL import Image
6
  import io
7
  import sys
8
  import os
 
 
 
 
9
  import sa_handler
10
  import inversion
11
 
 
21
  ).to("cuda")
22
 
23
  # Function to process the image
24
+ def process_image(image, prompt, style, src_description, inference_steps, shared_score_shift, shared_score_scale, guidance_scale):
25
+ src_prompt = f'{src_description}, {style}.'
26
+
27
+ num_inference_steps = inference_steps
28
  x0 = np.array(Image.fromarray(image).resize((1024, 1024)))
29
  zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)
30
 
 
33
  f"{prompt}, {style}."
34
  ]
35
 
 
 
 
36
  handler = sa_handler.Handler(pipeline)
37
  sa_args = sa_handler.StyleAlignedArgs(
38
  share_group_norm=True, share_layer_norm=True, share_attention=True,
 
51
 
52
  images_a = pipeline(prompts, latents=latents,
53
  callback_on_step_end=inversion_callback,
54
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images
55
 
56
  handler.remove()
57
 
58
  return Image.fromarray(images_a[1])
59
 
 
60
  iface = gr.Interface(
61
  fn=process_image,
62
  inputs=[
63
+ gr.Image(type="numpy"),
64
+ gr.Textbox(label="Enter your prompt"),
65
+ gr.Textbox(label="Enter your style", default="medieval painting"),
66
+ gr.Textbox(label="Enter source description", default="Man laying in a bed"),
67
+ gr.Slider(minimum=5, maximum=50, step=1, default=50, label="Number of Inference Steps"),
68
+ gr.Slider(minimum=1, maximum=2, step=0.01, default=1.5, label="Shared Score Shift"),
69
+ gr.Slider(minimum=0, maximum=1, step=0.01, default=0.5, label="Shared Score Scale"),
70
+ gr.Slider(minimum=5, maximum=120, step=1, default=10, label="Guidance Scale")
71
  ],
72
  outputs="image",
73
  title="Stable Diffusion XL with Style Alignment",
74
  description="Generate images in the style of your choice."
75
  )
76
 
77
+ iface.launch()