ResearcherXman committed on
Commit 2192aaf
1 Parent(s): ec7fc1c
Files changed (2)
  1. app.py +1 -2
  2. model_util.py +0 -472
app.py CHANGED
@@ -18,7 +18,6 @@ from insightface.app import FaceAnalysis
 
  from style_template import styles
  from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
- from model_util import load_models_xl, get_torch_device
  from controlnet_util import openpose, get_depth_map, get_canny_image
 
  import gradio as gr
@@ -27,7 +26,7 @@ import spaces
 
  # global variable
  MAX_SEED = np.iinfo(np.int32).max
- device = get_torch_device()
+ device = "cuda" if torch.cuda.is_available() else "cpu"
  dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
  STYLE_NAMES = list(styles.keys())
  DEFAULT_STYLE_NAME = "Watercolor"
model_util.py DELETED
@@ -1,472 +0,0 @@
- from typing import Literal, Union, Optional, Tuple, List
-
- import torch
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
- from diffusers import (
-     UNet2DConditionModel,
-     SchedulerMixin,
-     StableDiffusionPipeline,
-     StableDiffusionXLPipeline,
-     AutoencoderKL,
- )
- from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-     convert_ldm_unet_checkpoint,
- )
- from safetensors.torch import load_file
- from diffusers.schedulers import (
-     DDIMScheduler,
-     DDPMScheduler,
-     LMSDiscreteScheduler,
-     EulerDiscreteScheduler,
-     EulerAncestralDiscreteScheduler,
-     UniPCMultistepScheduler,
- )
-
- from omegaconf import OmegaConf
-
- # Model parameters of the Diffusers version of Stable Diffusion
- NUM_TRAIN_TIMESTEPS = 1000
- BETA_START = 0.00085
- BETA_END = 0.0120
-
- UNET_PARAMS_MODEL_CHANNELS = 320
- UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
- UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
- UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
- UNET_PARAMS_IN_CHANNELS = 4
- UNET_PARAMS_OUT_CHANNELS = 4
- UNET_PARAMS_NUM_RES_BLOCKS = 2
- UNET_PARAMS_CONTEXT_DIM = 768
- UNET_PARAMS_NUM_HEADS = 8
- # UNET_PARAMS_USE_LINEAR_PROJECTION = False
-
- VAE_PARAMS_Z_CHANNELS = 4
- VAE_PARAMS_RESOLUTION = 256
- VAE_PARAMS_IN_CHANNELS = 3
- VAE_PARAMS_OUT_CH = 3
- VAE_PARAMS_CH = 128
- VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
- VAE_PARAMS_NUM_RES_BLOCKS = 2
-
- # V2
- V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
- V2_UNET_PARAMS_CONTEXT_DIM = 1024
- # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
-
- TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
- TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
-
- AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"]
-
- SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
-
- DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
-
-
- def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
-     # Handle models that store the text encoder in a different format (without 'text_model')
-     TEXT_ENCODER_KEY_REPLACEMENTS = [
-         (
-             "cond_stage_model.transformer.embeddings.",
-             "cond_stage_model.transformer.text_model.embeddings.",
-         ),
-         (
-             "cond_stage_model.transformer.encoder.",
-             "cond_stage_model.transformer.text_model.encoder.",
-         ),
-         (
-             "cond_stage_model.transformer.final_layer_norm.",
-             "cond_stage_model.transformer.text_model.final_layer_norm.",
-         ),
-     ]
-
-     if ckpt_path.endswith(".safetensors"):
-         checkpoint = None
-         state_dict = load_file(ckpt_path) # , device) # may causes error
-     else:
-         checkpoint = torch.load(ckpt_path, map_location=device)
-         if "state_dict" in checkpoint:
-             state_dict = checkpoint["state_dict"]
-         else:
-             state_dict = checkpoint
-             checkpoint = None
-
-     key_reps = []
-     for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
-         for key in state_dict.keys():
-             if key.startswith(rep_from):
-                 new_key = rep_to + key[len(rep_from) :]
-                 key_reps.append((key, new_key))
-
-     for key, new_key in key_reps:
-         state_dict[new_key] = state_dict[key]
-         del state_dict[key]
-
-     return checkpoint, state_dict
-
-
- def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
-     """
-     Creates a config for the diffusers based on the config of the LDM model.
-     """
-     # unet_params = original_config.model.params.unet_config.params
-
-     block_out_channels = [
-         UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
-     ]
-
-     down_block_types = []
-     resolution = 1
-     for i in range(len(block_out_channels)):
-         block_type = (
-             "CrossAttnDownBlock2D"
-             if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
-             else "DownBlock2D"
-         )
-         down_block_types.append(block_type)
-         if i != len(block_out_channels) - 1:
-             resolution *= 2
-
-     up_block_types = []
-     for i in range(len(block_out_channels)):
-         block_type = (
-             "CrossAttnUpBlock2D"
-             if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
-             else "UpBlock2D"
-         )
-         up_block_types.append(block_type)
-         resolution //= 2
-
-     config = dict(
-         sample_size=UNET_PARAMS_IMAGE_SIZE,
-         in_channels=UNET_PARAMS_IN_CHANNELS,
-         out_channels=UNET_PARAMS_OUT_CHANNELS,
-         down_block_types=tuple(down_block_types),
-         up_block_types=tuple(up_block_types),
-         block_out_channels=tuple(block_out_channels),
-         layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
-         cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
-         if not v2
-         else V2_UNET_PARAMS_CONTEXT_DIM,
-         attention_head_dim=UNET_PARAMS_NUM_HEADS
-         if not v2
-         else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
-         # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
-     )
-     if v2 and use_linear_projection_in_v2:
-         config["use_linear_projection"] = True
-
-     return config
-
-
- def load_diffusers_model(
-     pretrained_model_name_or_path: str,
-     v2: bool = False,
-     clip_skip: Optional[int] = None,
-     weight_dtype: torch.dtype = torch.float32,
- ) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
-     if v2:
-         tokenizer = CLIPTokenizer.from_pretrained(
-             TOKENIZER_V2_MODEL_NAME,
-             subfolder="tokenizer",
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         )
-         text_encoder = CLIPTextModel.from_pretrained(
-             pretrained_model_name_or_path,
-             subfolder="text_encoder",
-             # default is clip skip 2
-             num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         )
-     else:
-         tokenizer = CLIPTokenizer.from_pretrained(
-             TOKENIZER_V1_MODEL_NAME,
-             subfolder="tokenizer",
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         )
-         text_encoder = CLIPTextModel.from_pretrained(
-             pretrained_model_name_or_path,
-             subfolder="text_encoder",
-             num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         )
-
-     unet = UNet2DConditionModel.from_pretrained(
-         pretrained_model_name_or_path,
-         subfolder="unet",
-         torch_dtype=weight_dtype,
-         cache_dir=DIFFUSERS_CACHE_DIR,
-     )
-
-     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
-
-     return tokenizer, text_encoder, unet, vae
-
-
- def load_checkpoint_model(
-     checkpoint_path: str,
-     v2: bool = False,
-     clip_skip: Optional[int] = None,
-     weight_dtype: torch.dtype = torch.float32,
- ) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
-     pipe = StableDiffusionPipeline.from_single_file(
-         checkpoint_path,
-         upcast_attention=True if v2 else False,
-         torch_dtype=weight_dtype,
-         cache_dir=DIFFUSERS_CACHE_DIR,
-     )
-
-     _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
-     unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
-     unet_config["class_embed_type"] = None
-     unet_config["addition_embed_type"] = None
-     converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
-     unet = UNet2DConditionModel(**unet_config)
-     unet.load_state_dict(converted_unet_checkpoint)
-
-     tokenizer = pipe.tokenizer
-     text_encoder = pipe.text_encoder
-     vae = pipe.vae
-     if clip_skip is not None:
-         if v2:
-             text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
-         else:
-             text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
-
-     del pipe
-
-     return tokenizer, text_encoder, unet, vae
-
-
- def load_models(
-     pretrained_model_name_or_path: str,
-     scheduler_name: str,
-     v2: bool = False,
-     v_pred: bool = False,
-     weight_dtype: torch.dtype = torch.float32,
- ) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
-     if pretrained_model_name_or_path.endswith(
-         ".ckpt"
-     ) or pretrained_model_name_or_path.endswith(".safetensors"):
-         tokenizer, text_encoder, unet, vae = load_checkpoint_model(
-             pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
-         )
-     else: # diffusers
-         tokenizer, text_encoder, unet, vae = load_diffusers_model(
-             pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
-         )
-
-     if scheduler_name:
-         scheduler = create_noise_scheduler(
-             scheduler_name,
-             prediction_type="v_prediction" if v_pred else "epsilon",
-         )
-     else:
-         scheduler = None
-
-     return tokenizer, text_encoder, unet, scheduler, vae
-
-
- def load_diffusers_model_xl(
-     pretrained_model_name_or_path: str,
-     weight_dtype: torch.dtype = torch.float32,
- ) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
-     # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
-
-     tokenizers = [
-         CLIPTokenizer.from_pretrained(
-             pretrained_model_name_or_path,
-             subfolder="tokenizer",
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         ),
-         CLIPTokenizer.from_pretrained(
-             pretrained_model_name_or_path,
-             subfolder="tokenizer_2",
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-             pad_token_id=0, # same as open clip
-         ),
-     ]
-
-     text_encoders = [
-         CLIPTextModel.from_pretrained(
-             pretrained_model_name_or_path,
-             subfolder="text_encoder",
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         ),
-         CLIPTextModelWithProjection.from_pretrained(
-             pretrained_model_name_or_path,
-             subfolder="text_encoder_2",
-             torch_dtype=weight_dtype,
-             cache_dir=DIFFUSERS_CACHE_DIR,
-         ),
-     ]
-
-     unet = UNet2DConditionModel.from_pretrained(
-         pretrained_model_name_or_path,
-         subfolder="unet",
-         torch_dtype=weight_dtype,
-         cache_dir=DIFFUSERS_CACHE_DIR,
-     )
-     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
-     return tokenizers, text_encoders, unet, vae
-
-
- def load_checkpoint_model_xl(
-     checkpoint_path: str,
-     weight_dtype: torch.dtype = torch.float32,
- ) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
-     pipe = StableDiffusionXLPipeline.from_single_file(
-         checkpoint_path,
-         torch_dtype=weight_dtype,
-         cache_dir=DIFFUSERS_CACHE_DIR,
-     )
-
-     unet = pipe.unet
-     vae = pipe.vae
-     tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
-     text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
-     if len(text_encoders) == 2:
-         text_encoders[1].pad_token_id = 0
-
-     del pipe
-
-     return tokenizers, text_encoders, unet, vae
-
-
- def load_models_xl(
-     pretrained_model_name_or_path: str,
-     scheduler_name: str,
-     weight_dtype: torch.dtype = torch.float32,
-     noise_scheduler_kwargs=None,
- ) -> Tuple[
-     List[CLIPTokenizer],
-     List[SDXL_TEXT_ENCODER_TYPE],
-     UNet2DConditionModel,
-     SchedulerMixin,
- ]:
-     if pretrained_model_name_or_path.endswith(
-         ".ckpt"
-     ) or pretrained_model_name_or_path.endswith(".safetensors"):
-         (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl(
-             pretrained_model_name_or_path, weight_dtype
-         )
-     else: # diffusers
-         (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl(
-             pretrained_model_name_or_path, weight_dtype
-         )
-     if scheduler_name:
-         scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
-     else:
-         scheduler = None
-
-     return tokenizers, text_encoders, unet, scheduler, vae
-
- def create_noise_scheduler(
-     scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
-     noise_scheduler_kwargs=None,
-     prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
- ) -> SchedulerMixin:
-     name = scheduler_name.lower().replace(" ", "_")
-     if name.lower() == "ddim":
-         # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
-         scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
-     elif name.lower() == "ddpm":
-         # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
-         scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
-     elif name.lower() == "lms":
-         # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
-         scheduler = LMSDiscreteScheduler(
-             **OmegaConf.to_container(noise_scheduler_kwargs)
-         )
-     elif name.lower() == "euler_a":
-         # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
-         scheduler = EulerAncestralDiscreteScheduler(
-             **OmegaConf.to_container(noise_scheduler_kwargs)
-         )
-     elif name.lower() == "euler":
-         # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
-         scheduler = EulerDiscreteScheduler(
-             **OmegaConf.to_container(noise_scheduler_kwargs)
-         )
-     elif name.lower() == "unipc":
-         # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc
-         scheduler = UniPCMultistepScheduler(
-             **OmegaConf.to_container(noise_scheduler_kwargs)
-         )
-     else:
-         raise ValueError(f"Unknown scheduler name: {name}")
-
-     return scheduler
-
-
- def torch_gc():
-     import gc
-
-     gc.collect()
-     if torch.cuda.is_available():
-         with torch.cuda.device("cuda"):
-             torch.cuda.empty_cache()
-             torch.cuda.ipc_collect()
-
-
- from enum import Enum
-
-
- class CPUState(Enum):
-     GPU = 0
-     CPU = 1
-     MPS = 2
-
-
- cpu_state = CPUState.GPU
- xpu_available = False
- directml_enabled = False
-
-
- def is_intel_xpu():
-     global cpu_state
-     global xpu_available
-     if cpu_state == CPUState.GPU:
-         if xpu_available:
-             return True
-     return False
-
-
- try:
-     import intel_extension_for_pytorch as ipex
-
-     if torch.xpu.is_available():
-         xpu_available = True
- except:
-     pass
-
- try:
-     if torch.backends.mps.is_available():
-         cpu_state = CPUState.MPS
-         import torch.mps
- except:
-     pass
-
-
- def get_torch_device():
-     global directml_enabled
-     global cpu_state
-     if directml_enabled:
-         global directml_device
-         return directml_device
-     if cpu_state == CPUState.MPS:
-         return torch.device("mps")
-     if cpu_state == CPUState.CPU:
-         return torch.device("cpu")
-     else:
-         if is_intel_xpu():
-             return torch.device("xpu")
-         else:
-             return torch.device(torch.cuda.current_device())
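
For reference, the device and dtype selection that app.py relies on after this commit can be exercised on its own. The snippet below is a minimal standalone sketch, not part of the commit itself, using only the expressions visible in the app.py hunk above:

import numpy as np
import torch

# Same selection logic as the updated app.py: prefer CUDA when available,
# otherwise fall back to CPU, and use float16 only on a CUDA device.
MAX_SEED = np.iinfo(np.int32).max
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32

print(device, dtype, MAX_SEED)

Unlike the deleted get_torch_device(), this check does not consider MPS, DirectML, or Intel XPU backends; on such machines the app now falls back to CPU with float32.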