dgoot commited on
Commit
9b97455
·
1 Parent(s): fdbf283

Add LORA weight customization

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -160,7 +160,8 @@ elif model_type == "LORA":
160
  else:
161
  raise ValueError(f"Unsupported base model: {base_model}")
162
 
163
- pipe.load_lora_weights(get_file_name("Model"), adapter_name=slugify(model_name))
 
164
  else:
165
  raise ValueError(f"Unsupported model type: {model_type}")
166
 
@@ -176,6 +177,7 @@ def infer(
176
  strength: float,
177
  num_inference_steps: int,
178
  guidance_scale: float,
 
179
  progress=gr.Progress(track_tqdm=True),
180
  ):
181
  logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")
@@ -193,6 +195,9 @@ def infer(
193
  if v
194
  }
195
 
 
 
 
196
  logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")
197
 
198
  images = pipe(
@@ -249,6 +254,15 @@ with gr.Blocks(css=css) as demo:
249
  value=0.0,
250
  )
251
 
 
 
 
 
 
 
 
 
 
252
  num_inference_steps = gr.Slider(
253
  label="Number of inference steps",
254
  minimum=0,
@@ -274,6 +288,7 @@ with gr.Blocks(css=css) as demo:
274
  strength,
275
  num_inference_steps,
276
  guidance_scale,
 
277
  ],
278
  outputs=[result],
279
  )
 
160
  else:
161
  raise ValueError(f"Unsupported base model: {base_model}")
162
 
163
+ adapter_name = slugify(model_name)
164
+ pipe.load_lora_weights(get_file_name("Model"), adapter_name=adapter_name)
165
  else:
166
  raise ValueError(f"Unsupported model type: {model_type}")
167
 
 
177
  strength: float,
178
  num_inference_steps: int,
179
  guidance_scale: float,
180
+ lora_weight: float,
181
  progress=gr.Progress(track_tqdm=True),
182
  ):
183
  logger.info(f"Starting image generation: {dict(prompt=prompt, image=init_image)}")
 
195
  if v
196
  }
197
 
198
+ if lora_weight:
199
+ pipe.set_adapters(adapter_name, lora_weight)
200
+
201
  logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")
202
 
203
  images = pipe(
 
254
  value=0.0,
255
  )
256
 
257
+ lora_weight = gr.Slider(
258
+ label="LORA weight",
259
+ minimum=0.0,
260
+ maximum=1.0,
261
+ step=0.01,
262
+ value=0.0,
263
+ visible=model_type == "LORA",
264
+ )
265
+
266
  num_inference_steps = gr.Slider(
267
  label="Number of inference steps",
268
  minimum=0,
 
288
  strength,
289
  num_inference_steps,
290
  guidance_scale,
291
+ lora_weight,
292
  ],
293
  outputs=[result],
294
  )