SemaSci commited on
Commit
1fd63c5
·
verified ·
1 Parent(s): 823eada

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -32
app.py CHANGED
@@ -11,6 +11,9 @@ import torch
11
  from peft import PeftModel, LoraConfig
12
  import os
13
 
 
 
 
14
  def get_lora_sd_pipeline(
15
  ckpt_dir='./lora_logos',
16
  base_model_name_or_path=None,
@@ -33,14 +36,15 @@ def get_lora_sd_pipeline(
33
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
34
  pipe.unet.set_adapter(adapter_name)
35
  after_params = pipe.unet.parameters()
36
- print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
37
 
38
  if os.path.exists(text_encoder_sub_dir):
39
  pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
40
 
41
  if dtype in (torch.float16, torch.bfloat16):
42
  pipe.unet.half()
43
- pipe.text_encoder.half()
 
44
 
45
  return pipe
46
 
@@ -94,43 +98,50 @@ def infer(
94
  progress=gr.Progress(track_tqdm=True),
95
  ):
96
 
 
 
97
  if randomize_seed:
98
  seed = random.randint(0, MAX_SEED)
99
 
100
  generator = torch.Generator().manual_seed(seed)
101
 
102
- # убираем обновление pipe всегда
103
- #pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
104
- #pipe = pipe.to(device)
105
-
106
- # добавляем обновление pipe по условию
107
- if model_repo_id != model_id_default:
108
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
109
- prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
110
- negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
111
- prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
 
 
 
 
 
 
112
  else:
113
- # добавляем lora
114
- #pipe = get_lora_sd_pipeline(ckpt_dir='./lora_lady_and_cats_logos', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
115
- pipe = get_lora_sd_pipeline(ckpt_dir='./'+model_lora_id, base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
116
- prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
117
- negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
118
- prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
119
- print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
 
 
 
 
 
 
 
 
 
 
 
 
120
  print(f"LoRA scale applied: {lora_scale}")
121
- pipe.fuse_lora(lora_scale=lora_scale)
122
-
123
-
124
- # заменяем просто вызов pipe с промптом
125
- #image = pipe(
126
- # prompt=prompt,
127
- # negative_prompt=negative_prompt,
128
- # guidance_scale=guidance_scale,
129
- # num_inference_steps=num_inference_steps,
130
- # width=width,
131
- # height=height,
132
- # generator=generator,
133
- #).images[0]
134
 
135
 
136
  # на вызов pipe с эмбеддингами
 
11
  from peft import PeftModel, LoraConfig
12
  import os
13
 
14
+ # Добавляем глобальный кэш для пайплайнов
15
+ pipe_cache = {}
16
+
17
  def get_lora_sd_pipeline(
18
  ckpt_dir='./lora_logos',
19
  base_model_name_or_path=None,
 
36
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
37
  pipe.unet.set_adapter(adapter_name)
38
  after_params = pipe.unet.parameters()
39
+ print("UNet Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
40
 
41
  if os.path.exists(text_encoder_sub_dir):
42
  pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
43
 
44
  if dtype in (torch.float16, torch.bfloat16):
45
  pipe.unet.half()
46
+ if pipe.text_encoder is not None:
47
+ pipe.text_encoder.half()
48
 
49
  return pipe
50
 
 
98
  progress=gr.Progress(track_tqdm=True),
99
  ):
100
 
101
+ global pipe_cache
102
+
103
  if randomize_seed:
104
  seed = random.randint(0, MAX_SEED)
105
 
106
  generator = torch.Generator().manual_seed(seed)
107
 
108
+ # Кэширование пайплайнов
109
+ cache_key = f"{model_repo_id}_{model_lora_id}"
110
+ if cache_key not in pipe_cache:
111
+ if model_repo_id != model_id_default:
112
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
113
+ prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
114
+ negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
115
+ prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
116
+ else:
117
+ pipe = get_lora_sd_pipeline(
118
+ ckpt_dir='./'+model_lora_id,
119
+ base_model_name_or_path=model_id_default,
120
+ dtype=torch_dtype
121
+ ).to(device)
122
+
123
+ pipe_cache[cache_key] = pipe
124
  else:
125
+ pipe = pipe_cache[cache_key]
126
+
127
+ # Динамическое применение масштаба LoRA
128
+ if model_repo_id == model_id_default:
129
+ # Убираем fuse_lora()
130
+ # pipe.fuse_lora(lora_scale=lora_scale) # Закомментировали проблемную строку
131
+
132
+ # Вместо этого устанавливаем адаптеры динамически
133
+ pipe.unet.set_adapters(
134
+ [model_lora_id],
135
+ adapter_weights=[lora_scale]
136
+ )
137
+ if hasattr(pipe, 'text_encoder') and pipe.text_encoder is not None:
138
+ pipe.text_encoder.set_adapters(
139
+ [model_lora_id],
140
+ adapter_weights=[lora_scale]
141
+ )
142
+
143
+ print(f"Active adapters - UNet: {pipe.unet.active_adapters}, Text Encoder: {pipe.text_encoder.active_adapters if hasattr(pipe, 'text_encoder') else None}")
144
  print(f"LoRA scale applied: {lora_scale}")
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
 
147
  # на вызов pipe с эмбеддингами