SemaSci commited on
Commit
1a6ef1d
·
verified ·
1 Parent(s): b8f4d7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -20,7 +20,7 @@ def get_lora_sd_pipeline(
20
  dtype=torch.float16,
21
  adapter_name="default"
22
  ):
23
-
24
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
25
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
26
 
@@ -33,7 +33,17 @@ def get_lora_sd_pipeline(
33
 
34
  pipe = DiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
35
  before_params = pipe.unet.parameters()
36
- pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
 
 
 
 
 
 
 
 
 
 
37
  pipe.unet.set_adapter(adapter_name)
38
  after_params = pipe.unet.parameters()
39
  print("UNet Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
@@ -141,6 +151,7 @@ def infer(
141
  )
142
 
143
  print(f"Active adapters - UNet: {pipe.unet.active_adapters}, Text Encoder: {pipe.text_encoder.active_adapters if hasattr(pipe, 'text_encoder') else None}")
 
144
  print(f"LoRA scale applied: {lora_scale}")
145
 
146
 
 
20
  dtype=torch.float16,
21
  adapter_name="default"
22
  ):
23
+
24
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
25
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
26
 
 
33
 
34
  pipe = DiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
35
  before_params = pipe.unet.parameters()
36
+ # pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
37
+ # Исправляем загрузку конфигурации
38
+ config = LoraConfig.from_pretrained(unet_sub_dir)
39
+
40
+ pipe.unet = PeftModel.from_pretrained(
41
+ pipe.unet,
42
+ unet_sub_dir,
43
+ adapter_name=adapter_name,
44
+ config=config # Явно передаем конфигурацию
45
+ )
46
+
47
  pipe.unet.set_adapter(adapter_name)
48
  after_params = pipe.unet.parameters()
49
  print("UNet Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
 
151
  )
152
 
153
  print(f"Active adapters - UNet: {pipe.unet.active_adapters}, Text Encoder: {pipe.text_encoder.active_adapters if hasattr(pipe, 'text_encoder') else None}")
154
+ print("UNet first layer weights:", pipe.unet.base_model.model[0].weight.data[0,0,:5])
155
  print(f"LoRA scale applied: {lora_scale}")
156
 
157