forced lora
handler.py CHANGED (+58 -11)
```diff
@@ -6,6 +6,7 @@ from io import BytesIO
 from pprint import pprint
 from typing import Any, Dict, List
 import os
+import re
 from pathlib import Path
 from typing import Union
 from concurrent.futures import ThreadPoolExecutor
@@ -87,6 +88,7 @@ class EndpointHandler:
         self.inference_progress = {}  # Dictionary to store progress of each request
         self.inference_images = {}  # Dictionary to store latest image of each request
         self.total_steps = {}
+        self.active_request_ids = set()
         self.inference_in_progress = False
 
         self.executor = ThreadPoolExecutor(
@@ -131,6 +133,18 @@ class EndpointHandler:
         self.pipe.enable_attention_slicing()
         # may need a requirement in the root with xformer
 
+        # Load the LoRAs only once, at startup.
+        # Must be replaced once we know how to hot load/unload them;
+        # this uses our own load_lora function below.
+        self.load_selected_loras(
+            [
+                ["polyhedron_new_skin_v1.1", 0.2],
+                ["detailed_eye-10", 0.2],
+                ["add_detail", 0.3],
+                ["MuscleGirl_v1", 0.2],
+            ]
+        )
+
     def load_lora(self, pipeline, lora_path, lora_weight=0.5):
         state_dict = load_file(lora_path)
         LORA_PREFIX_UNET = "lora_unet"
@@ -218,10 +232,33 @@ class EndpointHandler:
         """Load LoRA models; can lead to marvelous creations"""
         for model_name, weight in selections:
             lora_path = EndpointHandler.LORA_PATHS[model_name]
-            self.pipe
-
-
-
+            # self.pipe.load_lora_weights(lora_path)
+            self.load_lora(self.pipe, lora_path, weight)
+
+    def clean_negative_prompt(self, negative_prompt):
+        """Clean the negative prompt, removing embedding tokens that are already used"""
+
+        # negative_prompt = (
+        #     negative_prompt
+        #     + """, easynegative, badhandv4, bad-artist-anime, negfeetv2, ng_deepnegative_v1_75t, bad-hands-5, """
+        # )
+
+        tokens = [item["token"] for item in self.TEXTUAL_INVERSION]
+
+        # Remove every token from negative_prompt if it is already present
+        for token in tokens:
+            # Use a regular expression for a case-insensitive replacement
+            negative_prompt = re.sub(
+                r"\b" + re.escape(token) + r"\b",
+                "",
+                negative_prompt,
+                flags=re.IGNORECASE,
+            ).strip()
+
+        # Append every token at the end of negative_prompt
+        negative_prompt += " " + " ".join(tokens)
+
+        return negative_prompt
 
     def clean_request_data(self, request_id: str):
         """Clean up the data related to a specific request ID."""
@@ -235,6 +272,9 @@ class EndpointHandler:
         # Remove the request ID from the total_steps dictionary
         self.total_steps.pop(request_id, None)
 
+        # Delete the request id
+        self.active_request_ids.discard(request_id)
+
         # Set inference to False
         self.inference_in_progress = False
 
@@ -349,17 +389,18 @@ class EndpointHandler:
         self.total_steps[request_id] = num_inference_steps
 
         # Use this to automatically add some negative prompts
-        forced_negative = (
-            negative_prompt
-            + """, easynegative, badhandv4, bad-artist-anime, negfeetv2, ng_deepnegative_v1_75t, bad-hands-5, """
-        )
+        forced_negative = self.clean_negative_prompt(negative_prompt)
 
         # Set the generator seed if provided
         generator = torch.Generator(device="cuda").manual_seed(seed) if seed else None
 
         # Load the provided LoRA models
+        # self.pipe.unload_lora_weights()  # Unload models to avoid LoRA stacking
         # if loras_model:
-        #     self.
+        #     self.load_selected_loras(loras_model)
+
+        # Set the LoRA scale: for now take only the first scale of the loaded LoRAs and apply it to all, until we find a way to apply per-LoRA scales
+        # scale = {"scale": loras_model[0][1]} if loras_model else None
 
         try:
             # 2. Process
@@ -376,8 +417,8 @@ class EndpointHandler:
                 callback=lambda step, timestep, latents: self.progress_callback(
                     step, timestep, latents, request_id, "progress"
                 ),
-                callback_steps=5,
-                #
+                callback_steps=5,
+                # cross_attention_kwargs={"scale": 0.2},
             ).images[0]
 
             # print(image)
@@ -405,6 +446,11 @@ class EndpointHandler:
             return {"flag": "error", "message": "Missing request_id."}
 
         if action == "check_progress":
+            if request_id not in self.active_request_ids:
+                return {
+                    "flag": "error",
+                    "message": "Request id doesn't match any active request.",
+                }
             return self.check_progress(request_id)
 
         elif action == "inference":
@@ -420,6 +466,7 @@ class EndpointHandler:
             self.inference_in_progress = True
             self.inference_progress[request_id] = 0
             self.inference_images[request_id] = None
+            self.active_request_ids.add(request_id)
 
             self.executor.submit(self.start_inference, data)
 
```
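The startup block above leans on the hand-rolled `load_lora`, of which the diff only shows the first lines. For reference, a minimal sketch of the widely shared community pattern it appears to follow, which merges each LoRA up/down pair straight into the UNet and text-encoder weights (kohya-style key naming assumed; per-key alpha scaling omitted for brevity):

```python
import torch
from safetensors.torch import load_file

def load_lora(pipeline, lora_path, lora_weight=0.5):
    """Merge a kohya-style .safetensors LoRA directly into the pipeline weights."""
    state_dict = load_file(lora_path)
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    visited = set()

    for key in state_dict:
        # Each up/down pair is handled once; alpha entries are skipped.
        if ".alpha" in key or key in visited:
            continue

        # Route the key to the text encoder or the UNet.
        if LORA_PREFIX_TEXT_ENCODER in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # Module names themselves contain underscores, so re-join the parts
        # until getattr succeeds (assumes well-formed keys).
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = getattr(curr_layer, temp_name)
                if not layer_infos:
                    break
                temp_name = layer_infos.pop(0)
            except AttributeError:
                temp_name += "_" + layer_infos.pop(0)

        up_key = key if "lora_up" in key else key.replace("lora_down", "lora_up")
        down_key = up_key.replace("lora_up", "lora_down")
        weight_up = state_dict[up_key].to(torch.float32)
        weight_down = state_dict[down_key].to(torch.float32)

        # W' = W + lora_weight * (up @ down); conv weights need their 1x1 dims squeezed.
        if weight_up.dim() == 4:
            update = torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            update = torch.mm(weight_up, weight_down)

        w = curr_layer.weight.data
        w += lora_weight * update.to(device=w.device, dtype=w.dtype)
        visited.update({up_key, down_key})
```

Because the weights are merged in place, loading the same LoRA twice stacks its effect, which is exactly why the diff keeps the `unload_lora_weights` reminder commented out next to the per-request loading path.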
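The token handling in `clean_negative_prompt` is easy to check in isolation. A runnable sketch with a hypothetical `TEXTUAL_INVERSION` table (the real table lives outside this diff; it only shows that each item carries a "token" field):

```python
import re

# Hypothetical entries, mirroring the {"token": ...} shape the handler indexes.
TEXTUAL_INVERSION = [
    {"token": "easynegative"},
    {"token": "badhandv4"},
]

def clean_negative_prompt(negative_prompt):
    tokens = [item["token"] for item in TEXTUAL_INVERSION]
    for token in tokens:
        # \b keeps partial matches such as "easynegative2" intact;
        # IGNORECASE catches user-typed variants like "EasyNegative".
        negative_prompt = re.sub(
            r"\b" + re.escape(token) + r"\b", "", negative_prompt, flags=re.IGNORECASE
        ).strip()
    # Re-append every token exactly once at the end.
    return negative_prompt + " " + " ".join(tokens)

print(clean_negative_prompt("blurry, EasyNegative, lowres"))
# -> "blurry, , lowres easynegative badhandv4"
```

Note that removal leaves the surrounding separators behind (the dangling comma above); that is harmless in practice, but trimming `", ,"` leftovers would need an extra cleanup pass.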
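Taken together, the `active_request_ids` changes give each request id a clear lifecycle: registered when inference starts, checked on every `check_progress` call, discarded during cleanup. A condensed stand-alone sketch of that bookkeeping (names mirror the handler; everything else is stripped away):

```python
active_request_ids = set()

def begin_inference(request_id):
    # Registered before the worker thread is submitted.
    active_request_ids.add(request_id)

def check_progress(request_id):
    # Unknown or already-finished ids are rejected instead of returning stale data.
    if request_id not in active_request_ids:
        return {"flag": "error", "message": "Request id doesn't match any active request."}
    return {"flag": "ok"}

def clean_request_data(request_id):
    # discard() (unlike remove()) is a no-op when the id is already gone.
    active_request_ids.discard(request_id)

begin_inference("req-1")
assert check_progress("req-1")["flag"] == "ok"
clean_request_data("req-1")
assert check_progress("req-1")["flag"] == "error"
```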
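Finally, the callback wiring: `callback` plus `callback_steps=5` is the older diffusers progress API (newer releases replace it with `callback_on_step_end`), so the pipeline reports every fifth denoising step. `progress_callback` itself is not part of this diff; a hypothetical sketch of the bookkeeping it presumably does with `total_steps`:

```python
# Hypothetical: the handler's actual progress_callback is outside this diff.
def progress_callback(self, step, timestep, latents, request_id, status):
    total = self.total_steps.get(request_id) or 1  # avoid division by zero
    self.inference_progress[request_id] = round(step / total * 100)  # percent done
```

With `callback_steps=5`, a 30-step run updates the stored progress about six times, which keeps `check_progress` polling cheap without letting the reported progress go stale.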