kepler296e commited on
Commit
74c2772
·
1 Parent(s): f3eb48f

improve ui

Browse files
Files changed (1) hide show
  1. main.py +31 -21
main.py CHANGED
@@ -17,8 +17,16 @@ models = model_loader.from_pretrained(weights_url, device)
17
  pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
18
  pipe = pipe.to(device)
19
 
20
- def generate_image(prompt, negative_prompt, seed, guidance_scale, num_inference_steps, model):
 
21
 
 
 
 
 
 
 
 
22
  generator = torch.Generator(device=device).manual_seed(seed)
23
 
24
  if model == "from-scratch":
@@ -29,8 +37,8 @@ def generate_image(prompt, negative_prompt, seed, guidance_scale, num_inference_
29
  strength=0.9,
30
  cfg_scale=guidance_scale,
31
  n_inference_steps=num_inference_steps,
32
- width=512,
33
- height=512,
34
  generator=generator,
35
  device=device,
36
  idle_device="cpu",
@@ -43,8 +51,8 @@ def generate_image(prompt, negative_prompt, seed, guidance_scale, num_inference_
43
  negative_prompt=negative_prompt,
44
  guidance_scale=guidance_scale,
45
  num_inference_steps=num_inference_steps,
46
- width=512,
47
- height=512,
48
  generator=generator,
49
  ).images[0]
50
 
@@ -96,30 +104,33 @@ with gr.Blocks(css=css) as demo:
96
  interactive=True,
97
  )
98
 
 
 
99
  seed = gr.Slider(
100
  label="Seed",
101
  minimum=0,
102
- maximum=2147483647, # 2^31 - 1
103
  step=1,
104
  value=42,
105
  )
 
 
 
 
106
 
107
- negative_prompt = gr.Text(
108
- label="Negative prompt",
109
- max_lines=1,
110
- placeholder="Enter a negative prompt",
111
- visible=False,
112
- value="",
113
- )
114
-
115
- # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
116
-
117
- """
118
  with gr.Row():
119
 
120
  width = gr.Slider(
121
  label="Width",
122
- minimum=256,
123
  maximum=MAX_IMAGE_SIZE,
124
  step=32,
125
  value=512,
@@ -127,12 +138,11 @@ with gr.Blocks(css=css) as demo:
127
 
128
  height = gr.Slider(
129
  label="Height",
130
- minimum=256,
131
  maximum=MAX_IMAGE_SIZE,
132
  step=32,
133
  value=512,
134
  )
135
- """
136
 
137
  with gr.Row():
138
 
@@ -154,7 +164,7 @@ with gr.Blocks(css=css) as demo:
154
 
155
  run_button.click(
156
  fn = generate_image,
157
- inputs = [prompt, negative_prompt, seed, guidance_scale, num_inference_steps, model],
158
  outputs = [result]
159
  )
160
 
 
17
  pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
18
  pipe = pipe.to(device)
19
 
20
+ MIN_IMAGE_SIZE = 256
21
+ MAX_IMAGE_SIZE = 1024
22
 
23
+ MAX_SEED = 2147483647 # 2^31 - 1
24
+
25
+ def generate_image(prompt, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, model, width, height):
26
+
27
+ if randomize_seed:
28
+ seed = torch.randint(0, MAX_SEED, (1,)).item()
29
+
30
  generator = torch.Generator(device=device).manual_seed(seed)
31
 
32
  if model == "from-scratch":
 
37
  strength=0.9,
38
  cfg_scale=guidance_scale,
39
  n_inference_steps=num_inference_steps,
40
+ width=width,
41
+ height=height,
42
  generator=generator,
43
  device=device,
44
  idle_device="cpu",
 
51
  negative_prompt=negative_prompt,
52
  guidance_scale=guidance_scale,
53
  num_inference_steps=num_inference_steps,
54
+ width=width,
55
+ height=height,
56
  generator=generator,
57
  ).images[0]
58
 
 
104
  interactive=True,
105
  )
106
 
107
+ with gr.Row():
108
+
109
  seed = gr.Slider(
110
  label="Seed",
111
  minimum=0,
112
+ maximum=MAX_SEED,
113
  step=1,
114
  value=42,
115
  )
116
+
117
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
118
+
119
+ with gr.Row():
120
 
121
+ negative_prompt = gr.Text(
122
+ label="Negative prompt",
123
+ max_lines=1,
124
+ placeholder="Enter a negative prompt",
125
+ visible=False,
126
+ value="",
127
+ )
128
+
 
 
 
129
  with gr.Row():
130
 
131
  width = gr.Slider(
132
  label="Width",
133
+ minimum=MIN_IMAGE_SIZE,
134
  maximum=MAX_IMAGE_SIZE,
135
  step=32,
136
  value=512,
 
138
 
139
  height = gr.Slider(
140
  label="Height",
141
+ minimum=MIN_IMAGE_SIZE,
142
  maximum=MAX_IMAGE_SIZE,
143
  step=32,
144
  value=512,
145
  )
 
146
 
147
  with gr.Row():
148
 
 
164
 
165
  run_button.click(
166
  fn = generate_image,
167
+ inputs = [prompt, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, model, width, height],
168
  outputs = [result]
169
  )
170