Update telebot/telebotTest/sd_models.py

telebot/telebotTest/sd_models.py  (+55 -56)  CHANGED
@@ -22,8 +22,6 @@ from backend.loader import forge_loader
 from backend import memory_management
 from backend.args import dynamic_args
 from backend.utils import load_torch_file
-active_model_reloads = 0
-model_reload_lock = threading.Lock()


 model_dir = "Stable-diffusion"
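The two deleted globals are the in-process guard that the old forge_model_reload() (removed further down in this diff) used to skip overlapping reloads. Pulled together from the removed lines, the old pattern looked roughly like the sketch below; the guarded_reload name is hypothetical and the exact placement of the decrement is not visible in this view, so treat it as an illustration rather than the repository's code.

import threading

# Old in-process guard (deleted by this change): a counter of overlapping
# reloads protected by a module-level lock, both process-local.
active_model_reloads = 0
model_reload_lock = threading.Lock()

def guarded_reload():  # hypothetical wrapper name
    global active_model_reloads
    with model_reload_lock:
        active_model_reloads += 1
    try:
        if active_model_reloads > 1:
            return None, False   # another reload is already in flight; skip
        ...                      # the actual model loading happened here
    finally:
        # Decrement counter when done
        with model_reload_lock:
            active_model_reloads -= 1

Because threading.Lock only coordinates threads inside one Python process, this guard cannot serialize work between separate processes; the lock file introduced in the next hunk can, which is presumably why the globals go away.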
@@ -451,8 +449,26 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):


 def unload_model_weights(sd_model=None, info=None):
-    memory_management.unload_all_models()
-    return
+    # Add synchronization point before unloading
+    import time
+    import os
+    lock_file = "/kaggle/working/unload.lock"
+
+    # Wait for previous unloading to complete
+    while os.path.exists(lock_file):
+        time.sleep(0.1)
+
+    # Create lock file
+    with open(lock_file, 'w') as f:
+        f.write("locked")
+
+    try:
+        memory_management.unload_all_models()
+        return
+    finally:
+        # Release lock
+        if os.path.exists(lock_file):
+            os.remove(lock_file)


 def apply_token_merging(sd_model, token_merging_ratio):
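The rewritten unload_model_weights serializes unloads through a lock file under /kaggle/working instead of the removed in-process lock, so separate processes that share that directory also take turns. One caveat: the os.path.exists check followed by open(..., 'w') is not atomic, so two processes can still slip past the check together. If that ever matters, an atomic variant can be built on os.open with O_CREAT | O_EXCL; the helper below is a hypothetical sketch, not code from this repository.

import os
import time
from contextlib import contextmanager

@contextmanager
def file_lock(path, poll=0.1):
    # O_CREAT | O_EXCL makes creation fail if the file already exists,
    # so acquiring the lock is a single atomic step.
    while True:
        try:
            fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            break
        except FileExistsError:
            time.sleep(poll)  # someone else holds the lock; wait and retry
    try:
        os.write(fd, b"locked")
        os.close(fd)
        yield
    finally:
        if os.path.exists(path):
            os.remove(path)

# Usage sketch, mirroring the committed code:
# with file_lock("/kaggle/working/unload.lock"):
#     memory_management.unload_all_models()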
@@ -473,73 +489,56 @@ def apply_token_merging(sd_model, token_merging_ratio):

 @torch.inference_mode()
 def forge_model_reload():
-    # Increment counter of active reloads
-    with model_reload_lock:
-        active_model_reloads += 1
-
-    try:
-        current_hash = str(model_data.forge_loading_parameters)
-
-        if model_data.forge_hash == current_hash:
-            return model_data.sd_model, False
-
-        if active_model_reloads > 1:
-            return model_data.sd_model, False
-
-        ...
-        shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
-        ...
-
-    # Decrement counter when done
-    with model_reload_lock:
-        active_model_reloads -= 1
+    current_hash = str(model_data.forge_loading_parameters)
+
+    if model_data.forge_hash == current_hash:
+        return model_data.sd_model, False
+
+    print('Loading Model: ' + str(model_data.forge_loading_parameters))
+
+    timer = Timer()
+
+    if model_data.sd_model:
+        model_data.sd_model = None
+        memory_management.unload_all_models()
+        memory_management.soft_empty_cache()
+        gc.collect()
+
+    timer.record("unload existing model")
+
+    checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
+
+    if checkpoint_info is None:
+        raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')
+
+    state_dict = checkpoint_info.filename
+    additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
+
+    timer.record("cache state dict")
+
+    dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
+    dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
+    dynamic_args['emphasis_name'] = opts.emphasis
+    sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
+    timer.record("forge model load")
+
+    sd_model.extra_generation_params = {}
+    sd_model.comments = []
+    sd_model.sd_checkpoint_info = checkpoint_info
+    sd_model.filename = checkpoint_info.filename
+    sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
+
+    model_data.set_sd_model(sd_model)
+
+    script_callbacks.model_loaded_callback(sd_model)
+
+    timer.record("scripts callbacks")
+
+    print(f"Model loaded in {timer.summary()}.")
+
+    model_data.forge_hash = current_hash
+
+    return sd_model, True
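After the change, forge_model_reload() keeps the duplicate-work check but drops the counter: it stringifies model_data.forge_loading_parameters, returns (model_data.sd_model, False) when that matches the cached forge_hash, and otherwise tears down the current model and rebuilds it with forge_loader, returning (sd_model, True). A minimal caller-side sketch inferred only from this diff; the modules.sd_models import path and the switch_checkpoint wrapper are assumptions.

# Hypothetical caller; only the dictionary keys and the return tuple come from the diff.
from modules import sd_models             # assumed module path for this file
from modules.sd_models import model_data  # assumed shared model_data object

def switch_checkpoint(checkpoint_info, unet_storage_dtype=None):
    # forge_model_reload() hashes this dict via str() to decide whether a reload is needed.
    model_data.forge_loading_parameters = {
        'checkpoint_info': checkpoint_info,        # must not be None
        'additional_modules': [],                  # extra state dicts, optional
        'unet_storage_dtype': unet_storage_dtype,  # optional storage dtype override
    }
    sd_model, just_reloaded = sd_models.forge_model_reload()
    if just_reloaded:
        print(f"Switched to {sd_model.filename}")
    return sd_model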