multimodalart HF staff commited on
Commit
fb4901e
1 Parent(s): 3b6af48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -98,15 +98,19 @@ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lor
98
  print("Loading state dicts...")
99
  start_time = time()
100
  state_dict_1 = copy.deepcopy(state_dicts[repo_id_1]["state_dict"])
 
101
  state_dict_2 = copy.deepcopy(state_dicts[repo_id_2]["state_dict"])
 
102
  state_dict_time = time() - start_time
103
  print(f"State Dict time: {state_dict_time}")
104
  #pipe = copy.deepcopy(original_pipe)
105
  start_time = time()
106
  unet = copy.deepcopy(original_pipe.unet)
 
 
107
  pipe = StableDiffusionXLPipeline(vae=original_pipe.vae,
108
- text_encoder=original_pipe.text_encoder,
109
- text_encoder_2=original_pipe.text_encoder_2,
110
  scheduler=original_pipe.scheduler,
111
  tokenizer=original_pipe.tokenizer,
112
  tokenizer_2=original_pipe.tokenizer_2,
 
98
  print("Loading state dicts...")
99
  start_time = time()
100
  state_dict_1 = copy.deepcopy(state_dicts[repo_id_1]["state_dict"])
101
+ state_dict_1 = {k: v.to(device="cuda", dtype=torch.float16) for k,v in state_dict_1.items() if torch.is_tensor(v)}
102
  state_dict_2 = copy.deepcopy(state_dicts[repo_id_2]["state_dict"])
103
+ state_dict_2 = {k: v.to(device="cuda", dtype=torch.float16) for k,v in state_dict_2.items() if torch.is_tensor(v)}
104
  state_dict_time = time() - start_time
105
  print(f"State Dict time: {state_dict_time}")
106
  #pipe = copy.deepcopy(original_pipe)
107
  start_time = time()
108
  unet = copy.deepcopy(original_pipe.unet)
109
+ text_encoder=copy.deepcopy(original_pipe.text_encoder)
110
+ text_encoder_2=copy.deepcopy(original_pipe.text_encoder_2)
111
  pipe = StableDiffusionXLPipeline(vae=original_pipe.vae,
112
+ text_encoder=text_encoder,
113
+ text_encoder_2=text_encoder_2,
114
  scheduler=original_pipe.scheduler,
115
  tokenizer=original_pipe.tokenizer,
116
  tokenizer_2=original_pipe.tokenizer_2,