nseq commited on
Commit
f9463e1
·
verified ·
1 Parent(s): b2f519f

Update telebot/telebotTest/sd_models.py

Browse files
Files changed (1) hide show
  1. telebot/telebotTest/sd_models.py +55 -56
telebot/telebotTest/sd_models.py CHANGED
@@ -22,8 +22,6 @@ from backend.loader import forge_loader
22
  from backend import memory_management
23
  from backend.args import dynamic_args
24
  from backend.utils import load_torch_file
25
- active_model_reloads = 0
26
- model_reload_lock = threading.Lock()
27
 
28
 
29
  model_dir = "Stable-diffusion"
@@ -451,8 +449,26 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
451
 
452
 
453
  def unload_model_weights(sd_model=None, info=None):
454
- memory_management.unload_all_models()
455
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
 
458
  def apply_token_merging(sd_model, token_merging_ratio):
@@ -473,73 +489,56 @@ def apply_token_merging(sd_model, token_merging_ratio):
473
 
474
  @torch.inference_mode()
475
  def forge_model_reload():
476
- global active_model_reloads
477
-
478
- # Increment counter of active reloads
479
- with model_reload_lock:
480
- active_model_reloads += 1
481
-
482
- try:
483
- current_hash = str(model_data.forge_loading_parameters)
484
-
485
- if model_data.forge_hash == current_hash:
486
- return model_data.sd_model, False
487
 
488
- # Only proceed with unloading/reloading if this is the last active reload
489
- with model_reload_lock:
490
- if active_model_reloads > 1:
491
- return model_data.sd_model, False
492
 
493
- print('Loading Model: ' + str(model_data.forge_loading_parameters))
494
 
495
- timer = Timer()
496
 
497
- if model_data.sd_model:
498
- model_data.sd_model = None
499
- memory_management.unload_all_models()
500
- memory_management.soft_empty_cache()
501
- gc.collect()
502
 
503
- timer.record("unload existing model")
504
 
505
- checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
506
 
507
- if checkpoint_info is None:
508
- raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')
509
 
510
- state_dict = checkpoint_info.filename
511
- additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
512
 
513
- timer.record("cache state dict")
514
 
515
- dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
516
- dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
517
- dynamic_args['emphasis_name'] = opts.emphasis
518
- sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
519
- timer.record("forge model load")
520
 
521
- sd_model.extra_generation_params = {}
522
- sd_model.comments = []
523
- sd_model.sd_checkpoint_info = checkpoint_info
524
- sd_model.filename = checkpoint_info.filename
525
- sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
526
- timer.record("calculate hash")
527
-
528
- shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
529
 
530
- model_data.set_sd_model(sd_model)
531
 
532
- script_callbacks.model_loaded_callback(sd_model)
533
 
534
- timer.record("scripts callbacks")
535
 
536
- print(f"Model loaded in {timer.summary()}.")
537
 
538
- model_data.forge_hash = current_hash
539
 
540
- return sd_model, True
541
 
542
- finally:
543
- # Decrement counter when done
544
- with model_reload_lock:
545
- active_model_reloads -= 1
 
22
  from backend import memory_management
23
  from backend.args import dynamic_args
24
  from backend.utils import load_torch_file
 
 
25
 
26
 
27
  model_dir = "Stable-diffusion"
 
449
 
450
 
451
  def unload_model_weights(sd_model=None, info=None):
452
+ # Add synchronization point before unloading
453
+ import time
454
+ import os
455
+ lock_file = "/kaggle/working/unload.lock"
456
+
457
+ # Wait for previous unloading to complete
458
+ while os.path.exists(lock_file):
459
+ time.sleep(0.1)
460
+
461
+ # Create lock file
462
+ with open(lock_file, 'w') as f:
463
+ f.write("locked")
464
+
465
+ try:
466
+ memory_management.unload_all_models()
467
+ return
468
+ finally:
469
+ # Release lock
470
+ if os.path.exists(lock_file):
471
+ os.remove(lock_file)
472
 
473
 
474
  def apply_token_merging(sd_model, token_merging_ratio):
 
489
 
490
  @torch.inference_mode()
491
  def forge_model_reload():
492
+ current_hash = str(model_data.forge_loading_parameters)
 
 
 
 
 
 
 
 
 
 
493
 
494
+ if model_data.forge_hash == current_hash:
495
+ return model_data.sd_model, False
 
 
496
 
497
+ print('Loading Model: ' + str(model_data.forge_loading_parameters))
498
 
499
+ timer = Timer()
500
 
501
+ if model_data.sd_model:
502
+ model_data.sd_model = None
503
+ memory_management.unload_all_models()
504
+ memory_management.soft_empty_cache()
505
+ gc.collect()
506
 
507
+ timer.record("unload existing model")
508
 
509
+ checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
510
 
511
+ if checkpoint_info is None:
512
+ raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')
513
 
514
+ state_dict = checkpoint_info.filename
515
+ additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
516
 
517
+ timer.record("cache state dict")
518
 
519
+ dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
520
+ dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
521
+ dynamic_args['emphasis_name'] = opts.emphasis
522
+ sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
523
+ timer.record("forge model load")
524
 
525
+ sd_model.extra_generation_params = {}
526
+ sd_model.comments = []
527
+ sd_model.sd_checkpoint_info = checkpoint_info
528
+ sd_model.filename = checkpoint_info.filename
529
+ sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
530
+ timer.record("calculate hash")
 
 
531
 
532
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
533
 
534
+ model_data.set_sd_model(sd_model)
535
 
536
+ script_callbacks.model_loaded_callback(sd_model)
537
 
538
+ timer.record("scripts callbacks")
539
 
540
+ print(f"Model loaded in {timer.summary()}.")
541
 
542
+ model_data.forge_hash = current_hash
543
 
544
+ return sd_model, True