Upload with huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .ipynb_checkpoints/env-checkpoint.py +13 -0
- README.md +6 -6
- app.py +1677 -0
- env.py +13 -0
- ppdiffusers/__init__.py +162 -0
- ppdiffusers/__pycache__/__init__.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/configuration_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/download_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/fastdeploy_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/initializer.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/loaders.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/modeling_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/optimization.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/pipeline_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/ppnlp_patch_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/training_utils.cpython-37.pyc +0 -0
- ppdiffusers/__pycache__/version.cpython-37.pyc +0 -0
- ppdiffusers/commands/__init__.py +28 -0
- ppdiffusers/commands/env.py +67 -0
- ppdiffusers/commands/ppdiffusers_cli.py +41 -0
- ppdiffusers/configuration_utils.py +591 -0
- ppdiffusers/download_utils.py +44 -0
- ppdiffusers/experimental/README.md +6 -0
- ppdiffusers/experimental/__init__.py +17 -0
- ppdiffusers/experimental/rl/__init__.py +17 -0
- ppdiffusers/experimental/rl/value_guided_sampling.py +146 -0
- ppdiffusers/fastdeploy_utils.py +260 -0
- ppdiffusers/initializer.py +303 -0
- ppdiffusers/loaders.py +190 -0
- ppdiffusers/modeling_paddle_pytorch_utils.py +106 -0
- ppdiffusers/modeling_utils.py +619 -0
- ppdiffusers/models/__init__.py +25 -0
- ppdiffusers/models/__pycache__/__init__.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/attention.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/cross_attention.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/embeddings.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/prior_transformer.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/resnet.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/unet_1d.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/unet_1d_blocks.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/unet_2d.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/unet_2d_blocks.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc +0 -0
- ppdiffusers/models/__pycache__/vae.cpython-37.pyc +0 -0
- ppdiffusers/models/attention.py +683 -0
- ppdiffusers/models/cross_attention.py +435 -0
- ppdiffusers/models/ema.py +103 -0
- ppdiffusers/models/embeddings.py +199 -0
- ppdiffusers/models/prior_transformer.py +220 -0
- ppdiffusers/models/resnet.py +716 -0
.ipynb_checkpoints/env-checkpoint.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
############################################################################################################################
|
2 |
+
# 修改下面的参数
|
3 |
+
# (1)BASE_MODEL_NAME 代表你训练的基础模型
|
4 |
+
BASE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"
|
5 |
+
|
6 |
+
# 是否开启lora
|
7 |
+
# (2)LORA_WEIGHTS_PATH 代码你上传到huggingface后的lora权重。
|
8 |
+
# LORA_WEIGHTS_PATH = None 表示不适应lora
|
9 |
+
LORA_WEIGHTS_PATH = "xianbao/demo_test"
|
10 |
+
|
11 |
+
# (3)PROMPTS 需要展示的prompt文本
|
12 |
+
PROMPTS = "A photo of sks dog in a bucket"
|
13 |
+
############################################################################################################################
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: LoRa ppdiffusers dreambooth
|
3 |
+
emoji: 🎨🎞️
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.18.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,1677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import gradio as gr
|
16 |
+
from env import BASE_MODEL_NAME, LORA_WEIGHTS_PATH, PROMPTS
|
17 |
+
|
18 |
+
examples = [
|
19 |
+
[
|
20 |
+
PROMPTS,
|
21 |
+
'low quality',
|
22 |
+
7.5,
|
23 |
+
512,
|
24 |
+
512,
|
25 |
+
25,
|
26 |
+
"DPMSolver"
|
27 |
+
],
|
28 |
+
]
|
29 |
+
import inspect
|
30 |
+
import os
|
31 |
+
import random
|
32 |
+
import re
|
33 |
+
import time
|
34 |
+
from typing import Callable, List, Optional, Union
|
35 |
+
|
36 |
+
import numpy as np
|
37 |
+
import paddle
|
38 |
+
import PIL
|
39 |
+
import PIL.Image
|
40 |
+
from packaging import version
|
41 |
+
|
42 |
+
from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
43 |
+
|
44 |
+
from ppdiffusers.configuration_utils import FrozenDict
|
45 |
+
from ppdiffusers.models import AutoencoderKL, UNet2DConditionModel
|
46 |
+
from ppdiffusers.pipeline_utils import DiffusionPipeline
|
47 |
+
from ppdiffusers.schedulers import (
|
48 |
+
DDIMScheduler,
|
49 |
+
DPMSolverMultistepScheduler,
|
50 |
+
EulerAncestralDiscreteScheduler,
|
51 |
+
EulerDiscreteScheduler,
|
52 |
+
LMSDiscreteScheduler,
|
53 |
+
PNDMScheduler,
|
54 |
+
HeunDiscreteScheduler,
|
55 |
+
KDPM2AncestralDiscreteScheduler,
|
56 |
+
KDPM2DiscreteScheduler,
|
57 |
+
|
58 |
+
)
|
59 |
+
from ppdiffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
60 |
+
from ppdiffusers.utils.testing_utils import load_image
|
61 |
+
from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
62 |
+
from ppdiffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
63 |
+
|
64 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
65 |
+
|
66 |
+
|
67 |
+
def save_all(images, FORMAT="jpg", OUTDIR="./outputs/"):
|
68 |
+
if not isinstance(images, (list, tuple)):
|
69 |
+
images = [images]
|
70 |
+
for image in images:
|
71 |
+
PRECISION = "fp32"
|
72 |
+
argument = image.argument
|
73 |
+
os.makedirs(OUTDIR, exist_ok=True)
|
74 |
+
epoch_time = argument["epoch_time"]
|
75 |
+
PROMPT = argument["prompt"]
|
76 |
+
NEGPROMPT = argument["negative_prompt"]
|
77 |
+
HEIGHT = argument["height"]
|
78 |
+
WIDTH = argument["width"]
|
79 |
+
SEED = argument["seed"]
|
80 |
+
STRENGTH = argument.get("strength", 1)
|
81 |
+
INFERENCE_STEPS = argument["num_inference_steps"]
|
82 |
+
GUIDANCE_SCALE = argument["guidance_scale"]
|
83 |
+
|
84 |
+
filename = f"{str(epoch_time)}_scale_{GUIDANCE_SCALE}_steps_{INFERENCE_STEPS}_seed_{SEED}.{FORMAT}"
|
85 |
+
filedir = f"{OUTDIR}/{filename}"
|
86 |
+
image.save(filedir)
|
87 |
+
with open(f"{OUTDIR}/{epoch_time}_prompt.txt", "w") as file:
|
88 |
+
file.write(
|
89 |
+
f"PROMPT: {PROMPT}\nNEG_PROMPT: {NEGPROMPT}\n\nINFERENCE_STEPS: {INFERENCE_STEPS}\nHeight: {HEIGHT}\nWidth: {WIDTH}\nSeed: {SEED}\n\nPrecision: {PRECISION}\nSTRENGTH: {STRENGTH}\nGUIDANCE_SCALE: {GUIDANCE_SCALE}"
|
90 |
+
)
|
91 |
+
|
92 |
+
|
93 |
+
re_attention = re.compile(
|
94 |
+
r"""
|
95 |
+
\\\(|
|
96 |
+
\\\)|
|
97 |
+
\\\[|
|
98 |
+
\\]|
|
99 |
+
\\\\|
|
100 |
+
\\|
|
101 |
+
\(|
|
102 |
+
\[|
|
103 |
+
:([+-]?[.\d]+)\)|
|
104 |
+
\)|
|
105 |
+
]|
|
106 |
+
[^\\()\[\]:]+|
|
107 |
+
:
|
108 |
+
""",
|
109 |
+
re.X,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def parse_prompt_attention(text):
|
114 |
+
"""
|
115 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
116 |
+
Accepted tokens are:
|
117 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
118 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
119 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
120 |
+
\( - literal character '('
|
121 |
+
\[ - literal character '['
|
122 |
+
\) - literal character ')'
|
123 |
+
\] - literal character ']'
|
124 |
+
\\ - literal character '\'
|
125 |
+
anything else - just text
|
126 |
+
>>> parse_prompt_attention('normal text')
|
127 |
+
[['normal text', 1.0]]
|
128 |
+
>>> parse_prompt_attention('an (important) word')
|
129 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
130 |
+
>>> parse_prompt_attention('(unbalanced')
|
131 |
+
[['unbalanced', 1.1]]
|
132 |
+
>>> parse_prompt_attention('\(literal\]')
|
133 |
+
[['(literal]', 1.0]]
|
134 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
135 |
+
[['unnecessaryparens', 1.1]]
|
136 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
137 |
+
[['a ', 1.0],
|
138 |
+
['house', 1.5730000000000004],
|
139 |
+
[' ', 1.1],
|
140 |
+
['on', 1.0],
|
141 |
+
[' a ', 1.1],
|
142 |
+
['hill', 0.55],
|
143 |
+
[', sun, ', 1.1],
|
144 |
+
['sky', 1.4641000000000006],
|
145 |
+
['.', 1.1]]
|
146 |
+
"""
|
147 |
+
|
148 |
+
res = []
|
149 |
+
round_brackets = []
|
150 |
+
square_brackets = []
|
151 |
+
|
152 |
+
round_bracket_multiplier = 1.1
|
153 |
+
square_bracket_multiplier = 1 / 1.1
|
154 |
+
|
155 |
+
def multiply_range(start_position, multiplier):
|
156 |
+
for p in range(start_position, len(res)):
|
157 |
+
res[p][1] *= multiplier
|
158 |
+
|
159 |
+
for m in re_attention.finditer(text):
|
160 |
+
text = m.group(0)
|
161 |
+
weight = m.group(1)
|
162 |
+
|
163 |
+
if text.startswith("\\"):
|
164 |
+
res.append([text[1:], 1.0])
|
165 |
+
elif text == "(":
|
166 |
+
round_brackets.append(len(res))
|
167 |
+
elif text == "[":
|
168 |
+
square_brackets.append(len(res))
|
169 |
+
elif weight is not None and len(round_brackets) > 0:
|
170 |
+
multiply_range(round_brackets.pop(), float(weight))
|
171 |
+
elif text == ")" and len(round_brackets) > 0:
|
172 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
173 |
+
elif text == "]" and len(square_brackets) > 0:
|
174 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
175 |
+
else:
|
176 |
+
res.append([text, 1.0])
|
177 |
+
|
178 |
+
for pos in round_brackets:
|
179 |
+
multiply_range(pos, round_bracket_multiplier)
|
180 |
+
|
181 |
+
for pos in square_brackets:
|
182 |
+
multiply_range(pos, square_bracket_multiplier)
|
183 |
+
|
184 |
+
if len(res) == 0:
|
185 |
+
res = [["", 1.0]]
|
186 |
+
|
187 |
+
# merge runs of identical weights
|
188 |
+
i = 0
|
189 |
+
while i + 1 < len(res):
|
190 |
+
if res[i][1] == res[i + 1][1]:
|
191 |
+
res[i][0] += res[i + 1][0]
|
192 |
+
res.pop(i + 1)
|
193 |
+
else:
|
194 |
+
i += 1
|
195 |
+
|
196 |
+
return res
|
197 |
+
|
198 |
+
|
199 |
+
def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
|
200 |
+
r"""
|
201 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
202 |
+
|
203 |
+
No padding, starting or ending token is included.
|
204 |
+
"""
|
205 |
+
tokens = []
|
206 |
+
weights = []
|
207 |
+
for text in prompt:
|
208 |
+
texts_and_weights = parse_prompt_attention(text)
|
209 |
+
text_token = []
|
210 |
+
text_weight = []
|
211 |
+
for word, weight in texts_and_weights:
|
212 |
+
# tokenize and discard the starting and the ending token
|
213 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
214 |
+
text_token += token
|
215 |
+
|
216 |
+
# copy the weight by length of token
|
217 |
+
text_weight += [weight] * len(token)
|
218 |
+
|
219 |
+
# stop if the text is too long (longer than truncation limit)
|
220 |
+
if len(text_token) > max_length:
|
221 |
+
break
|
222 |
+
|
223 |
+
# truncate
|
224 |
+
if len(text_token) > max_length:
|
225 |
+
text_token = text_token[:max_length]
|
226 |
+
text_weight = text_weight[:max_length]
|
227 |
+
|
228 |
+
tokens.append(text_token)
|
229 |
+
weights.append(text_weight)
|
230 |
+
return tokens, weights
|
231 |
+
|
232 |
+
|
233 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
|
234 |
+
r"""
|
235 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
236 |
+
"""
|
237 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
238 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
239 |
+
for i in range(len(tokens)):
|
240 |
+
tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
|
241 |
+
if no_boseos_middle:
|
242 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
243 |
+
else:
|
244 |
+
w = []
|
245 |
+
if len(weights[i]) == 0:
|
246 |
+
w = [1.0] * weights_length
|
247 |
+
else:
|
248 |
+
for j in range((len(weights[i]) - 1) // chunk_length + 1):
|
249 |
+
w.append(1.0) # weight for starting token in this chunk
|
250 |
+
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
|
251 |
+
w.append(1.0) # weight for ending token in this chunk
|
252 |
+
w += [1.0] * (weights_length - len(w))
|
253 |
+
weights[i] = w[:]
|
254 |
+
|
255 |
+
return tokens, weights
|
256 |
+
|
257 |
+
|
258 |
+
def get_unweighted_text_embeddings(
|
259 |
+
pipe: DiffusionPipeline, text_input: paddle.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
|
260 |
+
):
|
261 |
+
"""
|
262 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
263 |
+
it should be split into chunks and sent to the text encoder individually.
|
264 |
+
"""
|
265 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
266 |
+
if max_embeddings_multiples > 1:
|
267 |
+
text_embeddings = []
|
268 |
+
for i in range(max_embeddings_multiples):
|
269 |
+
# extract the i-th chunk
|
270 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
271 |
+
|
272 |
+
# cover the head and the tail by the starting and the ending tokens
|
273 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
274 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
275 |
+
|
276 |
+
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
277 |
+
|
278 |
+
if no_boseos_middle:
|
279 |
+
if i == 0:
|
280 |
+
# discard the ending token
|
281 |
+
text_embedding = text_embedding[:, :-1]
|
282 |
+
elif i == max_embeddings_multiples - 1:
|
283 |
+
# discard the starting token
|
284 |
+
text_embedding = text_embedding[:, 1:]
|
285 |
+
else:
|
286 |
+
# discard both starting and ending tokens
|
287 |
+
text_embedding = text_embedding[:, 1:-1]
|
288 |
+
|
289 |
+
text_embeddings.append(text_embedding)
|
290 |
+
text_embeddings = paddle.concat(text_embeddings, axis=1)
|
291 |
+
else:
|
292 |
+
text_embeddings = pipe.text_encoder(text_input)[0]
|
293 |
+
return text_embeddings
|
294 |
+
|
295 |
+
|
296 |
+
def get_weighted_text_embeddings(
|
297 |
+
pipe: DiffusionPipeline,
|
298 |
+
prompt: Union[str, List[str]],
|
299 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
300 |
+
max_embeddings_multiples: Optional[int] = 1,
|
301 |
+
no_boseos_middle: Optional[bool] = False,
|
302 |
+
skip_parsing: Optional[bool] = False,
|
303 |
+
skip_weighting: Optional[bool] = False,
|
304 |
+
**kwargs
|
305 |
+
):
|
306 |
+
r"""
|
307 |
+
Prompts can be assigned with local weights using brackets. For example,
|
308 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
309 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
310 |
+
|
311 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
pipe (`DiffusionPipeline`):
|
315 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
316 |
+
prompt (`str` or `List[str]`):
|
317 |
+
The prompt or prompts to guide the image generation.
|
318 |
+
uncond_prompt (`str` or `List[str]`):
|
319 |
+
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
320 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
321 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `1`):
|
322 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
323 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
324 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
325 |
+
ending token in each of the chunk in the middle.
|
326 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
327 |
+
Skip the parsing of brackets.
|
328 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
329 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
330 |
+
"""
|
331 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
332 |
+
if isinstance(prompt, str):
|
333 |
+
prompt = [prompt]
|
334 |
+
|
335 |
+
if not skip_parsing:
|
336 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
337 |
+
if uncond_prompt is not None:
|
338 |
+
if isinstance(uncond_prompt, str):
|
339 |
+
uncond_prompt = [uncond_prompt]
|
340 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
341 |
+
else:
|
342 |
+
prompt_tokens = [
|
343 |
+
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
|
344 |
+
]
|
345 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
346 |
+
if uncond_prompt is not None:
|
347 |
+
if isinstance(uncond_prompt, str):
|
348 |
+
uncond_prompt = [uncond_prompt]
|
349 |
+
uncond_tokens = [
|
350 |
+
token[1:-1]
|
351 |
+
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
352 |
+
]
|
353 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
354 |
+
|
355 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
356 |
+
max_length = max([len(token) for token in prompt_tokens])
|
357 |
+
if uncond_prompt is not None:
|
358 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
359 |
+
|
360 |
+
max_embeddings_multiples = min(
|
361 |
+
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
|
362 |
+
)
|
363 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
364 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
365 |
+
|
366 |
+
# pad the length of tokens and weights
|
367 |
+
# support bert tokenizer
|
368 |
+
bos = pipe.tokenizer.bos_token_id if pipe.tokenizer.bos_token_id is not None else pipe.tokenizer.cls_token_id
|
369 |
+
eos = pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id is not None else pipe.tokenizer.sep_token_id
|
370 |
+
pad = pipe.tokenizer.pad_token_id
|
371 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
372 |
+
prompt_tokens,
|
373 |
+
prompt_weights,
|
374 |
+
max_length,
|
375 |
+
bos,
|
376 |
+
eos,
|
377 |
+
pad,
|
378 |
+
no_boseos_middle=no_boseos_middle,
|
379 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
380 |
+
)
|
381 |
+
prompt_tokens = paddle.to_tensor(prompt_tokens)
|
382 |
+
if uncond_prompt is not None:
|
383 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
384 |
+
uncond_tokens,
|
385 |
+
uncond_weights,
|
386 |
+
max_length,
|
387 |
+
bos,
|
388 |
+
eos,
|
389 |
+
pad,
|
390 |
+
no_boseos_middle=no_boseos_middle,
|
391 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
392 |
+
)
|
393 |
+
uncond_tokens = paddle.to_tensor(uncond_tokens)
|
394 |
+
|
395 |
+
# get the embeddings
|
396 |
+
text_embeddings = get_unweighted_text_embeddings(
|
397 |
+
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
398 |
+
)
|
399 |
+
prompt_weights = paddle.to_tensor(prompt_weights, dtype=text_embeddings.dtype)
|
400 |
+
if uncond_prompt is not None:
|
401 |
+
uncond_embeddings = get_unweighted_text_embeddings(
|
402 |
+
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
403 |
+
)
|
404 |
+
uncond_weights = paddle.to_tensor(uncond_weights, dtype=uncond_embeddings.dtype)
|
405 |
+
|
406 |
+
# assign weights to the prompts and normalize in the sense of mean
|
407 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
408 |
+
if (not skip_parsing) and (not skip_weighting):
|
409 |
+
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
410 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
411 |
+
text_embeddings *= previous_mean / text_embeddings.mean(axis=[-2, -1])
|
412 |
+
if uncond_prompt is not None:
|
413 |
+
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
414 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
415 |
+
uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
|
416 |
+
|
417 |
+
# For classifier free guidance, we need to do two forward passes.
|
418 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
419 |
+
# to avoid doing two forward passes
|
420 |
+
if uncond_prompt is not None:
|
421 |
+
text_embeddings = paddle.concat([uncond_embeddings, text_embeddings])
|
422 |
+
|
423 |
+
return text_embeddings
|
424 |
+
|
425 |
+
|
426 |
+
def preprocess_image(image):
|
427 |
+
w, h = image.size
|
428 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
429 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
430 |
+
image = np.array(image).astype(np.float32) / 255.0
|
431 |
+
image = image[None].transpose(0, 3, 1, 2)
|
432 |
+
image = paddle.to_tensor(image)
|
433 |
+
return 2.0 * image - 1.0
|
434 |
+
|
435 |
+
|
436 |
+
def preprocess_mask(mask):
|
437 |
+
mask = mask.convert("L")
|
438 |
+
w, h = mask.size
|
439 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
440 |
+
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
|
441 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
442 |
+
mask = np.tile(mask, (4, 1, 1))
|
443 |
+
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
444 |
+
mask = 1 - mask # repaint white, keep black
|
445 |
+
mask = paddle.to_tensor(mask)
|
446 |
+
return mask
|
447 |
+
|
448 |
+
|
449 |
+
class StableDiffusionPipelineAllinOne(DiffusionPipeline):
|
450 |
+
r"""
|
451 |
+
Pipeline for text-to-image image-to-image inpainting generation using Stable Diffusion.
|
452 |
+
|
453 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
454 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular xxxx, etc.)
|
455 |
+
|
456 |
+
Args:
|
457 |
+
vae ([`AutoencoderKL`]):
|
458 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
459 |
+
text_encoder ([`CLIPTextModel`]):
|
460 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
461 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
462 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
463 |
+
tokenizer (`CLIPTokenizer`):
|
464 |
+
Tokenizer of class
|
465 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
466 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
467 |
+
scheduler ([`SchedulerMixin`]):
|
468 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
469 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
|
470 |
+
or [`DPMSolverMultistepScheduler`].
|
471 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
472 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
473 |
+
Please, refer to the [model card](https://huggingface.co/junnyu/stable-diffusion-v1-4-paddle) for details.
|
474 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
475 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
476 |
+
"""
|
477 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
478 |
+
|
479 |
+
def __init__(
|
480 |
+
self,
|
481 |
+
vae: AutoencoderKL,
|
482 |
+
text_encoder: CLIPTextModel,
|
483 |
+
tokenizer: CLIPTokenizer,
|
484 |
+
unet: UNet2DConditionModel,
|
485 |
+
scheduler: Union[
|
486 |
+
DDIMScheduler,
|
487 |
+
PNDMScheduler,
|
488 |
+
LMSDiscreteScheduler,
|
489 |
+
EulerDiscreteScheduler,
|
490 |
+
EulerAncestralDiscreteScheduler,
|
491 |
+
DPMSolverMultistepScheduler,
|
492 |
+
],
|
493 |
+
safety_checker: StableDiffusionSafetyChecker,
|
494 |
+
feature_extractor: CLIPFeatureExtractor,
|
495 |
+
requires_safety_checker: bool = False,
|
496 |
+
):
|
497 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
498 |
+
deprecation_message = (
|
499 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
500 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
501 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
502 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
503 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
504 |
+
" file"
|
505 |
+
)
|
506 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
507 |
+
new_config = dict(scheduler.config)
|
508 |
+
new_config["steps_offset"] = 1
|
509 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
510 |
+
|
511 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
512 |
+
deprecation_message = (
|
513 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
514 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
515 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
516 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
517 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
518 |
+
)
|
519 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
520 |
+
new_config = dict(scheduler.config)
|
521 |
+
new_config["clip_sample"] = False
|
522 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
523 |
+
|
524 |
+
if safety_checker is None and requires_safety_checker:
|
525 |
+
logger.warning(
|
526 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
527 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
528 |
+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
|
529 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
530 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
531 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
532 |
+
)
|
533 |
+
if safety_checker is not None and feature_extractor is None:
|
534 |
+
raise ValueError(
|
535 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
536 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
537 |
+
)
|
538 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_ppdiffusers_version") and version.parse(
|
539 |
+
version.parse(unet.config._ppdiffusers_version).base_version
|
540 |
+
) < version.parse("0.9.0.dev0")
|
541 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
542 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
543 |
+
deprecation_message = (
|
544 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
545 |
+
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
546 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
547 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
548 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
549 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
550 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
551 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
552 |
+
" the `unet/config.json` file"
|
553 |
+
)
|
554 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
555 |
+
new_config = dict(unet.config)
|
556 |
+
new_config["sample_size"] = 64
|
557 |
+
unet._internal_dict = FrozenDict(new_config)
|
558 |
+
|
559 |
+
self.register_modules(
|
560 |
+
vae=vae,
|
561 |
+
text_encoder=text_encoder,
|
562 |
+
tokenizer=tokenizer,
|
563 |
+
unet=unet,
|
564 |
+
scheduler=scheduler,
|
565 |
+
safety_checker=safety_checker,
|
566 |
+
feature_extractor=feature_extractor,
|
567 |
+
)
|
568 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
569 |
+
|
570 |
+
def create_scheduler(self, name="DPMSolver"):
|
571 |
+
config = self.scheduler.config
|
572 |
+
if name == "DPMSolver":
|
573 |
+
return DPMSolverMultistepScheduler.from_config(
|
574 |
+
config,
|
575 |
+
thresholding=False,
|
576 |
+
algorithm_type="dpmsolver++",
|
577 |
+
solver_type="midpoint",
|
578 |
+
lower_order_final=True,
|
579 |
+
)
|
580 |
+
if name == "EulerDiscrete":
|
581 |
+
return EulerDiscreteScheduler.from_config(config)
|
582 |
+
elif name == "EulerAncestralDiscrete":
|
583 |
+
return EulerAncestralDiscreteScheduler.from_config(config)
|
584 |
+
elif name == "PNDM":
|
585 |
+
return PNDMScheduler.from_config(config)
|
586 |
+
elif name == "DDIM":
|
587 |
+
return DDIMScheduler.from_config(config)
|
588 |
+
elif name == "LMSDiscrete":
|
589 |
+
return LMSDiscreteScheduler.from_config(config)
|
590 |
+
elif name == "HeunDiscrete":
|
591 |
+
return HeunDiscreteScheduler.from_config(config)
|
592 |
+
elif name == "KDPM2AncestralDiscrete":
|
593 |
+
return KDPM2AncestralDiscreteScheduler.from_config(config)
|
594 |
+
elif name == "KDPM2Discrete":
|
595 |
+
return KDPM2DiscreteScheduler.from_config(config)
|
596 |
+
else:
|
597 |
+
raise NotImplementedError
|
598 |
+
|
599 |
+
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
600 |
+
r"""
|
601 |
+
Enable sliced attention computation.
|
602 |
+
|
603 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
604 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
605 |
+
|
606 |
+
Args:
|
607 |
+
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
608 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
609 |
+
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
610 |
+
`attention_head_dim` must be a multiple of `slice_size`.
|
611 |
+
"""
|
612 |
+
if slice_size == "auto":
|
613 |
+
if isinstance(self.unet.config.attention_head_dim, int):
|
614 |
+
# half the attention head size is usually a good trade-off between
|
615 |
+
# speed and memory
|
616 |
+
slice_size = self.unet.config.attention_head_dim // 2
|
617 |
+
else:
|
618 |
+
# if `attention_head_dim` is a list, take the smallest head size
|
619 |
+
slice_size = min(self.unet.config.attention_head_dim)
|
620 |
+
self.unet.set_attention_slice(slice_size)
|
621 |
+
|
622 |
+
def disable_attention_slicing(self):
|
623 |
+
r"""
|
624 |
+
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
625 |
+
back to computing attention in one step.
|
626 |
+
"""
|
627 |
+
# set slice_size = `None` to disable `attention slicing`
|
628 |
+
self.enable_attention_slicing(None)
|
629 |
+
|
630 |
+
def __call__(self, *args, **kwargs):
|
631 |
+
return self.text2image(*args, **kwargs)
|
632 |
+
|
633 |
+
def text2img(self, *args, **kwargs):
|
634 |
+
return self.text2image(*args, **kwargs)
|
635 |
+
|
636 |
+
def _encode_prompt(
|
637 |
+
self,
|
638 |
+
prompt,
|
639 |
+
negative_prompt,
|
640 |
+
max_embeddings_multiples,
|
641 |
+
no_boseos_middle,
|
642 |
+
skip_parsing,
|
643 |
+
skip_weighting,
|
644 |
+
do_classifier_free_guidance,
|
645 |
+
num_images_per_prompt,
|
646 |
+
):
|
647 |
+
if do_classifier_free_guidance and negative_prompt is None:
|
648 |
+
negative_prompt = ""
|
649 |
+
text_embeddings = get_weighted_text_embeddings(
|
650 |
+
self, prompt, negative_prompt, max_embeddings_multiples, no_boseos_middle, skip_parsing, skip_weighting
|
651 |
+
)
|
652 |
+
|
653 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
654 |
+
text_embeddings = text_embeddings.tile([1, num_images_per_prompt, 1])
|
655 |
+
text_embeddings = text_embeddings.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
|
656 |
+
return text_embeddings
|
657 |
+
|
658 |
+
def run_safety_checker(self, image, dtype):
|
659 |
+
if self.safety_checker is not None:
|
660 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
|
661 |
+
image, has_nsfw_concept = self.safety_checker(
|
662 |
+
images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
|
663 |
+
)
|
664 |
+
else:
|
665 |
+
has_nsfw_concept = None
|
666 |
+
return image, has_nsfw_concept
|
667 |
+
|
668 |
+
def decode_latents(self, latents):
|
669 |
+
latents = 1 / 0.18215 * latents
|
670 |
+
image = self.vae.decode(latents).sample
|
671 |
+
image = (image / 2 + 0.5).clip(0, 1)
|
672 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
673 |
+
image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
|
674 |
+
return image
|
675 |
+
|
676 |
+
def prepare_extra_step_kwargs(self, eta, scheduler):
|
677 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
678 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
679 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
680 |
+
# and should be between [0, 1]
|
681 |
+
|
682 |
+
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
|
683 |
+
extra_step_kwargs = {}
|
684 |
+
if accepts_eta:
|
685 |
+
extra_step_kwargs["eta"] = eta
|
686 |
+
|
687 |
+
return extra_step_kwargs
|
688 |
+
|
689 |
+
def check_inputs_text2img(self, prompt, height, width, callback_steps):
|
690 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
691 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
692 |
+
|
693 |
+
if height % 8 != 0 or width % 8 != 0:
|
694 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
695 |
+
|
696 |
+
if (callback_steps is None) or (
|
697 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
698 |
+
):
|
699 |
+
raise ValueError(
|
700 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
701 |
+
f" {type(callback_steps)}."
|
702 |
+
)
|
703 |
+
|
704 |
+
def check_inputs_img2img_inpaint(self, prompt, strength, callback_steps):
|
705 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
706 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
707 |
+
|
708 |
+
if strength < 0 or strength > 1:
|
709 |
+
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
|
710 |
+
|
711 |
+
if (callback_steps is None) or (
|
712 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
713 |
+
):
|
714 |
+
raise ValueError(
|
715 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
716 |
+
f" {type(callback_steps)}."
|
717 |
+
)
|
718 |
+
|
719 |
+
def prepare_latents_text2img(self, batch_size, num_channels_latents, height, width, dtype, latents=None, scheduler=None):
|
720 |
+
shape = [batch_size, num_channels_latents, height // 8, width // 8]
|
721 |
+
if latents is None:
|
722 |
+
latents = paddle.randn(shape, dtype=dtype)
|
723 |
+
else:
|
724 |
+
if latents.shape != shape:
|
725 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
726 |
+
|
727 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
728 |
+
latents = latents * scheduler.init_noise_sigma
|
729 |
+
return latents
|
730 |
+
|
731 |
+
def prepare_latents_img2img(self, image, timestep, num_images_per_prompt, dtype, scheduler):
|
732 |
+
image = image.cast(dtype=dtype)
|
733 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
734 |
+
init_latents = init_latent_dist.sample()
|
735 |
+
init_latents = 0.18215 * init_latents
|
736 |
+
|
737 |
+
b, c, h, w = init_latents.shape
|
738 |
+
init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
|
739 |
+
init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
|
740 |
+
|
741 |
+
# add noise to latents using the timesteps
|
742 |
+
noise = paddle.randn(init_latents.shape, dtype=dtype)
|
743 |
+
|
744 |
+
# get latents
|
745 |
+
init_latents = scheduler.add_noise(init_latents, noise, timestep)
|
746 |
+
latents = init_latents
|
747 |
+
|
748 |
+
return latents
|
749 |
+
|
750 |
+
def get_timesteps(self, num_inference_steps, strength, scheduler):
|
751 |
+
# get the original timestep using init_timestep
|
752 |
+
offset = scheduler.config.get("steps_offset", 0)
|
753 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
754 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
755 |
+
|
756 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
757 |
+
timesteps = scheduler.timesteps[t_start:]
|
758 |
+
|
759 |
+
return timesteps, num_inference_steps - t_start
|
760 |
+
|
761 |
+
def prepare_latents_inpaint(self, image, timestep, num_images_per_prompt, dtype, scheduler):
|
762 |
+
image = image.cast(dtype)
|
763 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
764 |
+
init_latents = init_latent_dist.sample()
|
765 |
+
init_latents = 0.18215 * init_latents
|
766 |
+
|
767 |
+
b, c, h, w = init_latents.shape
|
768 |
+
init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
|
769 |
+
init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
|
770 |
+
|
771 |
+
init_latents_orig = init_latents
|
772 |
+
|
773 |
+
# add noise to latents using the timesteps
|
774 |
+
noise = paddle.randn(init_latents.shape, dtype=dtype)
|
775 |
+
init_latents = scheduler.add_noise(init_latents, noise, timestep)
|
776 |
+
latents = init_latents
|
777 |
+
return latents, init_latents_orig, noise
|
778 |
+
|
779 |
+
@paddle.no_grad()
|
780 |
+
def text2image(
|
781 |
+
self,
|
782 |
+
prompt: Union[str, List[str]],
|
783 |
+
height: int = 512,
|
784 |
+
width: int = 512,
|
785 |
+
num_inference_steps: int = 50,
|
786 |
+
guidance_scale: float = 7.5,
|
787 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
788 |
+
num_images_per_prompt: Optional[int] = 1,
|
789 |
+
eta: float = 0.0,
|
790 |
+
seed: Optional[int] = None,
|
791 |
+
latents: Optional[paddle.Tensor] = None,
|
792 |
+
output_type: Optional[str] = "pil",
|
793 |
+
return_dict: bool = True,
|
794 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
795 |
+
callback_steps: Optional[int] = 1,
|
796 |
+
# new add
|
797 |
+
max_embeddings_multiples: Optional[int] = 1,
|
798 |
+
no_boseos_middle: Optional[bool] = False,
|
799 |
+
skip_parsing: Optional[bool] = False,
|
800 |
+
skip_weighting: Optional[bool] = False,
|
801 |
+
scheduler=None,
|
802 |
+
**kwargs,
|
803 |
+
):
|
804 |
+
r"""
|
805 |
+
Function invoked when calling the pipeline for generation.
|
806 |
+
|
807 |
+
Args:
|
808 |
+
prompt (`str` or `List[str]`):
|
809 |
+
The prompt or prompts to guide the image generation.
|
810 |
+
height (`int`, *optional*, defaults to 512):
|
811 |
+
The height in pixels of the generated image.
|
812 |
+
width (`int`, *optional*, defaults to 512):
|
813 |
+
The width in pixels of the generated image.
|
814 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
815 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
816 |
+
expense of slower inference.
|
817 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
818 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
819 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
820 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
821 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
822 |
+
usually at the expense of lower image quality.
|
823 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
824 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
825 |
+
if `guidance_scale` is less than `1`).
|
826 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
827 |
+
The number of images to generate per prompt.
|
828 |
+
eta (`float`, *optional*, defaults to 0.0):
|
829 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
830 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
831 |
+
seed (`int`, *optional*):
|
832 |
+
Random number seed.
|
833 |
+
latents (`paddle.Tensor`, *optional*):
|
834 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
835 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
836 |
+
tensor will ge generated by sampling using the supplied random `seed`.
|
837 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
838 |
+
The output format of the generate image. Choose between
|
839 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
840 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
841 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
842 |
+
plain tuple.
|
843 |
+
callback (`Callable`, *optional*):
|
844 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
845 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
|
846 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
847 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
848 |
+
called at every step.
|
849 |
+
|
850 |
+
Returns:
|
851 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
852 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
853 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
854 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
855 |
+
(nsfw) content, according to the `safety_checker`.
|
856 |
+
"""
|
857 |
+
if scheduler is None:
|
858 |
+
scheduler = self.scheduler
|
859 |
+
seed = random.randint(0, 2**32) if seed is None else seed
|
860 |
+
argument = dict(
|
861 |
+
prompt=prompt,
|
862 |
+
negative_prompt=negative_prompt,
|
863 |
+
height=height,
|
864 |
+
width=width,
|
865 |
+
num_inference_steps=num_inference_steps,
|
866 |
+
guidance_scale=guidance_scale,
|
867 |
+
num_images_per_prompt=num_images_per_prompt,
|
868 |
+
eta=eta,
|
869 |
+
seed=seed,
|
870 |
+
latents=latents,
|
871 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
872 |
+
no_boseos_middle=no_boseos_middle,
|
873 |
+
skip_parsing=skip_parsing,
|
874 |
+
skip_weighting=skip_weighting,
|
875 |
+
epoch_time=time.time(),
|
876 |
+
)
|
877 |
+
paddle.seed(seed)
|
878 |
+
# 1. Check inputs. Raise error if not correct
|
879 |
+
self.check_inputs_text2img(prompt, height, width, callback_steps)
|
880 |
+
|
881 |
+
# 2. Define call parameters
|
882 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
883 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
884 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
885 |
+
# corresponds to doing no classifier free guidance.
|
886 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
887 |
+
|
888 |
+
# 3. Encode input prompt
|
889 |
+
text_embeddings = self._encode_prompt(
|
890 |
+
prompt,
|
891 |
+
negative_prompt,
|
892 |
+
max_embeddings_multiples,
|
893 |
+
no_boseos_middle,
|
894 |
+
skip_parsing,
|
895 |
+
skip_weighting,
|
896 |
+
do_classifier_free_guidance,
|
897 |
+
num_images_per_prompt,
|
898 |
+
)
|
899 |
+
|
900 |
+
# 4. Prepare timesteps
|
901 |
+
scheduler.set_timesteps(num_inference_steps)
|
902 |
+
timesteps = scheduler.timesteps
|
903 |
+
|
904 |
+
# 5. Prepare latent variables
|
905 |
+
num_channels_latents = self.unet.in_channels
|
906 |
+
latents = self.prepare_latents_text2img(
|
907 |
+
batch_size * num_images_per_prompt,
|
908 |
+
num_channels_latents,
|
909 |
+
height,
|
910 |
+
width,
|
911 |
+
text_embeddings.dtype,
|
912 |
+
latents,
|
913 |
+
scheduler=scheduler,
|
914 |
+
)
|
915 |
+
|
916 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
917 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
|
918 |
+
|
919 |
+
# 7. Denoising loop
|
920 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
|
921 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
922 |
+
for i, t in enumerate(timesteps):
|
923 |
+
# expand the latents if we are doing classifier free guidance
|
924 |
+
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
|
925 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
926 |
+
|
927 |
+
# predict the noise residual
|
928 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
929 |
+
|
930 |
+
# perform guidance
|
931 |
+
if do_classifier_free_guidance:
|
932 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
933 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
934 |
+
|
935 |
+
# compute the previous noisy sample x_t -> x_t-1
|
936 |
+
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
937 |
+
|
938 |
+
# call the callback, if provided
|
939 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
940 |
+
progress_bar.update()
|
941 |
+
if callback is not None and i % callback_steps == 0:
|
942 |
+
callback(progress_bar.n, progress_bar.total, progress_bar)
|
943 |
+
|
944 |
+
# 8. Post-processing
|
945 |
+
image = self.decode_latents(latents)
|
946 |
+
|
947 |
+
# 9. Run safety checker
|
948 |
+
image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
|
949 |
+
|
950 |
+
# 10. Convert to PIL
|
951 |
+
if output_type == "pil":
|
952 |
+
image = self.numpy_to_pil(image, argument=argument)
|
953 |
+
|
954 |
+
if not return_dict:
|
955 |
+
return (image, has_nsfw_concept)
|
956 |
+
|
957 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
958 |
+
|
959 |
+
@paddle.no_grad()
|
960 |
+
def img2img(
|
961 |
+
self,
|
962 |
+
prompt: Union[str, List[str]],
|
963 |
+
image: Union[paddle.Tensor, PIL.Image.Image],
|
964 |
+
strength: float = 0.8,
|
965 |
+
height=None,
|
966 |
+
width=None,
|
967 |
+
num_inference_steps: Optional[int] = 50,
|
968 |
+
guidance_scale: Optional[float] = 7.5,
|
969 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
970 |
+
num_images_per_prompt: Optional[int] = 1,
|
971 |
+
eta: Optional[float] = 0.0,
|
972 |
+
seed: Optional[int] = None,
|
973 |
+
output_type: Optional[str] = "pil",
|
974 |
+
return_dict: bool = True,
|
975 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
976 |
+
callback_steps: Optional[int] = 1,
|
977 |
+
# new add
|
978 |
+
max_embeddings_multiples: Optional[int] = 1,
|
979 |
+
no_boseos_middle: Optional[bool] = False,
|
980 |
+
skip_parsing: Optional[bool] = False,
|
981 |
+
skip_weighting: Optional[bool] = False,
|
982 |
+
scheduler=None,
|
983 |
+
**kwargs,
|
984 |
+
):
|
985 |
+
r"""
|
986 |
+
Function invoked when calling the pipeline for generation.
|
987 |
+
|
988 |
+
Args:
|
989 |
+
prompt (`str` or `List[str]`):
|
990 |
+
The prompt or prompts to guide the image generation.
|
991 |
+
image (`paddle.Tensor` or `PIL.Image.Image`):
|
992 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
993 |
+
process.
|
994 |
+
strength (`float`, *optional*, defaults to 0.8):
|
995 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
996 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
997 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
998 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
999 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1000 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1001 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1002 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1003 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1004 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1005 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1006 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1007 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1008 |
+
usually at the expense of lower image quality.
|
1009 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1010 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1011 |
+
if `guidance_scale` is less than `1`).
|
1012 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1013 |
+
The number of images to generate per prompt.
|
1014 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1015 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1016 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1017 |
+
seed (`int`, *optional*):
|
1018 |
+
A random seed.
|
1019 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1020 |
+
The output format of the generate image. Choose between
|
1021 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1022 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1023 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1024 |
+
plain tuple.
|
1025 |
+
callback (`Callable`, *optional*):
|
1026 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1027 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
|
1028 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1029 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1030 |
+
called at every step.
|
1031 |
+
|
1032 |
+
Returns:
|
1033 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1034 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1035 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1036 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1037 |
+
(nsfw) content, according to the `safety_checker`.
|
1038 |
+
"""
|
1039 |
+
if scheduler is None:
|
1040 |
+
scheduler = self.scheduler
|
1041 |
+
seed = random.randint(0, 2**32) if seed is None else seed
|
1042 |
+
image_str = image
|
1043 |
+
if isinstance(image_str, str):
|
1044 |
+
image = load_image(image_str)
|
1045 |
+
|
1046 |
+
if height is None and width is None:
|
1047 |
+
width = (image.size[0] // 8) * 8
|
1048 |
+
height = (image.size[1] // 8) * 8
|
1049 |
+
elif height is None and width is not None:
|
1050 |
+
height = (image.size[1] // 8) * 8
|
1051 |
+
elif width is None and height is not None:
|
1052 |
+
width = (image.size[0] // 8) * 8
|
1053 |
+
else:
|
1054 |
+
height = height
|
1055 |
+
width = width
|
1056 |
+
|
1057 |
+
argument = dict(
|
1058 |
+
prompt=prompt,
|
1059 |
+
image=image_str,
|
1060 |
+
negative_prompt=negative_prompt,
|
1061 |
+
height=height,
|
1062 |
+
width=width,
|
1063 |
+
strength=strength,
|
1064 |
+
num_inference_steps=num_inference_steps,
|
1065 |
+
guidance_scale=guidance_scale,
|
1066 |
+
num_images_per_prompt=num_images_per_prompt,
|
1067 |
+
eta=eta,
|
1068 |
+
seed=seed,
|
1069 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1070 |
+
no_boseos_middle=no_boseos_middle,
|
1071 |
+
skip_parsing=skip_parsing,
|
1072 |
+
skip_weighting=skip_weighting,
|
1073 |
+
epoch_time=time.time(),
|
1074 |
+
)
|
1075 |
+
paddle.seed(seed)
|
1076 |
+
|
1077 |
+
# 1. Check inputs
|
1078 |
+
self.check_inputs_img2img_inpaint(prompt, strength, callback_steps)
|
1079 |
+
|
1080 |
+
# 2. Define call parameters
|
1081 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
1082 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
1083 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
1084 |
+
# corresponds to doing no classifier free guidance.
|
1085 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
1086 |
+
|
1087 |
+
# 3. Encode input prompt
|
1088 |
+
text_embeddings = self._encode_prompt(
|
1089 |
+
prompt,
|
1090 |
+
negative_prompt,
|
1091 |
+
max_embeddings_multiples,
|
1092 |
+
no_boseos_middle,
|
1093 |
+
skip_parsing,
|
1094 |
+
skip_weighting,
|
1095 |
+
do_classifier_free_guidance,
|
1096 |
+
num_images_per_prompt,
|
1097 |
+
)
|
1098 |
+
|
1099 |
+
# 4. Preprocess image
|
1100 |
+
if isinstance(image, PIL.Image.Image):
|
1101 |
+
image = image.resize((width, height))
|
1102 |
+
image = preprocess_image(image)
|
1103 |
+
|
1104 |
+
# 5. set timesteps
|
1105 |
+
scheduler.set_timesteps(num_inference_steps)
|
1106 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
|
1107 |
+
latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
|
1108 |
+
|
1109 |
+
# 6. Prepare latent variables
|
1110 |
+
latents = self.prepare_latents_img2img(image, latent_timestep, num_images_per_prompt, text_embeddings.dtype, scheduler)
|
1111 |
+
|
1112 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1113 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
|
1114 |
+
|
1115 |
+
# 8. Denoising loop
|
1116 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
|
1117 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1118 |
+
for i, t in enumerate(timesteps):
|
1119 |
+
# expand the latents if we are doing classifier free guidance
|
1120 |
+
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
|
1121 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
1122 |
+
|
1123 |
+
# predict the noise residual
|
1124 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
1125 |
+
|
1126 |
+
# perform guidance
|
1127 |
+
if do_classifier_free_guidance:
|
1128 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1129 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1130 |
+
|
1131 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1132 |
+
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
1133 |
+
|
1134 |
+
# call the callback, if provided
|
1135 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
1136 |
+
progress_bar.update()
|
1137 |
+
if callback is not None and i % callback_steps == 0:
|
1138 |
+
callback(progress_bar.n, progress_bar.total, progress_bar)
|
1139 |
+
|
1140 |
+
# 9. Post-processing
|
1141 |
+
image = self.decode_latents(latents)
|
1142 |
+
|
1143 |
+
# 10. Run safety checker
|
1144 |
+
image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
|
1145 |
+
|
1146 |
+
# 11. Convert to PIL
|
1147 |
+
if output_type == "pil":
|
1148 |
+
image = self.numpy_to_pil(image, argument=argument)
|
1149 |
+
|
1150 |
+
if not return_dict:
|
1151 |
+
return (image, has_nsfw_concept)
|
1152 |
+
|
1153 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
1154 |
+
|
1155 |
+
@paddle.no_grad()
|
1156 |
+
def inpaint(
|
1157 |
+
self,
|
1158 |
+
prompt: Union[str, List[str]],
|
1159 |
+
image: Union[paddle.Tensor, PIL.Image.Image],
|
1160 |
+
mask_image: Union[paddle.Tensor, PIL.Image.Image],
|
1161 |
+
height=None,
|
1162 |
+
width=None,
|
1163 |
+
strength: float = 0.8,
|
1164 |
+
num_inference_steps: Optional[int] = 50,
|
1165 |
+
guidance_scale: Optional[float] = 7.5,
|
1166 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1167 |
+
num_images_per_prompt: Optional[int] = 1,
|
1168 |
+
eta: Optional[float] = 0.0,
|
1169 |
+
seed: Optional[int] = None,
|
1170 |
+
output_type: Optional[str] = "pil",
|
1171 |
+
return_dict: bool = True,
|
1172 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
1173 |
+
callback_steps: Optional[int] = 1,
|
1174 |
+
# new add
|
1175 |
+
max_embeddings_multiples: Optional[int] = 1,
|
1176 |
+
no_boseos_middle: Optional[bool] = False,
|
1177 |
+
skip_parsing: Optional[bool] = False,
|
1178 |
+
skip_weighting: Optional[bool] = False,
|
1179 |
+
scheduler=None,
|
1180 |
+
**kwargs,
|
1181 |
+
):
|
1182 |
+
r"""
|
1183 |
+
Function invoked when calling the pipeline for generation.
|
1184 |
+
|
1185 |
+
Args:
|
1186 |
+
prompt (`str` or `List[str]`):
|
1187 |
+
The prompt or prompts to guide the image generation.
|
1188 |
+
image (`paddle.Tensor` or `PIL.Image.Image`):
|
1189 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1190 |
+
process. This is the image whose masked region will be inpainted.
|
1191 |
+
mask_image (`paddle.Tensor` or `PIL.Image.Image`):
|
1192 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1193 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1194 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1195 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1196 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1197 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1198 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1199 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1200 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1201 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1202 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1203 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1204 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1205 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1206 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1207 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1208 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1209 |
+
usually at the expense of lower image quality.
|
1210 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1211 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1212 |
+
if `guidance_scale` is less than `1`).
|
1213 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1214 |
+
The number of images to generate per prompt.
|
1215 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1216 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1217 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1218 |
+
seed (`int`, *optional*):
|
1219 |
+
A random seed.
|
1220 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1221 |
+
The output format of the generate image. Choose between
|
1222 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1223 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1224 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1225 |
+
plain tuple.
|
1226 |
+
callback (`Callable`, *optional*):
|
1227 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1228 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
|
1229 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1230 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1231 |
+
called at every step.
|
1232 |
+
|
1233 |
+
Returns:
|
1234 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1235 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1236 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1237 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1238 |
+
(nsfw) content, according to the `safety_checker`.
|
1239 |
+
"""
|
1240 |
+
if scheduler is None:
|
1241 |
+
scheduler = self.scheduler
|
1242 |
+
seed = random.randint(0, 2**32) if seed is None else seed
|
1243 |
+
image_str = image
|
1244 |
+
mask_image_str = mask_image
|
1245 |
+
|
1246 |
+
if isinstance(image_str, str):
|
1247 |
+
image = load_image(image_str)
|
1248 |
+
if isinstance(mask_image_str, str):
|
1249 |
+
mask_image = load_image(mask_image_str)
|
1250 |
+
|
1251 |
+
if height is None and width is None:
|
1252 |
+
width = (image.size[0] // 8) * 8
|
1253 |
+
height = (image.size[1] // 8) * 8
|
1254 |
+
elif height is None and width is not None:
|
1255 |
+
height = (image.size[1] // 8) * 8
|
1256 |
+
elif width is None and height is not None:
|
1257 |
+
width = (image.size[0] // 8) * 8
|
1258 |
+
else:
|
1259 |
+
height = height
|
1260 |
+
width = width
|
1261 |
+
|
1262 |
+
argument = dict(
|
1263 |
+
prompt=prompt,
|
1264 |
+
image=image_str,
|
1265 |
+
mask_image=mask_image_str,
|
1266 |
+
negative_prompt=negative_prompt,
|
1267 |
+
height=height,
|
1268 |
+
width=width,
|
1269 |
+
strength=strength,
|
1270 |
+
num_inference_steps=num_inference_steps,
|
1271 |
+
guidance_scale=guidance_scale,
|
1272 |
+
num_images_per_prompt=num_images_per_prompt,
|
1273 |
+
eta=eta,
|
1274 |
+
seed=seed,
|
1275 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1276 |
+
no_boseos_middle=no_boseos_middle,
|
1277 |
+
skip_parsing=skip_parsing,
|
1278 |
+
skip_weighting=skip_weighting,
|
1279 |
+
epoch_time=time.time(),
|
1280 |
+
)
|
1281 |
+
paddle.seed(seed)
|
1282 |
+
|
1283 |
+
# 1. Check inputs
|
1284 |
+
self.check_inputs_img2img_inpaint(prompt, strength, callback_steps)
|
1285 |
+
|
1286 |
+
# 2. Define call parameters
|
1287 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
1288 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
1289 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
1290 |
+
# corresponds to doing no classifier free guidance.
|
1291 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
1292 |
+
|
1293 |
+
# 3. Encode input prompt
|
1294 |
+
text_embeddings = self._encode_prompt(
|
1295 |
+
prompt,
|
1296 |
+
negative_prompt,
|
1297 |
+
max_embeddings_multiples,
|
1298 |
+
no_boseos_middle,
|
1299 |
+
skip_parsing,
|
1300 |
+
skip_weighting,
|
1301 |
+
do_classifier_free_guidance,
|
1302 |
+
num_images_per_prompt,
|
1303 |
+
)
|
1304 |
+
|
1305 |
+
if not isinstance(image, paddle.Tensor):
|
1306 |
+
image = image.resize((width, height))
|
1307 |
+
image = preprocess_image(image)
|
1308 |
+
|
1309 |
+
if not isinstance(mask_image, paddle.Tensor):
|
1310 |
+
mask_image = mask_image.resize((width, height))
|
1311 |
+
mask_image = preprocess_mask(mask_image)
|
1312 |
+
|
1313 |
+
# 5. set timesteps
|
1314 |
+
scheduler.set_timesteps(num_inference_steps)
|
1315 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
|
1316 |
+
latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
|
1317 |
+
|
1318 |
+
# 6. Prepare latent variables
|
1319 |
+
# encode the init image into latents and scale the latents
|
1320 |
+
latents, init_latents_orig, noise = self.prepare_latents_inpaint(
|
1321 |
+
image, latent_timestep, num_images_per_prompt, text_embeddings.dtype, scheduler
|
1322 |
+
)
|
1323 |
+
|
1324 |
+
# 7. Prepare mask latent
|
1325 |
+
mask = mask_image.cast(latents.dtype)
|
1326 |
+
mask = paddle.concat([mask] * batch_size * num_images_per_prompt)
|
1327 |
+
|
1328 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1329 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
|
1330 |
+
|
1331 |
+
# 9. Denoising loop
|
1332 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
|
1333 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1334 |
+
for i, t in enumerate(timesteps):
|
1335 |
+
# expand the latents if we are doing classifier free guidance
|
1336 |
+
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
|
1337 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
1338 |
+
|
1339 |
+
# predict the noise residual
|
1340 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
1341 |
+
|
1342 |
+
# perform guidance
|
1343 |
+
if do_classifier_free_guidance:
|
1344 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1345 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1346 |
+
|
1347 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1348 |
+
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
1349 |
+
# masking
|
1350 |
+
init_latents_proper = scheduler.add_noise(init_latents_orig, noise, t)
|
1351 |
+
|
1352 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
1353 |
+
|
1354 |
+
# call the callback, if provided
|
1355 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
1356 |
+
progress_bar.update()
|
1357 |
+
if callback is not None and i % callback_steps == 0:
|
1358 |
+
callback(progress_bar.n, progress_bar.total, progress_bar)
|
1359 |
+
|
1360 |
+
# 10. Post-processing
|
1361 |
+
image = self.decode_latents(latents)
|
1362 |
+
|
1363 |
+
# 11. Run safety checker
|
1364 |
+
image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
|
1365 |
+
|
1366 |
+
# 12. Convert to PIL
|
1367 |
+
if output_type == "pil":
|
1368 |
+
image = self.numpy_to_pil(image, argument=argument)
|
1369 |
+
|
1370 |
+
if not return_dict:
|
1371 |
+
return (image, has_nsfw_concept)
|
1372 |
+
|
1373 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
1374 |
+
|
1375 |
+
@staticmethod
|
1376 |
+
def numpy_to_pil(images, **kwargs):
|
1377 |
+
"""
|
1378 |
+
Convert a numpy image or a batch of images to a PIL image.
|
1379 |
+
"""
|
1380 |
+
if images.ndim == 3:
|
1381 |
+
images = images[None, ...]
|
1382 |
+
images = (images * 255).round().astype("uint8")
|
1383 |
+
pil_images = []
|
1384 |
+
argument = kwargs.pop("argument", None)
|
1385 |
+
for image in images:
|
1386 |
+
image = PIL.Image.fromarray(image)
|
1387 |
+
if argument is not None:
|
1388 |
+
image.argument = argument
|
1389 |
+
pil_images.append(image)
|
1390 |
+
|
1391 |
+
return pil_images
|
1392 |
+
pipeline = StableDiffusionPipelineAllinOne.from_pretrained(BASE_MODEL_NAME, safety_checker=None)
|
1393 |
+
|
1394 |
+
if LORA_WEIGHTS_PATH is not None:
|
1395 |
+
pipeline.unet.load_attn_procs(LORA_WEIGHTS_PATH, from_hf_hub=True)
|
1396 |
+
|
1397 |
+
support_scheduler = [
|
1398 |
+
"DPMSolver",
|
1399 |
+
"EulerDiscrete",
|
1400 |
+
"EulerAncestralDiscrete",
|
1401 |
+
"PNDM",
|
1402 |
+
"DDIM",
|
1403 |
+
"LMSDiscrete",
|
1404 |
+
"HeunDiscrete",
|
1405 |
+
"KDPM2AncestralDiscrete",
|
1406 |
+
"KDPM2Discrete"
|
1407 |
+
]
|
1408 |
+
|
1409 |
+
# generate images
|
1410 |
+
def infer(prompt, negative, scale, height, width, num_inference_steps, scheduler_name):
|
1411 |
+
scheduler = pipeline.create_scheduler(scheduler_name)
|
1412 |
+
|
1413 |
+
images = pipeline(
|
1414 |
+
prompt=prompt, negative_prompt=negative, guidance_scale=scale, height=height, width=width, num_inference_steps=num_inference_steps, scheduler=scheduler,
|
1415 |
+
).images
|
1416 |
+
return images
|
1417 |
+
|
1418 |
+
|
1419 |
+
css = """
|
1420 |
+
.gradio-container {
|
1421 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
1422 |
+
}
|
1423 |
+
.gr-button {
|
1424 |
+
color: white;
|
1425 |
+
border-color: black;
|
1426 |
+
background: black;
|
1427 |
+
}
|
1428 |
+
input[type='range'] {
|
1429 |
+
accent-color: black;
|
1430 |
+
}
|
1431 |
+
.dark input[type='range'] {
|
1432 |
+
accent-color: #dfdfdf;
|
1433 |
+
}
|
1434 |
+
.container {
|
1435 |
+
max-width: 730px;
|
1436 |
+
margin: auto;
|
1437 |
+
padding-top: 1.5rem;
|
1438 |
+
}
|
1439 |
+
#gallery {
|
1440 |
+
min-height: 22rem;
|
1441 |
+
margin-bottom: 15px;
|
1442 |
+
margin-left: auto;
|
1443 |
+
margin-right: auto;
|
1444 |
+
border-bottom-right-radius: .5rem !important;
|
1445 |
+
border-bottom-left-radius: .5rem !important;
|
1446 |
+
}
|
1447 |
+
#gallery>div>.h-full {
|
1448 |
+
min-height: 20rem;
|
1449 |
+
}
|
1450 |
+
.details:hover {
|
1451 |
+
text-decoration: underline;
|
1452 |
+
}
|
1453 |
+
.gr-button {
|
1454 |
+
white-space: nowrap;
|
1455 |
+
}
|
1456 |
+
.gr-button:focus {
|
1457 |
+
border-color: rgb(147 197 253 / var(--tw-border-opacity));
|
1458 |
+
outline: none;
|
1459 |
+
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
|
1460 |
+
--tw-border-opacity: 1;
|
1461 |
+
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
|
1462 |
+
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
|
1463 |
+
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
|
1464 |
+
--tw-ring-opacity: .5;
|
1465 |
+
}
|
1466 |
+
#advanced-btn {
|
1467 |
+
font-size: .7rem !important;
|
1468 |
+
line-height: 19px;
|
1469 |
+
margin-top: 12px;
|
1470 |
+
margin-bottom: 12px;
|
1471 |
+
padding: 2px 8px;
|
1472 |
+
border-radius: 14px !important;
|
1473 |
+
}
|
1474 |
+
#advanced-options {
|
1475 |
+
display: none;
|
1476 |
+
margin-bottom: 20px;
|
1477 |
+
}
|
1478 |
+
.footer {
|
1479 |
+
margin-bottom: 45px;
|
1480 |
+
margin-top: 35px;
|
1481 |
+
text-align: center;
|
1482 |
+
border-bottom: 1px solid #e5e5e5;
|
1483 |
+
}
|
1484 |
+
.footer>p {
|
1485 |
+
font-size: .8rem;
|
1486 |
+
display: inline-block;
|
1487 |
+
padding: 0 10px;
|
1488 |
+
transform: translateY(10px);
|
1489 |
+
background: white;
|
1490 |
+
}
|
1491 |
+
.dark .footer {
|
1492 |
+
border-color: #303030;
|
1493 |
+
}
|
1494 |
+
.dark .footer>p {
|
1495 |
+
background: #0b0f19;
|
1496 |
+
}
|
1497 |
+
.acknowledgments h4{
|
1498 |
+
margin: 1.25em 0 .25em 0;
|
1499 |
+
font-weight: bold;
|
1500 |
+
font-size: 115%;
|
1501 |
+
}
|
1502 |
+
.animate-spin {
|
1503 |
+
animation: spin 1s linear infinite;
|
1504 |
+
}
|
1505 |
+
@keyframes spin {
|
1506 |
+
from {
|
1507 |
+
transform: rotate(0deg);
|
1508 |
+
}
|
1509 |
+
to {
|
1510 |
+
transform: rotate(360deg);
|
1511 |
+
}
|
1512 |
+
}
|
1513 |
+
#share-btn-container {
|
1514 |
+
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
|
1515 |
+
margin-top: 10px;
|
1516 |
+
margin-left: auto;
|
1517 |
+
}
|
1518 |
+
#share-btn {
|
1519 |
+
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
|
1520 |
+
}
|
1521 |
+
#share-btn * {
|
1522 |
+
all: unset;
|
1523 |
+
}
|
1524 |
+
#share-btn-container div:nth-child(-n+2){
|
1525 |
+
width: auto !important;
|
1526 |
+
min-height: 0px !important;
|
1527 |
+
}
|
1528 |
+
#share-btn-container .wrap {
|
1529 |
+
display: none !important;
|
1530 |
+
}
|
1531 |
+
|
1532 |
+
.gr-form{
|
1533 |
+
flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
|
1534 |
+
}
|
1535 |
+
#prompt-container{
|
1536 |
+
gap: 0;
|
1537 |
+
}
|
1538 |
+
#prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
|
1539 |
+
#component-16{border-top-width: 1px!important;margin-top: 1em}
|
1540 |
+
.image_duplication{position: absolute; width: 100px; left: 50px}
|
1541 |
+
"""
|
1542 |
+
|
1543 |
+
block = gr.Blocks(css=css)
|
1544 |
+
|
1545 |
+
with block:
|
1546 |
+
gr.HTML(
|
1547 |
+
"""
|
1548 |
+
<div style="text-align: center; margin: 0 auto;">
|
1549 |
+
<div
|
1550 |
+
style="
|
1551 |
+
display: inline-flex;
|
1552 |
+
align-items: center;
|
1553 |
+
gap: 0.8rem;
|
1554 |
+
font-size: 1.75rem;
|
1555 |
+
"
|
1556 |
+
>
|
1557 |
+
<svg
|
1558 |
+
width="0.65em"
|
1559 |
+
height="0.65em"
|
1560 |
+
viewBox="0 0 115 115"
|
1561 |
+
fill="none"
|
1562 |
+
xmlns="http://www.w3.org/2000/svg"
|
1563 |
+
>
|
1564 |
+
<rect width="23" height="23" fill="white"></rect>
|
1565 |
+
<rect y="69" width="23" height="23" fill="white"></rect>
|
1566 |
+
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
|
1567 |
+
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
1568 |
+
<rect x="46" width="23" height="23" fill="white"></rect>
|
1569 |
+
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
|
1570 |
+
<rect x="69" width="23" height="23" fill="black"></rect>
|
1571 |
+
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
|
1572 |
+
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
|
1573 |
+
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
1574 |
+
<rect x="115" y="46" width="23" height="23" fill="white"></rect>
|
1575 |
+
<rect x="115" y="115" width="23" height="23" fill="white"></rect>
|
1576 |
+
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
1577 |
+
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
|
1578 |
+
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
1579 |
+
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
|
1580 |
+
<rect x="69" y="46" width="23" height="23" fill="white"></rect>
|
1581 |
+
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
|
1582 |
+
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
1583 |
+
<rect x="46" y="46" width="23" height="23" fill="black"></rect>
|
1584 |
+
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
|
1585 |
+
<rect x="46" y="69" width="23" height="23" fill="black"></rect>
|
1586 |
+
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
|
1587 |
+
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
1588 |
+
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
|
1589 |
+
</svg>
|
1590 |
+
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
|
1591 |
+
Dreambooth LoRa Demo
|
1592 |
+
</h1>
|
1593 |
+
</div>
|
1594 |
+
</div>
|
1595 |
+
"""
|
1596 |
+
)
|
1597 |
+
with gr.Group():
|
1598 |
+
with gr.Box():
|
1599 |
+
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
1600 |
+
with gr.Column():
|
1601 |
+
text = gr.Textbox(
|
1602 |
+
label="Enter your prompt",
|
1603 |
+
value=PROMPTS,
|
1604 |
+
show_label=False,
|
1605 |
+
max_lines=1,
|
1606 |
+
placeholder="Enter your prompt",
|
1607 |
+
elem_id="prompt-text-input",
|
1608 |
+
).style(
|
1609 |
+
border=(True, False, True, True),
|
1610 |
+
rounded=(True, False, False, True),
|
1611 |
+
container=False,
|
1612 |
+
)
|
1613 |
+
negative = gr.Textbox(
|
1614 |
+
label="Enter your negative prompt",
|
1615 |
+
show_label=False,
|
1616 |
+
max_lines=1,
|
1617 |
+
placeholder="Enter a negative prompt",
|
1618 |
+
elem_id="negative-prompt-text-input",
|
1619 |
+
).style(
|
1620 |
+
border=(True, False, True, True),
|
1621 |
+
rounded=(True, False, False, True),
|
1622 |
+
container=False,
|
1623 |
+
)
|
1624 |
+
btn = gr.Button("Generate image").style(
|
1625 |
+
margin=False,
|
1626 |
+
rounded=(False, True, True, False),
|
1627 |
+
full_width=False,
|
1628 |
+
)
|
1629 |
+
|
1630 |
+
gallery = gr.Gallery(
|
1631 |
+
label="Generated images", show_label=False, elem_id="gallery"
|
1632 |
+
).style(grid=[1], height="auto")
|
1633 |
+
|
1634 |
+
|
1635 |
+
with gr.Accordion("Advanced settings", open=False):
|
1636 |
+
scheduler_name = gr.Dropdown(
|
1637 |
+
label="scheduler_name", choices=support_scheduler, value="DPMSolver"
|
1638 |
+
)
|
1639 |
+
guidance_scale = gr.Slider(
|
1640 |
+
label="Guidance Scale", minimum=1, maximum=30, value=7.5, step=0.1
|
1641 |
+
)
|
1642 |
+
height = gr.Slider(
|
1643 |
+
label="Height", minimum=256, maximum=1024, value=512, step=8
|
1644 |
+
)
|
1645 |
+
width = gr.Slider(
|
1646 |
+
label="Width", minimum=256, maximum=1024, value=512, step=0.1
|
1647 |
+
)
|
1648 |
+
num_inference_steps = gr.Slider(
|
1649 |
+
label="num_inference_steps", minimum=10, maximum=100, value=25, step=1
|
1650 |
+
)
|
1651 |
+
|
1652 |
+
|
1653 |
+
inputs = [text, negative, guidance_scale, height, width, num_inference_steps, scheduler_name]
|
1654 |
+
# ex = gr.Examples(examples=examples, fn=infer, inputs=inputs, outputs=gallery, cache_examples=False)
|
1655 |
+
# ex.dataset.headers = [""]
|
1656 |
+
negative.submit(infer, inputs=inputs, outputs=gallery)
|
1657 |
+
text.submit(infer, inputs=inputs, outputs=gallery)
|
1658 |
+
btn.click(infer, inputs=inputs, outputs=gallery)
|
1659 |
+
|
1660 |
+
|
1661 |
+
gr.HTML(
|
1662 |
+
"""
|
1663 |
+
<div class="footer">
|
1664 |
+
<p>Model by <a href="https://www.paddlepaddle.org.cn/" style="text-decoration: underline;" target="_blank">PaddlePaddle</a> - Gradio Demo by 🤗 Hugging Face
|
1665 |
+
</p>
|
1666 |
+
</div>
|
1667 |
+
<div class="acknowledgments">
|
1668 |
+
<p><h4>LICENSE</h4>
|
1669 |
+
The model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
|
1670 |
+
<p><h4>Biases and content acknowledgment</h4>
|
1671 |
+
Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a></p>
|
1672 |
+
</div>
|
1673 |
+
"""
|
1674 |
+
)
|
1675 |
+
|
1676 |
+
block.launch(server_name="0.0.0.0", server_port=8221)
|
1677 |
+
|
env.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
############################################################################################################################
|
2 |
+
# 修改下面的参数
|
3 |
+
# (1)BASE_MODEL_NAME 代表你训练的基础模型
|
4 |
+
BASE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"
|
5 |
+
|
6 |
+
# 是否开启lora
|
7 |
+
# (2)LORA_WEIGHTS_PATH 代码你上传到huggingface后的lora权重。
|
8 |
+
# LORA_WEIGHTS_PATH = None 表示不适应lora
|
9 |
+
LORA_WEIGHTS_PATH = "xianbao/demo_test"
|
10 |
+
|
11 |
+
# (3)PROMPTS 需要展示的prompt文本
|
12 |
+
PROMPTS = "A photo of sks dog in a bucket"
|
13 |
+
############################################################################################################################
|
ppdiffusers/__init__.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# flake8: noqa
|
16 |
+
|
17 |
+
from .configuration_utils import ConfigMixin
|
18 |
+
from .fastdeploy_utils import FastDeployRuntimeModel
|
19 |
+
from .ppnlp_patch_utils import *
|
20 |
+
from .utils import (
|
21 |
+
OptionalDependencyNotAvailable,
|
22 |
+
is_fastdeploy_available,
|
23 |
+
is_inflect_available,
|
24 |
+
is_k_diffusion_available,
|
25 |
+
is_librosa_available,
|
26 |
+
is_onnx_available,
|
27 |
+
is_paddle_available,
|
28 |
+
is_paddlenlp_available,
|
29 |
+
is_scipy_available,
|
30 |
+
is_unidecode_available,
|
31 |
+
logging,
|
32 |
+
)
|
33 |
+
from .version import VERSION as __version__
|
34 |
+
|
35 |
+
try:
|
36 |
+
if not is_paddle_available():
|
37 |
+
raise OptionalDependencyNotAvailable()
|
38 |
+
except OptionalDependencyNotAvailable:
|
39 |
+
from .utils.dummy_paddle_objects import * # noqa F403
|
40 |
+
else:
|
41 |
+
from .initializer import *
|
42 |
+
from .modeling_utils import ModelMixin
|
43 |
+
from .models import (
|
44 |
+
AutoencoderKL,
|
45 |
+
PriorTransformer,
|
46 |
+
Transformer2DModel,
|
47 |
+
UNet1DModel,
|
48 |
+
UNet2DConditionModel,
|
49 |
+
UNet2DModel,
|
50 |
+
VQModel,
|
51 |
+
)
|
52 |
+
from .optimization import (
|
53 |
+
get_constant_schedule,
|
54 |
+
get_constant_schedule_with_warmup,
|
55 |
+
get_cosine_schedule_with_warmup,
|
56 |
+
get_cosine_with_hard_restarts_schedule_with_warmup,
|
57 |
+
get_linear_schedule_with_warmup,
|
58 |
+
get_polynomial_decay_schedule_with_warmup,
|
59 |
+
get_scheduler,
|
60 |
+
)
|
61 |
+
from .pipeline_utils import DiffusionPipeline
|
62 |
+
from .pipelines import (
|
63 |
+
DanceDiffusionPipeline,
|
64 |
+
DDIMPipeline,
|
65 |
+
DDPMPipeline,
|
66 |
+
KarrasVePipeline,
|
67 |
+
LDMPipeline,
|
68 |
+
LDMSuperResolutionPipeline,
|
69 |
+
PNDMPipeline,
|
70 |
+
RePaintPipeline,
|
71 |
+
ScoreSdeVePipeline,
|
72 |
+
)
|
73 |
+
from .schedulers import (
|
74 |
+
DDIMScheduler,
|
75 |
+
DDPMScheduler,
|
76 |
+
DPMSolverMultistepScheduler,
|
77 |
+
DPMSolverSinglestepScheduler,
|
78 |
+
EulerAncestralDiscreteScheduler,
|
79 |
+
EulerDiscreteScheduler,
|
80 |
+
HeunDiscreteScheduler,
|
81 |
+
IPNDMScheduler,
|
82 |
+
KarrasVeScheduler,
|
83 |
+
KDPM2AncestralDiscreteScheduler,
|
84 |
+
KDPM2DiscreteScheduler,
|
85 |
+
PNDMScheduler,
|
86 |
+
RePaintScheduler,
|
87 |
+
SchedulerMixin,
|
88 |
+
ScoreSdeVeScheduler,
|
89 |
+
UnCLIPScheduler,
|
90 |
+
VQDiffusionScheduler,
|
91 |
+
)
|
92 |
+
from .schedulers.preconfig import PreconfigEulerAncestralDiscreteScheduler
|
93 |
+
from .training_utils import EMAModel
|
94 |
+
|
95 |
+
try:
|
96 |
+
if not (is_paddle_available() and is_scipy_available()):
|
97 |
+
raise OptionalDependencyNotAvailable()
|
98 |
+
except OptionalDependencyNotAvailable:
|
99 |
+
from .utils.dummy_paddle_and_scipy_objects import * # noqa F403
|
100 |
+
else:
|
101 |
+
from .schedulers import LMSDiscreteScheduler
|
102 |
+
from .schedulers.preconfig import PreconfigLMSDiscreteScheduler
|
103 |
+
|
104 |
+
try:
|
105 |
+
if not (is_paddle_available() and is_paddlenlp_available()):
|
106 |
+
raise OptionalDependencyNotAvailable()
|
107 |
+
except OptionalDependencyNotAvailable:
|
108 |
+
from .utils.dummy_paddle_and_paddlenlp_objects import * # noqa F403
|
109 |
+
else:
|
110 |
+
from .pipelines import (
|
111 |
+
AltDiffusionImg2ImgPipeline,
|
112 |
+
AltDiffusionPipeline,
|
113 |
+
CycleDiffusionPipeline,
|
114 |
+
LDMBertModel,
|
115 |
+
LDMTextToImagePipeline,
|
116 |
+
PaintByExamplePipeline,
|
117 |
+
StableDiffusionDepth2ImgPipeline,
|
118 |
+
StableDiffusionImageVariationPipeline,
|
119 |
+
StableDiffusionImg2ImgPipeline,
|
120 |
+
StableDiffusionInpaintPipeline,
|
121 |
+
StableDiffusionInpaintPipelineLegacy,
|
122 |
+
StableDiffusionMegaPipeline,
|
123 |
+
StableDiffusionPipeline,
|
124 |
+
StableDiffusionPipelineAllinOne,
|
125 |
+
StableDiffusionPipelineSafe,
|
126 |
+
StableDiffusionUpscalePipeline,
|
127 |
+
UnCLIPPipeline,
|
128 |
+
VersatileDiffusionDualGuidedPipeline,
|
129 |
+
VersatileDiffusionImageVariationPipeline,
|
130 |
+
VersatileDiffusionPipeline,
|
131 |
+
VersatileDiffusionTextToImagePipeline,
|
132 |
+
VQDiffusionPipeline,
|
133 |
+
)
|
134 |
+
|
135 |
+
try:
|
136 |
+
if not (is_paddle_available() and is_paddlenlp_available() and is_k_diffusion_available()):
|
137 |
+
raise OptionalDependencyNotAvailable()
|
138 |
+
except OptionalDependencyNotAvailable:
|
139 |
+
from .utils.dummy_paddle_and_paddlenlp_and_k_diffusion_objects import * # noqa F403
|
140 |
+
else:
|
141 |
+
from .pipelines import StableDiffusionKDiffusionPipeline
|
142 |
+
|
143 |
+
try:
|
144 |
+
if not (is_paddle_available() and is_paddlenlp_available() and is_fastdeploy_available()):
|
145 |
+
raise OptionalDependencyNotAvailable()
|
146 |
+
except OptionalDependencyNotAvailable:
|
147 |
+
from .utils.dummy_paddle_and_paddlenlp_and_fastdeploy_objects import * # noqa F403
|
148 |
+
else:
|
149 |
+
from .pipelines import (
|
150 |
+
FastDeployStableDiffusionImg2ImgPipeline,
|
151 |
+
FastDeployStableDiffusionInpaintPipeline,
|
152 |
+
FastDeployStableDiffusionInpaintPipelineLegacy,
|
153 |
+
FastDeployStableDiffusionMegaPipeline,
|
154 |
+
FastDeployStableDiffusionPipeline,
|
155 |
+
)
|
156 |
+
try:
|
157 |
+
if not (is_paddle_available() and is_librosa_available()):
|
158 |
+
raise OptionalDependencyNotAvailable()
|
159 |
+
except OptionalDependencyNotAvailable:
|
160 |
+
from .utils.dummy_paddle_and_librosa_objects import * # noqa F403
|
161 |
+
else:
|
162 |
+
from .pipelines import AudioDiffusionPipeline, Mel
|
ppdiffusers/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (4.28 kB). View file
|
|
ppdiffusers/__pycache__/configuration_utils.cpython-37.pyc
ADDED
Binary file (20.7 kB). View file
|
|
ppdiffusers/__pycache__/download_utils.cpython-37.pyc
ADDED
Binary file (818 Bytes). View file
|
|
ppdiffusers/__pycache__/fastdeploy_utils.cpython-37.pyc
ADDED
Binary file (8.18 kB). View file
|
|
ppdiffusers/__pycache__/initializer.cpython-37.pyc
ADDED
Binary file (8.69 kB). View file
|
|
ppdiffusers/__pycache__/loaders.cpython-37.pyc
ADDED
Binary file (7.47 kB). View file
|
|
ppdiffusers/__pycache__/modeling_utils.cpython-37.pyc
ADDED
Binary file (19.8 kB). View file
|
|
ppdiffusers/__pycache__/optimization.cpython-37.pyc
ADDED
Binary file (10.7 kB). View file
|
|
ppdiffusers/__pycache__/pipeline_utils.cpython-37.pyc
ADDED
Binary file (22.3 kB). View file
|
|
ppdiffusers/__pycache__/ppnlp_patch_utils.cpython-37.pyc
ADDED
Binary file (15.6 kB). View file
|
|
ppdiffusers/__pycache__/training_utils.cpython-37.pyc
ADDED
Binary file (4.01 kB). View file
|
|
ppdiffusers/__pycache__/version.cpython-37.pyc
ADDED
Binary file (141 Bytes). View file
|
|
ppdiffusers/commands/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
from argparse import ArgumentParser
|
18 |
+
|
19 |
+
|
20 |
+
class BasePPDiffusersCLICommand(ABC):
|
21 |
+
@staticmethod
|
22 |
+
@abstractmethod
|
23 |
+
def register_subcommand(parser: ArgumentParser):
|
24 |
+
raise NotImplementedError()
|
25 |
+
|
26 |
+
@abstractmethod
|
27 |
+
def run(self):
|
28 |
+
raise NotImplementedError()
|
ppdiffusers/commands/env.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import platform
|
17 |
+
from argparse import ArgumentParser
|
18 |
+
|
19 |
+
from .. import __version__ as version
|
20 |
+
from ..utils import is_paddle_available, is_paddlenlp_available
|
21 |
+
from . import BasePPDiffusersCLICommand
|
22 |
+
|
23 |
+
|
24 |
+
def info_command_factory(_):
|
25 |
+
return EnvironmentCommand()
|
26 |
+
|
27 |
+
|
28 |
+
class EnvironmentCommand(BasePPDiffusersCLICommand):
|
29 |
+
@staticmethod
|
30 |
+
def register_subcommand(parser: ArgumentParser):
|
31 |
+
download_parser = parser.add_parser("env")
|
32 |
+
download_parser.set_defaults(func=info_command_factory)
|
33 |
+
|
34 |
+
def run(self):
|
35 |
+
|
36 |
+
pd_version = "not installed"
|
37 |
+
pd_cuda_available = "NA"
|
38 |
+
if is_paddle_available():
|
39 |
+
import paddle
|
40 |
+
|
41 |
+
pd_version = paddle.__version__
|
42 |
+
pd_cuda_available = paddle.device.is_compiled_with_cuda()
|
43 |
+
|
44 |
+
paddlenlp_version = "not installed"
|
45 |
+
if is_paddlenlp_available:
|
46 |
+
import paddlenlp
|
47 |
+
|
48 |
+
paddlenlp_version = paddlenlp.__version__
|
49 |
+
|
50 |
+
info = {
|
51 |
+
"`ppdiffusers` version": version,
|
52 |
+
"Platform": platform.platform(),
|
53 |
+
"Python version": platform.python_version(),
|
54 |
+
"Paddle version (GPU?)": f"{pd_version} ({pd_cuda_available})",
|
55 |
+
"PaddleNLP version": paddlenlp_version,
|
56 |
+
"Using GPU in script?": "<fill in>",
|
57 |
+
"Using distributed or parallel set-up in script?": "<fill in>",
|
58 |
+
}
|
59 |
+
|
60 |
+
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
61 |
+
print(self.format_dict(info))
|
62 |
+
|
63 |
+
return info
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def format_dict(d):
|
67 |
+
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
ppdiffusers/commands/ppdiffusers_cli.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
|
18 |
+
from .env import EnvironmentCommand
|
19 |
+
|
20 |
+
|
21 |
+
def main():
|
22 |
+
parser = ArgumentParser("PPDiffusers CLI tool", usage="ppdiffusers-cli <command> [<args>]")
|
23 |
+
commands_parser = parser.add_subparsers(help="ppdiffusers-cli command helpers")
|
24 |
+
|
25 |
+
# Register commands
|
26 |
+
EnvironmentCommand.register_subcommand(commands_parser)
|
27 |
+
|
28 |
+
# Let's go
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
if not hasattr(args, "func"):
|
32 |
+
parser.print_help()
|
33 |
+
exit(1)
|
34 |
+
|
35 |
+
# Run
|
36 |
+
service = args.func(args)
|
37 |
+
service.run()
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
main()
|
ppdiffusers/configuration_utils.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" ConfigMixin base class and utilities."""
|
17 |
+
import functools
|
18 |
+
import importlib
|
19 |
+
import inspect
|
20 |
+
import json
|
21 |
+
import os
|
22 |
+
import re
|
23 |
+
import tempfile
|
24 |
+
from collections import OrderedDict
|
25 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
from huggingface_hub import (
|
29 |
+
create_repo,
|
30 |
+
get_hf_file_metadata,
|
31 |
+
hf_hub_download,
|
32 |
+
hf_hub_url,
|
33 |
+
repo_type_and_id_from_hf_id,
|
34 |
+
upload_folder,
|
35 |
+
)
|
36 |
+
from huggingface_hub.utils import EntryNotFoundError
|
37 |
+
from requests import HTTPError
|
38 |
+
|
39 |
+
from .download_utils import ppdiffusers_bos_download
|
40 |
+
from .utils import (
|
41 |
+
DOWNLOAD_SERVER,
|
42 |
+
HF_CACHE,
|
43 |
+
PPDIFFUSERS_CACHE,
|
44 |
+
DummyObject,
|
45 |
+
deprecate,
|
46 |
+
logging,
|
47 |
+
)
|
48 |
+
from .version import VERSION as __version__
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
53 |
+
|
54 |
+
|
55 |
+
class FrozenDict(OrderedDict):
|
56 |
+
def __init__(self, *args, **kwargs):
|
57 |
+
super().__init__(*args, **kwargs)
|
58 |
+
|
59 |
+
for key, value in self.items():
|
60 |
+
setattr(self, key, value)
|
61 |
+
|
62 |
+
self.__frozen = True
|
63 |
+
|
64 |
+
def __delitem__(self, *args, **kwargs):
|
65 |
+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
66 |
+
|
67 |
+
def setdefault(self, *args, **kwargs):
|
68 |
+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
69 |
+
|
70 |
+
def pop(self, *args, **kwargs):
|
71 |
+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
72 |
+
|
73 |
+
def update(self, *args, **kwargs):
|
74 |
+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
75 |
+
|
76 |
+
def __setattr__(self, name, value):
|
77 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
78 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
79 |
+
super().__setattr__(name, value)
|
80 |
+
|
81 |
+
def __setitem__(self, name, value):
|
82 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
83 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
84 |
+
super().__setitem__(name, value)
|
85 |
+
|
86 |
+
|
87 |
+
class ConfigMixin:
|
88 |
+
r"""
|
89 |
+
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
|
90 |
+
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
|
91 |
+
- [`~ConfigMixin.from_config`]
|
92 |
+
- [`~ConfigMixin.save_config`]
|
93 |
+
|
94 |
+
Class attributes:
|
95 |
+
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
96 |
+
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
97 |
+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
98 |
+
overridden by subclass).
|
99 |
+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
100 |
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
101 |
+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
102 |
+
subclass).
|
103 |
+
"""
|
104 |
+
config_name = None
|
105 |
+
ignore_for_config = []
|
106 |
+
has_compatibles = False
|
107 |
+
_deprecated_kwargs = []
|
108 |
+
|
109 |
+
def register_to_config(self, **kwargs):
|
110 |
+
if self.config_name is None:
|
111 |
+
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
112 |
+
|
113 |
+
# Special case for `kwargs` used in deprecation warning added to schedulers
|
114 |
+
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
115 |
+
# or solve in a more general way.
|
116 |
+
kwargs.pop("kwargs", None)
|
117 |
+
for key, value in kwargs.items():
|
118 |
+
try:
|
119 |
+
setattr(self, key, value)
|
120 |
+
except AttributeError as err:
|
121 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
122 |
+
raise err
|
123 |
+
|
124 |
+
if not hasattr(self, "_internal_dict"):
|
125 |
+
internal_dict = kwargs
|
126 |
+
else:
|
127 |
+
previous_dict = dict(self._internal_dict)
|
128 |
+
internal_dict = {**self._internal_dict, **kwargs}
|
129 |
+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
130 |
+
|
131 |
+
self._internal_dict = FrozenDict(internal_dict)
|
132 |
+
|
133 |
+
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
134 |
+
"""
|
135 |
+
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
136 |
+
[`~ConfigMixin.from_config`] class method.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
save_directory (`str` or `os.PathLike`):
|
140 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
141 |
+
"""
|
142 |
+
if os.path.isfile(save_directory):
|
143 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
144 |
+
|
145 |
+
os.makedirs(save_directory, exist_ok=True)
|
146 |
+
|
147 |
+
# If we save using the predefined names, we can load using `from_config`
|
148 |
+
output_config_file = os.path.join(save_directory, self.config_name)
|
149 |
+
|
150 |
+
self.to_json_file(output_config_file)
|
151 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
152 |
+
|
153 |
+
def save_to_hf_hub(
|
154 |
+
self,
|
155 |
+
repo_id: str,
|
156 |
+
private: Optional[bool] = None,
|
157 |
+
subfolder: Optional[str] = None,
|
158 |
+
commit_message: Optional[str] = None,
|
159 |
+
revision: Optional[str] = None,
|
160 |
+
create_pr: bool = False,
|
161 |
+
):
|
162 |
+
"""
|
163 |
+
Uploads all elements of this config to a new HuggingFace Hub repository.
|
164 |
+
Args:
|
165 |
+
repo_id (str): Repository name for your model/tokenizer in the Hub.
|
166 |
+
private (bool, optional): Whether the model/tokenizer is set to private
|
167 |
+
subfolder (str, optional): Push to a subfolder of the repo instead of the root
|
168 |
+
commit_message (str, optional): The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"
|
169 |
+
revision (str, optional): The git revision to commit from. Defaults to the head of the "main" branch.
|
170 |
+
create_pr (boolean, optional): Whether or not to create a Pull Request with that commit. Defaults to False.
|
171 |
+
If revision is not set, PR is opened against the "main" branch. If revision is set and is a branch, PR is opened against this branch.
|
172 |
+
If revision is set and is not a branch name (example: a commit oid), an RevisionNotFoundError is returned by the server.
|
173 |
+
|
174 |
+
Returns: The url of the commit of your model in the given repository.
|
175 |
+
"""
|
176 |
+
repo_url = create_repo(repo_id, private=private, exist_ok=True)
|
177 |
+
|
178 |
+
# Infer complete repo_id from repo_url
|
179 |
+
# Can be different from the input `repo_id` if repo_owner was implicit
|
180 |
+
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
181 |
+
|
182 |
+
repo_id = f"{repo_owner}/{repo_name}"
|
183 |
+
|
184 |
+
# Check if README file already exist in repo
|
185 |
+
try:
|
186 |
+
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
187 |
+
has_readme = True
|
188 |
+
except EntryNotFoundError:
|
189 |
+
has_readme = False
|
190 |
+
|
191 |
+
with tempfile.TemporaryDirectory() as root_dir:
|
192 |
+
if subfolder is not None:
|
193 |
+
save_dir = os.path.join(root_dir, subfolder)
|
194 |
+
else:
|
195 |
+
save_dir = root_dir
|
196 |
+
# save config
|
197 |
+
self.save_config(save_dir)
|
198 |
+
# Add readme if does not exist
|
199 |
+
logger.info("README.md not found, adding the default README.md")
|
200 |
+
if not has_readme:
|
201 |
+
with open(os.path.join(root_dir, "README.md"), "w") as f:
|
202 |
+
f.write(f"---\nlibrary_name: ppdiffusers\n---\n# {repo_id}")
|
203 |
+
|
204 |
+
# Upload model and return
|
205 |
+
logger.info(f"Pushing to the {repo_id}. This might take a while")
|
206 |
+
return upload_folder(
|
207 |
+
repo_id=repo_id,
|
208 |
+
repo_type="model",
|
209 |
+
folder_path=root_dir,
|
210 |
+
commit_message=commit_message,
|
211 |
+
revision=revision,
|
212 |
+
create_pr=create_pr,
|
213 |
+
)
|
214 |
+
|
215 |
+
@classmethod
|
216 |
+
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
217 |
+
r"""
|
218 |
+
Instantiate a Python class from a config dictionary
|
219 |
+
|
220 |
+
Parameters:
|
221 |
+
config (`Dict[str, Any]`):
|
222 |
+
A config dictionary from which the Python class will be instantiated. Make sure to only load
|
223 |
+
configuration files of compatible classes.
|
224 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
225 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
226 |
+
|
227 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
228 |
+
Can be used to update the configuration object (after it being loaded) and initiate the Python class.
|
229 |
+
`**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
|
230 |
+
overwrite same named arguments of `config`.
|
231 |
+
|
232 |
+
Examples:
|
233 |
+
|
234 |
+
```python
|
235 |
+
>>> from ppdiffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
236 |
+
|
237 |
+
>>> # Download scheduler from BOS and cache.
|
238 |
+
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
239 |
+
|
240 |
+
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
241 |
+
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
242 |
+
|
243 |
+
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
244 |
+
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
245 |
+
```
|
246 |
+
"""
|
247 |
+
# <===== TO BE REMOVED WITH DEPRECATION
|
248 |
+
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
249 |
+
if "pretrained_model_name_or_path" in kwargs:
|
250 |
+
config = kwargs.pop("pretrained_model_name_or_path")
|
251 |
+
|
252 |
+
if config is None:
|
253 |
+
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
254 |
+
# ======>
|
255 |
+
|
256 |
+
if not isinstance(config, dict):
|
257 |
+
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
258 |
+
if "Scheduler" in cls.__name__:
|
259 |
+
deprecation_message += (
|
260 |
+
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
261 |
+
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
262 |
+
" be removed in v1.0.0."
|
263 |
+
)
|
264 |
+
elif "Model" in cls.__name__:
|
265 |
+
deprecation_message += (
|
266 |
+
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
267 |
+
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
268 |
+
" instead. This functionality will be removed in v1.0.0."
|
269 |
+
)
|
270 |
+
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
271 |
+
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
272 |
+
|
273 |
+
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
274 |
+
|
275 |
+
# Allow dtype to be specified on initialization
|
276 |
+
if "dtype" in unused_kwargs:
|
277 |
+
# (TODO junnyu, donot use dtype)
|
278 |
+
unused_kwargs.pop("dtype")
|
279 |
+
# init_dict["dtype"] = unused_kwargs.pop("dtype")
|
280 |
+
|
281 |
+
# add possible deprecated kwargs
|
282 |
+
for deprecated_kwarg in cls._deprecated_kwargs:
|
283 |
+
if deprecated_kwarg in unused_kwargs:
|
284 |
+
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
285 |
+
|
286 |
+
# Return model and optionally state and/or unused_kwargs
|
287 |
+
model = cls(**init_dict)
|
288 |
+
|
289 |
+
# make sure to also save config parameters that might be used for compatible classes
|
290 |
+
model.register_to_config(**hidden_dict)
|
291 |
+
|
292 |
+
# add hidden kwargs of compatible classes to unused_kwargs
|
293 |
+
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
294 |
+
|
295 |
+
if return_unused_kwargs:
|
296 |
+
return (model, unused_kwargs)
|
297 |
+
else:
|
298 |
+
return model
|
299 |
+
|
300 |
+
@classmethod
|
301 |
+
def get_config_dict(cls, *args, **kwargs):
|
302 |
+
deprecation_message = (
|
303 |
+
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
304 |
+
" removed in version v1.0.0"
|
305 |
+
)
|
306 |
+
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
307 |
+
return cls.load_config(*args, **kwargs)
|
308 |
+
|
309 |
+
@classmethod
|
310 |
+
def load_config(
|
311 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
312 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
313 |
+
r"""
|
314 |
+
Instantiate a Python class from a config dictionary
|
315 |
+
|
316 |
+
Parameters:
|
317 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
318 |
+
Can be either:
|
319 |
+
|
320 |
+
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
321 |
+
organization name, like `google/ddpm-celebahq-256`.
|
322 |
+
- A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
|
323 |
+
`./my_model_directory/`.
|
324 |
+
|
325 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
326 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
327 |
+
standard cache should not be used.
|
328 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
329 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
330 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
331 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
332 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
333 |
+
from_hf_hub (bool, *optional*):
|
334 |
+
Whether to load from Hugging Face Hub. Defaults to False
|
335 |
+
"""
|
336 |
+
from_hf_hub = kwargs.pop("from_hf_hub", False)
|
337 |
+
if from_hf_hub:
|
338 |
+
cache_dir = kwargs.pop("cache_dir", HF_CACHE)
|
339 |
+
else:
|
340 |
+
cache_dir = kwargs.pop("cache_dir", PPDIFFUSERS_CACHE)
|
341 |
+
subfolder = kwargs.pop("subfolder", None)
|
342 |
+
|
343 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
344 |
+
|
345 |
+
if cls.config_name is None:
|
346 |
+
raise ValueError(
|
347 |
+
"`self.config_name` is not defined. Note that one should not load a config from "
|
348 |
+
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
349 |
+
)
|
350 |
+
|
351 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
352 |
+
config_file = pretrained_model_name_or_path
|
353 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
354 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
355 |
+
# Load from a Paddle checkpoint
|
356 |
+
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
357 |
+
elif subfolder is not None and os.path.isfile(
|
358 |
+
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
359 |
+
):
|
360 |
+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
361 |
+
else:
|
362 |
+
raise EnvironmentError(
|
363 |
+
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
364 |
+
)
|
365 |
+
elif from_hf_hub:
|
366 |
+
config_file = hf_hub_download(
|
367 |
+
repo_id=pretrained_model_name_or_path,
|
368 |
+
filename=cls.config_name,
|
369 |
+
cache_dir=cache_dir,
|
370 |
+
subfolder=subfolder,
|
371 |
+
library_name="PPDiffusers",
|
372 |
+
library_version=__version__,
|
373 |
+
)
|
374 |
+
else:
|
375 |
+
try:
|
376 |
+
config_file = ppdiffusers_bos_download(
|
377 |
+
pretrained_model_name_or_path,
|
378 |
+
filename=cls.config_name,
|
379 |
+
subfolder=subfolder,
|
380 |
+
cache_dir=cache_dir,
|
381 |
+
)
|
382 |
+
except HTTPError as err:
|
383 |
+
raise EnvironmentError(
|
384 |
+
"There was a specific connection error when trying to load"
|
385 |
+
f" {pretrained_model_name_or_path}:\n{err}"
|
386 |
+
)
|
387 |
+
except ValueError:
|
388 |
+
raise EnvironmentError(
|
389 |
+
f"We couldn't connect to '{DOWNLOAD_SERVER}' to load this model, couldn't find it"
|
390 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
391 |
+
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
392 |
+
" run the library in offline mode at"
|
393 |
+
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
394 |
+
)
|
395 |
+
except EnvironmentError:
|
396 |
+
raise EnvironmentError(
|
397 |
+
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
398 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
399 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
400 |
+
f"containing a {cls.config_name} file"
|
401 |
+
)
|
402 |
+
|
403 |
+
try:
|
404 |
+
# Load config dict
|
405 |
+
config_dict = cls._dict_from_json_file(config_file)
|
406 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
407 |
+
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
408 |
+
|
409 |
+
if return_unused_kwargs:
|
410 |
+
return config_dict, kwargs
|
411 |
+
|
412 |
+
return config_dict
|
413 |
+
|
414 |
+
@staticmethod
|
415 |
+
def _get_init_keys(cls):
|
416 |
+
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
417 |
+
|
418 |
+
@classmethod
|
419 |
+
def extract_init_dict(cls, config_dict, **kwargs):
|
420 |
+
# 0. Copy origin config dict
|
421 |
+
original_dict = {k: v for k, v in config_dict.items()}
|
422 |
+
|
423 |
+
# 1. Retrieve expected config attributes from __init__ signature
|
424 |
+
expected_keys = cls._get_init_keys(cls)
|
425 |
+
expected_keys.remove("self")
|
426 |
+
# remove general kwargs if present in dict
|
427 |
+
if "kwargs" in expected_keys:
|
428 |
+
expected_keys.remove("kwargs")
|
429 |
+
|
430 |
+
# 2. Remove attributes that cannot be expected from expected config attributes
|
431 |
+
# remove keys to be ignored
|
432 |
+
if len(cls.ignore_for_config) > 0:
|
433 |
+
expected_keys = expected_keys - set(cls.ignore_for_config)
|
434 |
+
|
435 |
+
# load ppdiffusers library to import compatible and original scheduler
|
436 |
+
ppdiffusers_library = importlib.import_module(__name__.split(".")[0])
|
437 |
+
|
438 |
+
if cls.has_compatibles:
|
439 |
+
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
|
440 |
+
else:
|
441 |
+
compatible_classes = []
|
442 |
+
|
443 |
+
expected_keys_comp_cls = set()
|
444 |
+
for c in compatible_classes:
|
445 |
+
expected_keys_c = cls._get_init_keys(c)
|
446 |
+
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
|
447 |
+
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
|
448 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
|
449 |
+
|
450 |
+
# remove attributes from orig class that cannot be expected
|
451 |
+
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
452 |
+
if orig_cls_name != cls.__name__ and hasattr(ppdiffusers_library, orig_cls_name):
|
453 |
+
orig_cls = getattr(ppdiffusers_library, orig_cls_name)
|
454 |
+
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
455 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
456 |
+
|
457 |
+
# remove private attributes
|
458 |
+
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
459 |
+
|
460 |
+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
461 |
+
init_dict = {}
|
462 |
+
for key in expected_keys:
|
463 |
+
# if config param is passed to kwarg and is present in config dict
|
464 |
+
# it should overwrite existing config dict key
|
465 |
+
if key in kwargs and key in config_dict:
|
466 |
+
config_dict[key] = kwargs.pop(key)
|
467 |
+
|
468 |
+
if key in kwargs:
|
469 |
+
# overwrite key
|
470 |
+
init_dict[key] = kwargs.pop(key)
|
471 |
+
elif key in config_dict:
|
472 |
+
# use value from config dict
|
473 |
+
init_dict[key] = config_dict.pop(key)
|
474 |
+
|
475 |
+
# 4. Give nice warning if unexpected values have been passed
|
476 |
+
if len(config_dict) > 0:
|
477 |
+
logger.warning(
|
478 |
+
f"The config attributes {config_dict} were passed to {cls.__name__}, "
|
479 |
+
"but are not expected and will be ignored. Please verify your "
|
480 |
+
f"{cls.config_name} configuration file."
|
481 |
+
)
|
482 |
+
|
483 |
+
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
|
484 |
+
passed_keys = set(init_dict.keys())
|
485 |
+
if len(expected_keys - passed_keys) > 0:
|
486 |
+
logger.info(
|
487 |
+
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
488 |
+
)
|
489 |
+
|
490 |
+
# 6. Define unused keyword arguments
|
491 |
+
unused_kwargs = {**config_dict, **kwargs}
|
492 |
+
|
493 |
+
# 7. Define "hidden" config parameters that were saved for compatible classes
|
494 |
+
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
|
495 |
+
|
496 |
+
return init_dict, unused_kwargs, hidden_config_dict
|
497 |
+
|
498 |
+
@classmethod
|
499 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
500 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
501 |
+
text = reader.read()
|
502 |
+
return json.loads(text)
|
503 |
+
|
504 |
+
def __repr__(self):
|
505 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
506 |
+
|
507 |
+
@property
|
508 |
+
def config(self) -> Dict[str, Any]:
|
509 |
+
"""
|
510 |
+
Returns the config of the class as a frozen dictionary
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
`Dict[str, Any]`: Config of the class.
|
514 |
+
"""
|
515 |
+
return self._internal_dict
|
516 |
+
|
517 |
+
def to_json_string(self) -> str:
|
518 |
+
"""
|
519 |
+
Serializes this instance to a JSON string.
|
520 |
+
|
521 |
+
Returns:
|
522 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
523 |
+
"""
|
524 |
+
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
525 |
+
config_dict["_class_name"] = self.__class__.__name__
|
526 |
+
config_dict["_ppdiffusers_version"] = __version__
|
527 |
+
|
528 |
+
def to_json_saveable(value):
|
529 |
+
if isinstance(value, np.ndarray):
|
530 |
+
value = value.tolist()
|
531 |
+
return value
|
532 |
+
|
533 |
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
534 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
535 |
+
|
536 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
537 |
+
"""
|
538 |
+
Save this instance to a JSON file.
|
539 |
+
|
540 |
+
Args:
|
541 |
+
json_file_path (`str` or `os.PathLike`):
|
542 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
543 |
+
"""
|
544 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
545 |
+
writer.write(self.to_json_string())
|
546 |
+
|
547 |
+
|
548 |
+
def register_to_config(init):
|
549 |
+
r"""
|
550 |
+
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
551 |
+
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
552 |
+
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
553 |
+
|
554 |
+
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
555 |
+
"""
|
556 |
+
|
557 |
+
@functools.wraps(init)
|
558 |
+
def inner_init(self, *args, **kwargs):
|
559 |
+
# Ignore private kwargs in the init.
|
560 |
+
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
561 |
+
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
562 |
+
|
563 |
+
if not isinstance(self, ConfigMixin):
|
564 |
+
raise RuntimeError(
|
565 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
566 |
+
"not inherit from `ConfigMixin`."
|
567 |
+
)
|
568 |
+
|
569 |
+
ignore = getattr(self, "ignore_for_config", [])
|
570 |
+
# Get positional arguments aligned with kwargs
|
571 |
+
new_kwargs = {}
|
572 |
+
signature = inspect.signature(init)
|
573 |
+
parameters = {
|
574 |
+
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
|
575 |
+
}
|
576 |
+
for arg, name in zip(args, parameters.keys()):
|
577 |
+
new_kwargs[name] = arg
|
578 |
+
|
579 |
+
# Then add all kwargs
|
580 |
+
new_kwargs.update(
|
581 |
+
{
|
582 |
+
k: init_kwargs.get(k, default)
|
583 |
+
for k, default in parameters.items()
|
584 |
+
if k not in ignore and k not in new_kwargs
|
585 |
+
}
|
586 |
+
)
|
587 |
+
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
588 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
589 |
+
init(self, *args, **init_kwargs)
|
590 |
+
|
591 |
+
return inner_init
|
ppdiffusers/download_utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
|
18 |
+
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
|
19 |
+
from paddlenlp.utils.log import logger
|
20 |
+
|
21 |
+
from .utils import DOWNLOAD_SERVER, PPDIFFUSERS_CACHE
|
22 |
+
|
23 |
+
|
24 |
+
def ppdiffusers_bos_download(pretrained_model_name_or_path, filename=None, subfolder=None, cache_dir=None):
|
25 |
+
if cache_dir is None:
|
26 |
+
cache_dir = PPDIFFUSERS_CACHE
|
27 |
+
cache_dir = (
|
28 |
+
pretrained_model_name_or_path
|
29 |
+
if os.path.isdir(pretrained_model_name_or_path)
|
30 |
+
else os.path.join(cache_dir, pretrained_model_name_or_path)
|
31 |
+
)
|
32 |
+
url = DOWNLOAD_SERVER + "/" + pretrained_model_name_or_path
|
33 |
+
if subfolder is not None:
|
34 |
+
url = url + "/" + subfolder
|
35 |
+
cache_dir = os.path.join(cache_dir, subfolder)
|
36 |
+
if filename is not None:
|
37 |
+
url = url + "/" + filename
|
38 |
+
|
39 |
+
file_path = os.path.join(cache_dir, filename)
|
40 |
+
if os.path.exists(file_path):
|
41 |
+
logger.info("Already cached %s" % file_path)
|
42 |
+
else:
|
43 |
+
file_path = get_path_from_url_with_filelock(url, cache_dir)
|
44 |
+
return file_path
|
ppdiffusers/experimental/README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🧨 PPDiffusers Experimental
|
2 |
+
|
3 |
+
为了使得**PPDiffusers库**能够有更多的应用场景,我们在这里添加了一些**实验性的代码**。
|
4 |
+
|
5 |
+
目前我们支持了以下场景:
|
6 |
+
* Reinforcement learning via an implementation of the [PPDiffuser](https://arxiv.org/abs/2205.09991) model.
|
ppdiffusers/experimental/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# flake8: noqa
|
16 |
+
|
17 |
+
from .rl import ValueGuidedRLPipeline
|
ppdiffusers/experimental/rl/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# flake8: noqa
|
16 |
+
|
17 |
+
from .value_guided_sampling import ValueGuidedRLPipeline
|
ppdiffusers/experimental/rl/value_guided_sampling.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import paddle
|
18 |
+
|
19 |
+
from ...models.unet_1d import UNet1DModel
|
20 |
+
from ...pipeline_utils import DiffusionPipeline
|
21 |
+
from ...utils.dummy_paddle_objects import DDPMScheduler
|
22 |
+
|
23 |
+
|
24 |
+
class ValueGuidedRLPipeline(DiffusionPipeline):
|
25 |
+
r"""
|
26 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
27 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
28 |
+
Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
|
29 |
+
Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
|
33 |
+
unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
|
34 |
+
scheduler ([`SchedulerMixin`]):
|
35 |
+
A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
|
36 |
+
application is [`DDPMScheduler`].
|
37 |
+
env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
value_function: UNet1DModel,
|
43 |
+
unet: UNet1DModel,
|
44 |
+
scheduler: DDPMScheduler,
|
45 |
+
env,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
self.value_function = value_function
|
49 |
+
self.unet = unet
|
50 |
+
self.scheduler = scheduler
|
51 |
+
self.env = env
|
52 |
+
self.data = env.get_dataset()
|
53 |
+
self.means = dict()
|
54 |
+
for key in self.data.keys():
|
55 |
+
try:
|
56 |
+
self.means[key] = self.data[key].mean()
|
57 |
+
except Exception:
|
58 |
+
pass
|
59 |
+
self.stds = dict()
|
60 |
+
for key in self.data.keys():
|
61 |
+
try:
|
62 |
+
self.stds[key] = self.data[key].std()
|
63 |
+
except Exception:
|
64 |
+
pass
|
65 |
+
self.state_dim = env.observation_space.shape[0]
|
66 |
+
self.action_dim = env.action_space.shape[0]
|
67 |
+
|
68 |
+
def normalize(self, x_in, key):
|
69 |
+
return (x_in - self.means[key]) / self.stds[key]
|
70 |
+
|
71 |
+
def de_normalize(self, x_in, key):
|
72 |
+
return x_in * self.stds[key] + self.means[key]
|
73 |
+
|
74 |
+
def to_paddle(self, x_in):
|
75 |
+
if type(x_in) is dict:
|
76 |
+
return {k: self.to_paddle(v) for k, v in x_in.items()}
|
77 |
+
elif paddle.is_tensor(x_in):
|
78 |
+
return x_in
|
79 |
+
return paddle.to_tensor(x_in)
|
80 |
+
|
81 |
+
def reset_x0(self, x_in, cond, act_dim):
|
82 |
+
for key, val in cond.items():
|
83 |
+
x_in[:, key, act_dim:] = val.clone()
|
84 |
+
return x_in
|
85 |
+
|
86 |
+
def run_diffusion(self, x, conditions, n_guide_steps, scale):
|
87 |
+
batch_size = x.shape[0]
|
88 |
+
y = None
|
89 |
+
for i in self.progress_bar(self.scheduler.timesteps):
|
90 |
+
# create batch of timesteps to pass into model
|
91 |
+
timesteps = paddle.full((batch_size,), i, dtype="int64")
|
92 |
+
for _ in range(n_guide_steps):
|
93 |
+
with paddle.set_grad_enabled(True):
|
94 |
+
x.stop_gradient = False
|
95 |
+
# permute to match dimension for pre-trained models
|
96 |
+
y = self.value_function(x.transpose([0, 2, 1]), timesteps).sample
|
97 |
+
grad = paddle.autograd.grad([y.sum()], [x])[0]
|
98 |
+
|
99 |
+
posterior_variance = self.scheduler._get_variance(i)
|
100 |
+
model_std = paddle.exp(0.5 * posterior_variance)
|
101 |
+
grad = model_std * grad
|
102 |
+
|
103 |
+
grad[timesteps < 2] = 0
|
104 |
+
x = x.detach()
|
105 |
+
x = x + scale * grad
|
106 |
+
x = self.reset_x0(x, conditions, self.action_dim)
|
107 |
+
prev_x = self.unet(x.transpose([0, 2, 1]), timesteps).sample.transpose([0, 2, 1])
|
108 |
+
# TODO: verify deprecation of this kwarg
|
109 |
+
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
110 |
+
|
111 |
+
# apply conditions to the trajectory (set the initial state)
|
112 |
+
x = self.reset_x0(x, conditions, self.action_dim)
|
113 |
+
x = self.to_paddle(x)
|
114 |
+
return x, y
|
115 |
+
|
116 |
+
def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
|
117 |
+
# normalize the observations and create batch dimension
|
118 |
+
obs = self.normalize(obs, "observations")
|
119 |
+
obs = obs[None].repeat(batch_size, axis=0)
|
120 |
+
|
121 |
+
conditions = {0: self.to_paddle(obs)}
|
122 |
+
shape = [batch_size, planning_horizon, self.state_dim + self.action_dim]
|
123 |
+
|
124 |
+
# generate initial noise and apply our conditions (to make the trajectories start at current state)
|
125 |
+
x1 = paddle.randn(shape)
|
126 |
+
x = self.reset_x0(x1, conditions, self.action_dim)
|
127 |
+
x = self.to_paddle(x)
|
128 |
+
|
129 |
+
# run the diffusion process
|
130 |
+
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
|
131 |
+
|
132 |
+
# sort output trajectories by value
|
133 |
+
sorted_idx = paddle.argsort(y, 0, descending=True).squeeze()
|
134 |
+
sorted_values = x[sorted_idx]
|
135 |
+
actions = sorted_values[:, :, : self.action_dim]
|
136 |
+
actions = actions.detach().numpy()
|
137 |
+
denorm_actions = self.de_normalize(actions, key="actions")
|
138 |
+
|
139 |
+
# select the action with the highest value
|
140 |
+
if y is not None:
|
141 |
+
selected_index = 0
|
142 |
+
else:
|
143 |
+
# if we didn't run value guiding, select a random action
|
144 |
+
selected_index = np.random.randint(0, batch_size)
|
145 |
+
denorm_actions = denorm_actions[selected_index, 0]
|
146 |
+
return denorm_actions
|
ppdiffusers/fastdeploy_utils.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
import shutil
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import Optional, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .download_utils import ppdiffusers_bos_download
|
25 |
+
from .utils import (
|
26 |
+
FASTDEPLOY_MODEL_NAME,
|
27 |
+
FASTDEPLOY_WEIGHTS_NAME,
|
28 |
+
is_fastdeploy_available,
|
29 |
+
is_paddle_available,
|
30 |
+
logging,
|
31 |
+
)
|
32 |
+
|
33 |
+
if is_paddle_available():
|
34 |
+
import paddle
|
35 |
+
|
36 |
+
|
37 |
+
if is_fastdeploy_available():
|
38 |
+
import fastdeploy as fd
|
39 |
+
|
40 |
+
def fdtensor2pdtensor(fdtensor: fd.C.FDTensor):
|
41 |
+
dltensor = fdtensor.to_dlpack()
|
42 |
+
pdtensor = paddle.utils.dlpack.from_dlpack(dltensor)
|
43 |
+
return pdtensor
|
44 |
+
|
45 |
+
def pdtensor2fdtensor(pdtensor: paddle.Tensor, name: str = "", share_with_raw_ptr=False):
|
46 |
+
if not share_with_raw_ptr:
|
47 |
+
dltensor = paddle.utils.dlpack.to_dlpack(pdtensor)
|
48 |
+
return fd.C.FDTensor.from_dlpack(name, dltensor)
|
49 |
+
else:
|
50 |
+
return fd.C.FDTensor.from_external_data(
|
51 |
+
name,
|
52 |
+
pdtensor.data_ptr(),
|
53 |
+
pdtensor.shape,
|
54 |
+
pdtensor.dtype.name,
|
55 |
+
str(pdtensor.place),
|
56 |
+
int(pdtensor.place.gpu_device_id()),
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
logger = logging.get_logger(__name__)
|
61 |
+
|
62 |
+
|
63 |
+
class FastDeployRuntimeModel:
|
64 |
+
def __init__(self, model=None, **kwargs):
|
65 |
+
logger.info("`ppdiffusers.FastDeployRuntimeModel` is experimental and might change in the future.")
|
66 |
+
self.model = model
|
67 |
+
self.model_save_dir = kwargs.get("model_save_dir", None)
|
68 |
+
self.latest_model_name = kwargs.get("latest_model_name", "inference.pdmodel")
|
69 |
+
self.latest_params_name = kwargs.get("latest_params_name", "inference.pdiparams")
|
70 |
+
|
71 |
+
def zero_copy_infer(self, prebinded_inputs: dict, prebinded_outputs: dict, share_with_raw_ptr=True, **kwargs):
|
72 |
+
"""
|
73 |
+
Execute inference without copying data from cpu to gpu.
|
74 |
+
|
75 |
+
Arguments:
|
76 |
+
kwargs (`dict(name, paddle.Tensor)`):
|
77 |
+
An input map from name to tensor.
|
78 |
+
Return:
|
79 |
+
List of output tensor.
|
80 |
+
"""
|
81 |
+
for inputs_name, inputs_tensor in prebinded_inputs.items():
|
82 |
+
input_fdtensor = pdtensor2fdtensor(inputs_tensor, inputs_name, share_with_raw_ptr=share_with_raw_ptr)
|
83 |
+
self.model.bind_input_tensor(inputs_name, input_fdtensor)
|
84 |
+
|
85 |
+
for outputs_name, outputs_tensor in prebinded_outputs.items():
|
86 |
+
output_fdtensor = pdtensor2fdtensor(outputs_tensor, outputs_name, share_with_raw_ptr=share_with_raw_ptr)
|
87 |
+
self.model.bind_output_tensor(outputs_name, output_fdtensor)
|
88 |
+
|
89 |
+
self.model.zero_copy_infer()
|
90 |
+
|
91 |
+
def __call__(self, **kwargs):
|
92 |
+
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
93 |
+
return self.model.infer(inputs)
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def load_model(
|
97 |
+
model_path: Union[str, Path],
|
98 |
+
params_path: Union[str, Path],
|
99 |
+
runtime_options: Optional["fd.RuntimeOption"] = None,
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
Loads an FastDeploy Inference Model with fastdeploy.RuntimeOption
|
103 |
+
|
104 |
+
Arguments:
|
105 |
+
model_path (`str` or `Path`):
|
106 |
+
Model path from which to load
|
107 |
+
params_path (`str` or `Path`):
|
108 |
+
Params path from which to load
|
109 |
+
runtime_options (fd.RuntimeOption, *optional*):
|
110 |
+
The RuntimeOption of fastdeploy to initialize the fastdeploy runtime. Default setting
|
111 |
+
the device to cpu and the backend to paddle inference
|
112 |
+
"""
|
113 |
+
option = runtime_options
|
114 |
+
if option is None or not isinstance(runtime_options, fd.RuntimeOption):
|
115 |
+
logger.info("No fastdeploy.RuntimeOption specified, using CPU device and paddle inference backend.")
|
116 |
+
option = fd.RuntimeOption()
|
117 |
+
option.use_paddle_backend()
|
118 |
+
option.use_cpu()
|
119 |
+
option.set_model_path(model_path, params_path)
|
120 |
+
return fd.Runtime(option)
|
121 |
+
|
122 |
+
def _save_pretrained(
|
123 |
+
self,
|
124 |
+
save_directory: Union[str, Path],
|
125 |
+
model_file_name: Optional[str] = None,
|
126 |
+
params_file_name: Optional[str] = None,
|
127 |
+
**kwargs
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
131 |
+
[`~FastDeployRuntimeModel.from_pretrained`] class method. It will always save the
|
132 |
+
latest_model_name.
|
133 |
+
|
134 |
+
Arguments:
|
135 |
+
save_directory (`str` or `Path`):
|
136 |
+
Directory where to save the model file.
|
137 |
+
model_file_name(`str`, *optional*):
|
138 |
+
Overwrites the default model file name from `"inference.pdmodel"` to `model_file_name`. This allows you to save the
|
139 |
+
model with a different name.
|
140 |
+
params_file_name(`str`, *optional*):
|
141 |
+
Overwrites the default model file name from `"inference.pdiparams"` to `params_file_name`. This allows you to save the
|
142 |
+
model with a different name.
|
143 |
+
"""
|
144 |
+
|
145 |
+
model_file_name = model_file_name if model_file_name is not None else FASTDEPLOY_MODEL_NAME
|
146 |
+
params_file_name = params_file_name if params_file_name is not None else FASTDEPLOY_WEIGHTS_NAME
|
147 |
+
|
148 |
+
src_model_path = self.model_save_dir.joinpath(self.latest_model_name)
|
149 |
+
dst_model_path = Path(save_directory).joinpath(model_file_name)
|
150 |
+
|
151 |
+
src_params_path = self.model_save_dir.joinpath(self.latest_params_name)
|
152 |
+
dst_params_path = Path(save_directory).joinpath(params_file_name)
|
153 |
+
try:
|
154 |
+
shutil.copyfile(src_model_path, dst_model_path)
|
155 |
+
shutil.copyfile(src_params_path, dst_params_path)
|
156 |
+
except shutil.SameFileError:
|
157 |
+
pass
|
158 |
+
|
159 |
+
def save_pretrained(
|
160 |
+
self,
|
161 |
+
save_directory: Union[str, os.PathLike],
|
162 |
+
**kwargs,
|
163 |
+
):
|
164 |
+
"""
|
165 |
+
Save a model to a directory, so that it can be re-loaded using the [`~FastDeployRuntimeModel.from_pretrained`] class
|
166 |
+
method.:
|
167 |
+
|
168 |
+
Arguments:
|
169 |
+
save_directory (`str` or `os.PathLike`):
|
170 |
+
Directory to which to save. Will be created if it doesn't exist.
|
171 |
+
"""
|
172 |
+
if os.path.isfile(save_directory):
|
173 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
174 |
+
return
|
175 |
+
|
176 |
+
os.makedirs(save_directory, exist_ok=True)
|
177 |
+
|
178 |
+
# saving model weights/files
|
179 |
+
self._save_pretrained(save_directory, **kwargs)
|
180 |
+
|
181 |
+
@classmethod
|
182 |
+
def _from_pretrained(
|
183 |
+
cls,
|
184 |
+
pretrained_model_name_or_path: Union[str, Path],
|
185 |
+
cache_dir: Optional[str] = None,
|
186 |
+
model_file_name: Optional[str] = None,
|
187 |
+
params_file_name: Optional[str] = None,
|
188 |
+
runtime_options: Optional["fd.RuntimeOption"] = None,
|
189 |
+
**kwargs,
|
190 |
+
):
|
191 |
+
"""
|
192 |
+
Load a model from a directory or the BOS.
|
193 |
+
|
194 |
+
Arguments:
|
195 |
+
pretrained_model_name_or_path (`str` or `Path`):
|
196 |
+
Directory from which to load
|
197 |
+
cache_dir (`Union[str, Path]`, *optional*):
|
198 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
199 |
+
standard cache should not be used.
|
200 |
+
model_file_name (`str`):
|
201 |
+
Overwrites the default model file name from `"inference.pdmodel"` to `file_name`. This allows you to load
|
202 |
+
different model files from the same repository or directory.
|
203 |
+
params_file_name (`str`):
|
204 |
+
Overwrites the default params file name from `"inference.pdiparams"` to `file_name`. This allows you to load
|
205 |
+
different model files from the same repository or directory.
|
206 |
+
runtime_options (`fastdeploy.RuntimeOption`, *optional*):
|
207 |
+
The RuntimeOption of fastdeploy.
|
208 |
+
kwargs (`Dict`, *optional*):
|
209 |
+
kwargs will be passed to the model during initialization
|
210 |
+
"""
|
211 |
+
model_file_name = model_file_name if model_file_name is not None else FASTDEPLOY_MODEL_NAME
|
212 |
+
params_file_name = params_file_name if params_file_name is not None else FASTDEPLOY_WEIGHTS_NAME
|
213 |
+
# load model from local directory
|
214 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
215 |
+
model = FastDeployRuntimeModel.load_model(
|
216 |
+
os.path.join(pretrained_model_name_or_path, model_file_name),
|
217 |
+
os.path.join(pretrained_model_name_or_path, params_file_name),
|
218 |
+
runtime_options=runtime_options,
|
219 |
+
)
|
220 |
+
kwargs["model_save_dir"] = Path(pretrained_model_name_or_path)
|
221 |
+
# load model from hub
|
222 |
+
else:
|
223 |
+
# download model
|
224 |
+
model_cache_path = ppdiffusers_bos_download(
|
225 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
226 |
+
filename=model_file_name,
|
227 |
+
cache_dir=cache_dir,
|
228 |
+
)
|
229 |
+
# download params
|
230 |
+
params_cache_path = ppdiffusers_bos_download(
|
231 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
232 |
+
filename=params_file_name,
|
233 |
+
cache_dir=cache_dir,
|
234 |
+
)
|
235 |
+
kwargs["model_save_dir"] = Path(model_cache_path).parent
|
236 |
+
kwargs["latest_model_name"] = Path(model_cache_path).name
|
237 |
+
kwargs["latest_params_name"] = Path(params_cache_path).name
|
238 |
+
model = FastDeployRuntimeModel.load_model(
|
239 |
+
model_cache_path, params_cache_path, runtime_options=runtime_options
|
240 |
+
)
|
241 |
+
return cls(model=model, **kwargs)
|
242 |
+
|
243 |
+
@classmethod
|
244 |
+
def from_pretrained(
|
245 |
+
cls,
|
246 |
+
pretrained_model_name_or_path: Union[str, Path],
|
247 |
+
cache_dir: Optional[str] = None,
|
248 |
+
model_file_name: Optional[str] = None,
|
249 |
+
params_file_name: Optional[str] = None,
|
250 |
+
runtime_options: Optional["fd.RuntimeOption"] = None,
|
251 |
+
**model_kwargs,
|
252 |
+
):
|
253 |
+
return cls._from_pretrained(
|
254 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
255 |
+
cache_dir=cache_dir,
|
256 |
+
model_file_name=model_file_name,
|
257 |
+
params_file_name=params_file_name,
|
258 |
+
runtime_options=runtime_options,
|
259 |
+
**model_kwargs,
|
260 |
+
)
|
ppdiffusers/initializer.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
This code is based on https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
|
17 |
+
Ths copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file.
|
18 |
+
"""
|
19 |
+
|
20 |
+
import math
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import paddle
|
24 |
+
import paddle.nn as nn
|
25 |
+
|
26 |
+
__all__ = [
|
27 |
+
"uniform_",
|
28 |
+
"normal_",
|
29 |
+
"constant_",
|
30 |
+
"ones_",
|
31 |
+
"zeros_",
|
32 |
+
"xavier_uniform_",
|
33 |
+
"xavier_normal_",
|
34 |
+
"kaiming_uniform_",
|
35 |
+
"kaiming_normal_",
|
36 |
+
"linear_init_",
|
37 |
+
"conv_init_",
|
38 |
+
"reset_initialized_parameter",
|
39 |
+
]
|
40 |
+
|
41 |
+
|
42 |
+
def _no_grad_uniform_(tensor, a, b):
|
43 |
+
with paddle.no_grad():
|
44 |
+
tensor.set_value(paddle.uniform(shape=tensor.shape, dtype=tensor.dtype, min=a, max=b))
|
45 |
+
return tensor
|
46 |
+
|
47 |
+
|
48 |
+
def _no_grad_normal_(tensor, mean=0.0, std=1.0):
|
49 |
+
with paddle.no_grad():
|
50 |
+
tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape))
|
51 |
+
return tensor
|
52 |
+
|
53 |
+
|
54 |
+
def _no_grad_fill_(tensor, value=0.0):
|
55 |
+
with paddle.no_grad():
|
56 |
+
tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
|
57 |
+
return tensor
|
58 |
+
|
59 |
+
|
60 |
+
def uniform_(tensor, a, b):
|
61 |
+
"""
|
62 |
+
Modified tensor inspace using uniform_
|
63 |
+
Args:
|
64 |
+
tensor (paddle.Tensor): paddle Tensor
|
65 |
+
a (float|int): min value.
|
66 |
+
b (float|int): max value.
|
67 |
+
Return:
|
68 |
+
tensor
|
69 |
+
"""
|
70 |
+
return _no_grad_uniform_(tensor, a, b)
|
71 |
+
|
72 |
+
|
73 |
+
def normal_(tensor, mean=0.0, std=1.0):
|
74 |
+
"""
|
75 |
+
Modified tensor inspace using normal_
|
76 |
+
Args:
|
77 |
+
tensor (paddle.Tensor): paddle Tensor
|
78 |
+
mean (float|int): mean value.
|
79 |
+
std (float|int): std value.
|
80 |
+
Return:
|
81 |
+
tensor
|
82 |
+
"""
|
83 |
+
return _no_grad_normal_(tensor, mean, std)
|
84 |
+
|
85 |
+
|
86 |
+
def constant_(tensor, value=0.0):
|
87 |
+
"""
|
88 |
+
Modified tensor inspace using constant_
|
89 |
+
Args:
|
90 |
+
tensor (paddle.Tensor): paddle Tensor
|
91 |
+
value (float|int): value to fill tensor.
|
92 |
+
Return:
|
93 |
+
tensor
|
94 |
+
"""
|
95 |
+
return _no_grad_fill_(tensor, value)
|
96 |
+
|
97 |
+
|
98 |
+
def ones_(tensor):
|
99 |
+
"""
|
100 |
+
Modified tensor inspace using ones_
|
101 |
+
Args:
|
102 |
+
tensor (paddle.Tensor): paddle Tensor
|
103 |
+
Return:
|
104 |
+
tensor
|
105 |
+
"""
|
106 |
+
return _no_grad_fill_(tensor, 1)
|
107 |
+
|
108 |
+
|
109 |
+
def zeros_(tensor):
|
110 |
+
"""
|
111 |
+
Modified tensor inspace using zeros_
|
112 |
+
Args:
|
113 |
+
tensor (paddle.Tensor): paddle Tensor
|
114 |
+
Return:
|
115 |
+
tensor
|
116 |
+
"""
|
117 |
+
return _no_grad_fill_(tensor, 0)
|
118 |
+
|
119 |
+
|
120 |
+
def vector_(tensor, vector):
|
121 |
+
with paddle.no_grad():
|
122 |
+
tensor.set_value(paddle.to_tensor(vector, dtype=tensor.dtype))
|
123 |
+
return tensor
|
124 |
+
|
125 |
+
|
126 |
+
def _calculate_fan_in_and_fan_out(tensor, reverse=False):
|
127 |
+
"""
|
128 |
+
Calculate (fan_in, _fan_out) for tensor
|
129 |
+
Args:
|
130 |
+
tensor (Tensor): paddle.Tensor
|
131 |
+
reverse (bool: False): tensor data format order, False by default as [fout, fin, ...]. e.g. : conv.weight [cout, cin, kh, kw] is False; linear.weight [cin, cout] is True
|
132 |
+
Return:
|
133 |
+
Tuple[fan_in, fan_out]
|
134 |
+
"""
|
135 |
+
if tensor.ndim < 2:
|
136 |
+
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
137 |
+
|
138 |
+
if reverse:
|
139 |
+
num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1]
|
140 |
+
else:
|
141 |
+
num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0]
|
142 |
+
|
143 |
+
receptive_field_size = 1
|
144 |
+
if tensor.ndim > 2:
|
145 |
+
receptive_field_size = np.prod(tensor.shape[2:])
|
146 |
+
|
147 |
+
fan_in = num_input_fmaps * receptive_field_size
|
148 |
+
fan_out = num_output_fmaps * receptive_field_size
|
149 |
+
|
150 |
+
return fan_in, fan_out
|
151 |
+
|
152 |
+
|
153 |
+
def xavier_uniform_(tensor, gain=1.0, reverse=False):
|
154 |
+
"""
|
155 |
+
Modified tensor inspace using xavier_uniform_
|
156 |
+
Args:
|
157 |
+
tensor (paddle.Tensor): paddle Tensor
|
158 |
+
gain (float): super parameter, 1. default.
|
159 |
+
reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
|
160 |
+
Return:
|
161 |
+
tensor
|
162 |
+
"""
|
163 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
|
164 |
+
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
165 |
+
k = math.sqrt(3.0) * std
|
166 |
+
return _no_grad_uniform_(tensor, -k, k)
|
167 |
+
|
168 |
+
|
169 |
+
def xavier_normal_(tensor, gain=1.0, reverse=False):
|
170 |
+
"""
|
171 |
+
Modified tensor inspace using xavier_normal_
|
172 |
+
Args:
|
173 |
+
tensor (paddle.Tensor): paddle Tensor
|
174 |
+
gain (float): super parameter, 1. default.
|
175 |
+
reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
|
176 |
+
Return:
|
177 |
+
tensor
|
178 |
+
"""
|
179 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
|
180 |
+
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
181 |
+
return _no_grad_normal_(tensor, 0, std)
|
182 |
+
|
183 |
+
|
184 |
+
# reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
|
185 |
+
def _calculate_correct_fan(tensor, mode, reverse=False):
|
186 |
+
mode = mode.lower()
|
187 |
+
valid_modes = ["fan_in", "fan_out"]
|
188 |
+
if mode not in valid_modes:
|
189 |
+
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
190 |
+
|
191 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse)
|
192 |
+
|
193 |
+
return fan_in if mode == "fan_in" else fan_out
|
194 |
+
|
195 |
+
|
196 |
+
def _calculate_gain(nonlinearity, param=None):
|
197 |
+
linear_fns = ["linear", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"]
|
198 |
+
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
|
199 |
+
return 1
|
200 |
+
elif nonlinearity == "tanh":
|
201 |
+
return 5.0 / 3
|
202 |
+
elif nonlinearity == "relu":
|
203 |
+
return math.sqrt(2.0)
|
204 |
+
elif nonlinearity == "leaky_relu":
|
205 |
+
if param is None:
|
206 |
+
negative_slope = 0.01
|
207 |
+
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
208 |
+
# True/False are instances of int, hence check above
|
209 |
+
negative_slope = param
|
210 |
+
else:
|
211 |
+
raise ValueError("negative_slope {} not a valid number".format(param))
|
212 |
+
return math.sqrt(2.0 / (1 + negative_slope**2))
|
213 |
+
elif nonlinearity == "selu":
|
214 |
+
return 3.0 / 4
|
215 |
+
else:
|
216 |
+
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
217 |
+
|
218 |
+
|
219 |
+
def kaiming_uniform_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", reverse=False):
|
220 |
+
"""
|
221 |
+
Modified tensor inspace using kaiming_uniform method
|
222 |
+
Args:
|
223 |
+
tensor (paddle.Tensor): paddle Tensor
|
224 |
+
mode (str): ['fan_in', 'fan_out'], 'fin_in' defalut
|
225 |
+
nonlinearity (str): nonlinearity method name
|
226 |
+
reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
|
227 |
+
Return:
|
228 |
+
tensor
|
229 |
+
"""
|
230 |
+
fan = _calculate_correct_fan(tensor, mode, reverse)
|
231 |
+
gain = _calculate_gain(nonlinearity, a)
|
232 |
+
std = gain / math.sqrt(fan)
|
233 |
+
k = math.sqrt(3.0) * std
|
234 |
+
return _no_grad_uniform_(tensor, -k, k)
|
235 |
+
|
236 |
+
|
237 |
+
def kaiming_normal_(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", reverse=False):
|
238 |
+
"""
|
239 |
+
Modified tensor inspace using kaiming_normal_
|
240 |
+
Args:
|
241 |
+
tensor (paddle.Tensor): paddle Tensor
|
242 |
+
mode (str): ['fan_in', 'fan_out'], 'fin_in' defalut
|
243 |
+
nonlinearity (str): nonlinearity method name
|
244 |
+
reverse (bool): reverse (bool: False): tensor data format order, False by default as [fout, fin, ...].
|
245 |
+
Return:
|
246 |
+
tensor
|
247 |
+
"""
|
248 |
+
fan = _calculate_correct_fan(tensor, mode, reverse)
|
249 |
+
gain = _calculate_gain(nonlinearity, a)
|
250 |
+
std = gain / math.sqrt(fan)
|
251 |
+
return _no_grad_normal_(tensor, 0, std)
|
252 |
+
|
253 |
+
|
254 |
+
def linear_init_(module):
|
255 |
+
bound = 1 / math.sqrt(module.weight.shape[0])
|
256 |
+
uniform_(module.weight, -bound, bound)
|
257 |
+
uniform_(module.bias, -bound, bound)
|
258 |
+
|
259 |
+
|
260 |
+
def conv_init_(module):
|
261 |
+
bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
|
262 |
+
uniform_(module.weight, -bound, bound)
|
263 |
+
if module.bias is not None:
|
264 |
+
uniform_(module.bias, -bound, bound)
|
265 |
+
|
266 |
+
|
267 |
+
def bias_init_with_prob(prior_prob=0.01):
|
268 |
+
"""initialize conv/fc bias value according to a given probability value."""
|
269 |
+
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
|
270 |
+
return bias_init
|
271 |
+
|
272 |
+
|
273 |
+
@paddle.no_grad()
|
274 |
+
def reset_initialized_parameter(model, include_self=True):
|
275 |
+
"""
|
276 |
+
Reset initialized parameter using following method for [conv, linear, embedding, bn]
|
277 |
+
Args:
|
278 |
+
model (paddle.Layer): paddle Layer
|
279 |
+
include_self (bool: False): include_self for Layer.named_sublayers method. Indicate whether including itself
|
280 |
+
Return:
|
281 |
+
None
|
282 |
+
"""
|
283 |
+
for _, m in model.named_sublayers(include_self=include_self):
|
284 |
+
if isinstance(m, nn.Conv2D):
|
285 |
+
k = float(m._groups) / (m._in_channels * m._kernel_size[0] * m._kernel_size[1])
|
286 |
+
k = math.sqrt(k)
|
287 |
+
_no_grad_uniform_(m.weight, -k, k)
|
288 |
+
if hasattr(m, "bias") and getattr(m, "bias") is not None:
|
289 |
+
_no_grad_uniform_(m.bias, -k, k)
|
290 |
+
|
291 |
+
elif isinstance(m, nn.Linear):
|
292 |
+
k = math.sqrt(1.0 / m.weight.shape[0])
|
293 |
+
_no_grad_uniform_(m.weight, -k, k)
|
294 |
+
if hasattr(m, "bias") and getattr(m, "bias") is not None:
|
295 |
+
_no_grad_uniform_(m.bias, -k, k)
|
296 |
+
|
297 |
+
elif isinstance(m, nn.Embedding):
|
298 |
+
_no_grad_normal_(m.weight, mean=0.0, std=1.0)
|
299 |
+
|
300 |
+
elif isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
|
301 |
+
_no_grad_fill_(m.weight, 1.0)
|
302 |
+
if hasattr(m, "bias") and getattr(m, "bias") is not None:
|
303 |
+
_no_grad_fill_(m.bias, 0)
|
ppdiffusers/loaders.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import os
|
16 |
+
from collections import defaultdict
|
17 |
+
from typing import Callable, Dict, Union
|
18 |
+
|
19 |
+
import paddle
|
20 |
+
import paddle.nn as nn
|
21 |
+
|
22 |
+
from .modeling_utils import _get_model_file, load_dict
|
23 |
+
from .models.cross_attention import LoRACrossAttnProcessor
|
24 |
+
from .utils import HF_CACHE, PPDIFFUSERS_CACHE, logging
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
LORA_WEIGHT_NAME = "paddle_lora_weights.pdparams"
|
30 |
+
|
31 |
+
|
32 |
+
class AttnProcsLayers(nn.Layer):
|
33 |
+
def __init__(self, state_dict: Dict[str, paddle.Tensor]):
|
34 |
+
super().__init__()
|
35 |
+
self.layers = nn.LayerList(state_dict.values())
|
36 |
+
self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
|
37 |
+
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
|
38 |
+
|
39 |
+
# we add a hook to state_dict() and load_state_dict() so that the
|
40 |
+
# naming fits with `unet.attn_processors`
|
41 |
+
def map_to(state_dict, *args, **kwargs):
|
42 |
+
new_state_dict = {}
|
43 |
+
for key, value in state_dict.items():
|
44 |
+
num = int(key.split(".")[1]) # 0 is always "layers"
|
45 |
+
new_key = key.replace(f"layers.{num}", self.mapping[num])
|
46 |
+
new_state_dict[new_key] = value
|
47 |
+
|
48 |
+
return new_state_dict
|
49 |
+
|
50 |
+
def map_from(module, state_dict, *args, **kwargs):
|
51 |
+
all_keys = list(state_dict.keys())
|
52 |
+
for key in all_keys:
|
53 |
+
replace_key = key.split(".processor")[0] + ".processor"
|
54 |
+
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
|
55 |
+
state_dict[new_key] = state_dict[key]
|
56 |
+
del state_dict[key]
|
57 |
+
|
58 |
+
self.register_state_dict_hook(map_to)
|
59 |
+
self.register_load_state_dict_pre_hook(map_from, with_module=True)
|
60 |
+
|
61 |
+
|
62 |
+
class UNet2DConditionLoadersMixin:
|
63 |
+
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, paddle.Tensor]], **kwargs):
|
64 |
+
r"""
|
65 |
+
Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
|
66 |
+
defined in
|
67 |
+
[cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
|
68 |
+
and be a `paddle.nn.Layer` class.
|
69 |
+
<Tip warning={true}>
|
70 |
+
This function is experimental and might change in the future
|
71 |
+
</Tip>
|
72 |
+
Parameters:
|
73 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
74 |
+
Can be either:
|
75 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
76 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
77 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
78 |
+
`./my_model_directory/`.
|
79 |
+
- A [paddle state
|
80 |
+
dict].
|
81 |
+
from_hf_hub (bool, optional): whether to load from Huggingface Hub.
|
82 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
83 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
84 |
+
standard cache should not be used.
|
85 |
+
subfolder (`str`, *optional*, defaults to `None`):
|
86 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
87 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
88 |
+
"""
|
89 |
+
|
90 |
+
from_hf_hub = kwargs.pop("from_hf_hub", False)
|
91 |
+
if from_hf_hub:
|
92 |
+
cache_dir = kwargs.pop("cache_dir", HF_CACHE)
|
93 |
+
else:
|
94 |
+
cache_dir = kwargs.pop("cache_dir", PPDIFFUSERS_CACHE)
|
95 |
+
subfolder = kwargs.pop("subfolder", None)
|
96 |
+
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
|
97 |
+
|
98 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
99 |
+
model_file = _get_model_file(
|
100 |
+
pretrained_model_name_or_path_or_dict,
|
101 |
+
weights_name=weight_name,
|
102 |
+
cache_dir=cache_dir,
|
103 |
+
subfolder=subfolder,
|
104 |
+
from_hf_hub=from_hf_hub,
|
105 |
+
)
|
106 |
+
state_dict = load_dict(model_file, map_location="cpu")
|
107 |
+
else:
|
108 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
109 |
+
|
110 |
+
# fill attn processors
|
111 |
+
attn_processors = {}
|
112 |
+
|
113 |
+
is_lora = all("lora" in k for k in state_dict.keys())
|
114 |
+
|
115 |
+
if is_lora:
|
116 |
+
lora_grouped_dict = defaultdict(dict)
|
117 |
+
for key, value in state_dict.items():
|
118 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
119 |
+
lora_grouped_dict[attn_processor_key][sub_key] = value
|
120 |
+
|
121 |
+
for key, value_dict in lora_grouped_dict.items():
|
122 |
+
rank = value_dict["to_k_lora.down.weight"].shape[1] # 0 -> 1, torch vs paddle nn.Linear
|
123 |
+
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[0] # 1 -> 0, torch vs paddle nn.Linear
|
124 |
+
hidden_size = value_dict["to_k_lora.up.weight"].shape[1] # 0 -> 1, torch vs paddle nn.Linear
|
125 |
+
|
126 |
+
attn_processors[key] = LoRACrossAttnProcessor(
|
127 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
|
128 |
+
)
|
129 |
+
attn_processors[key].load_dict(value_dict)
|
130 |
+
|
131 |
+
else:
|
132 |
+
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
133 |
+
|
134 |
+
# set correct dtype & device
|
135 |
+
attn_processors = {k: v.to(dtype=self.dtype) for k, v in attn_processors.items()}
|
136 |
+
|
137 |
+
# set layers
|
138 |
+
self.set_attn_processor(attn_processors)
|
139 |
+
|
140 |
+
def save_attn_procs(
|
141 |
+
self,
|
142 |
+
save_directory: Union[str, os.PathLike],
|
143 |
+
is_main_process: bool = True,
|
144 |
+
weights_name: str = LORA_WEIGHT_NAME,
|
145 |
+
save_function: Callable = None,
|
146 |
+
):
|
147 |
+
r"""
|
148 |
+
Save an attention procesor to a directory, so that it can be re-loaded using the
|
149 |
+
`[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method.
|
150 |
+
Arguments:
|
151 |
+
save_directory (`str` or `os.PathLike`):
|
152 |
+
Directory to which to save. Will be created if it doesn't exist.
|
153 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
154 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
155 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
156 |
+
the main process to avoid race conditions.
|
157 |
+
weights_name (`str`, *optional*, defaults to `LORA_WEIGHT_NAME`):
|
158 |
+
The name of weights.
|
159 |
+
save_function (`Callable`):
|
160 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
161 |
+
need to replace `torch.save` by another method. Can be configured with the environment variable
|
162 |
+
`DIFFUSERS_SAVE_MODE`.
|
163 |
+
"""
|
164 |
+
if os.path.isfile(save_directory):
|
165 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
166 |
+
return
|
167 |
+
|
168 |
+
if save_function is None:
|
169 |
+
save_function = paddle.save
|
170 |
+
|
171 |
+
os.makedirs(save_directory, exist_ok=True)
|
172 |
+
|
173 |
+
model_to_save = AttnProcsLayers(self.attn_processors)
|
174 |
+
|
175 |
+
# Save the model
|
176 |
+
state_dict = model_to_save.state_dict()
|
177 |
+
|
178 |
+
# Clean the folder from a previous save
|
179 |
+
for filename in os.listdir(save_directory):
|
180 |
+
full_filename = os.path.join(save_directory, filename)
|
181 |
+
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
182 |
+
# in distributed settings to avoid race conditions.
|
183 |
+
weights_no_suffix = weights_name.replace(".pdparams", "")
|
184 |
+
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
|
185 |
+
os.remove(full_filename)
|
186 |
+
|
187 |
+
# Save the model
|
188 |
+
save_function(state_dict, os.path.join(save_directory, weights_name))
|
189 |
+
|
190 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
ppdiffusers/modeling_paddle_pytorch_utils.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch - Paddle general utilities."""
|
16 |
+
import re
|
17 |
+
|
18 |
+
from .utils import logging
|
19 |
+
|
20 |
+
logger = logging.get_logger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
def rename_key(key):
|
24 |
+
regex = r"\w+[.]\d+"
|
25 |
+
pats = re.findall(regex, key)
|
26 |
+
for pat in pats:
|
27 |
+
key = key.replace(pat, "_".join(pat.split(".")))
|
28 |
+
return key
|
29 |
+
|
30 |
+
|
31 |
+
#####################
|
32 |
+
# PyTorch => Paddle #
|
33 |
+
#####################
|
34 |
+
|
35 |
+
|
36 |
+
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_paddle_state_dict):
|
37 |
+
"""Rename PT weight names to corresponding Paddle weight names and reshape tensor if necessary"""
|
38 |
+
|
39 |
+
# conv norm or layer norm
|
40 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
41 |
+
if (
|
42 |
+
any("norm" in str_ for str_ in pt_tuple_key)
|
43 |
+
and (pt_tuple_key[-1] in ["bias", "beta"])
|
44 |
+
and (pt_tuple_key[:-1] + ("bias",) in random_paddle_state_dict)
|
45 |
+
):
|
46 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
47 |
+
return renamed_pt_tuple_key, pt_tensor
|
48 |
+
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("bias",) in random_paddle_state_dict:
|
49 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
50 |
+
return renamed_pt_tuple_key, pt_tensor
|
51 |
+
|
52 |
+
# embedding
|
53 |
+
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("weight",) in random_paddle_state_dict:
|
54 |
+
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
55 |
+
return renamed_pt_tuple_key, pt_tensor
|
56 |
+
|
57 |
+
# conv layer
|
58 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
59 |
+
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
|
60 |
+
return renamed_pt_tuple_key, pt_tensor
|
61 |
+
|
62 |
+
# linear layer
|
63 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
64 |
+
if pt_tuple_key[-1] == "weight":
|
65 |
+
pt_tensor = pt_tensor.t()
|
66 |
+
return renamed_pt_tuple_key, pt_tensor
|
67 |
+
|
68 |
+
# old PyTorch layer norm weight
|
69 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
70 |
+
if pt_tuple_key[-1] == "gamma":
|
71 |
+
return renamed_pt_tuple_key, pt_tensor
|
72 |
+
|
73 |
+
# old PyTorch layer norm bias
|
74 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
75 |
+
if pt_tuple_key[-1] == "beta":
|
76 |
+
return renamed_pt_tuple_key, pt_tensor
|
77 |
+
|
78 |
+
return pt_tuple_key, pt_tensor
|
79 |
+
|
80 |
+
|
81 |
+
def convert_pytorch_state_dict_to_paddle(pt_state_dict, paddle_model):
|
82 |
+
# Step 1: Convert pytorch tensor to numpy
|
83 |
+
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
84 |
+
|
85 |
+
random_paddle_state_dict = paddle_model.state_dict
|
86 |
+
paddle_state_dict = {}
|
87 |
+
|
88 |
+
# Need to change some parameters name to match Paddle names
|
89 |
+
for pt_key, pt_tensor in pt_state_dict.items():
|
90 |
+
renamed_pt_key = rename_key(pt_key)
|
91 |
+
pt_tuple_key = tuple(renamed_pt_key.split("."))
|
92 |
+
|
93 |
+
# Correctly rename weight parameters
|
94 |
+
paddle_key, paddle_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_paddle_state_dict)
|
95 |
+
|
96 |
+
if paddle_key in random_paddle_state_dict:
|
97 |
+
if list(paddle_tensor.shape) != list(random_paddle_state_dict[paddle_key].shape):
|
98 |
+
raise ValueError(
|
99 |
+
f"Paddle checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
100 |
+
f"{random_paddle_state_dict[paddle_key].shape}, but is {paddle_tensor.shape}."
|
101 |
+
)
|
102 |
+
|
103 |
+
# also add unexpected weight so that warning is thrown
|
104 |
+
paddle_state_dict[paddle_key] = paddle_tensor.numpy()
|
105 |
+
|
106 |
+
return paddle_state_dict
|
ppdiffusers/modeling_utils.py
ADDED
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import os
|
18 |
+
import tempfile
|
19 |
+
from functools import partial
|
20 |
+
from typing import Callable, Optional, Union
|
21 |
+
|
22 |
+
import paddle
|
23 |
+
import paddle.nn as nn
|
24 |
+
from huggingface_hub import (
|
25 |
+
create_repo,
|
26 |
+
get_hf_file_metadata,
|
27 |
+
hf_hub_download,
|
28 |
+
hf_hub_url,
|
29 |
+
repo_type_and_id_from_hf_id,
|
30 |
+
upload_folder,
|
31 |
+
)
|
32 |
+
from huggingface_hub.utils import EntryNotFoundError
|
33 |
+
from requests import HTTPError
|
34 |
+
|
35 |
+
from .download_utils import ppdiffusers_bos_download
|
36 |
+
from .utils import (
|
37 |
+
CONFIG_NAME,
|
38 |
+
DOWNLOAD_SERVER,
|
39 |
+
HF_CACHE,
|
40 |
+
PPDIFFUSERS_CACHE,
|
41 |
+
WEIGHTS_NAME,
|
42 |
+
logging,
|
43 |
+
)
|
44 |
+
from .version import VERSION as __version__
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__)
|
47 |
+
|
48 |
+
|
49 |
+
def unfreeze_params(params):
|
50 |
+
for param in params:
|
51 |
+
param.stop_gradient = False
|
52 |
+
|
53 |
+
|
54 |
+
def freeze_params(params):
|
55 |
+
for param in params:
|
56 |
+
param.stop_gradient = True
|
57 |
+
|
58 |
+
|
59 |
+
# device
|
60 |
+
def get_parameter_device(parameter: nn.Layer):
|
61 |
+
try:
|
62 |
+
return next(parameter.named_parameters())[1].place
|
63 |
+
except StopIteration:
|
64 |
+
return paddle.get_device()
|
65 |
+
|
66 |
+
|
67 |
+
def get_parameter_dtype(parameter: nn.Layer):
|
68 |
+
try:
|
69 |
+
return next(parameter.named_parameters())[1].dtype
|
70 |
+
except StopIteration:
|
71 |
+
return paddle.get_default_dtype()
|
72 |
+
|
73 |
+
|
74 |
+
def load_dict(checkpoint_file: Union[str, os.PathLike], map_location: str = "cpu"):
|
75 |
+
"""
|
76 |
+
Reads a Paddle checkpoint file, returning properly formatted errors if they arise.
|
77 |
+
"""
|
78 |
+
try:
|
79 |
+
if map_location == "cpu":
|
80 |
+
with paddle.device_scope("cpu"):
|
81 |
+
state_dict = paddle.load(checkpoint_file)
|
82 |
+
else:
|
83 |
+
state_dict = paddle.load(checkpoint_file)
|
84 |
+
return state_dict
|
85 |
+
except Exception as e:
|
86 |
+
try:
|
87 |
+
with open(checkpoint_file) as f:
|
88 |
+
if f.read().startswith("version"):
|
89 |
+
raise OSError(
|
90 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
91 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
92 |
+
"you cloned."
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
raise ValueError(
|
96 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
97 |
+
"model. Make sure you have saved the model properly."
|
98 |
+
) from e
|
99 |
+
except (UnicodeDecodeError, ValueError):
|
100 |
+
raise OSError(
|
101 |
+
f"Unable to load weights from Paddle checkpoint file for '{checkpoint_file}' "
|
102 |
+
f"at '{checkpoint_file}'. "
|
103 |
+
"If you tried to load a Paddle model from a TF 2.0 checkpoint, please set from_tf=True."
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
class ModelMixin(nn.Layer):
|
108 |
+
r"""
|
109 |
+
Base class for all models.
|
110 |
+
|
111 |
+
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
112 |
+
and saving models.
|
113 |
+
|
114 |
+
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
115 |
+
[`~modeling_utils.ModelMixin.save_pretrained`].
|
116 |
+
"""
|
117 |
+
config_name = CONFIG_NAME
|
118 |
+
_automatically_saved_args = ["_ppdiffusers_version", "_class_name", "_name_or_path"]
|
119 |
+
_supports_gradient_checkpointing = False
|
120 |
+
|
121 |
+
def __init__(self):
|
122 |
+
super().__init__()
|
123 |
+
|
124 |
+
@property
|
125 |
+
def is_gradient_checkpointing(self) -> bool:
|
126 |
+
"""
|
127 |
+
Whether gradient checkpointing is activated for this model or not.
|
128 |
+
|
129 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
130 |
+
activations".
|
131 |
+
"""
|
132 |
+
return any(
|
133 |
+
hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing
|
134 |
+
for m in self.sublayers(include_self=True)
|
135 |
+
)
|
136 |
+
|
137 |
+
def enable_gradient_checkpointing(self):
|
138 |
+
"""
|
139 |
+
Activates gradient checkpointing for the current model.
|
140 |
+
|
141 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
142 |
+
activations".
|
143 |
+
"""
|
144 |
+
if not self._supports_gradient_checkpointing:
|
145 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
146 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
147 |
+
|
148 |
+
def disable_gradient_checkpointing(self):
|
149 |
+
"""
|
150 |
+
Deactivates gradient checkpointing for the current model.
|
151 |
+
|
152 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
153 |
+
activations".
|
154 |
+
"""
|
155 |
+
if self._supports_gradient_checkpointing:
|
156 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
157 |
+
|
158 |
+
def save_pretrained(
|
159 |
+
self,
|
160 |
+
save_directory: Union[str, os.PathLike],
|
161 |
+
is_main_process: bool = True,
|
162 |
+
save_function: Callable = paddle.save,
|
163 |
+
):
|
164 |
+
"""
|
165 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
166 |
+
`[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
|
167 |
+
|
168 |
+
Arguments:
|
169 |
+
save_directory (`str` or `os.PathLike`):
|
170 |
+
Directory to which to save. Will be created if it doesn't exist.
|
171 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
172 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
173 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
174 |
+
the main process to avoid race conditions.
|
175 |
+
save_function (`Callable`):
|
176 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
177 |
+
need to replace `paddle.save` by another method.
|
178 |
+
"""
|
179 |
+
if os.path.isfile(save_directory):
|
180 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
181 |
+
return
|
182 |
+
|
183 |
+
os.makedirs(save_directory, exist_ok=True)
|
184 |
+
|
185 |
+
model_to_save = self
|
186 |
+
|
187 |
+
# Attach architecture to the config
|
188 |
+
# Save the config
|
189 |
+
if is_main_process:
|
190 |
+
model_to_save.save_config(save_directory)
|
191 |
+
|
192 |
+
# Save the model
|
193 |
+
state_dict = model_to_save.state_dict()
|
194 |
+
|
195 |
+
# Clean the folder from a previous save
|
196 |
+
for filename in os.listdir(save_directory):
|
197 |
+
full_filename = os.path.join(save_directory, filename)
|
198 |
+
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
199 |
+
# in distributed settings to avoid race conditions.
|
200 |
+
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
|
201 |
+
os.remove(full_filename)
|
202 |
+
|
203 |
+
# Save the model
|
204 |
+
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
|
205 |
+
|
206 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
207 |
+
|
208 |
+
def save_to_hf_hub(
|
209 |
+
self,
|
210 |
+
repo_id: str,
|
211 |
+
private: Optional[bool] = None,
|
212 |
+
subfolder: Optional[str] = None,
|
213 |
+
commit_message: Optional[str] = None,
|
214 |
+
revision: Optional[str] = None,
|
215 |
+
create_pr: bool = False,
|
216 |
+
):
|
217 |
+
"""
|
218 |
+
Uploads all elements of this model to a new HuggingFace Hub repository.
|
219 |
+
Args:
|
220 |
+
repo_id (str): Repository name for your model/tokenizer in the Hub.
|
221 |
+
private (bool, optional): Whether the model/tokenizer is set to private
|
222 |
+
subfolder (str, optional): Push to a subfolder of the repo instead of the root
|
223 |
+
commit_message (str, optional) — The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"
|
224 |
+
revision (str, optional) — The git revision to commit from. Defaults to the head of the "main" branch.
|
225 |
+
create_pr (boolean, optional) — Whether or not to create a Pull Request with that commit. Defaults to False.
|
226 |
+
If revision is not set, PR is opened against the "main" branch. If revision is set and is a branch, PR is opened against this branch.
|
227 |
+
If revision is set and is not a branch name (example: a commit oid), an RevisionNotFoundError is returned by the server.
|
228 |
+
|
229 |
+
Returns: The url of the commit of your model in the given repository.
|
230 |
+
"""
|
231 |
+
repo_url = create_repo(repo_id, private=private, exist_ok=True)
|
232 |
+
|
233 |
+
# Infer complete repo_id from repo_url
|
234 |
+
# Can be different from the input `repo_id` if repo_owner was implicit
|
235 |
+
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
236 |
+
|
237 |
+
repo_id = f"{repo_owner}/{repo_name}"
|
238 |
+
|
239 |
+
# Check if README file already exist in repo
|
240 |
+
try:
|
241 |
+
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
242 |
+
has_readme = True
|
243 |
+
except EntryNotFoundError:
|
244 |
+
has_readme = False
|
245 |
+
|
246 |
+
with tempfile.TemporaryDirectory() as root_dir:
|
247 |
+
if subfolder is not None:
|
248 |
+
save_dir = os.path.join(root_dir, subfolder)
|
249 |
+
else:
|
250 |
+
save_dir = root_dir
|
251 |
+
# save model
|
252 |
+
self.save_pretrained(save_dir)
|
253 |
+
# Add readme if does not exist
|
254 |
+
logger.info("README.md not found, adding the default README.md")
|
255 |
+
if not has_readme:
|
256 |
+
with open(os.path.join(root_dir, "README.md"), "w") as f:
|
257 |
+
f.write(f"---\nlibrary_name: ppdiffusers\n---\n# {repo_id}")
|
258 |
+
|
259 |
+
# Upload model and return
|
260 |
+
logger.info(f"Pushing to the {repo_id}. This might take a while")
|
261 |
+
return upload_folder(
|
262 |
+
repo_id=repo_id,
|
263 |
+
repo_type="model",
|
264 |
+
folder_path=root_dir,
|
265 |
+
commit_message=commit_message,
|
266 |
+
revision=revision,
|
267 |
+
create_pr=create_pr,
|
268 |
+
)
|
269 |
+
|
270 |
+
@classmethod
|
271 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
272 |
+
r"""
|
273 |
+
Instantiate a pretrained paddle model from a pre-trained model configuration.
|
274 |
+
|
275 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
276 |
+
the model, you should first set it back in training mode with `model.train()`.
|
277 |
+
|
278 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
279 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
280 |
+
task.
|
281 |
+
|
282 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
283 |
+
weights are discarded.
|
284 |
+
|
285 |
+
Parameters:
|
286 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
287 |
+
Can be either:
|
288 |
+
|
289 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
290 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
291 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
292 |
+
`./my_model_directory/`.
|
293 |
+
|
294 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
295 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
296 |
+
standard cache should not be used.
|
297 |
+
paddle_dtype (`str` or `paddle.dtype`, *optional*):
|
298 |
+
Override the default `paddle.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
299 |
+
will be automatically derived from the model's weights.
|
300 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
301 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
302 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
303 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
304 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
305 |
+
from_hf_hub (bool, *optional*):
|
306 |
+
Whether to load from Hugging Face Hub. Defaults to False
|
307 |
+
"""
|
308 |
+
from_hf_hub = kwargs.pop("from_hf_hub", False)
|
309 |
+
if from_hf_hub:
|
310 |
+
cache_dir = kwargs.pop("cache_dir", HF_CACHE)
|
311 |
+
else:
|
312 |
+
cache_dir = kwargs.pop("cache_dir", PPDIFFUSERS_CACHE)
|
313 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
314 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
315 |
+
paddle_dtype = kwargs.pop("paddle_dtype", None)
|
316 |
+
subfolder = kwargs.pop("subfolder", None)
|
317 |
+
ignore_keys = kwargs.pop("ignore_keys", [])
|
318 |
+
|
319 |
+
# Load config if we don't provide a configuration
|
320 |
+
config_path = pretrained_model_name_or_path
|
321 |
+
|
322 |
+
model_file = None
|
323 |
+
if model_file is None:
|
324 |
+
model_file = _get_model_file(
|
325 |
+
pretrained_model_name_or_path,
|
326 |
+
weights_name=WEIGHTS_NAME,
|
327 |
+
cache_dir=cache_dir,
|
328 |
+
subfolder=subfolder,
|
329 |
+
from_hf_hub=from_hf_hub,
|
330 |
+
)
|
331 |
+
|
332 |
+
config, unused_kwargs = cls.load_config(
|
333 |
+
config_path,
|
334 |
+
cache_dir=cache_dir,
|
335 |
+
return_unused_kwargs=True,
|
336 |
+
subfolder=subfolder,
|
337 |
+
from_hf_hub=from_hf_hub,
|
338 |
+
**kwargs,
|
339 |
+
)
|
340 |
+
model = cls.from_config(config, **unused_kwargs)
|
341 |
+
|
342 |
+
state_dict = load_dict(model_file, map_location="cpu")
|
343 |
+
|
344 |
+
keys = list(state_dict.keys())
|
345 |
+
for k in keys:
|
346 |
+
for ik in ignore_keys:
|
347 |
+
if k.startswith(ik):
|
348 |
+
logger.warning("Deleting key {} from state_dict.".format(k))
|
349 |
+
del state_dict[k]
|
350 |
+
|
351 |
+
dtype = set(v.dtype for v in state_dict.values())
|
352 |
+
|
353 |
+
if len(dtype) > 1 and paddle.float32 not in dtype:
|
354 |
+
raise ValueError(
|
355 |
+
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
|
356 |
+
f" make sure that {model_file} weights have only one dtype."
|
357 |
+
)
|
358 |
+
elif len(dtype) > 1 and paddle.float32 in dtype:
|
359 |
+
dtype = paddle.float32
|
360 |
+
else:
|
361 |
+
dtype = dtype.pop()
|
362 |
+
|
363 |
+
# move model to correct dtype
|
364 |
+
model = model.to(dtype=dtype)
|
365 |
+
|
366 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
367 |
+
model,
|
368 |
+
state_dict,
|
369 |
+
model_file,
|
370 |
+
pretrained_model_name_or_path,
|
371 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
372 |
+
)
|
373 |
+
|
374 |
+
loading_info = {
|
375 |
+
"missing_keys": missing_keys,
|
376 |
+
"unexpected_keys": unexpected_keys,
|
377 |
+
"mismatched_keys": mismatched_keys,
|
378 |
+
"error_msgs": error_msgs,
|
379 |
+
}
|
380 |
+
|
381 |
+
if paddle_dtype is not None and not isinstance(paddle_dtype, paddle.dtype):
|
382 |
+
raise ValueError(
|
383 |
+
f"{paddle_dtype} needs to be of type `paddle.dtype`, e.g. `paddle.float16`, but is {type(paddle_dtype)}."
|
384 |
+
)
|
385 |
+
elif paddle_dtype is not None:
|
386 |
+
model = model.to(dtype=paddle_dtype)
|
387 |
+
|
388 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
389 |
+
|
390 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
391 |
+
model.eval()
|
392 |
+
if output_loading_info:
|
393 |
+
return model, loading_info
|
394 |
+
|
395 |
+
return model
|
396 |
+
|
397 |
+
@classmethod
|
398 |
+
def _load_pretrained_model(
|
399 |
+
cls,
|
400 |
+
model,
|
401 |
+
state_dict,
|
402 |
+
resolved_archive_file,
|
403 |
+
pretrained_model_name_or_path,
|
404 |
+
ignore_mismatched_sizes=False,
|
405 |
+
):
|
406 |
+
# Retrieve missing & unexpected_keys
|
407 |
+
model_state_dict = model.state_dict()
|
408 |
+
loaded_keys = [k for k in state_dict.keys()]
|
409 |
+
|
410 |
+
expected_keys = list(model_state_dict.keys())
|
411 |
+
|
412 |
+
original_loaded_keys = loaded_keys
|
413 |
+
|
414 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
415 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
416 |
+
|
417 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
418 |
+
model_to_load = model
|
419 |
+
|
420 |
+
def _find_mismatched_keys(
|
421 |
+
state_dict,
|
422 |
+
model_state_dict,
|
423 |
+
loaded_keys,
|
424 |
+
ignore_mismatched_sizes,
|
425 |
+
):
|
426 |
+
mismatched_keys = []
|
427 |
+
if ignore_mismatched_sizes:
|
428 |
+
for checkpoint_key in loaded_keys:
|
429 |
+
model_key = checkpoint_key
|
430 |
+
|
431 |
+
if model_key in model_state_dict and list(state_dict[checkpoint_key].shape) != list(
|
432 |
+
model_state_dict[model_key].shape
|
433 |
+
):
|
434 |
+
mismatched_keys.append(
|
435 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
436 |
+
)
|
437 |
+
del state_dict[checkpoint_key]
|
438 |
+
return mismatched_keys
|
439 |
+
|
440 |
+
if state_dict is not None:
|
441 |
+
# Whole checkpoint
|
442 |
+
mismatched_keys = _find_mismatched_keys(
|
443 |
+
state_dict,
|
444 |
+
model_state_dict,
|
445 |
+
original_loaded_keys,
|
446 |
+
ignore_mismatched_sizes,
|
447 |
+
)
|
448 |
+
error_msgs = ""
|
449 |
+
model_to_load.load_dict(state_dict)
|
450 |
+
|
451 |
+
if len(error_msgs) > 0:
|
452 |
+
error_msg = "\n\t".join(error_msgs)
|
453 |
+
if "size mismatch" in error_msg:
|
454 |
+
error_msg += (
|
455 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
456 |
+
)
|
457 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
458 |
+
|
459 |
+
if len(unexpected_keys) > 0:
|
460 |
+
logger.warning(
|
461 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
462 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
463 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
464 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
465 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
466 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
467 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
468 |
+
" BertForSequenceClassification model)."
|
469 |
+
)
|
470 |
+
else:
|
471 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
472 |
+
if len(missing_keys) > 0:
|
473 |
+
logger.warning(
|
474 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
475 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
476 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
477 |
+
)
|
478 |
+
elif len(mismatched_keys) == 0:
|
479 |
+
logger.info(
|
480 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
481 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
482 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
483 |
+
" without further training."
|
484 |
+
)
|
485 |
+
if len(mismatched_keys) > 0:
|
486 |
+
mismatched_warning = "\n".join(
|
487 |
+
[
|
488 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
489 |
+
for key, shape1, shape2 in mismatched_keys
|
490 |
+
]
|
491 |
+
)
|
492 |
+
logger.warning(
|
493 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
494 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
495 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
496 |
+
" able to use it for predictions and inference."
|
497 |
+
)
|
498 |
+
|
499 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
500 |
+
|
501 |
+
@property
|
502 |
+
def device(self):
|
503 |
+
"""
|
504 |
+
`paddle.place`: The device on which the module is (assuming that all the module parameters are on the same
|
505 |
+
device).
|
506 |
+
"""
|
507 |
+
return get_parameter_device(self)
|
508 |
+
|
509 |
+
@property
|
510 |
+
def dtype(self) -> paddle.dtype:
|
511 |
+
"""
|
512 |
+
`paddle.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
513 |
+
"""
|
514 |
+
return get_parameter_dtype(self)
|
515 |
+
|
516 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
517 |
+
"""
|
518 |
+
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
519 |
+
|
520 |
+
Args:
|
521 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
522 |
+
Whether or not to return only the number of trainable parameters
|
523 |
+
|
524 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
525 |
+
Whether or not to return only the number of non-embeddings parameters
|
526 |
+
|
527 |
+
Returns:
|
528 |
+
`int`: The number of parameters.
|
529 |
+
"""
|
530 |
+
|
531 |
+
if exclude_embeddings:
|
532 |
+
embedding_param_names = [
|
533 |
+
f"{name}.weight"
|
534 |
+
for name, module_type in self.named_sublayers(include_self=True)
|
535 |
+
if isinstance(module_type, nn.Embedding)
|
536 |
+
]
|
537 |
+
non_embedding_parameters = [
|
538 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
539 |
+
]
|
540 |
+
return sum(p.numel() for p in non_embedding_parameters if not p.stop_gradient or not only_trainable)
|
541 |
+
else:
|
542 |
+
return sum(p.numel() for p in self.parameters() if not p.stop_gradient or not only_trainable)
|
543 |
+
|
544 |
+
|
545 |
+
def unwrap_model(model: nn.Layer) -> nn.Layer:
|
546 |
+
"""
|
547 |
+
Recursively unwraps a model from potential containers (as used in distributed training).
|
548 |
+
|
549 |
+
Args:
|
550 |
+
model (`nn.Layer`): The model to unwrap.
|
551 |
+
"""
|
552 |
+
# since there could be multiple levels of wrapping, unwrap recursively
|
553 |
+
if hasattr(model, "_layers"):
|
554 |
+
return unwrap_model(model._layers)
|
555 |
+
else:
|
556 |
+
return model
|
557 |
+
|
558 |
+
|
559 |
+
def _get_model_file(
|
560 |
+
pretrained_model_name_or_path,
|
561 |
+
*,
|
562 |
+
weights_name,
|
563 |
+
subfolder,
|
564 |
+
cache_dir,
|
565 |
+
from_hf_hub,
|
566 |
+
):
|
567 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
568 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
569 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
|
570 |
+
# Load from a PyTorch checkpoint
|
571 |
+
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
|
572 |
+
elif subfolder is not None and os.path.isfile(
|
573 |
+
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
574 |
+
):
|
575 |
+
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
576 |
+
else:
|
577 |
+
raise EnvironmentError(
|
578 |
+
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
579 |
+
)
|
580 |
+
return model_file
|
581 |
+
elif from_hf_hub:
|
582 |
+
model_file = hf_hub_download(
|
583 |
+
repo_id=pretrained_model_name_or_path,
|
584 |
+
filename=weights_name,
|
585 |
+
cache_dir=cache_dir,
|
586 |
+
subfolder=subfolder,
|
587 |
+
library_name="PPDiffusers",
|
588 |
+
library_version=__version__,
|
589 |
+
)
|
590 |
+
return model_file
|
591 |
+
else:
|
592 |
+
try:
|
593 |
+
# Load from URL or cache if already cached
|
594 |
+
model_file = ppdiffusers_bos_download(
|
595 |
+
pretrained_model_name_or_path,
|
596 |
+
filename=weights_name,
|
597 |
+
subfolder=subfolder,
|
598 |
+
cache_dir=cache_dir,
|
599 |
+
)
|
600 |
+
except HTTPError as err:
|
601 |
+
raise EnvironmentError(
|
602 |
+
"There was a specific connection error when trying to load" f" {pretrained_model_name_or_path}:\n{err}"
|
603 |
+
)
|
604 |
+
except ValueError:
|
605 |
+
raise EnvironmentError(
|
606 |
+
f"We couldn't connect to '{DOWNLOAD_SERVER}' to load this model, couldn't find it"
|
607 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
608 |
+
f" directory containing a file named {weights_name} or"
|
609 |
+
" \nCheckout your internet connection or see how to run the library in"
|
610 |
+
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
611 |
+
)
|
612 |
+
except EnvironmentError:
|
613 |
+
raise EnvironmentError(
|
614 |
+
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
615 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
616 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
617 |
+
f"containing a file named {weights_name}"
|
618 |
+
)
|
619 |
+
return model_file
|
ppdiffusers/models/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# flake8: noqa
|
16 |
+
|
17 |
+
from ..utils import is_paddle_available
|
18 |
+
|
19 |
+
if is_paddle_available():
|
20 |
+
from .attention import Transformer2DModel
|
21 |
+
from .prior_transformer import PriorTransformer
|
22 |
+
from .unet_1d import UNet1DModel
|
23 |
+
from .unet_2d import UNet2DModel
|
24 |
+
from .unet_2d_condition import UNet2DConditionModel
|
25 |
+
from .vae import AutoencoderKL, VQModel
|
ppdiffusers/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (511 Bytes). View file
|
|
ppdiffusers/models/__pycache__/attention.cpython-37.pyc
ADDED
Binary file (22.5 kB). View file
|
|
ppdiffusers/models/__pycache__/cross_attention.cpython-37.pyc
ADDED
Binary file (10.5 kB). View file
|
|
ppdiffusers/models/__pycache__/embeddings.cpython-37.pyc
ADDED
Binary file (5.68 kB). View file
|
|
ppdiffusers/models/__pycache__/prior_transformer.cpython-37.pyc
ADDED
Binary file (7.11 kB). View file
|
|
ppdiffusers/models/__pycache__/resnet.cpython-37.pyc
ADDED
Binary file (19.6 kB). View file
|
|
ppdiffusers/models/__pycache__/unet_1d.cpython-37.pyc
ADDED
Binary file (7.22 kB). View file
|
|
ppdiffusers/models/__pycache__/unet_1d_blocks.cpython-37.pyc
ADDED
Binary file (17.4 kB). View file
|
|
ppdiffusers/models/__pycache__/unet_2d.cpython-37.pyc
ADDED
Binary file (8.18 kB). View file
|
|
ppdiffusers/models/__pycache__/unet_2d_blocks.cpython-37.pyc
ADDED
Binary file (36.7 kB). View file
|
|
ppdiffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc
ADDED
Binary file (15.7 kB). View file
|
|
ppdiffusers/models/__pycache__/vae.cpython-37.pyc
ADDED
Binary file (16.9 kB). View file
|
|
ppdiffusers/models/attention.py
ADDED
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import math
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import paddle
|
20 |
+
import paddle.nn.functional as F
|
21 |
+
from paddle import nn
|
22 |
+
|
23 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from ..modeling_utils import ModelMixin
|
25 |
+
from ..models.embeddings import ImagePositionalEmbeddings
|
26 |
+
from ..utils import BaseOutput
|
27 |
+
from .cross_attention import CrossAttention
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class Transformer2DModelOutput(BaseOutput):
|
32 |
+
"""
|
33 |
+
Args:
|
34 |
+
sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
35 |
+
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
36 |
+
for the unnoised latent pixels.
|
37 |
+
"""
|
38 |
+
|
39 |
+
sample: paddle.Tensor
|
40 |
+
|
41 |
+
|
42 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
43 |
+
"""
|
44 |
+
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
45 |
+
embeddings) inputs.
|
46 |
+
|
47 |
+
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
48 |
+
transformer action. Finally, reshape to image.
|
49 |
+
|
50 |
+
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
51 |
+
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
52 |
+
classes of unnoised image.
|
53 |
+
|
54 |
+
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
55 |
+
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
59 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
60 |
+
in_channels (`int`, *optional*):
|
61 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
62 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
63 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
64 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
65 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
66 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
67 |
+
`ImagePositionalEmbeddings`.
|
68 |
+
num_vector_embeds (`int`, *optional*):
|
69 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
70 |
+
Includes the class for the masked latent pixel.
|
71 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
72 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
73 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
74 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
75 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
76 |
+
attention_bias (`bool`, *optional*):
|
77 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
78 |
+
"""
|
79 |
+
|
80 |
+
@register_to_config
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
num_attention_heads: int = 16,
|
84 |
+
attention_head_dim: int = 88,
|
85 |
+
in_channels: Optional[int] = None,
|
86 |
+
num_layers: int = 1,
|
87 |
+
dropout: float = 0.0,
|
88 |
+
norm_num_groups: int = 32,
|
89 |
+
cross_attention_dim: Optional[int] = None,
|
90 |
+
attention_bias: bool = False,
|
91 |
+
sample_size: Optional[int] = None,
|
92 |
+
num_vector_embeds: Optional[int] = None,
|
93 |
+
activation_fn: str = "geglu",
|
94 |
+
num_embeds_ada_norm: Optional[int] = None,
|
95 |
+
use_linear_projection: bool = False,
|
96 |
+
only_cross_attention: bool = False,
|
97 |
+
upcast_attention: bool = False,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
self.use_linear_projection = use_linear_projection
|
101 |
+
self.num_attention_heads = num_attention_heads
|
102 |
+
self.attention_head_dim = attention_head_dim
|
103 |
+
self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
|
104 |
+
|
105 |
+
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
106 |
+
# Define whether input is continuous or discrete depending on configuration
|
107 |
+
self.is_input_continuous = in_channels is not None
|
108 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
109 |
+
|
110 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
111 |
+
raise ValueError(
|
112 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
113 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
114 |
+
)
|
115 |
+
elif not self.is_input_continuous and not self.is_input_vectorized:
|
116 |
+
raise ValueError(
|
117 |
+
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
118 |
+
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
119 |
+
)
|
120 |
+
|
121 |
+
# 2. Define input layers
|
122 |
+
if self.is_input_continuous:
|
123 |
+
self.in_channels = in_channels
|
124 |
+
|
125 |
+
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6)
|
126 |
+
if use_linear_projection:
|
127 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
128 |
+
else:
|
129 |
+
self.proj_in = nn.Conv2D(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
130 |
+
elif self.is_input_vectorized:
|
131 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
132 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
133 |
+
|
134 |
+
self.height = sample_size
|
135 |
+
self.width = sample_size
|
136 |
+
self.num_vector_embeds = num_vector_embeds
|
137 |
+
self.num_latent_pixels = self.height * self.width
|
138 |
+
|
139 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
140 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
141 |
+
)
|
142 |
+
|
143 |
+
# 3. Define transformers blocks
|
144 |
+
self.transformer_blocks = nn.LayerList(
|
145 |
+
[
|
146 |
+
BasicTransformerBlock(
|
147 |
+
inner_dim,
|
148 |
+
num_attention_heads,
|
149 |
+
attention_head_dim,
|
150 |
+
dropout=dropout,
|
151 |
+
cross_attention_dim=cross_attention_dim,
|
152 |
+
activation_fn=activation_fn,
|
153 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
154 |
+
attention_bias=attention_bias,
|
155 |
+
only_cross_attention=only_cross_attention,
|
156 |
+
upcast_attention=upcast_attention,
|
157 |
+
)
|
158 |
+
for d in range(num_layers)
|
159 |
+
]
|
160 |
+
)
|
161 |
+
|
162 |
+
# 4. Define output layers
|
163 |
+
if self.is_input_continuous:
|
164 |
+
if use_linear_projection:
|
165 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
166 |
+
else:
|
167 |
+
self.proj_out = nn.Conv2D(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
168 |
+
elif self.is_input_vectorized:
|
169 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
170 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
171 |
+
|
172 |
+
def forward(
|
173 |
+
self,
|
174 |
+
hidden_states,
|
175 |
+
encoder_hidden_states=None,
|
176 |
+
timestep=None,
|
177 |
+
cross_attention_kwargs=None,
|
178 |
+
return_dict: bool = True,
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Args:
|
182 |
+
hidden_states ( When discrete, `paddle.Tensor` of shape `(batch size, num latent pixels)`.
|
183 |
+
When continous, `paddle.Tensor` of shape `(batch size, channel, height, width)`): Input
|
184 |
+
hidden_states
|
185 |
+
encoder_hidden_states ( `paddle.Tensor` of shape `(batch size, encoder_hidden_states)`, *optional*):
|
186 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
187 |
+
self-attention.
|
188 |
+
timestep ( `paddle.Tensor`, *optional*):
|
189 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
190 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
191 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
195 |
+
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
196 |
+
tensor.
|
197 |
+
"""
|
198 |
+
# 1. Input
|
199 |
+
if self.is_input_continuous:
|
200 |
+
_, _, height, width = hidden_states.shape
|
201 |
+
residual = hidden_states
|
202 |
+
hidden_states = self.norm(hidden_states)
|
203 |
+
if not self.use_linear_projection:
|
204 |
+
hidden_states = self.proj_in(hidden_states)
|
205 |
+
hidden_states = hidden_states.transpose([0, 2, 3, 1]).flatten(1, 2)
|
206 |
+
if self.use_linear_projection:
|
207 |
+
hidden_states = self.proj_in(hidden_states)
|
208 |
+
elif self.is_input_vectorized:
|
209 |
+
hidden_states = self.latent_image_embedding(hidden_states.cast("int64"))
|
210 |
+
|
211 |
+
# 2. Blocks
|
212 |
+
for block in self.transformer_blocks:
|
213 |
+
hidden_states = block(
|
214 |
+
hidden_states,
|
215 |
+
encoder_hidden_states=encoder_hidden_states,
|
216 |
+
timestep=timestep,
|
217 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
218 |
+
)
|
219 |
+
|
220 |
+
# 3. Output
|
221 |
+
if self.is_input_continuous:
|
222 |
+
if self.use_linear_projection:
|
223 |
+
hidden_states = self.proj_out(hidden_states)
|
224 |
+
hidden_states = hidden_states.reshape([-1, height, width, self.inner_dim]).transpose([0, 3, 1, 2])
|
225 |
+
if not self.use_linear_projection:
|
226 |
+
hidden_states = self.proj_out(hidden_states)
|
227 |
+
output = hidden_states + residual
|
228 |
+
elif self.is_input_vectorized:
|
229 |
+
hidden_states = self.norm_out(hidden_states)
|
230 |
+
logits = self.out(hidden_states)
|
231 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
232 |
+
logits = logits.transpose([0, 2, 1])
|
233 |
+
|
234 |
+
# log(p(x_0))
|
235 |
+
output = F.log_softmax(logits.cast("float64"), axis=1).cast("float32")
|
236 |
+
|
237 |
+
if not return_dict:
|
238 |
+
return (output,)
|
239 |
+
|
240 |
+
return Transformer2DModelOutput(sample=output)
|
241 |
+
|
242 |
+
|
243 |
+
class AttentionBlock(nn.Layer):
|
244 |
+
"""
|
245 |
+
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
246 |
+
to the N-d case.
|
247 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
248 |
+
Uses three q, k, v linear layers to compute attention.
|
249 |
+
|
250 |
+
Parameters:
|
251 |
+
channels (`int`): The number of channels in the input and output.
|
252 |
+
num_head_channels (`int`, *optional*):
|
253 |
+
The number of channels in each head. If None, then `num_heads` = 1.
|
254 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
255 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
256 |
+
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
channels: int,
|
262 |
+
num_head_channels: Optional[int] = None,
|
263 |
+
norm_num_groups: int = 32,
|
264 |
+
rescale_output_factor: float = 1.0,
|
265 |
+
eps: float = 1e-5,
|
266 |
+
):
|
267 |
+
super().__init__()
|
268 |
+
self.channels = channels
|
269 |
+
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
270 |
+
self.head_dim = self.channels // self.num_heads
|
271 |
+
self.scale = 1 / math.sqrt(self.channels / self.num_heads)
|
272 |
+
|
273 |
+
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, epsilon=eps)
|
274 |
+
|
275 |
+
# define q,k,v as linear layers
|
276 |
+
self.query = nn.Linear(channels, channels)
|
277 |
+
self.key = nn.Linear(channels, channels)
|
278 |
+
self.value = nn.Linear(channels, channels)
|
279 |
+
|
280 |
+
self.rescale_output_factor = rescale_output_factor
|
281 |
+
self.proj_attn = nn.Linear(channels, channels)
|
282 |
+
|
283 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
284 |
+
tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
|
285 |
+
tensor = tensor.transpose([0, 2, 1, 3])
|
286 |
+
return tensor
|
287 |
+
|
288 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
289 |
+
tensor = tensor.transpose([0, 2, 1, 3])
|
290 |
+
tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
|
291 |
+
return tensor
|
292 |
+
|
293 |
+
def forward(self, hidden_states):
|
294 |
+
residual = hidden_states
|
295 |
+
batch, channel, height, width = hidden_states.shape
|
296 |
+
|
297 |
+
# norm
|
298 |
+
hidden_states = self.group_norm(hidden_states)
|
299 |
+
|
300 |
+
hidden_states = hidden_states.reshape([batch, channel, height * width]).transpose([0, 2, 1])
|
301 |
+
|
302 |
+
# proj to q, k, v
|
303 |
+
query_proj = self.query(hidden_states)
|
304 |
+
key_proj = self.key(hidden_states)
|
305 |
+
value_proj = self.value(hidden_states)
|
306 |
+
|
307 |
+
query_proj = self.reshape_heads_to_batch_dim(query_proj)
|
308 |
+
key_proj = self.reshape_heads_to_batch_dim(key_proj)
|
309 |
+
value_proj = self.reshape_heads_to_batch_dim(value_proj)
|
310 |
+
|
311 |
+
# get scores
|
312 |
+
attention_scores = paddle.matmul(query_proj, key_proj, transpose_y=True) * self.scale
|
313 |
+
attention_probs = F.softmax(attention_scores.cast("float32"), axis=-1).cast(attention_scores.dtype)
|
314 |
+
|
315 |
+
# compute attention output
|
316 |
+
hidden_states = paddle.matmul(attention_probs, value_proj)
|
317 |
+
|
318 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
319 |
+
|
320 |
+
# compute next hidden_states
|
321 |
+
hidden_states = self.proj_attn(hidden_states)
|
322 |
+
hidden_states = hidden_states.transpose([0, 2, 1]).reshape([batch, channel, height, width])
|
323 |
+
|
324 |
+
# res connect and rescale
|
325 |
+
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
326 |
+
return hidden_states
|
327 |
+
|
328 |
+
|
329 |
+
class BasicTransformerBlock(nn.Layer):
|
330 |
+
r"""
|
331 |
+
A basic Transformer block.
|
332 |
+
|
333 |
+
Parameters:
|
334 |
+
dim (`int`): The number of channels in the input and output.
|
335 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
336 |
+
attention_head_dim (`int`): The number of channels in each head.
|
337 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
338 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
339 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
340 |
+
num_embeds_ada_norm (:
|
341 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
342 |
+
attention_bias (:
|
343 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(
|
347 |
+
self,
|
348 |
+
dim: int,
|
349 |
+
num_attention_heads: int,
|
350 |
+
attention_head_dim: int,
|
351 |
+
dropout=0.0,
|
352 |
+
cross_attention_dim: Optional[int] = None,
|
353 |
+
activation_fn: str = "geglu",
|
354 |
+
num_embeds_ada_norm: Optional[int] = None,
|
355 |
+
attention_bias: bool = False,
|
356 |
+
only_cross_attention: bool = False,
|
357 |
+
upcast_attention: bool = False,
|
358 |
+
):
|
359 |
+
super().__init__()
|
360 |
+
self.only_cross_attention = only_cross_attention
|
361 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
362 |
+
|
363 |
+
# 1. Self-Attn
|
364 |
+
self.attn1 = CrossAttention(
|
365 |
+
query_dim=dim,
|
366 |
+
heads=num_attention_heads,
|
367 |
+
dim_head=attention_head_dim,
|
368 |
+
dropout=dropout,
|
369 |
+
bias=attention_bias,
|
370 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
371 |
+
upcast_attention=upcast_attention,
|
372 |
+
)
|
373 |
+
|
374 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
375 |
+
|
376 |
+
# 2. Cross-Attn
|
377 |
+
if cross_attention_dim is not None:
|
378 |
+
self.attn2 = CrossAttention(
|
379 |
+
query_dim=dim,
|
380 |
+
cross_attention_dim=cross_attention_dim,
|
381 |
+
heads=num_attention_heads,
|
382 |
+
dim_head=attention_head_dim,
|
383 |
+
dropout=dropout,
|
384 |
+
bias=attention_bias,
|
385 |
+
upcast_attention=upcast_attention,
|
386 |
+
) # is self-attn if encoder_hidden_states is none
|
387 |
+
else:
|
388 |
+
self.attn2 = None
|
389 |
+
|
390 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
391 |
+
|
392 |
+
if cross_attention_dim is not None:
|
393 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
394 |
+
else:
|
395 |
+
self.norm2 = None
|
396 |
+
|
397 |
+
# 3. Feed-forward
|
398 |
+
self.norm3 = nn.LayerNorm(dim)
|
399 |
+
|
400 |
+
def forward(
|
401 |
+
self,
|
402 |
+
hidden_states,
|
403 |
+
encoder_hidden_states=None,
|
404 |
+
timestep=None,
|
405 |
+
attention_mask=None,
|
406 |
+
cross_attention_kwargs=None,
|
407 |
+
):
|
408 |
+
# 1. Self-Attention
|
409 |
+
norm_hidden_states = (
|
410 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
411 |
+
)
|
412 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
413 |
+
attn_output = self.attn1(
|
414 |
+
norm_hidden_states,
|
415 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
416 |
+
attention_mask=attention_mask,
|
417 |
+
**cross_attention_kwargs,
|
418 |
+
)
|
419 |
+
hidden_states = attn_output + hidden_states
|
420 |
+
|
421 |
+
if self.attn2 is not None:
|
422 |
+
# 2. Cross-Attention
|
423 |
+
norm_hidden_states = (
|
424 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
425 |
+
)
|
426 |
+
attn_output = self.attn2(
|
427 |
+
norm_hidden_states,
|
428 |
+
encoder_hidden_states=encoder_hidden_states,
|
429 |
+
attention_mask=attention_mask,
|
430 |
+
**cross_attention_kwargs,
|
431 |
+
)
|
432 |
+
hidden_states = attn_output + hidden_states
|
433 |
+
|
434 |
+
# 3. Feed-forward
|
435 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
436 |
+
|
437 |
+
return hidden_states
|
438 |
+
|
439 |
+
|
440 |
+
class FeedForward(nn.Layer):
|
441 |
+
r"""
|
442 |
+
A feed-forward layer.
|
443 |
+
|
444 |
+
Parameters:
|
445 |
+
dim (`int`): The number of channels in the input.
|
446 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
447 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
448 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
449 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
450 |
+
"""
|
451 |
+
|
452 |
+
def __init__(
|
453 |
+
self,
|
454 |
+
dim: int,
|
455 |
+
dim_out: Optional[int] = None,
|
456 |
+
mult: int = 4,
|
457 |
+
dropout: float = 0.0,
|
458 |
+
activation_fn: str = "geglu",
|
459 |
+
):
|
460 |
+
super().__init__()
|
461 |
+
inner_dim = int(dim * mult)
|
462 |
+
dim_out = dim_out if dim_out is not None else dim
|
463 |
+
|
464 |
+
if activation_fn == "gelu":
|
465 |
+
act_fn = GELU(dim, inner_dim)
|
466 |
+
elif activation_fn == "geglu":
|
467 |
+
act_fn = GEGLU(dim, inner_dim)
|
468 |
+
elif activation_fn == "geglu-approximate":
|
469 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
470 |
+
|
471 |
+
self.net = nn.LayerList([])
|
472 |
+
# project in
|
473 |
+
self.net.append(act_fn)
|
474 |
+
# project dropout
|
475 |
+
self.net.append(nn.Dropout(dropout))
|
476 |
+
# project out
|
477 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
478 |
+
|
479 |
+
def forward(self, hidden_states):
|
480 |
+
for module in self.net:
|
481 |
+
hidden_states = module(hidden_states)
|
482 |
+
return hidden_states
|
483 |
+
|
484 |
+
|
485 |
+
class GELU(nn.Layer):
|
486 |
+
r"""
|
487 |
+
GELU activation function
|
488 |
+
"""
|
489 |
+
|
490 |
+
def __init__(self, dim_in: int, dim_out: int):
|
491 |
+
super().__init__()
|
492 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
493 |
+
|
494 |
+
def forward(self, hidden_states):
|
495 |
+
hidden_states = self.proj(hidden_states)
|
496 |
+
hidden_states = F.gelu(hidden_states)
|
497 |
+
return hidden_states
|
498 |
+
|
499 |
+
|
500 |
+
# feedforward
|
501 |
+
class GEGLU(nn.Layer):
|
502 |
+
r"""
|
503 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
504 |
+
|
505 |
+
Parameters:
|
506 |
+
dim_in (`int`): The number of channels in the input.
|
507 |
+
dim_out (`int`): The number of channels in the output.
|
508 |
+
"""
|
509 |
+
|
510 |
+
def __init__(self, dim_in: int, dim_out: int):
|
511 |
+
super().__init__()
|
512 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
513 |
+
|
514 |
+
def forward(self, hidden_states):
|
515 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, axis=-1)
|
516 |
+
return hidden_states * F.gelu(gate)
|
517 |
+
|
518 |
+
|
519 |
+
class ApproximateGELU(nn.Layer):
|
520 |
+
"""
|
521 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
522 |
+
|
523 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
524 |
+
"""
|
525 |
+
|
526 |
+
def __init__(self, dim_in: int, dim_out: int):
|
527 |
+
super().__init__()
|
528 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
529 |
+
|
530 |
+
def forward(self, x):
|
531 |
+
x = self.proj(x)
|
532 |
+
return x * F.sigmoid(1.702 * x)
|
533 |
+
|
534 |
+
|
535 |
+
class AdaLayerNorm(nn.Layer):
|
536 |
+
"""
|
537 |
+
Norm layer modified to incorporate timestep embeddings.
|
538 |
+
"""
|
539 |
+
|
540 |
+
def __init__(self, embedding_dim, num_embeddings):
|
541 |
+
super().__init__()
|
542 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
543 |
+
self.silu = nn.Silu()
|
544 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
545 |
+
self.norm = nn.LayerNorm(embedding_dim) # elementwise_affine=False
|
546 |
+
|
547 |
+
def forward(self, x, timestep):
|
548 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
549 |
+
scale, shift = paddle.chunk(emb, 2, axis=-1)
|
550 |
+
x = self.norm(x) * (1 + scale) + shift
|
551 |
+
return x
|
552 |
+
|
553 |
+
|
554 |
+
class DualTransformer2DModel(nn.Layer):
|
555 |
+
"""
|
556 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
557 |
+
Parameters:
|
558 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
559 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
560 |
+
in_channels (`int`, *optional*):
|
561 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
562 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
563 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
564 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
565 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
566 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
567 |
+
`ImagePositionalEmbeddings`.
|
568 |
+
num_vector_embeds (`int`, *optional*):
|
569 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
570 |
+
Includes the class for the masked latent pixel.
|
571 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
572 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
573 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
574 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
575 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
576 |
+
attention_bias (`bool`, *optional*):
|
577 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
578 |
+
"""
|
579 |
+
|
580 |
+
def __init__(
|
581 |
+
self,
|
582 |
+
num_attention_heads: int = 16,
|
583 |
+
attention_head_dim: int = 88,
|
584 |
+
in_channels: Optional[int] = None,
|
585 |
+
num_layers: int = 1,
|
586 |
+
dropout: float = 0.0,
|
587 |
+
norm_num_groups: int = 32,
|
588 |
+
cross_attention_dim: Optional[int] = None,
|
589 |
+
attention_bias: bool = False,
|
590 |
+
sample_size: Optional[int] = None,
|
591 |
+
num_vector_embeds: Optional[int] = None,
|
592 |
+
activation_fn: str = "geglu",
|
593 |
+
num_embeds_ada_norm: Optional[int] = None,
|
594 |
+
):
|
595 |
+
super().__init__()
|
596 |
+
self.transformers = nn.LayerList(
|
597 |
+
[
|
598 |
+
Transformer2DModel(
|
599 |
+
num_attention_heads=num_attention_heads,
|
600 |
+
attention_head_dim=attention_head_dim,
|
601 |
+
in_channels=in_channels,
|
602 |
+
num_layers=num_layers,
|
603 |
+
dropout=dropout,
|
604 |
+
norm_num_groups=norm_num_groups,
|
605 |
+
cross_attention_dim=cross_attention_dim,
|
606 |
+
attention_bias=attention_bias,
|
607 |
+
sample_size=sample_size,
|
608 |
+
num_vector_embeds=num_vector_embeds,
|
609 |
+
activation_fn=activation_fn,
|
610 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
611 |
+
)
|
612 |
+
for _ in range(2)
|
613 |
+
]
|
614 |
+
)
|
615 |
+
|
616 |
+
# Variables that can be set by a pipeline:
|
617 |
+
|
618 |
+
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
619 |
+
self.mix_ratio = 0.5
|
620 |
+
|
621 |
+
# The shape of `encoder_hidden_states` is expected to be
|
622 |
+
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
623 |
+
self.condition_lengths = [77, 257]
|
624 |
+
|
625 |
+
# Which transformer to use to encode which condition.
|
626 |
+
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
627 |
+
self.transformer_index_for_condition = [1, 0]
|
628 |
+
|
629 |
+
def forward(
|
630 |
+
self,
|
631 |
+
hidden_states,
|
632 |
+
encoder_hidden_states,
|
633 |
+
timestep=None,
|
634 |
+
attention_mask=None,
|
635 |
+
cross_attention_kwargs=None,
|
636 |
+
return_dict: bool = True,
|
637 |
+
):
|
638 |
+
"""
|
639 |
+
Args:
|
640 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
641 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
642 |
+
hidden_states
|
643 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
644 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
645 |
+
self-attention.
|
646 |
+
timestep ( `torch.long`, *optional*):
|
647 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
648 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
649 |
+
Optional attention mask to be applied in CrossAttention
|
650 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
651 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
652 |
+
|
653 |
+
Returns:
|
654 |
+
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
655 |
+
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
656 |
+
tensor.
|
657 |
+
"""
|
658 |
+
input_states = hidden_states
|
659 |
+
|
660 |
+
encoded_states = []
|
661 |
+
tokens_start = 0
|
662 |
+
# attention_mask is not used yet
|
663 |
+
for i in range(2):
|
664 |
+
# for each of the two transformers, pass the corresponding condition tokens
|
665 |
+
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
666 |
+
transformer_index = self.transformer_index_for_condition[i]
|
667 |
+
encoded_state = self.transformers[transformer_index](
|
668 |
+
input_states,
|
669 |
+
encoder_hidden_states=condition_state,
|
670 |
+
timestep=timestep,
|
671 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
672 |
+
return_dict=False,
|
673 |
+
)[0]
|
674 |
+
encoded_states.append(encoded_state - input_states)
|
675 |
+
tokens_start += self.condition_lengths[i]
|
676 |
+
|
677 |
+
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
678 |
+
output_states = output_states + input_states
|
679 |
+
|
680 |
+
if not return_dict:
|
681 |
+
return (output_states,)
|
682 |
+
|
683 |
+
return Transformer2DModelOutput(sample=output_states)
|
ppdiffusers/models/cross_attention.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Union
|
15 |
+
|
16 |
+
import paddle
|
17 |
+
import paddle.nn as nn
|
18 |
+
import paddle.nn.functional as F
|
19 |
+
|
20 |
+
from ..initializer import normal_, zeros_
|
21 |
+
|
22 |
+
|
23 |
+
class CrossAttention(nn.Layer):
|
24 |
+
r"""
|
25 |
+
A cross attention layer.
|
26 |
+
|
27 |
+
Parameters:
|
28 |
+
query_dim (`int`): The number of channels in the query.
|
29 |
+
cross_attention_dim (`int`, *optional*):
|
30 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
31 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
32 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
33 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
34 |
+
bias (`bool`, *optional*, defaults to False):
|
35 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
query_dim: int,
|
41 |
+
cross_attention_dim: Optional[int] = None,
|
42 |
+
heads: int = 8,
|
43 |
+
dim_head: int = 64,
|
44 |
+
dropout: float = 0.0,
|
45 |
+
bias=False,
|
46 |
+
upcast_attention: bool = False,
|
47 |
+
upcast_softmax: bool = False,
|
48 |
+
added_kv_proj_dim: Optional[int] = None,
|
49 |
+
norm_num_groups: Optional[int] = None,
|
50 |
+
processor: Optional["AttnProcessor"] = None,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
inner_dim = dim_head * heads
|
54 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
55 |
+
self.upcast_attention = upcast_attention
|
56 |
+
self.upcast_softmax = upcast_softmax
|
57 |
+
|
58 |
+
self.scale = dim_head**-0.5
|
59 |
+
self.num_heads = heads
|
60 |
+
self.head_dim = inner_dim // heads
|
61 |
+
# for slice_size > 0 the attention score computation
|
62 |
+
# is split across the batch axis to save memory
|
63 |
+
# You can set slice_size with `set_attention_slice`
|
64 |
+
self.sliceable_head_dim = heads
|
65 |
+
|
66 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
67 |
+
|
68 |
+
if norm_num_groups is not None:
|
69 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
|
70 |
+
else:
|
71 |
+
self.group_norm = None
|
72 |
+
|
73 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
|
74 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
|
75 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
|
76 |
+
|
77 |
+
if self.added_kv_proj_dim is not None:
|
78 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
79 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
80 |
+
|
81 |
+
self.to_out = nn.LayerList([])
|
82 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
83 |
+
self.to_out.append(nn.Dropout(dropout))
|
84 |
+
|
85 |
+
# set attention processor
|
86 |
+
processor = processor if processor is not None else CrossAttnProcessor()
|
87 |
+
self.set_processor(processor)
|
88 |
+
|
89 |
+
def set_attention_slice(self, slice_size):
|
90 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
91 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
92 |
+
|
93 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
94 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
95 |
+
elif slice_size is not None:
|
96 |
+
processor = SlicedAttnProcessor(slice_size)
|
97 |
+
elif self.added_kv_proj_dim is not None:
|
98 |
+
processor = CrossAttnAddedKVProcessor()
|
99 |
+
else:
|
100 |
+
processor = CrossAttnProcessor()
|
101 |
+
|
102 |
+
self.set_processor(processor)
|
103 |
+
|
104 |
+
def set_processor(self, processor: "AttnProcessor"):
|
105 |
+
self.processor = processor
|
106 |
+
|
107 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
108 |
+
# The `CrossAttention` class can call different attention processors / attention functions
|
109 |
+
# here we simply pass along all tensors to the selected processor class
|
110 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
111 |
+
return self.processor(
|
112 |
+
self,
|
113 |
+
hidden_states,
|
114 |
+
encoder_hidden_states=encoder_hidden_states,
|
115 |
+
attention_mask=attention_mask,
|
116 |
+
**cross_attention_kwargs,
|
117 |
+
)
|
118 |
+
|
119 |
+
def batch_to_head_dim(self, tensor):
|
120 |
+
tensor = tensor.transpose([0, 2, 1, 3])
|
121 |
+
tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
|
122 |
+
return tensor
|
123 |
+
|
124 |
+
def head_to_batch_dim(self, tensor):
|
125 |
+
tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
|
126 |
+
tensor = tensor.transpose([0, 2, 1, 3])
|
127 |
+
return tensor
|
128 |
+
|
129 |
+
def get_attention_scores(self, query, key, attention_mask=None):
|
130 |
+
if self.upcast_attention:
|
131 |
+
query = query.cast("float32")
|
132 |
+
key = key.cast("float32")
|
133 |
+
|
134 |
+
attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale
|
135 |
+
|
136 |
+
if attention_mask is not None:
|
137 |
+
attention_scores = attention_scores + attention_mask
|
138 |
+
|
139 |
+
if self.upcast_softmax:
|
140 |
+
attention_scores = attention_scores.cast("float32")
|
141 |
+
|
142 |
+
attention_probs = F.softmax(attention_scores, axis=-1)
|
143 |
+
if self.upcast_softmax:
|
144 |
+
attention_probs = attention_probs.cast(query.dtype)
|
145 |
+
|
146 |
+
return attention_probs
|
147 |
+
|
148 |
+
def prepare_attention_mask(self, attention_mask, target_length):
|
149 |
+
if attention_mask is None:
|
150 |
+
return attention_mask
|
151 |
+
|
152 |
+
if attention_mask.shape[-1] != target_length:
|
153 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0, data_format="NCL")
|
154 |
+
attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
|
155 |
+
return attention_mask
|
156 |
+
|
157 |
+
|
158 |
+
class CrossAttnProcessor:
|
159 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
160 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
161 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
162 |
+
attention_mask = (
|
163 |
+
attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
|
164 |
+
if attention_mask is not None
|
165 |
+
else None
|
166 |
+
)
|
167 |
+
|
168 |
+
query = attn.to_q(hidden_states)
|
169 |
+
query = attn.head_to_batch_dim(query)
|
170 |
+
|
171 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
172 |
+
key = attn.to_k(encoder_hidden_states)
|
173 |
+
value = attn.to_v(encoder_hidden_states)
|
174 |
+
key = attn.head_to_batch_dim(key)
|
175 |
+
value = attn.head_to_batch_dim(value)
|
176 |
+
|
177 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
178 |
+
hidden_states = paddle.matmul(attention_probs, value)
|
179 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
180 |
+
|
181 |
+
# linear proj
|
182 |
+
hidden_states = attn.to_out[0](hidden_states)
|
183 |
+
# dropout
|
184 |
+
hidden_states = attn.to_out[1](hidden_states)
|
185 |
+
|
186 |
+
return hidden_states
|
187 |
+
|
188 |
+
|
189 |
+
class LoRALinearLayer(nn.Layer):
|
190 |
+
def __init__(self, in_features, out_features, rank=4):
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
if rank > min(in_features, out_features):
|
194 |
+
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
195 |
+
|
196 |
+
self.down = nn.Linear(in_features, rank, bias_attr=False)
|
197 |
+
self.up = nn.Linear(rank, out_features, bias_attr=False)
|
198 |
+
self.scale = 1.0
|
199 |
+
|
200 |
+
normal_(self.down.weight, std=1 / rank)
|
201 |
+
zeros_(self.up.weight)
|
202 |
+
|
203 |
+
def forward(self, hidden_states):
|
204 |
+
orig_dtype = hidden_states.dtype
|
205 |
+
dtype = self.down.weight.dtype
|
206 |
+
|
207 |
+
down_hidden_states = self.down(hidden_states.cast(dtype))
|
208 |
+
up_hidden_states = self.up(down_hidden_states)
|
209 |
+
|
210 |
+
return up_hidden_states.cast(orig_dtype)
|
211 |
+
|
212 |
+
|
213 |
+
class LoRACrossAttnProcessor(nn.Layer):
|
214 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
|
215 |
+
super().__init__()
|
216 |
+
|
217 |
+
self.hidden_size = hidden_size
|
218 |
+
self.cross_attention_dim = cross_attention_dim
|
219 |
+
self.rank = rank
|
220 |
+
|
221 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
222 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
|
223 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
|
224 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
|
225 |
+
|
226 |
+
def __call__(
|
227 |
+
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
|
228 |
+
):
|
229 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
230 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
231 |
+
attention_mask = (
|
232 |
+
attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
|
233 |
+
if attention_mask is not None
|
234 |
+
else None
|
235 |
+
)
|
236 |
+
|
237 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
238 |
+
query = attn.head_to_batch_dim(query)
|
239 |
+
|
240 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
241 |
+
|
242 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
243 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
244 |
+
|
245 |
+
key = attn.head_to_batch_dim(key)
|
246 |
+
value = attn.head_to_batch_dim(value)
|
247 |
+
|
248 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
249 |
+
hidden_states = paddle.matmul(attention_probs, value)
|
250 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
251 |
+
|
252 |
+
# linear proj
|
253 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
254 |
+
# dropout
|
255 |
+
hidden_states = attn.to_out[1](hidden_states)
|
256 |
+
|
257 |
+
return hidden_states
|
258 |
+
|
259 |
+
|
260 |
+
class CrossAttnAddedKVProcessor:
|
261 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
262 |
+
residual = hidden_states
|
263 |
+
hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
|
264 |
+
[0, 2, 1]
|
265 |
+
)
|
266 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
267 |
+
encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])
|
268 |
+
|
269 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
270 |
+
attention_mask = (
|
271 |
+
attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
|
272 |
+
if attention_mask is not None
|
273 |
+
else None
|
274 |
+
)
|
275 |
+
|
276 |
+
hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])
|
277 |
+
|
278 |
+
query = attn.to_q(hidden_states)
|
279 |
+
query = attn.head_to_batch_dim(query)
|
280 |
+
|
281 |
+
key = attn.to_k(hidden_states)
|
282 |
+
value = attn.to_v(hidden_states)
|
283 |
+
key = attn.head_to_batch_dim(key)
|
284 |
+
value = attn.head_to_batch_dim(value)
|
285 |
+
|
286 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
287 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
288 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
289 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
290 |
+
|
291 |
+
key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
|
292 |
+
value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)
|
293 |
+
|
294 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
295 |
+
hidden_states = paddle.matmul(attention_probs, value)
|
296 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
297 |
+
|
298 |
+
# linear proj
|
299 |
+
hidden_states = attn.to_out[0](hidden_states)
|
300 |
+
# dropout
|
301 |
+
hidden_states = attn.to_out[1](hidden_states)
|
302 |
+
|
303 |
+
hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
|
304 |
+
hidden_states = hidden_states + residual
|
305 |
+
|
306 |
+
return hidden_states
|
307 |
+
|
308 |
+
|
309 |
+
class SlicedAttnProcessor:
|
310 |
+
def __init__(self, slice_size):
|
311 |
+
self.slice_size = slice_size
|
312 |
+
|
313 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
314 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
315 |
+
|
316 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
317 |
+
|
318 |
+
query = attn.to_q(hidden_states)
|
319 |
+
query = attn.head_to_batch_dim(query)
|
320 |
+
|
321 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
322 |
+
key = attn.to_k(encoder_hidden_states)
|
323 |
+
value = attn.to_v(encoder_hidden_states)
|
324 |
+
key = attn.head_to_batch_dim(key)
|
325 |
+
value = attn.head_to_batch_dim(value)
|
326 |
+
|
327 |
+
query = query.flatten(0, 1)
|
328 |
+
key = key.flatten(0, 1)
|
329 |
+
value = value.flatten(0, 1)
|
330 |
+
|
331 |
+
batch_size_attention = query.shape[0]
|
332 |
+
hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)
|
333 |
+
|
334 |
+
for i in range(hidden_states.shape[0] // self.slice_size):
|
335 |
+
start_idx = i * self.slice_size
|
336 |
+
end_idx = (i + 1) * self.slice_size
|
337 |
+
|
338 |
+
query_slice = query[start_idx:end_idx]
|
339 |
+
key_slice = key[start_idx:end_idx]
|
340 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
341 |
+
|
342 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
343 |
+
|
344 |
+
attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])
|
345 |
+
|
346 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
347 |
+
|
348 |
+
# reshape back to [bs, num_heads, seqlen, head_dim]
|
349 |
+
hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
|
350 |
+
# reshape hidden_states
|
351 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
352 |
+
|
353 |
+
# linear proj
|
354 |
+
hidden_states = attn.to_out[0](hidden_states)
|
355 |
+
# dropout
|
356 |
+
hidden_states = attn.to_out[1](hidden_states)
|
357 |
+
|
358 |
+
return hidden_states
|
359 |
+
|
360 |
+
|
361 |
+
class SlicedAttnAddedKVProcessor:
|
362 |
+
def __init__(self, slice_size):
|
363 |
+
self.slice_size = slice_size
|
364 |
+
|
365 |
+
def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
|
366 |
+
residual = hidden_states
|
367 |
+
hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
|
368 |
+
[0, 2, 1]
|
369 |
+
)
|
370 |
+
encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])
|
371 |
+
|
372 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
373 |
+
|
374 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
375 |
+
|
376 |
+
hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])
|
377 |
+
|
378 |
+
query = attn.to_q(hidden_states)
|
379 |
+
query = attn.head_to_batch_dim(query)
|
380 |
+
|
381 |
+
key = attn.to_k(hidden_states)
|
382 |
+
value = attn.to_v(hidden_states)
|
383 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
384 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
385 |
+
|
386 |
+
key = attn.head_to_batch_dim(key)
|
387 |
+
value = attn.head_to_batch_dim(value)
|
388 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
389 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
390 |
+
|
391 |
+
key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
|
392 |
+
value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)
|
393 |
+
|
394 |
+
query = query.flatten(0, 1)
|
395 |
+
key = key.flatten(0, 1)
|
396 |
+
value = value.flatten(0, 1)
|
397 |
+
|
398 |
+
batch_size_attention = query.shape[0]
|
399 |
+
hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)
|
400 |
+
for i in range(hidden_states.shape[0] // self.slice_size):
|
401 |
+
start_idx = i * self.slice_size
|
402 |
+
end_idx = (i + 1) * self.slice_size
|
403 |
+
|
404 |
+
query_slice = query[start_idx:end_idx]
|
405 |
+
key_slice = key[start_idx:end_idx]
|
406 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
407 |
+
|
408 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
409 |
+
|
410 |
+
attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])
|
411 |
+
|
412 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
413 |
+
|
414 |
+
# reshape back to [bs, num_heads, seqlen, head_dim]
|
415 |
+
hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
|
416 |
+
# reshape hidden_states
|
417 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
418 |
+
|
419 |
+
# linear proj
|
420 |
+
hidden_states = attn.to_out[0](hidden_states)
|
421 |
+
# dropout
|
422 |
+
hidden_states = attn.to_out[1](hidden_states)
|
423 |
+
|
424 |
+
hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
|
425 |
+
hidden_states = hidden_states + residual
|
426 |
+
|
427 |
+
return hidden_states
|
428 |
+
|
429 |
+
|
430 |
+
AttnProcessor = Union[
|
431 |
+
CrossAttnProcessor,
|
432 |
+
SlicedAttnProcessor,
|
433 |
+
CrossAttnAddedKVProcessor,
|
434 |
+
SlicedAttnAddedKVProcessor,
|
435 |
+
]
|
ppdiffusers/models/ema.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import paddle
|
17 |
+
from paddle import nn
|
18 |
+
|
19 |
+
|
20 |
+
class LitEma(nn.Layer):
|
21 |
+
"""
|
22 |
+
Exponential Moving Average (EMA) of model updates
|
23 |
+
|
24 |
+
Parameters:
|
25 |
+
model: The model architecture for apply EMA.
|
26 |
+
decay: The exponential decay. Default 0.9999.
|
27 |
+
use_num_updates: Whether to use number of updates when computing
|
28 |
+
averages.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
32 |
+
super().__init__()
|
33 |
+
if decay < 0.0 or decay > 1.0:
|
34 |
+
raise ValueError("Decay must be between 0 and 1")
|
35 |
+
|
36 |
+
self.m_name2s_name = {}
|
37 |
+
self.register_buffer("decay", paddle.to_tensor(decay, dtype=paddle.float32))
|
38 |
+
self.register_buffer(
|
39 |
+
"num_updates",
|
40 |
+
paddle.to_tensor(0, dtype=paddle.int64) if use_num_upates else paddle.to_tensor(-1, dtype=paddle.int64),
|
41 |
+
)
|
42 |
+
|
43 |
+
for name, p in model.named_parameters():
|
44 |
+
if not p.stop_gradient:
|
45 |
+
# remove as '.'-character is not allowed in buffers
|
46 |
+
s_name = name.replace(".", "")
|
47 |
+
self.m_name2s_name.update({name: s_name})
|
48 |
+
self.register_buffer(s_name, p.clone().detach())
|
49 |
+
|
50 |
+
self.collected_params = []
|
51 |
+
|
52 |
+
def forward(self, model):
|
53 |
+
decay = self.decay
|
54 |
+
|
55 |
+
if self.num_updates >= 0:
|
56 |
+
self.num_updates += 1
|
57 |
+
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
58 |
+
|
59 |
+
one_minus_decay = 1.0 - decay
|
60 |
+
|
61 |
+
with paddle.no_grad():
|
62 |
+
m_param = dict(model.named_parameters())
|
63 |
+
shadow_params = dict(self.named_buffers())
|
64 |
+
|
65 |
+
for key in m_param:
|
66 |
+
if not m_param[key].stop_gradient:
|
67 |
+
sname = self.m_name2s_name[key]
|
68 |
+
shadow_params[sname].scale_(decay)
|
69 |
+
shadow_params[sname].add_(m_param[key] * one_minus_decay)
|
70 |
+
else:
|
71 |
+
assert key not in self.m_name2s_name
|
72 |
+
|
73 |
+
def copy_to(self, model):
|
74 |
+
m_param = dict(model.named_parameters())
|
75 |
+
shadow_params = dict(self.named_buffers())
|
76 |
+
for key in m_param:
|
77 |
+
if not m_param[key].stop_gradient:
|
78 |
+
m_param[key].copy_(shadow_params[self.m_name2s_name[key]], True)
|
79 |
+
else:
|
80 |
+
assert key not in self.m_name2s_name
|
81 |
+
|
82 |
+
def store(self, parameters):
|
83 |
+
"""
|
84 |
+
Save the current parameters for restoring later.
|
85 |
+
Args:
|
86 |
+
parameters: Iterable of `EagerParamBase`; the parameters to be
|
87 |
+
temporarily stored.
|
88 |
+
"""
|
89 |
+
self.collected_params = [param.clone() for param in parameters]
|
90 |
+
|
91 |
+
def restore(self, parameters):
|
92 |
+
"""
|
93 |
+
Restore the parameters stored with the `store` method.
|
94 |
+
Useful to validate the model with EMA parameters without affecting the
|
95 |
+
original optimization process. Store the parameters before the
|
96 |
+
`copy_to` method. After validation (or model saving), use this to
|
97 |
+
restore the former parameters.
|
98 |
+
Args:
|
99 |
+
parameters: Iterable of `EagerParamBase`; the parameters to be
|
100 |
+
updated with the stored parameters.
|
101 |
+
"""
|
102 |
+
for c_param, param in zip(self.collected_params, parameters):
|
103 |
+
param.copy_(c_param, True)
|
ppdiffusers/models/embeddings.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import math
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import paddle
|
19 |
+
from paddle import nn
|
20 |
+
|
21 |
+
|
22 |
+
def get_timestep_embedding(
|
23 |
+
timesteps: paddle.Tensor,
|
24 |
+
embedding_dim: int,
|
25 |
+
flip_sin_to_cos: bool = False,
|
26 |
+
downscale_freq_shift: float = 1,
|
27 |
+
scale: float = 1,
|
28 |
+
max_period: int = 10000,
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
32 |
+
|
33 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
34 |
+
These may be fractional.
|
35 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
36 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
37 |
+
"""
|
38 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
39 |
+
|
40 |
+
half_dim = embedding_dim // 2
|
41 |
+
exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
|
42 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
43 |
+
|
44 |
+
emb = paddle.exp(exponent)
|
45 |
+
emb = timesteps[:, None].cast("float32") * emb[None, :]
|
46 |
+
|
47 |
+
# scale embeddings
|
48 |
+
emb = scale * emb
|
49 |
+
|
50 |
+
# concat sine and cosine embeddings
|
51 |
+
emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
|
52 |
+
|
53 |
+
# flip sine and cosine embeddings
|
54 |
+
if flip_sin_to_cos:
|
55 |
+
emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
|
56 |
+
|
57 |
+
# zero pad
|
58 |
+
if embedding_dim % 2 == 1:
|
59 |
+
emb = paddle.concat(emb, paddle.zeros([emb.shape[0], 1]), axis=-1)
|
60 |
+
return emb
|
61 |
+
|
62 |
+
|
63 |
+
class TimestepEmbedding(nn.Layer):
|
64 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
68 |
+
self.act = None
|
69 |
+
if act_fn == "silu":
|
70 |
+
self.act = nn.Silu()
|
71 |
+
elif act_fn == "mish":
|
72 |
+
self.act = nn.Mish()
|
73 |
+
|
74 |
+
if out_dim is not None:
|
75 |
+
time_embed_dim_out = out_dim
|
76 |
+
else:
|
77 |
+
time_embed_dim_out = time_embed_dim
|
78 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
79 |
+
|
80 |
+
def forward(self, sample):
|
81 |
+
sample = self.linear_1(sample)
|
82 |
+
|
83 |
+
if self.act is not None:
|
84 |
+
sample = self.act(sample)
|
85 |
+
|
86 |
+
sample = self.linear_2(sample)
|
87 |
+
return sample
|
88 |
+
|
89 |
+
|
90 |
+
class Timesteps(nn.Layer):
|
91 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
92 |
+
super().__init__()
|
93 |
+
self.num_channels = num_channels
|
94 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
95 |
+
self.downscale_freq_shift = downscale_freq_shift
|
96 |
+
|
97 |
+
def forward(self, timesteps):
|
98 |
+
t_emb = get_timestep_embedding(
|
99 |
+
timesteps,
|
100 |
+
self.num_channels,
|
101 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
102 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
103 |
+
)
|
104 |
+
return t_emb
|
105 |
+
|
106 |
+
|
107 |
+
class GaussianFourierProjection(nn.Layer):
|
108 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
self.register_buffer("weight", paddle.randn((embedding_size,)) * scale)
|
115 |
+
self.log = log
|
116 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
117 |
+
|
118 |
+
if set_W_to_weight:
|
119 |
+
# to delete later
|
120 |
+
self.register_buffer("W", paddle.randn((embedding_size,)) * scale)
|
121 |
+
|
122 |
+
self.weight = self.W
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
if self.log:
|
126 |
+
x = paddle.log(x.cast(self.weight.dtype))
|
127 |
+
|
128 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
129 |
+
|
130 |
+
if self.flip_sin_to_cos:
|
131 |
+
out = paddle.concat([paddle.cos(x_proj), paddle.sin(x_proj)], axis=-1)
|
132 |
+
else:
|
133 |
+
out = paddle.concat([paddle.sin(x_proj), paddle.cos(x_proj)], axis=-1)
|
134 |
+
return out
|
135 |
+
|
136 |
+
|
137 |
+
class ImagePositionalEmbeddings(nn.Layer):
|
138 |
+
"""
|
139 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
140 |
+
height and width of the latent space.
|
141 |
+
|
142 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
143 |
+
|
144 |
+
For VQ-diffusion:
|
145 |
+
|
146 |
+
Output vector embeddings are used as input for the transformer.
|
147 |
+
|
148 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
num_embed (`int`):
|
152 |
+
Number of embeddings for the latent pixels embeddings.
|
153 |
+
height (`int`):
|
154 |
+
Height of the latent image i.e. the number of height embeddings.
|
155 |
+
width (`int`):
|
156 |
+
Width of the latent image i.e. the number of width embeddings.
|
157 |
+
embed_dim (`int`):
|
158 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
num_embed: int,
|
164 |
+
height: int,
|
165 |
+
width: int,
|
166 |
+
embed_dim: int,
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
|
170 |
+
self.height = height
|
171 |
+
self.width = width
|
172 |
+
self.num_embed = num_embed
|
173 |
+
self.embed_dim = embed_dim
|
174 |
+
|
175 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
176 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
177 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
178 |
+
|
179 |
+
def forward(self, index):
|
180 |
+
emb = self.emb(index)
|
181 |
+
|
182 |
+
height_emb = self.height_emb(paddle.arange(self.height).reshape([1, self.height]))
|
183 |
+
|
184 |
+
# 1 x H x D -> 1 x H x 1 x D
|
185 |
+
height_emb = height_emb.unsqueeze(2)
|
186 |
+
|
187 |
+
width_emb = self.width_emb(paddle.arange(self.width).reshape([1, self.width]))
|
188 |
+
|
189 |
+
# 1 x W x D -> 1 x 1 x W x D
|
190 |
+
width_emb = width_emb.unsqueeze(1)
|
191 |
+
|
192 |
+
pos_emb = height_emb + width_emb
|
193 |
+
|
194 |
+
# 1 x H x W x D -> 1 x L xD
|
195 |
+
pos_emb = pos_emb.reshape([1, self.height * self.width, -1])
|
196 |
+
|
197 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
198 |
+
|
199 |
+
return emb
|
ppdiffusers/models/prior_transformer.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Optional, Union
|
17 |
+
|
18 |
+
import paddle
|
19 |
+
import paddle.nn as nn
|
20 |
+
import paddle.nn.functional as F
|
21 |
+
|
22 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from ..modeling_utils import ModelMixin
|
24 |
+
from ..utils import BaseOutput
|
25 |
+
from .attention import BasicTransformerBlock
|
26 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
27 |
+
|
28 |
+
NEG_INF = -1e4
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class PriorTransformerOutput(BaseOutput):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
predicted_image_embedding (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
|
36 |
+
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
37 |
+
"""
|
38 |
+
|
39 |
+
predicted_image_embedding: paddle.Tensor
|
40 |
+
|
41 |
+
|
42 |
+
class PriorTransformer(ModelMixin, ConfigMixin):
|
43 |
+
"""
|
44 |
+
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
|
45 |
+
transformer predicts the image embeddings through a denoising diffusion process.
|
46 |
+
|
47 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
48 |
+
implements for all the models (such as downloading or saving, etc.)
|
49 |
+
|
50 |
+
For more details, see the original paper: https://arxiv.org/abs/2204.06125
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
54 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
55 |
+
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
56 |
+
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
|
57 |
+
image embeddings and text embeddings are both the same dimension.
|
58 |
+
num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
|
59 |
+
length of the prompt after it has been tokenized.
|
60 |
+
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
61 |
+
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
|
62 |
+
additional_embeddings`.
|
63 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
64 |
+
|
65 |
+
"""
|
66 |
+
|
67 |
+
@register_to_config
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
num_attention_heads: int = 32,
|
71 |
+
attention_head_dim: int = 64,
|
72 |
+
num_layers: int = 20,
|
73 |
+
embedding_dim: int = 768,
|
74 |
+
num_embeddings=77,
|
75 |
+
additional_embeddings=4,
|
76 |
+
dropout: float = 0.0,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
self.num_attention_heads = num_attention_heads
|
80 |
+
self.attention_head_dim = attention_head_dim
|
81 |
+
inner_dim = num_attention_heads * attention_head_dim
|
82 |
+
self.additional_embeddings = additional_embeddings
|
83 |
+
|
84 |
+
self.time_proj = Timesteps(inner_dim, True, 0)
|
85 |
+
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
|
86 |
+
|
87 |
+
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
88 |
+
|
89 |
+
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
|
90 |
+
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
91 |
+
|
92 |
+
self.positional_embedding = self.create_parameter(
|
93 |
+
(1, num_embeddings + additional_embeddings, inner_dim),
|
94 |
+
dtype=paddle.get_default_dtype(),
|
95 |
+
default_initializer=nn.initializer.Constant(0.0),
|
96 |
+
)
|
97 |
+
|
98 |
+
self.prd_embedding = self.create_parameter(
|
99 |
+
(1, 1, inner_dim), dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(0.0)
|
100 |
+
)
|
101 |
+
|
102 |
+
self.transformer_blocks = nn.LayerList(
|
103 |
+
[
|
104 |
+
BasicTransformerBlock(
|
105 |
+
inner_dim,
|
106 |
+
num_attention_heads,
|
107 |
+
attention_head_dim,
|
108 |
+
dropout=dropout,
|
109 |
+
activation_fn="gelu",
|
110 |
+
attention_bias=True,
|
111 |
+
)
|
112 |
+
for d in range(num_layers)
|
113 |
+
]
|
114 |
+
)
|
115 |
+
|
116 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
117 |
+
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
118 |
+
|
119 |
+
causal_attention_mask = paddle.triu(
|
120 |
+
paddle.full([num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], NEG_INF), 1
|
121 |
+
)
|
122 |
+
causal_attention_mask = causal_attention_mask.unsqueeze(0)
|
123 |
+
self.register_buffer("causal_attention_mask", causal_attention_mask, persistable=False)
|
124 |
+
|
125 |
+
self.clip_mean = self.create_parameter(
|
126 |
+
(1, embedding_dim), dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(0.0)
|
127 |
+
)
|
128 |
+
self.clip_std = self.create_parameter(
|
129 |
+
(1, embedding_dim), dtype=paddle.get_default_dtype(), default_initializer=nn.initializer.Constant(0.0)
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(
|
133 |
+
self,
|
134 |
+
hidden_states,
|
135 |
+
timestep: Union[paddle.Tensor, float, int],
|
136 |
+
proj_embedding: paddle.Tensor,
|
137 |
+
encoder_hidden_states: paddle.Tensor,
|
138 |
+
attention_mask: Optional[paddle.Tensor] = None,
|
139 |
+
return_dict: bool = True,
|
140 |
+
):
|
141 |
+
"""
|
142 |
+
Args:
|
143 |
+
hidden_states (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
|
144 |
+
x_t, the currently predicted image embeddings.
|
145 |
+
timestep (`paddle.Tensor`):
|
146 |
+
Current denoising step.
|
147 |
+
proj_embedding (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
|
148 |
+
Projected embedding vector the denoising process is conditioned on.
|
149 |
+
encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
150 |
+
Hidden states of the text embeddings the denoising process is conditioned on.
|
151 |
+
attention_mask (`paddle.Tensor` of shape `(batch_size, num_embeddings)`):
|
152 |
+
Text mask for the text embeddings.
|
153 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
154 |
+
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
155 |
+
tuple.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
159 |
+
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
160 |
+
returning a tuple, the first element is the sample tensor.
|
161 |
+
"""
|
162 |
+
batch_size = hidden_states.shape[0]
|
163 |
+
|
164 |
+
timesteps = timestep
|
165 |
+
if not paddle.is_tensor(timesteps):
|
166 |
+
timesteps = paddle.to_tensor([timesteps], dtype=paddle.int64)
|
167 |
+
elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
168 |
+
timesteps = timesteps[None]
|
169 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
170 |
+
timesteps = timesteps * paddle.ones((batch_size,), dtype=timesteps.dtype)
|
171 |
+
|
172 |
+
timesteps_projected = self.time_proj(timesteps)
|
173 |
+
|
174 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
175 |
+
# but time_embedding might be fp16, so we need to cast here.
|
176 |
+
timesteps_projected = timesteps_projected.cast(dtype=self.dtype)
|
177 |
+
time_embeddings = self.time_embedding(timesteps_projected)
|
178 |
+
|
179 |
+
proj_embeddings = self.embedding_proj(proj_embedding)
|
180 |
+
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
181 |
+
hidden_states = self.proj_in(hidden_states)
|
182 |
+
prd_embedding = self.prd_embedding.cast(hidden_states.dtype).expand([batch_size, -1, -1])
|
183 |
+
positional_embeddings = self.positional_embedding.cast(hidden_states.dtype)
|
184 |
+
|
185 |
+
hidden_states = paddle.concat(
|
186 |
+
[
|
187 |
+
encoder_hidden_states,
|
188 |
+
proj_embeddings[:, None, :],
|
189 |
+
time_embeddings[:, None, :],
|
190 |
+
hidden_states[:, None, :],
|
191 |
+
prd_embedding,
|
192 |
+
],
|
193 |
+
axis=1,
|
194 |
+
)
|
195 |
+
|
196 |
+
hidden_states = hidden_states + positional_embeddings
|
197 |
+
|
198 |
+
if attention_mask is not None:
|
199 |
+
attention_mask = (1 - attention_mask.cast(hidden_states.dtype)) * -10000.0
|
200 |
+
attention_mask = F.pad(
|
201 |
+
attention_mask.unsqueeze(0), (0, self.additional_embeddings), value=0.0, data_format="NCL"
|
202 |
+
).squeeze(0)
|
203 |
+
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).cast(hidden_states.dtype)
|
204 |
+
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, axis=0)
|
205 |
+
|
206 |
+
for block in self.transformer_blocks:
|
207 |
+
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
208 |
+
|
209 |
+
hidden_states = self.norm_out(hidden_states)
|
210 |
+
hidden_states = hidden_states[:, -1]
|
211 |
+
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
212 |
+
|
213 |
+
if not return_dict:
|
214 |
+
return (predicted_image_embedding,)
|
215 |
+
|
216 |
+
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
217 |
+
|
218 |
+
def post_process_latents(self, prior_latents):
|
219 |
+
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
220 |
+
return prior_latents
|
ppdiffusers/models/resnet.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
|
18 |
+
import paddle
|
19 |
+
import paddle.nn as nn
|
20 |
+
import paddle.nn.functional as F
|
21 |
+
|
22 |
+
|
23 |
+
class Upsample1D(nn.Layer):
|
24 |
+
"""
|
25 |
+
An upsampling layer with an optional convolution.
|
26 |
+
|
27 |
+
Parameters:
|
28 |
+
channels: channels in the inputs and outputs.
|
29 |
+
use_conv: a bool determining if a convolution is applied.
|
30 |
+
use_conv_transpose:
|
31 |
+
out_channels:
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
35 |
+
super().__init__()
|
36 |
+
self.channels = channels
|
37 |
+
self.out_channels = out_channels or channels
|
38 |
+
self.use_conv = use_conv
|
39 |
+
self.use_conv_transpose = use_conv_transpose
|
40 |
+
self.name = name
|
41 |
+
|
42 |
+
self.conv = None
|
43 |
+
if use_conv_transpose:
|
44 |
+
self.conv = nn.Conv1DTranspose(channels, self.out_channels, 4, 2, 1)
|
45 |
+
elif use_conv:
|
46 |
+
self.conv = nn.Conv1D(self.channels, self.out_channels, 3, padding=1)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
assert x.shape[1] == self.channels
|
50 |
+
if self.use_conv_transpose:
|
51 |
+
return self.conv(x)
|
52 |
+
|
53 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
54 |
+
|
55 |
+
if self.use_conv:
|
56 |
+
x = self.conv(x)
|
57 |
+
|
58 |
+
return x
|
59 |
+
|
60 |
+
|
61 |
+
class Downsample1D(nn.Layer):
|
62 |
+
"""
|
63 |
+
A downsampling layer with an optional convolution.
|
64 |
+
|
65 |
+
Parameters:
|
66 |
+
channels: channels in the inputs and outputs.
|
67 |
+
use_conv: a bool determining if a convolution is applied.
|
68 |
+
out_channels:
|
69 |
+
padding:
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
73 |
+
super().__init__()
|
74 |
+
self.channels = channels
|
75 |
+
self.out_channels = out_channels or channels
|
76 |
+
self.use_conv = use_conv
|
77 |
+
self.padding = padding
|
78 |
+
stride = 2
|
79 |
+
self.name = name
|
80 |
+
|
81 |
+
if use_conv:
|
82 |
+
self.conv = nn.Conv1D(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
83 |
+
else:
|
84 |
+
assert self.channels == self.out_channels
|
85 |
+
self.conv = nn.AvgPool1D(kernel_size=stride, stride=stride)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
assert x.shape[1] == self.channels
|
89 |
+
return self.conv(x)
|
90 |
+
|
91 |
+
|
92 |
+
class Upsample2D(nn.Layer):
|
93 |
+
"""
|
94 |
+
An upsampling layer with an optional convolution.
|
95 |
+
|
96 |
+
Parameters:
|
97 |
+
channels: channels in the inputs and outputs.
|
98 |
+
use_conv: a bool determining if a convolution is applied.
|
99 |
+
use_conv_transpose:
|
100 |
+
out_channels:
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
104 |
+
super().__init__()
|
105 |
+
self.channels = channels
|
106 |
+
self.out_channels = out_channels or channels
|
107 |
+
self.use_conv = use_conv
|
108 |
+
self.use_conv_transpose = use_conv_transpose
|
109 |
+
self.name = name
|
110 |
+
|
111 |
+
conv = None
|
112 |
+
if use_conv_transpose:
|
113 |
+
conv = nn.Conv2DTranspose(channels, self.out_channels, 4, 2, 1)
|
114 |
+
elif use_conv:
|
115 |
+
conv = nn.Conv2D(self.channels, self.out_channels, 3, padding=1)
|
116 |
+
|
117 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
118 |
+
if name == "conv":
|
119 |
+
self.conv = conv
|
120 |
+
else:
|
121 |
+
self.Conv2d_0 = conv
|
122 |
+
|
123 |
+
def forward(self, hidden_states, output_size=None):
|
124 |
+
assert hidden_states.shape[1] == self.channels
|
125 |
+
|
126 |
+
if self.use_conv_transpose:
|
127 |
+
return self.conv(hidden_states)
|
128 |
+
|
129 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
130 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
131 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
132 |
+
dtype = hidden_states.dtype
|
133 |
+
if dtype == paddle.bfloat16:
|
134 |
+
hidden_states = hidden_states.cast("float32")
|
135 |
+
|
136 |
+
# if `output_size` is passed we force the interpolation output
|
137 |
+
# size and do not make use of `scale_factor=2`
|
138 |
+
if output_size is None:
|
139 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
140 |
+
else:
|
141 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
142 |
+
|
143 |
+
# If the input is bfloat16, we cast back to bfloat16
|
144 |
+
if dtype == paddle.bfloat16:
|
145 |
+
hidden_states = hidden_states.cast(dtype)
|
146 |
+
|
147 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
148 |
+
if self.use_conv:
|
149 |
+
if self.name == "conv":
|
150 |
+
hidden_states = self.conv(hidden_states)
|
151 |
+
else:
|
152 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
153 |
+
|
154 |
+
return hidden_states
|
155 |
+
|
156 |
+
|
157 |
+
class Downsample2D(nn.Layer):
|
158 |
+
"""
|
159 |
+
A downsampling layer with an optional convolution.
|
160 |
+
|
161 |
+
Parameters:
|
162 |
+
channels: channels in the inputs and outputs.
|
163 |
+
use_conv: a bool determining if a convolution is applied.
|
164 |
+
out_channels:
|
165 |
+
padding:
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
169 |
+
super().__init__()
|
170 |
+
self.channels = channels
|
171 |
+
self.out_channels = out_channels or channels
|
172 |
+
self.use_conv = use_conv
|
173 |
+
self.padding = padding
|
174 |
+
stride = 2
|
175 |
+
self.name = name
|
176 |
+
|
177 |
+
if use_conv:
|
178 |
+
conv = nn.Conv2D(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
179 |
+
else:
|
180 |
+
assert self.channels == self.out_channels
|
181 |
+
conv = nn.AvgPool2D(kernel_size=stride, stride=stride)
|
182 |
+
|
183 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
184 |
+
if name == "conv":
|
185 |
+
self.Conv2d_0 = conv
|
186 |
+
self.conv = conv
|
187 |
+
elif name == "Conv2d_0":
|
188 |
+
self.conv = conv
|
189 |
+
else:
|
190 |
+
self.conv = conv
|
191 |
+
|
192 |
+
def forward(self, hidden_states):
|
193 |
+
assert hidden_states.shape[1] == self.channels
|
194 |
+
if self.use_conv and self.padding == 0:
|
195 |
+
pad = (0, 1, 0, 1)
|
196 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
197 |
+
|
198 |
+
assert hidden_states.shape[1] == self.channels
|
199 |
+
hidden_states = self.conv(hidden_states)
|
200 |
+
|
201 |
+
return hidden_states
|
202 |
+
|
203 |
+
|
204 |
+
class FirUpsample2D(nn.Layer):
|
205 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
206 |
+
super().__init__()
|
207 |
+
out_channels = out_channels if out_channels else channels
|
208 |
+
if use_conv:
|
209 |
+
self.Conv2d_0 = nn.Conv2D(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
210 |
+
self.use_conv = use_conv
|
211 |
+
self.fir_kernel = fir_kernel
|
212 |
+
self.out_channels = out_channels
|
213 |
+
|
214 |
+
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
215 |
+
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
216 |
+
|
217 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
218 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
219 |
+
arbitrary order.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
223 |
+
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
|
224 |
+
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
225 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
226 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
227 |
+
factor: Integer upsampling factor (default: 2).
|
228 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
232 |
+
datatype as `hidden_states`.
|
233 |
+
"""
|
234 |
+
|
235 |
+
assert isinstance(factor, int) and factor >= 1
|
236 |
+
|
237 |
+
# Setup filter kernel.
|
238 |
+
if kernel is None:
|
239 |
+
kernel = [1] * factor
|
240 |
+
|
241 |
+
# setup kernel
|
242 |
+
kernel = paddle.to_tensor(kernel, dtype="float32")
|
243 |
+
if kernel.ndim == 1:
|
244 |
+
kernel = paddle.outer(kernel, kernel)
|
245 |
+
kernel /= paddle.sum(kernel)
|
246 |
+
|
247 |
+
kernel = kernel * (gain * (factor**2))
|
248 |
+
|
249 |
+
if self.use_conv:
|
250 |
+
convH = weight.shape[2]
|
251 |
+
convW = weight.shape[3]
|
252 |
+
inC = weight.shape[1]
|
253 |
+
|
254 |
+
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
255 |
+
|
256 |
+
stride = (factor, factor)
|
257 |
+
# Determine data dimensions.
|
258 |
+
output_shape = (
|
259 |
+
(hidden_states.shape[2] - 1) * factor + convH,
|
260 |
+
(hidden_states.shape[3] - 1) * factor + convW,
|
261 |
+
)
|
262 |
+
output_padding = (
|
263 |
+
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
264 |
+
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
265 |
+
)
|
266 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
267 |
+
num_groups = hidden_states.shape[1] // inC
|
268 |
+
|
269 |
+
# Transpose weights.
|
270 |
+
weight = weight.reshape([num_groups, -1, inC, convH, convW])
|
271 |
+
weight = paddle.flip(weight, axis=[3, 4]).transpose([0, 2, 1, 3, 4])
|
272 |
+
weight = weight.reshape([num_groups * inC, -1, convH, convW])
|
273 |
+
|
274 |
+
inverse_conv = F.conv2d_transpose(
|
275 |
+
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
|
276 |
+
)
|
277 |
+
|
278 |
+
output = upfirdn2d_native(
|
279 |
+
inverse_conv,
|
280 |
+
paddle.to_tensor(kernel),
|
281 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
pad_value = kernel.shape[0] - factor
|
285 |
+
output = upfirdn2d_native(
|
286 |
+
hidden_states,
|
287 |
+
paddle.to_tensor(kernel),
|
288 |
+
up=factor,
|
289 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
290 |
+
)
|
291 |
+
|
292 |
+
return output
|
293 |
+
|
294 |
+
def forward(self, hidden_states):
|
295 |
+
if self.use_conv:
|
296 |
+
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
297 |
+
height = height + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
|
298 |
+
else:
|
299 |
+
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
300 |
+
|
301 |
+
return height
|
302 |
+
|
303 |
+
|
304 |
+
class FirDownsample2D(nn.Layer):
|
305 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
306 |
+
super().__init__()
|
307 |
+
out_channels = out_channels if out_channels else channels
|
308 |
+
if use_conv:
|
309 |
+
self.Conv2d_0 = nn.Conv2D(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
310 |
+
self.fir_kernel = fir_kernel
|
311 |
+
self.use_conv = use_conv
|
312 |
+
self.out_channels = out_channels
|
313 |
+
|
314 |
+
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
315 |
+
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
316 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
317 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
318 |
+
arbitrary order.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
322 |
+
weight:
|
323 |
+
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
324 |
+
performed by `inChannels = x.shape[0] // numGroups`.
|
325 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
326 |
+
factor`, which corresponds to average pooling.
|
327 |
+
factor: Integer downsampling factor (default: 2).
|
328 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
329 |
+
|
330 |
+
Returns:
|
331 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
|
332 |
+
same datatype as `x`.
|
333 |
+
"""
|
334 |
+
|
335 |
+
assert isinstance(factor, int) and factor >= 1
|
336 |
+
if kernel is None:
|
337 |
+
kernel = [1] * factor
|
338 |
+
|
339 |
+
# setup kernel
|
340 |
+
kernel = paddle.to_tensor(kernel, dtype="float32")
|
341 |
+
if kernel.ndim == 1:
|
342 |
+
kernel = paddle.outer(kernel, kernel)
|
343 |
+
kernel /= paddle.sum(kernel)
|
344 |
+
|
345 |
+
kernel = kernel * gain
|
346 |
+
|
347 |
+
if self.use_conv:
|
348 |
+
_, _, convH, convW = weight.shape
|
349 |
+
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
350 |
+
stride_value = [factor, factor]
|
351 |
+
upfirdn_input = upfirdn2d_native(
|
352 |
+
hidden_states,
|
353 |
+
paddle.to_tensor(kernel),
|
354 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
355 |
+
)
|
356 |
+
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
357 |
+
else:
|
358 |
+
pad_value = kernel.shape[0] - factor
|
359 |
+
output = upfirdn2d_native(
|
360 |
+
hidden_states,
|
361 |
+
paddle.to_tensor(kernel),
|
362 |
+
down=factor,
|
363 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
364 |
+
)
|
365 |
+
|
366 |
+
return output
|
367 |
+
|
368 |
+
def forward(self, hidden_states):
|
369 |
+
if self.use_conv:
|
370 |
+
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
371 |
+
hidden_states = downsample_input + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
|
372 |
+
else:
|
373 |
+
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
374 |
+
|
375 |
+
return hidden_states
|
376 |
+
|
377 |
+
|
378 |
+
class ResnetBlock2D(nn.Layer):
|
379 |
+
def __init__(
|
380 |
+
self,
|
381 |
+
*,
|
382 |
+
in_channels,
|
383 |
+
out_channels=None,
|
384 |
+
conv_shortcut=False,
|
385 |
+
dropout=0.0,
|
386 |
+
temb_channels=512,
|
387 |
+
groups=32,
|
388 |
+
groups_out=None,
|
389 |
+
pre_norm=True,
|
390 |
+
eps=1e-6,
|
391 |
+
non_linearity="swish",
|
392 |
+
time_embedding_norm="default",
|
393 |
+
kernel=None,
|
394 |
+
output_scale_factor=1.0,
|
395 |
+
use_in_shortcut=None,
|
396 |
+
up=False,
|
397 |
+
down=False,
|
398 |
+
):
|
399 |
+
super().__init__()
|
400 |
+
self.pre_norm = pre_norm
|
401 |
+
self.pre_norm = True
|
402 |
+
self.in_channels = in_channels
|
403 |
+
out_channels = in_channels if out_channels is None else out_channels
|
404 |
+
self.out_channels = out_channels
|
405 |
+
self.use_conv_shortcut = conv_shortcut
|
406 |
+
self.time_embedding_norm = time_embedding_norm
|
407 |
+
self.up = up
|
408 |
+
self.down = down
|
409 |
+
self.output_scale_factor = output_scale_factor
|
410 |
+
|
411 |
+
if groups_out is None:
|
412 |
+
groups_out = groups
|
413 |
+
|
414 |
+
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, epsilon=eps)
|
415 |
+
|
416 |
+
self.conv1 = nn.Conv2D(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
417 |
+
|
418 |
+
if temb_channels is not None:
|
419 |
+
if self.time_embedding_norm == "default":
|
420 |
+
time_emb_proj_out_channels = out_channels
|
421 |
+
elif self.time_embedding_norm == "scale_shift":
|
422 |
+
time_emb_proj_out_channels = out_channels * 2
|
423 |
+
else:
|
424 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
425 |
+
|
426 |
+
self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels)
|
427 |
+
else:
|
428 |
+
self.time_emb_proj = None
|
429 |
+
|
430 |
+
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, epsilon=eps)
|
431 |
+
self.dropout = nn.Dropout(dropout)
|
432 |
+
self.conv2 = nn.Conv2D(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
433 |
+
|
434 |
+
if non_linearity == "swish":
|
435 |
+
self.nonlinearity = lambda x: F.silu(x)
|
436 |
+
elif non_linearity == "mish":
|
437 |
+
self.nonlinearity = Mish()
|
438 |
+
elif non_linearity == "silu":
|
439 |
+
self.nonlinearity = nn.Silu()
|
440 |
+
|
441 |
+
self.upsample = self.downsample = None
|
442 |
+
if self.up:
|
443 |
+
if kernel == "fir":
|
444 |
+
fir_kernel = (1, 3, 3, 1)
|
445 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
446 |
+
elif kernel == "sde_vp":
|
447 |
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
448 |
+
else:
|
449 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
450 |
+
elif self.down:
|
451 |
+
if kernel == "fir":
|
452 |
+
fir_kernel = (1, 3, 3, 1)
|
453 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
454 |
+
elif kernel == "sde_vp":
|
455 |
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
456 |
+
else:
|
457 |
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
458 |
+
|
459 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
460 |
+
|
461 |
+
self.conv_shortcut = None
|
462 |
+
if self.use_in_shortcut:
|
463 |
+
self.conv_shortcut = nn.Conv2D(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
464 |
+
|
465 |
+
def forward(self, input_tensor, temb):
|
466 |
+
hidden_states = input_tensor
|
467 |
+
|
468 |
+
hidden_states = self.norm1(hidden_states)
|
469 |
+
hidden_states = self.nonlinearity(hidden_states)
|
470 |
+
|
471 |
+
if self.upsample is not None:
|
472 |
+
input_tensor = self.upsample(input_tensor)
|
473 |
+
hidden_states = self.upsample(hidden_states)
|
474 |
+
elif self.downsample is not None:
|
475 |
+
input_tensor = self.downsample(input_tensor)
|
476 |
+
hidden_states = self.downsample(hidden_states)
|
477 |
+
|
478 |
+
hidden_states = self.conv1(hidden_states)
|
479 |
+
|
480 |
+
if temb is not None:
|
481 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
482 |
+
|
483 |
+
if temb is not None and self.time_embedding_norm == "default":
|
484 |
+
hidden_states = hidden_states + temb
|
485 |
+
|
486 |
+
hidden_states = self.norm2(hidden_states)
|
487 |
+
|
488 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
489 |
+
scale, shift = paddle.chunk(temb, 2, axis=1)
|
490 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
491 |
+
|
492 |
+
hidden_states = self.nonlinearity(hidden_states)
|
493 |
+
|
494 |
+
hidden_states = self.dropout(hidden_states)
|
495 |
+
hidden_states = self.conv2(hidden_states)
|
496 |
+
|
497 |
+
if self.conv_shortcut is not None:
|
498 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
499 |
+
|
500 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
501 |
+
|
502 |
+
return output_tensor
|
503 |
+
|
504 |
+
|
505 |
+
class Mish(nn.Layer):
|
506 |
+
def forward(self, hidden_states):
|
507 |
+
return hidden_states * paddle.tanh(F.softplus(hidden_states))
|
508 |
+
|
509 |
+
|
510 |
+
# unet_rl.py
|
511 |
+
def rearrange_dims(tensor):
|
512 |
+
if len(tensor.shape) == 2:
|
513 |
+
return tensor[:, :, None]
|
514 |
+
if len(tensor.shape) == 3:
|
515 |
+
return tensor[:, :, None, :]
|
516 |
+
elif len(tensor.shape) == 4:
|
517 |
+
return tensor[:, :, 0, :]
|
518 |
+
else:
|
519 |
+
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
520 |
+
|
521 |
+
|
522 |
+
class Conv1dBlock(nn.Layer):
|
523 |
+
"""
|
524 |
+
Conv1d --> GroupNorm --> Mish
|
525 |
+
"""
|
526 |
+
|
527 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
528 |
+
super().__init__()
|
529 |
+
|
530 |
+
self.conv1d = nn.Conv1D(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
531 |
+
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
532 |
+
self.mish = nn.Mish()
|
533 |
+
|
534 |
+
def forward(self, x):
|
535 |
+
x = self.conv1d(x)
|
536 |
+
x = rearrange_dims(x)
|
537 |
+
x = self.group_norm(x)
|
538 |
+
x = rearrange_dims(x)
|
539 |
+
x = self.mish(x)
|
540 |
+
return x
|
541 |
+
|
542 |
+
|
543 |
+
# unet_rl.py
|
544 |
+
class ResidualTemporalBlock1D(nn.Layer):
|
545 |
+
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
546 |
+
super().__init__()
|
547 |
+
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
548 |
+
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
549 |
+
|
550 |
+
self.time_emb_act = nn.Mish()
|
551 |
+
self.time_emb = nn.Linear(embed_dim, out_channels)
|
552 |
+
|
553 |
+
self.residual_conv = (
|
554 |
+
nn.Conv1D(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
555 |
+
)
|
556 |
+
|
557 |
+
def forward(self, x, t):
|
558 |
+
"""
|
559 |
+
Args:
|
560 |
+
x : [ batch_size x inp_channels x horizon ]
|
561 |
+
t : [ batch_size x embed_dim ]
|
562 |
+
|
563 |
+
returns:
|
564 |
+
out : [ batch_size x out_channels x horizon ]
|
565 |
+
"""
|
566 |
+
t = self.time_emb_act(t)
|
567 |
+
t = self.time_emb(t)
|
568 |
+
out = self.conv_in(x) + rearrange_dims(t)
|
569 |
+
out = self.conv_out(out)
|
570 |
+
return out + self.residual_conv(x)
|
571 |
+
|
572 |
+
|
573 |
+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
574 |
+
r"""Upsample2D a batch of 2D images with the given filter.
|
575 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
576 |
+
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
577 |
+
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
578 |
+
a: multiple of the upsampling factor.
|
579 |
+
|
580 |
+
Args:
|
581 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
582 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
583 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
584 |
+
factor: Integer upsampling factor (default: 2).
|
585 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
586 |
+
|
587 |
+
Returns:
|
588 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]`
|
589 |
+
"""
|
590 |
+
assert isinstance(factor, int) and factor >= 1
|
591 |
+
if kernel is None:
|
592 |
+
kernel = [1] * factor
|
593 |
+
|
594 |
+
kernel = paddle.to_tensor(kernel, dtype="float32")
|
595 |
+
if kernel.ndim == 1:
|
596 |
+
kernel = paddle.outer(kernel, kernel)
|
597 |
+
kernel /= paddle.sum(kernel)
|
598 |
+
|
599 |
+
if gain != 1:
|
600 |
+
kernel = kernel * (gain * (factor**2))
|
601 |
+
else:
|
602 |
+
kernel = kernel * (factor**2)
|
603 |
+
pad_value = kernel.shape[0] - factor
|
604 |
+
output = upfirdn2d_native(
|
605 |
+
hidden_states,
|
606 |
+
kernel,
|
607 |
+
up=factor,
|
608 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
609 |
+
)
|
610 |
+
return output
|
611 |
+
|
612 |
+
|
613 |
+
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
614 |
+
r"""Downsample2D a batch of 2D images with the given filter.
|
615 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
616 |
+
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
617 |
+
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
618 |
+
shape is a multiple of the downsampling factor.
|
619 |
+
|
620 |
+
Args:
|
621 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
622 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
623 |
+
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
624 |
+
factor: Integer downsampling factor (default: 2).
|
625 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
626 |
+
|
627 |
+
Returns:
|
628 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]`
|
629 |
+
"""
|
630 |
+
|
631 |
+
assert isinstance(factor, int) and factor >= 1
|
632 |
+
if kernel is None:
|
633 |
+
kernel = [1] * factor
|
634 |
+
|
635 |
+
kernel = paddle.to_tensor(kernel, dtype="float32")
|
636 |
+
if kernel.ndim == 1:
|
637 |
+
kernel = paddle.outer(kernel, kernel)
|
638 |
+
kernel /= paddle.sum(kernel)
|
639 |
+
|
640 |
+
kernel = kernel * gain
|
641 |
+
pad_value = kernel.shape[0] - factor
|
642 |
+
output = upfirdn2d_native(hidden_states, kernel, down=factor, pad=((pad_value + 1) // 2, pad_value // 2))
|
643 |
+
return output
|
644 |
+
|
645 |
+
|
646 |
+
def dummy_pad(tensor, up_x=0, up_y=0):
|
647 |
+
if up_x > 0:
|
648 |
+
tensor = paddle.concat(
|
649 |
+
[
|
650 |
+
tensor,
|
651 |
+
paddle.zeros(
|
652 |
+
[tensor.shape[0], tensor.shape[1], tensor.shape[2], tensor.shape[3], up_x, tensor.shape[5]],
|
653 |
+
dtype=tensor.dtype,
|
654 |
+
),
|
655 |
+
],
|
656 |
+
axis=4,
|
657 |
+
)
|
658 |
+
if up_y > 0:
|
659 |
+
tensor = paddle.concat(
|
660 |
+
[
|
661 |
+
tensor,
|
662 |
+
paddle.zeros(
|
663 |
+
[tensor.shape[0], tensor.shape[1], up_y, tensor.shape[3], tensor.shape[4], tensor.shape[5]],
|
664 |
+
dtype=tensor.dtype,
|
665 |
+
),
|
666 |
+
],
|
667 |
+
axis=2,
|
668 |
+
)
|
669 |
+
return tensor
|
670 |
+
|
671 |
+
|
672 |
+
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
673 |
+
up_x = up_y = up
|
674 |
+
down_x = down_y = down
|
675 |
+
pad_x0 = pad_y0 = pad[0]
|
676 |
+
pad_x1 = pad_y1 = pad[1]
|
677 |
+
|
678 |
+
_, channel, in_h, in_w = tensor.shape
|
679 |
+
tensor = tensor.reshape([-1, in_h, in_w, 1])
|
680 |
+
|
681 |
+
_, in_h, in_w, minor = tensor.shape
|
682 |
+
kernel_h, kernel_w = kernel.shape
|
683 |
+
|
684 |
+
out = tensor.reshape([-1, in_h, 1, in_w, 1, minor])
|
685 |
+
# (TODO, junnyu F.pad bug)
|
686 |
+
# F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
687 |
+
out = dummy_pad(out, up_x - 1, up_y - 1)
|
688 |
+
out = out.reshape([-1, in_h * up_y, in_w * up_x, minor])
|
689 |
+
|
690 |
+
# (TODO, junnyu F.pad bug)
|
691 |
+
# out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
692 |
+
out = out.unsqueeze(0)
|
693 |
+
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0), 0, 0], data_format="NDHWC")
|
694 |
+
out = out.squeeze(0)
|
695 |
+
|
696 |
+
out = out[
|
697 |
+
:,
|
698 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
699 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
700 |
+
:,
|
701 |
+
]
|
702 |
+
|
703 |
+
out = out.transpose([0, 3, 1, 2])
|
704 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
705 |
+
w = paddle.flip(kernel, [0, 1]).reshape([1, 1, kernel_h, kernel_w])
|
706 |
+
out = F.conv2d(out, w)
|
707 |
+
out = out.reshape(
|
708 |
+
[-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1]
|
709 |
+
)
|
710 |
+
out = out.transpose([0, 2, 3, 1])
|
711 |
+
out = out[:, ::down_y, ::down_x, :]
|
712 |
+
|
713 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
714 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
715 |
+
|
716 |
+
return out.reshape([-1, channel, out_h, out_w])
|