3v324v23 committed
Commit b3e9026
1 Parent(s): b816b7d

forced lora

Files changed (1)
  1. handler.py +58 -11
handler.py CHANGED
@@ -6,6 +6,7 @@ from io import BytesIO
 from pprint import pprint
 from typing import Any, Dict, List
 import os
+import re
 from pathlib import Path
 from typing import Union
 from concurrent.futures import ThreadPoolExecutor
@@ -87,6 +88,7 @@ class EndpointHandler:
         self.inference_progress = {}  # Dictionary to store progress of each request
         self.inference_images = {}  # Dictionary to store latest image of each request
         self.total_steps = {}
+        self.active_request_ids = set()
         self.inference_in_progress = False

         self.executor = ThreadPoolExecutor(
@@ -131,6 +133,18 @@ class EndpointHandler:
         self.pipe.enable_attention_slicing()
         # may need a requirement in the root with xformer

+        # Load the Loras one time only.
+        # Must be replaced once we know how to hot load/unload them;
+        # it uses the hand-made load_lora function.
+        self.load_selected_loras(
+            [
+                ["polyhedron_new_skin_v1.1", 0.2],
+                ["detailed_eye-10", 0.2],
+                ["add_detail", 0.3],
+                ["MuscleGirl_v1", 0.2],
+            ]
+        )
+
     def load_lora(self, pipeline, lora_path, lora_weight=0.5):
         state_dict = load_file(lora_path)
         LORA_PREFIX_UNET = "lora_unet"
@@ -218,10 +232,33 @@ class EndpointHandler:
         """Load Loras models, can lead to marvelous creations"""
         for model_name, weight in selections:
             lora_path = EndpointHandler.LORA_PATHS[model_name]
-            self.pipe = self.load_lora(
-                pipeline=self.pipe, lora_path=lora_path, lora_weight=weight
-            )
-        return self.pipe
+            # self.pipe.load_lora_weights(lora_path)
+            self.load_lora(self.pipe, lora_path, weight)
+
+    def clean_negative_prompt(self, negative_prompt):
+        """Clean the negative prompt by removing negative embedding tokens that are already present."""
+
+        # negative_prompt = (
+        #     negative_prompt
+        #     + """, easynegative, badhandv4, bad-artist-anime, negfeetv2, ng_deepnegative_v1_75t, bad-hands-5, """
+        # )
+
+        tokens = [item["token"] for item in self.TEXTUAL_INVERSION]
+
+        # Remove every token from negative_prompt if it is already there
+        for token in tokens:
+            # Use a regular expression for a case-insensitive replacement
+            negative_prompt = re.sub(
+                r"\b" + re.escape(token) + r"\b",
+                "",
+                negative_prompt,
+                flags=re.IGNORECASE,
+            ).strip()
+
+        # Append all tokens to the end of negative_prompt
+        negative_prompt += " " + " ".join(tokens)
+
+        return negative_prompt

     def clean_request_data(self, request_id: str):
         """Clean up the data related to a specific request ID."""
@@ -235,6 +272,9 @@ class EndpointHandler:
         # Remove the request ID from the total_steps dictionary
         self.total_steps.pop(request_id, None)

+        # Delete the request id
+        self.active_request_ids.discard(request_id)
+
         # Set inference to False
         self.inference_in_progress = False

@@ -349,17 +389,18 @@ class EndpointHandler:
         self.total_steps[request_id] = num_inference_steps

         # Use this to automatically add some negative prompts
-        forced_negative = (
-            negative_prompt
-            + """, easynegative, badhandv4, bad-artist-anime, negfeetv2, ng_deepnegative_v1_75t, bad-hands-5, """
-        )
+        forced_negative = self.clean_negative_prompt(negative_prompt)

         # Set the generator seed if provided
         generator = torch.Generator(device="cuda").manual_seed(seed) if seed else None

         # Load the provided Lora models
+        # self.pipe.unload_lora_weights()  # Unload models to avoid lora stacking
         # if loras_model:
-        #     self.pipe = self.load_selected_loras(loras_model)
+        #     self.load_selected_loras(loras_model)
+
+        # Set the scale of the loras; for now take only the first scale of the loaded loras and apply it to all until we find a way to apply each specified scale
+        # scale = {"scale": loras_model[0][1]} if loras_model else None

         try:
             # 2. Process
@@ -376,8 +417,8 @@ class EndpointHandler:
                 callback=lambda step, timestep, latents: self.progress_callback(
                     step, timestep, latents, request_id, "progress"
                 ),
-                callback_steps=5,  # The frequency at which the callback function is called.
-                # output_type="pt",
+                callback_steps=5,
+                # cross_attention_kwargs={"scale": 0.2},
             ).images[0]

             # print(image)
@@ -405,6 +446,11 @@ class EndpointHandler:
             return {"flag": "error", "message": "Missing request_id."}

         if action == "check_progress":
+            if request_id not in self.active_request_ids:
+                return {
+                    "flag": "error",
+                    "message": "Request id doesn't match any active request.",
+                }
             return self.check_progress(request_id)

         elif action == "inference":
@@ -420,6 +466,7 @@ class EndpointHandler:
             self.inference_in_progress = True
             self.inference_progress[request_id] = 0
             self.inference_images[request_id] = None
+            self.active_request_ids.add(request_id)

             self.executor.submit(self.start_inference, data)
 
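For context, here is a minimal, self-contained sketch of what the new clean_negative_prompt helper does, pulled out of the handler: it strips any textual-inversion trigger tokens the caller already wrote (whole word, case-insensitive) and then appends the full token set exactly once, so the forced negatives are never duplicated. The TEXTUAL_INVERSION list below is an illustrative stand-in shaped like the handler's (dicts with a "token" key); the sample tokens are taken from the forced string this commit removes.

import re

# Illustrative stand-in for EndpointHandler.TEXTUAL_INVERSION (entries assumed for this sketch).
TEXTUAL_INVERSION = [
    {"token": "easynegative"},
    {"token": "badhandv4"},
    {"token": "bad-hands-5"},
]

def clean_negative_prompt(negative_prompt: str) -> str:
    tokens = [item["token"] for item in TEXTUAL_INVERSION]
    # Drop any token the caller already included, whole-word and case-insensitive.
    for token in tokens:
        negative_prompt = re.sub(
            r"\b" + re.escape(token) + r"\b", "", negative_prompt, flags=re.IGNORECASE
        ).strip()
    # Re-append the full token set once at the end.
    return negative_prompt + " " + " ".join(tokens)

print(clean_negative_prompt("blurry, EasyNegative, lowres"))
# blurry, , lowres easynegative badhandv4 bad-hands-5

As in the committed code, removal leaves the surrounding separators behind (the ", ," above); that is harmless for prompt conditioning, and stripping before re-appending is what keeps each trigger token from appearing twice.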