AlekseyCalvin commited on
Commit
e8d61e0
1 Parent(s): 868ec74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -104,7 +104,7 @@ def update_selection(evt: gr.SelectData, width, height):
104
  )
105
 
106
  @spaces.GPU(duration=70)
107
- def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, lora_scale, progress):
108
  pipe.to("cuda")
109
  generator = torch.Generator(device="cuda").manual_seed(seed)
110
 
@@ -118,11 +118,12 @@ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height,
118
  height=height,
119
  generator=generator,
120
  negative_prompt=negative_prompt,
 
121
  joint_attention_kwargs={"scale": lora_scale},
122
  ).images[0]
123
  return image
124
 
125
- def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, negative_prompt, lora_scale, progress=gr.Progress(track_tqdm=True)):
126
  if selected_index is None:
127
  raise gr.Error("You must select a LoRA before proceeding.")
128
 
@@ -152,7 +153,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
152
  if randomize_seed:
153
  seed = random.randint(0, MAX_SEED)
154
 
155
- image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, lora_scale, progress)
156
  pipe.to("cpu")
157
  pipe.unload_lora_weights()
158
  return image, seed
@@ -206,8 +207,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
206
  with gr.Accordion("Advanced Settings", open=True):
207
  with gr.Column():
208
  with gr.Row():
209
- cfg_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=20, step=0.5, value=3.0)
210
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=12)
211
 
212
  with gr.Row():
213
  width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
@@ -216,7 +217,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
216
  with gr.Row():
217
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
218
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
219
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.0, step=0.01, value=1.05)
220
 
221
  gallery.select(
222
  update_selection,
 
104
  )
105
 
106
  @spaces.GPU(duration=70)
107
+ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, no_cfg_until_timestep, lora_scale, progress):
108
  pipe.to("cuda")
109
  generator = torch.Generator(device="cuda").manual_seed(seed)
110
 
 
118
  height=height,
119
  generator=generator,
120
  negative_prompt=negative_prompt,
121
+ no_cfg_until_timestep=2,
122
  joint_attention_kwargs={"scale": lora_scale},
123
  ).images[0]
124
  return image
125
 
126
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, negative_prompt, no_cfg_until_timestep=2, lora_scale, progress=gr.Progress(track_tqdm=True)):
127
  if selected_index is None:
128
  raise gr.Error("You must select a LoRA before proceeding.")
129
 
 
153
  if randomize_seed:
154
  seed = random.randint(0, MAX_SEED)
155
 
156
+ image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, no_cfg_until_timestep, lora_scale, progress)
157
  pipe.to("cpu")
158
  pipe.unload_lora_weights()
159
  return image, seed
 
207
  with gr.Accordion("Advanced Settings", open=True):
208
  with gr.Column():
209
  with gr.Row():
210
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=20, step=0.5, value=2.5)
211
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=20)
212
 
213
  with gr.Row():
214
  width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
 
217
  with gr.Row():
218
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
219
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
220
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.0, step=0.01, value=0.8)
221
 
222
  gallery.select(
223
  update_selection,