AlekseyCalvin commited on
Commit
83c9489
·
verified ·
1 Parent(s): 2c4abd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -99,7 +99,7 @@ def update_selection(evt: gr.SelectData, width, height):
99
  )
100
 
101
  @spaces.GPU(duration=70)
102
- def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
103
  pipe.to("cuda")
104
  generator = torch.Generator(device="cuda").manual_seed(seed)
105
 
@@ -107,6 +107,7 @@ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height,
107
  # Generate image
108
  image = pipe(
109
  prompt=f"{prompt} {trigger_word}",
 
110
  num_inference_steps=steps,
111
  guidance_scale=cfg_scale,
112
  width=width,
@@ -116,7 +117,9 @@ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height,
116
  ).images[0]
117
  return image
118
 
119
- def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
 
 
120
  if selected_index is None:
121
  raise gr.Error("You must select a LoRA before proceeding.")
122
 
@@ -136,7 +139,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
136
  if randomize_seed:
137
  seed = random.randint(0, MAX_SEED)
138
 
139
- image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
140
  pipe.to("cpu")
141
  pipe.unload_lora_weights()
142
  return image, seed
@@ -167,6 +170,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
167
  with gr.Row():
168
  with gr.Column(scale=3):
169
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
 
 
 
170
  with gr.Column(scale=1, elem_id="gen_column"):
171
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
172
  with gr.Row():
@@ -208,7 +214,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
208
  gr.on(
209
  triggers=[generate_button.click, prompt.submit],
210
  fn=run_lora,
211
- inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
212
  outputs=[result, seed]
213
  )
214
 
 
99
  )
100
 
101
  @spaces.GPU(duration=70)
102
+ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, lora_scale, progress):
103
  pipe.to("cuda")
104
  generator = torch.Generator(device="cuda").manual_seed(seed)
105
 
 
107
  # Generate image
108
  image = pipe(
109
  prompt=f"{prompt} {trigger_word}",
110
+ negative_prompt=negative_prompt,
111
  num_inference_steps=steps,
112
  guidance_scale=cfg_scale,
113
  width=width,
 
117
  ).images[0]
118
  return image
119
 
120
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, negative_prompt, lora_scale, progress=gr.Progress(track_tqdm=True)):
121
+ if negative_prompt == "":
122
+ negative_prompt = None
123
  if selected_index is None:
124
  raise gr.Error("You must select a LoRA before proceeding.")
125
 
 
139
  if randomize_seed:
140
  seed = random.randint(0, MAX_SEED)
141
 
142
+ image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, lora_scale, progress)
143
  pipe.to("cpu")
144
  pipe.unload_lora_weights()
145
  return image, seed
 
170
  with gr.Row():
171
  with gr.Column(scale=3):
172
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
173
+ with gr.Row():
174
+ with gr.Column(scale=3):
175
+ negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, placeholder="List unwanted conditions, open-fluxedly!")
176
  with gr.Column(scale=1, elem_id="gen_column"):
177
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
178
  with gr.Row():
 
214
  gr.on(
215
  triggers=[generate_button.click, prompt.submit],
216
  fn=run_lora,
217
+ inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, negative_prompt, lora_scale],
218
  outputs=[result, seed]
219
  )
220