StoryVisualization Task

#193
by Anyou - opened
README.md DELETED
@@ -1,207 +0,0 @@
1
- ---
2
- license: creativeml-openrail-m
3
- tags:
4
- - stable-diffusion
5
- - stable-diffusion-diffusers
6
- - text-to-image
7
- inference: true
8
- extra_gated_prompt: |-
9
- This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
10
- The CreativeML OpenRAIL License specifies:
11
-
12
- 1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content
13
- 2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
14
- 3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
15
- Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
16
-
17
- extra_gated_heading: Please read the LICENSE to access this model
18
- ---
19
-
20
- # Stable Diffusion v1-5 Model Card
21
-
22
- Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.
23
- For more information about how Stable Diffusion functions, please have a look at [🤗's Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion).
24
-
25
- The **Stable-Diffusion-v1-5** checkpoint was initialized with the weights of the [Stable-Diffusion-v1-2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
26
- checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
27
-
28
- You can use this both with the [🧨Diffusers library](https://github.com/huggingface/diffusers) and the [RunwayML GitHub repository](https://github.com/runwayml/stable-diffusion).
29
-
30
- ### Diffusers
31
- ```py
32
- from diffusers import StableDiffusionPipeline
33
- import torch
34
-
35
- model_id = "runwayml/stable-diffusion-v1-5"
36
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
37
- pipe = pipe.to("cuda")
38
-
39
- prompt = "a photo of an astronaut riding a horse on mars"
40
- image = pipe(prompt).images[0]
41
-
42
- image.save("astronaut_rides_horse.png")
43
- ```
44
- For more detailed instructions, use-cases and examples in JAX, follow the instructions [here](https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion).
45
-
46
- ### Original GitHub Repository
47
-
48
- 1. Download the weights
49
- - [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, ema-only weight. uses less VRAM - suitable for inference
50
- - [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, ema+non-ema weights. uses more VRAM - suitable for fine-tuning
51
-
52
- 2. Follow instructions [here](https://github.com/runwayml/stable-diffusion).
53
-
54
- ## Model Details
55
- - **Developed by:** Robin Rombach, Patrick Esser
56
- - **Model type:** Diffusion-based text-to-image generation model
57
- - **Language(s):** English
58
- **License:** [The CreativeML OpenRAIL M license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying out in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.
59
- - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
60
- - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
61
- - **Cite as:**
62
-
63
- @InProceedings{Rombach_2022_CVPR,
64
- author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
65
- title = {High-Resolution Image Synthesis With Latent Diffusion Models},
66
- booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
67
- month = {June},
68
- year = {2022},
69
- pages = {10684-10695}
70
- }
71
-
72
- # Uses
73
-
74
- ## Direct Use
75
- The model is intended for research purposes only. Possible research areas and
76
- tasks include
77
-
78
- - Safe deployment of models which have the potential to generate harmful content.
79
- - Probing and understanding the limitations and biases of generative models.
80
- - Generation of artworks and use in design and other artistic processes.
81
- - Applications in educational or creative tools.
82
- - Research on generative models.
83
-
84
- Excluded uses are described below.
85
-
86
- ### Misuse, Malicious Use, and Out-of-Scope Use
87
- _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
88
-
89
-
90
- The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
91
-
92
- #### Out-of-Scope Use
93
- The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
94
-
95
- #### Misuse and Malicious Use
96
- Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
97
-
98
- - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
99
- - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
100
- - Impersonating individuals without their consent.
101
- - Sexual content without consent of the people who might see it.
102
- - Mis- and disinformation
103
- - Representations of egregious violence and gore
104
- - Sharing of copyrighted or licensed material in violation of its terms of use.
105
- - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
106
-
107
- ## Limitations and Bias
108
-
109
- ### Limitations
110
-
111
- - The model does not achieve perfect photorealism
112
- - The model cannot render legible text
113
- - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
114
- - Faces and people in general may not be generated properly.
115
- - The model was trained mainly with English captions and will not work as well in other languages.
116
- - The autoencoding part of the model is lossy
117
- - The model was trained on a large-scale dataset
118
- [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
119
- and is not fit for product use without additional safety mechanisms and
120
- considerations.
121
- - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
122
- The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
123
-
124
- ### Bias
125
-
126
- While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
127
- Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
128
- which consists of images that are primarily limited to English descriptions.
129
- Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
130
- This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
131
- ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
132
-
133
- ### Safety Module
134
-
135
- The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
136
- This checker works by checking model outputs against known hard-coded NSFW concepts.
137
- The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
138
- Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPTextModel` *after generation* of the images.
139
- The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
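A minimal sketch of using the checker through 🧨Diffusers (it is loaded together with the pipeline by default, and its per-image verdict is exposed on the pipeline output):

```py
from diffusers import StableDiffusionPipeline

# The safety checker ships with this checkpoint and is loaded by default.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
out = pipe("a photo of an astronaut riding a horse on mars")

image = out.images[0]
print(out.nsfw_content_detected)  # one boolean per generated image; flagged images are returned blacked out
```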
140
-
141
-
142
- ## Training
143
-
144
- **Training Data**
145
- The model developers used the following dataset for training the model:
146
-
147
- - LAION-2B (en) and subsets thereof (see next section)
148
-
149
- **Training Procedure**
150
- Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
151
-
152
- - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
153
- - Text prompts are encoded through a ViT-L/14 text-encoder.
154
- - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
155
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet (a minimal sketch of this step follows below).
156
-
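A minimal sketch of one such training step, assembled from the Diffusers components this checkpoint ships with (the random batch and short prompt are placeholders, not the original training data or code):

```py
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

repo = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
noise_scheduler = DDPMScheduler.from_pretrained(repo, subfolder="scheduler")

images = torch.randn(1, 3, 512, 512)  # stand-in for a training batch scaled to [-1, 1]
ids = tokenizer(["a caption"], padding="max_length",
                max_length=tokenizer.model_max_length, return_tensors="pt").input_ids

latents = vae.encode(images).latent_dist.sample() * 0.18215      # 4 x 64 x 64 latents (f = 8)
noise = torch.randn_like(latents)
t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = noise_scheduler.add_noise(latents, noise, t)
cond = text_encoder(ids)[0]                                       # non-pooled CLIP text output
pred = unet(noisy_latents, t, encoder_hidden_states=cond).sample
loss = F.mse_loss(pred, noise)                                    # reconstruction objective on the added noise
```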
157
- Currently six Stable Diffusion checkpoints are provided, which were trained as follows.
158
- - [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
159
- 194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
160
- - [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`.
161
- 515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
162
- filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
163
- - [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
164
- - [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
165
- - [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
166
- [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution 512x512 on “laion-aesthetics v2 5+” and 10% dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, we generate synthetic masks and, in 25% of the cases, mask everything.
167
-
168
- - **Hardware:** 32 x 8 x A100 GPUs
169
- - **Optimizer:** AdamW
170
- - **Gradient Accumulations**: 2
171
- - **Batch:** 32 x 8 x 2 x 4 = 2048
172
- - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
173
-
174
- ## Evaluation Results
175
- Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
176
- 5.0, 6.0, 7.0, 8.0) and 50 PNDM/PLMS sampling
177
- steps show the relative improvements of the checkpoints:
178
-
179
- ![pareto](https://huggingface.co/CompVis/stable-diffusion/resolve/main/v1-1-to-v1-5.png)
180
-
181
- Evaluated at 512x512 resolution using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set. Not optimized for FID scores.
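A minimal sketch of such a sweep with 🧨Diffusers (illustrative only: a single prompt instead of the 10000 COCO prompts; the checkpoint's default scheduler is already the PNDM/PLMS sampler referred to above):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
for scale in [1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]:
    image = pipe(prompt, guidance_scale=scale, num_inference_steps=50).images[0]
    image.save("astronaut_cfg_{}.png".format(scale))
```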
182
- ## Environmental Impact
183
-
184
- **Stable Diffusion v1** **Estimated Emissions**
185
- Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
186
-
187
- - **Hardware Type:** A100 PCIe 40GB
188
- - **Hours used:** 150000
189
- - **Cloud Provider:** AWS
190
- - **Compute Region:** US-east
191
- - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
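As a rough plausibility check (assuming a ~250 W average draw per A100 PCIe 40GB and a grid intensity of about 0.3 kg CO2eq/kWh, neither of which is stated above): 150,000 h x 0.25 kW x 0.3 kg CO2eq/kWh ≈ 11,250 kg CO2 eq., which matches the reported figure.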
192
-
193
-
194
- ## Citation
195
-
196
- ```bibtex
197
- @InProceedings{Rombach_2022_CVPR,
198
- author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
199
- title = {High-Resolution Image Synthesis With Latent Diffusion Models},
200
- booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
201
- month = {June},
202
- year = {2022},
203
- pages = {10684-10695}
204
- }
205
- ```
206
-
207
- *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
__init__.py ADDED
Binary file (2 Bytes).
 
config.yaml ADDED
@@ -0,0 +1,63 @@
1
+ # device
2
+ mode: sample # train sample
3
+ gpu_ids: [3] # gpu ids
4
+ batch_size: 1 # batch size each item denotes one story
5
+ num_workers: 4 # number of workers
6
+ num_cpu_cores: -1 # number of cpu cores
7
+ seed: 0 # random seed
8
+ ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
9
+ run_name: ARLDM # name for this run
10
+
11
+ # task
12
+ dataset: pororo # pororo flintstones vistsis vistdii
13
+ task: visualization # continuation visualization
14
+
15
+ # train
16
+ init_lr: 1e-5 # initial learning rate
17
+ warmup_epochs: 1 # warmup epochs
18
+ max_epochs: 5 #50 # max epochs
19
+ train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch
20
+ freeze_clip: True #False # whether to freeze clip
21
+ freeze_blip: True #False # whether to freeze blip
22
+ freeze_resnet: True #False # whether to freeze resnet
23
+
24
+ # sample
25
+ test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
26
+ calculate_fid: True # whether to calculate FID scores
27
+ scheduler: ddim # ddim pndm
28
+ guidance_scale: 6 # guidance scale
29
+ num_inference_steps: 250 # number of inference steps
30
+ sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
31
+
32
+ pororo:
33
+ hdf5_file: /root/lihui/StoryVisualization/pororo.h5
34
+ max_length: 85
35
+ new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
36
+ clip_embedding_tokens: 49416
37
+ blip_embedding_tokens: 30530
38
+
39
+ flintstones:
40
+ hdf5_file: /path/to/flintstones.h5
41
+ max_length: 91
42
+ new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
43
+ clip_embedding_tokens: 49412
44
+ blip_embedding_tokens: 30525
45
+
46
+ vistsis:
47
+ hdf5_file: /path/to/vist.h5
48
+ max_length: 100
49
+ clip_embedding_tokens: 49408
50
+ blip_embedding_tokens: 30524
51
+
52
+ vistdii:
53
+ hdf5_file: /path/to/vist.h5
54
+ max_length: 65
55
+ clip_embedding_tokens: 49408
56
+ blip_embedding_tokens: 30524
57
+
58
+ hydra:
59
+ run:
60
+ dir: .
61
+ output_subdir: null
62
+ hydra/job_logging: disabled
63
+ hydra/hydra_logging: disabled
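A minimal sketch of how the per-dataset blocks above are typically consumed, assuming plain OmegaConf loading (the project itself is driven through a Hydra entry point):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
ds = cfg.get(cfg.dataset)            # e.g. the pororo block when dataset: pororo
print(cfg.task, cfg.mode)            # visualization / sample with the values above
print(ds.hdf5_file, ds.max_length)   # per-dataset settings used by the dataset classes below
```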
data_script/flintstones_hdf5.py ADDED
@@ -0,0 +1,51 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import pickle
5
+
6
+ import cv2
7
+ import h5py
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+
12
+ def main(args):
13
+ splits = json.load(open(os.path.join(args.data_dir, 'train-val-test_split.json'), 'r'))
14
+ train_ids, val_ids, test_ids = splits["train"], splits["val"], splits["test"]
15
+ followings = pickle.load(open(os.path.join(args.data_dir, 'following_cache4.pkl'), 'rb'))
16
+ annotations = json.load(open(os.path.join(args.data_dir, 'flintstones_annotations_v1-0.json')))
17
+ descriptions = dict()
18
+ for sample in annotations:
19
+ descriptions[sample["globalID"]] = sample["description"]
20
+
21
+ f = h5py.File(args.save_path, "w")
22
+ for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
23
+ ids = [i for i in ids if i in followings and len(followings[i]) == 4]
24
+ length = len(ids)
25
+
26
+ group = f.create_group(subset)
27
+ images = list()
28
+ for i in range(5):
29
+ images.append(
30
+ group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
31
+ text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
32
+ for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
33
+ globalIDs = [item] + followings[item]
34
+ txt = list()
35
+ for j, globalID in enumerate(globalIDs):
36
+ img = np.load(os.path.join(args.data_dir, 'video_frames_sampled', '{}.npy'.format(globalID)))
37
+ img = np.concatenate(img, axis=0).astype(np.uint8)
38
+ img = cv2.imencode('.png', img)[1].tobytes()
39
+ img = np.frombuffer(img, np.uint8)
40
+ images[j][i] = img
41
+ txt.append(descriptions[globalID])
42
+ text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
43
+ f.close()
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser(description='arguments for flintstones hdf5 file saving')
48
+ parser.add_argument('--data_dir', type=str, required=True, help='flintstones data directory')
49
+ parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
50
+ args = parser.parse_args()
51
+ main(args)
data_script/pororo_hdf5.py ADDED
@@ -0,0 +1,83 @@
1
+ import argparse
2
+ import os
3
+
4
+ import cv2
5
+ import h5py
6
+ import numpy as np
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+
10
+
11
+ def main(args):
12
+ # Load descriptions.npy with numpy.load; the file stores a Python dict, so item() converts the loaded 0-d array back into a dict.
13
+ # os.path.join joins the path components,
14
+ # using args.data_dir as the base directory with 'descriptions.npy' appended to it.
15
+ # allow_pickle=True allows loading files that contain pickled Python objects;
16
+ # encoding='latin1' reads the file with the Latin-1 encoding.
17
+ descriptions = np.load(os.path.join(args.data_dir, 'descriptions.npy'), allow_pickle=True, encoding='latin1').item()
18
+ # imgs_list holds the paths of the image files,
19
+ # and followings_list holds the follow-up frames associated with each image.
20
+ imgs_list = np.load(os.path.join(args.data_dir, 'img_cache4.npy'), encoding='latin1')
21
+ followings_list = np.load(os.path.join(args.data_dir, 'following_cache4.npy'))
22
+ # Load train_seen_unseen_ids.npy with numpy.load.
23
+ # The file contains three numpy arrays - train_ids, val_ids and test_ids - the ID lists of the train, validation and test splits.
24
+ # Tuple unpacking loads all three arrays and assigns them to the corresponding variables in one statement.
25
+ train_ids, val_ids, test_ids = np.load(os.path.join(args.data_dir, 'train_seen_unseen_ids.npy'), allow_pickle=True)
26
+ # Sort each ID list in ascending order.
27
+ train_ids = np.sort(train_ids)
28
+ val_ids = np.sort(val_ids)
29
+ test_ids = np.sort(test_ids)
30
+
31
+ # Create a new HDF5 file named args.save_path.
32
+ # h5py.File creates the file object in write mode ("w");
33
+ # the processed image and text data are stored in this file.
34
+ f = h5py.File(args.save_path, "w")
35
+ for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
36
+ length = len(ids)
37
+
38
+ # Create one group per split (train, val and test).
39
+ # Each split gets five datasets named 'image0' ... 'image4', corresponding to the current image and the four images that follow it.
40
+ # Purpose: store each image together with its related frames in one HDF5 file, organized so that later reading and processing is straightforward.
41
+ group = f.create_group(subset)
42
+ # Build the list images and append the five HDF5 dataset objects image0-image4 in order (each dataset has one entry per ID).
43
+ images = list()
44
+ # Create five image datasets for the current split.
45
+ # Each dataset uses vlen_dtype(np.dtype('uint8')) as its element type and is added to the current group.
46
+ # vlen_dtype(np.dtype('uint8')) denotes a variable-length array of unsigned 8-bit integers.
47
+ for i in range(5):
48
+ images.append(
49
+ group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
50
+ # Create the dataset text for the captions of the images in this split; it stores UTF-8 strings and is added to the current group.
51
+ text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
52
+ # Iterate over every item in the current split and write its data into the HDF5 file.
53
+ for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
54
+ # Collect the paths of all images related to the current item into img_paths.
55
+ # imgs_list stores the path of every image;
56
+ # followings_list stores the paths of the four images that follow each image.
57
+ img_paths = [str(imgs_list[item])[2:-1]] + [str(followings_list[item][i])[2:-1] for i in range(4)]
58
+ # Open every image in img_paths and convert it to an RGB PIL image.
59
+ imgs = [Image.open(os.path.join(args.data_dir, img_path)).convert('RGB') for img_path in img_paths]
60
+ # Convert each PIL image to a numpy array,
61
+ for j, img in enumerate(imgs):
62
+ img = np.array(img).astype(np.uint8)
63
+ # encode it as PNG bytes with OpenCV,
64
+ img = cv2.imencode('.png', img)[1].tobytes()
65
+ # turn the bytes back into a numpy uint8 buffer,
66
+ img = np.frombuffer(img, np.uint8)
67
+ # and store it in the dataset (from the images list) that corresponds to this frame position.
68
+ images[j][i] = img
69
+ # Collect the file names (without the .png suffix) of all related images into tgt_img_ids.
70
+ tgt_img_ids = [str(img_path).replace('.png', '') for img_path in img_paths]
71
+ # Look up the caption of each target image by its file name and collect them into txt.
72
+ txt = [descriptions[tgt_img_id][0] for tgt_img_id in tgt_img_ids]
73
+ # Join all captions in txt into one '|'-separated string, stripping "\n" and "\t" characters, and store it in the text dataset.
74
+ text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
75
+ f.close()
76
+
77
+
78
+ if __name__ == '__main__':
79
+ parser = argparse.ArgumentParser(description='arguments for pororo hdf5 file saving')
80
+ parser.add_argument('--data_dir', type=str, required=True, help='pororo data directory')
81
+ parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
82
+ args = parser.parse_args()
83
+ main(args)
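A minimal sketch of reading one story back from the HDF5 file written above (the path is an example; the decoding mirrors the cv2.imdecode calls in the dataset classes below):

```python
import cv2
import h5py

with h5py.File("pororo.h5", "r") as f:                            # example path
    split = f["train"]
    captions = split["text"][0].decode("utf-8").split("|")        # five '|'-separated captions
    frames = []
    for i in range(5):
        buf = split["image{}".format(i)][0]                       # variable-length uint8 PNG bytes
        frames.append(cv2.imdecode(buf, cv2.IMREAD_COLOR))        # decoded BGR frame strip
    print(len(captions), [fr.shape for fr in frames])
```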
data_script/vist_hdf5.py ADDED
@@ -0,0 +1,111 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import cv2
6
+ import h5py
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+
12
+ def main(args):
13
+ train_data = json.load(open(os.path.join(args.sis_json_dir, 'train.story-in-sequence.json')))
14
+ val_data = json.load(open(os.path.join(args.sis_json_dir, 'val.story-in-sequence.json')))
15
+ test_data = json.load(open(os.path.join(args.sis_json_dir, 'test.story-in-sequence.json')))
16
+
17
+ prefix = ["train", "val", "test"]
18
+ whole_album = {}
19
+ for i, data in enumerate([train_data, val_data, test_data]):
20
+ album_mapping = {}
21
+ for annot_new in data["annotations"]:
22
+ annot = annot_new[0]
23
+ assert len(annot_new) == 1
24
+ if annot['story_id'] not in album_mapping:
25
+ album_mapping[annot['story_id']] = {"flickr_id": [annot['photo_flickr_id']],
26
+ "sis": [annot['original_text']],
27
+ "length": 1}
28
+ else:
29
+ album_mapping[annot['story_id']]["flickr_id"].append(annot['photo_flickr_id'])
30
+ album_mapping[annot['story_id']]["sis"].append(
31
+ annot['original_text'])
32
+ album_mapping[annot['story_id']]["length"] += 1
33
+ whole_album[prefix[i]] = album_mapping
34
+
35
+ for p in prefix:
36
+ deletables = []
37
+ for story_id, story in whole_album[p].items():
38
+ if story['length'] != 5:
39
+ print("deleting {}".format(story_id))
40
+ deletables.append(story_id)
41
+ continue
42
+ d = [os.path.exists(os.path.join(args.img_dir, "{}.jpg".format(_))) for _ in story["flickr_id"]]
43
+ if sum(d) < 5:
44
+ print("deleting {}".format(story_id))
45
+ deletables.append(story_id)
46
+ else:
47
+ pass
48
+ for i in deletables:
49
+ del whole_album[p][i]
50
+
51
+ train_data = json.load(open(os.path.join(args.dii_json_dir, 'train.description-in-isolation.json')))
52
+ val_data = json.load(open(os.path.join(args.dii_json_dir, 'val.description-in-isolation.json')))
53
+ test_data = json.load(open(os.path.join(args.dii_json_dir, 'test.description-in-isolation.json')))
54
+
55
+ flickr_id2text = {}
56
+ for i, data in enumerate([train_data, val_data, test_data]):
57
+ for l in data['annotations']:
58
+ assert len(l) == 1
59
+ if l[0]['photo_flickr_id'] in flickr_id2text:
60
+ flickr_id2text[l[0]['photo_flickr_id']] = \
61
+ max([flickr_id2text[l[0]['photo_flickr_id']], l[0]['original_text']], key=len)
62
+ else:
63
+ flickr_id2text[l[0]['photo_flickr_id']] = l[0]['original_text']
64
+
65
+ for p in prefix:
66
+ deletables = []
67
+ for story_id, story in whole_album[p].items():
68
+ story['dii'] = []
69
+ for i, flickr_id in enumerate(story['flickr_id']):
70
+ if flickr_id not in flickr_id2text:
71
+ print("{} not found in story {}".format(flickr_id, story_id))
72
+ deletables.append(story_id)
73
+ break
74
+ story['dii'].append(flickr_id2text[flickr_id])
75
+ for i in deletables:
76
+ del whole_album[p][i]
77
+
78
+ f = h5py.File(args.save_path, "w")
79
+ for p in prefix:
80
+ group = f.create_group(p)
81
+ story_dict = whole_album[p]
82
+ length = len(story_dict)
83
+ images = list()
84
+ for i in range(5):
85
+ images.append(
86
+ group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
87
+ sis = group.create_dataset('sis', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
88
+ dii = group.create_dataset('dii', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
89
+ for i, (story_id, story) in enumerate(tqdm(story_dict.items(), leave=True, desc="saveh5")):
90
+ imgs = [Image.open('{}/{}.jpg'.format(args.img_dir, flickr_id)).convert('RGB') for flickr_id in
91
+ story['flickr_id']]
92
+ for j, img in enumerate(imgs):
93
+ img = np.array(img).astype(np.uint8)
94
+ img = cv2.imencode('.png', img)[1].tobytes()
95
+ img = np.frombuffer(img, np.uint8)
96
+ images[j][i] = img
97
+ sis[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in story['sis']])
98
+ txt_dii = [t.replace('\n', '').replace('\t', '').strip() for t in story['dii']]
99
+ txt_dii = sorted(set(txt_dii), key=txt_dii.index)
100
+ dii[i] = '|'.join(txt_dii)
101
+ f.close()
102
+
103
+
104
+ if __name__ == '__main__':
105
+ parser = argparse.ArgumentParser(description='arguments for vist hdf5 file saving')
106
+ parser.add_argument('--sis_json_dir', type=str, required=True, help='sis json file directory')
107
+ parser.add_argument('--dii_json_dir', type=str, required=True, help='dii json file directory')
108
+ parser.add_argument('--img_dir', type=str, required=True, help='image file directory')
109
+ parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
110
+ args = parser.parse_args()
111
+ main(args)
data_script/vist_img_download.py ADDED
@@ -0,0 +1,61 @@
1
+ import json
2
+ import requests
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ from multiprocessing import Process
7
+ import os
8
+ import argparse
9
+
10
+
11
+ def download_subprocess(dii, save_dir):
12
+ for image in tqdm(dii):
13
+ key, value = image.popitem()
14
+ try:
15
+ img_data = requests.get(value).content
16
+ img = Image.open(BytesIO(img_data)).convert('RGB')
17
+ h = img.size[0]
18
+ w = img.size[1]
19
+ if min(h, w) > 512:
20
+ img = img.resize((int(h / (w / 512)), 512) if h > w else (512, int(w / (h / 512))))
21
+ img.save('{}/{}.jpg'.format(save_dir, key))
22
+ except:
23
+ print(key, value)
24
+
25
+
26
+ def main(args):
27
+ train_data = json.load(open(os.path.join(args.json_dir, 'train.description-in-isolation.json')))
28
+ val_data = json.load(open(os.path.join(args.json_dir, 'val.description-in-isolation.json')))
29
+ test_data = json.load(open(os.path.join(args.json_dir, 'test.description-in-isolation.json')))
30
+ dii = []
31
+ for subset in [train_data, val_data, test_data]:
32
+ for image in subset["images"]:
33
+ try:
34
+ dii.append({image['id']: image['url_o']})
35
+ except:
36
+ dii.append({image['id']: image['url_m']})
37
+
38
+ dii = [image for image in dii if not os.path.exists('{}/{}.jpg'.format(args.img_dir, list(image)[0]))]
39
+ print('total images: {}'.format(len(dii)))
40
+
41
+ def splitlist(inlist, chunksize):
42
+ return [inlist[x:x + chunksize] for x in range(0, len(inlist), chunksize)]
43
+
44
+ dii_splitted = splitlist(dii, int((len(dii) / args.num_process)))
45
+ process_list = []
46
+ for dii_sub_list in dii_splitted:
47
+ p = Process(target=download_subprocess, args=(dii_sub_list, args.img_dir))
48
+ process_list.append(p)
49
+ p.daemon = True
50
+ p.start()
51
+ for p in process_list:
52
+ p.join()
53
+
54
+
55
+ if __name__ == "__main__":
56
+ parser = argparse.ArgumentParser(description='arguments for vist images downloading')
57
+ parser.add_argument('--json_dir', type=str, required=True, help='dii json file directory')
58
+ parser.add_argument('--img_dir', type=str, required=True, help='images saving directory')
59
+ parser.add_argument('--num_process', type=int, default=32)
60
+ args = parser.parse_args()
61
+ main(args)
datasets/flintstones.py ADDED
@@ -0,0 +1,93 @@
1
+ import random
2
+
3
+ import cv2
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+ from transformers import CLIPTokenizer
10
+
11
+ from models.blip_override.blip import init_tokenizer
12
+
13
+
14
+ class StoryDataset(Dataset):
15
+ """
16
+ A custom subset class for the LRW (includes train, val, test) subset
17
+ """
18
+
19
+ def __init__(self, subset, args):
20
+ super(StoryDataset, self).__init__()
21
+ self.args = args
22
+
23
+ self.h5_file = args.get(args.dataset).hdf5_file
24
+ self.subset = subset
25
+
26
+ self.augment = transforms.Compose([
27
+ transforms.ToPILImage(),
28
+ transforms.Resize([512, 512]),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize([0.5], [0.5])
31
+ ])
32
+ self.dataset = args.dataset
33
+ self.max_length = args.get(args.dataset).max_length
34
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
35
+ self.blip_tokenizer = init_tokenizer()
36
+ msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
37
+ print("clip {} new tokens added".format(msg))
38
+ msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
39
+ print("blip {} new tokens added".format(msg))
40
+
41
+ self.blip_image_processor = transforms.Compose([
42
+ transforms.ToPILImage(),
43
+ transforms.Resize([224, 224]),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
46
+ ])
47
+
48
+ def open_h5(self):
49
+ h5 = h5py.File(self.h5_file, "r")
50
+ self.h5 = h5[self.subset]
51
+
52
+ def __getitem__(self, index):
53
+ if not hasattr(self, 'h5'):
54
+ self.open_h5()
55
+
56
+ images = list()
57
+ for i in range(5):
58
+ im = self.h5['image{}'.format(i)][index]
59
+ im = cv2.imdecode(im, cv2.IMREAD_COLOR)
60
+ idx = random.randint(0, 4)
61
+ images.append(im[idx * 128: (idx + 1) * 128])
62
+
63
+ source_images = torch.stack([self.blip_image_processor(im) for im in images])
64
+ images = images[1:] if self.args.task == 'continuation' else images
65
+ images = torch.stack([self.augment(im) for im in images]) \
66
+ if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
67
+
68
+ texts = self.h5['text'][index].decode('utf-8').split('|')
69
+
70
+ # tokenize caption using default tokenizer
71
+ tokenized = self.clip_tokenizer(
72
+ texts[1:] if self.args.task == 'continuation' else texts,
73
+ padding="max_length",
74
+ max_length=self.max_length,
75
+ truncation=False,
76
+ return_tensors="pt",
77
+ )
78
+ captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
79
+
80
+ tokenized = self.blip_tokenizer(
81
+ texts,
82
+ padding="max_length",
83
+ max_length=self.max_length,
84
+ truncation=False,
85
+ return_tensors="pt",
86
+ )
87
+ source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
88
+ return images, captions, attention_mask, source_images, source_caption, source_attention_mask
89
+
90
+ def __len__(self):
91
+ if not hasattr(self, 'h5'):
92
+ self.open_h5()
93
+ return len(self.h5['text'])
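A minimal sketch of wrapping this dataset in a DataLoader, assuming the OmegaConf config shown earlier (the real training script may batch, shuffle and collate differently):

```python
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from datasets.flintstones import StoryDataset

cfg = OmegaConf.load("config.yaml")
cfg.dataset = "flintstones"   # select the flintstones block of the config
# assumes cfg.flintstones.hdf5_file points at the .h5 produced by data_script/flintstones_hdf5.py
train_set = StoryDataset("train", cfg)
loader = DataLoader(train_set, batch_size=cfg.batch_size,
                    num_workers=cfg.num_workers, shuffle=True)

images, captions, attention_mask, source_images, source_caption, source_attention_mask = next(iter(loader))
print(images.shape, captions.shape)   # e.g. [B, 5, 3, 512, 512] and [B, 5, max_length]
```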
datasets/pororo.py ADDED
@@ -0,0 +1,144 @@
1
+ import copy
2
+ import os
3
+ import random
4
+ from PIL import Image
5
+ import cv2
6
+ import h5py
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from torchvision import transforms
11
+ from transformers import CLIPTokenizer
12
+
13
+ from models.blip_override.blip import init_tokenizer
14
+
15
+
16
+ class StoryDataset(Dataset):
17
+ """
18
+ A custom subset class for the LRW (includes train, val, test) subset
19
+ """
20
+ # Constructor of the StoryDataset class.
21
+ def __init__(self, subset, args):
22
+ # Call the parent Dataset initializer so the class inherits all of Dataset's methods and attributes.
23
+ super(StoryDataset, self).__init__()
24
+ # args carries the remaining parameters for the class as a namespace-like config object.
25
+ self.args = args
26
+ # Path of the HDF5 file that stores the images and texts of the train, validation and test splits.
27
+ # args.get(args.dataset) fetches the parameter block of the selected dataset from the args object.
28
+ self.h5_file = args.get(args.dataset).hdf5_file
29
+ # subset selects which split to read (train, val or test).
30
+ self.subset = subset
31
+
32
+ # A transform pipeline that preprocesses images: convert to PIL, resize, convert to a tensor and normalize.
33
+ self.augment = transforms.Compose([
34
+ transforms.ToPILImage(),
35
+ # transforms.Resize([256, 256]),
36
+ transforms.Resize([512, 512]),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize([0.5], [0.5])
39
+ ])
40
+ # Name of the dataset in use.
41
+ self.dataset = args.dataset
42
+ # Maximum caption length; captions are padded to this length during tokenization.
43
+ self.max_length = args.get(args.dataset).max_length
44
+ # A CLIP tokenizer used to tokenize the captions.
45
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
46
+ # A custom (BLIP) tokenizer used for the text input.
47
+ self.blip_tokenizer = init_tokenizer()
48
+ msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
49
+ print("clip {} new tokens added".format(msg))
50
+ msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
51
+ print("blip {} new tokens added".format(msg))
52
+
53
+ # A transform pipeline for the source images: convert to PIL, resize, convert to a tensor and normalize.
54
+ self.blip_image_processor = transforms.Compose([
55
+ transforms.ToPILImage(),
56
+ transforms.Resize([224, 224]),
57
+ transforms.ToTensor(),
58
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
59
+ ])
60
+
61
+ # Open the HDF5 file that backs this dataset.
62
+ def open_h5(self):
63
+ h5 = h5py.File(self.h5_file, "r")
64
+ self.h5 = h5[self.subset]
65
+
66
+ # Fetch one sample by index.
67
+
68
+ # Each image is run through the augmentation transform,
69
+ # then the captions are tokenized
70
+ # with both the CLIP tokenizer and the BLIP tokenizer,
71
+ # and finally the processed images, captions and attention masks are returned.
72
+ def __getitem__(self, index):
73
+ # First make sure the HDF5 file is open via open_h5().
74
+ if not hasattr(self, 'h5'):
75
+ self.open_h5()
76
+ #index = 1
77
+ images = list()
78
+ for i in range(5):
79
+ # Read one group of images and the matching text from the HDF5 file.
80
+ im = self.h5['image{}'.format(i)][index]
81
+ # print(im)
82
+ # pil_img = Image.fromarray(im)
83
+ # # save the image
84
+ # pil_img.save(os.path.join('/root/lihui/StoryVisualization/ori_test_images', '{:04d}.png'.format(i)))
85
+ # Decode each image from its PNG bytes.
86
+ im = cv2.imdecode(im, cv2.IMREAD_COLOR)
87
+ # Randomly pick one 128-pixel-high slice of the frame strip.
88
+ idx = random.randint(0, im.shape[0] / 128 - 1)
89
+ # Append the sliced frame to the images list.
90
+ images.append(im[idx * 128: (idx + 1) * 128])
91
+ # Deep copy so the originals do not change when images is modified later.
92
+ ori_images = copy.deepcopy(images)
93
+ # Save the original test images (left commented out below).
94
+
95
+ # for i, im in enumerate(images):
96
+ # file_path = '/root/lihui/StoryVisualization/ori_test_images/group{:02d}_image{:02d}.png'.format(index + 1,
97
+ # i + 1)
98
+ # cv2.imwrite(file_path, im)
99
+ # Preprocess the images for BLIP and stack them into a tensor.
100
+ source_images = torch.stack([self.blip_image_processor(im) for im in images])
101
+ # For the continuation task, drop the first image from images.
102
+ images = images[1:] if self.args.task == 'continuation' else images
103
+ # For the train/val splits, augment every image in images and convert it to a tensor;
104
+ # otherwise convert the images list to a tensor via numpy.array and permute the dimensions.
105
+ images = torch.stack([self.augment(im) for im in images]) \
106
+ if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
107
+ ######################
108
+ # Read the text at the current index and decode it as UTF-8.
109
+ texts = self.h5['text'][index].decode('utf-8').split('|')
110
+ # print(f"index: {index}")
111
+ # for text in texts:
112
+ # print(f"texts: {text}")
113
+
114
+ # tokenize caption using default tokenizer
115
+ tokenized = self.clip_tokenizer(
116
+ texts[1:] if self.args.task == 'continuation' else texts,
117
+ padding="max_length",
118
+ max_length=self.max_length,
119
+ truncation=False,
120
+ return_tensors="pt",
121
+ )
122
+ captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
123
+
124
+ tokenized = self.blip_tokenizer(
125
+ texts,
126
+ padding="max_length",
127
+ max_length=self.max_length,
128
+ truncation=False,
129
+ return_tensors="pt",
130
+ )
131
+ source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
132
+ return images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images
133
+
134
+ # Return the number of samples in the dataset.
135
+ # For the test split only 1 sample is returned here (the commented-out variant returned 100); otherwise the full split length is returned.
136
+ def __len__(self):
137
+ if not hasattr(self, 'h5'):
138
+ self.open_h5()
139
+ if self.subset == 'test':
140
+ #print('')
141
+ return 1
142
+ # if self.subset == 'test':
143
+ # return 100
144
+ return len(self.h5['text'])
datasets/vistdii.py ADDED
@@ -0,0 +1,94 @@
1
+ import cv2
2
+ import h5py
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ from transformers import CLIPTokenizer
8
+
9
+ from models.blip_override.blip import init_tokenizer
10
+
11
+
12
+ class StoryDataset(Dataset):
13
+ """
14
+ A custom subset class for the LRW (includes train, val, test) subset
15
+ """
16
+
17
+ def __init__(self, subset, args):
18
+ super(StoryDataset, self).__init__()
19
+ self.args = args
20
+
21
+ self.h5_file = args.get(args.dataset).hdf5_file
22
+ self.subset = subset
23
+
24
+ self.augment = transforms.Compose([
25
+ transforms.ToPILImage(),
26
+ transforms.Resize(512),
27
+ transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.5], [0.5])
30
+ ]) if self.subset in ['train', 'val'] else transforms.Compose([
31
+ transforms.ToPILImage(),
32
+ transforms.Resize(64),
33
+ transforms.CenterCrop(64)
34
+ ])
35
+
36
+ self.dataset = args.dataset
37
+ self.max_length = args.get(args.dataset).max_length
38
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
39
+ self.blip_tokenizer = init_tokenizer()
40
+
41
+ self.blip_image_processor = transforms.Compose([
42
+ transforms.ToPILImage(),
43
+ transforms.Resize(224),
44
+ transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
45
+ transforms.ToTensor(),
46
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
47
+ ])
48
+
49
+ def open_h5(self):
50
+ h5 = h5py.File(self.h5_file, "r")
51
+ self.h5 = h5[self.subset]
52
+
53
+ def __getitem__(self, index):
54
+ if not hasattr(self, 'h5'):
55
+ self.open_h5()
56
+
57
+ images = list()
58
+ for i in range(5):
59
+ im = self.h5['image{}'.format(i)][index]
60
+ im = cv2.imdecode(im, cv2.IMREAD_COLOR)
61
+ images.append(im)
62
+
63
+ source_images = torch.stack([self.blip_image_processor(im) for im in images])
64
+ images = images[1:] if self.args.task == 'continuation' else images
65
+ images = [self.augment(im) for im in images]
66
+ images = torch.stack(images) if self.subset in ['train', 'val'] \
67
+ else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
68
+
69
+ texts = self.h5['dii'][index].decode('utf-8').split('|')
70
+
71
+ # tokenize caption using default tokenizer
72
+ tokenized = self.clip_tokenizer(
73
+ texts[1:] if self.args.task == 'continuation' else texts,
74
+ padding="max_length",
75
+ max_length=self.max_length,
76
+ truncation=False,
77
+ return_tensors="pt",
78
+ )
79
+ captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
80
+
81
+ tokenized = self.blip_tokenizer(
82
+ texts,
83
+ padding="max_length",
84
+ max_length=self.max_length,
85
+ truncation=False,
86
+ return_tensors="pt",
87
+ )
88
+ source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
89
+ return images, captions, attention_mask, source_images, source_caption, source_attention_mask
90
+
91
+ def __len__(self):
92
+ if not hasattr(self, 'h5'):
93
+ self.open_h5()
94
+ return len(self.h5['dii'])
datasets/vistsis.py ADDED
@@ -0,0 +1,94 @@
1
+ import cv2
2
+ import h5py
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ from transformers import CLIPTokenizer
8
+
9
+ from models.blip_override.blip import init_tokenizer
10
+
11
+
12
+ class StoryDataset(Dataset):
13
+ """
14
+ A custom subset class for the LRW (includes train, val, test) subset
15
+ """
16
+
17
+ def __init__(self, subset, args):
18
+ super(StoryDataset, self).__init__()
19
+ self.args = args
20
+
21
+ self.h5_file = args.get(args.dataset).hdf5_file
22
+ self.subset = subset
23
+
24
+ self.augment = transforms.Compose([
25
+ transforms.ToPILImage(),
26
+ transforms.Resize(512),
27
+ transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.5], [0.5])
30
+ ]) if self.subset in ['train', 'val'] else transforms.Compose([
31
+ transforms.ToPILImage(),
32
+ transforms.Resize(64),
33
+ transforms.CenterCrop(64)
34
+ ])
35
+
36
+ self.dataset = args.dataset
37
+ self.max_length = args.get(args.dataset).max_length
38
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
39
+ self.blip_tokenizer = init_tokenizer()
40
+
41
+ self.blip_image_processor = transforms.Compose([
42
+ transforms.ToPILImage(),
43
+ transforms.Resize(224),
44
+ transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
45
+ transforms.ToTensor(),
46
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
47
+ ])
48
+
49
+ def open_h5(self):
50
+ h5 = h5py.File(self.h5_file, "r")
51
+ self.h5 = h5[self.subset]
52
+
53
+ def __getitem__(self, index):
54
+ if not hasattr(self, 'h5'):
55
+ self.open_h5()
56
+
57
+ images = list()
58
+ for i in range(5):
59
+ im = self.h5['image{}'.format(i)][index]
60
+ im = cv2.imdecode(im, cv2.IMREAD_COLOR)
61
+ images.append(im)
62
+
63
+ source_images = torch.stack([self.blip_image_processor(im) for im in images])
64
+ images = images[1:] if self.args.task == 'continuation' else images
65
+ images = [self.augment(im) for im in images]
66
+ images = torch.stack(images) if self.subset in ['train', 'val'] \
67
+ else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
68
+
69
+ texts = self.h5['sis'][index].decode('utf-8').split('|')
70
+
71
+ # tokenize caption using default tokenizer
72
+ tokenized = self.clip_tokenizer(
73
+ texts[1:] if self.args.task == 'continuation' else texts,
74
+ padding="max_length",
75
+ max_length=self.max_length,
76
+ truncation=False,
77
+ return_tensors="pt",
78
+ )
79
+ captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
80
+
81
+ tokenized = self.blip_tokenizer(
82
+ texts,
83
+ padding="max_length",
84
+ max_length=self.max_length,
85
+ truncation=False,
86
+ return_tensors="pt",
87
+ )
88
+ source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
89
+ return images, captions, attention_mask, source_images, source_caption, source_attention_mask
90
+
91
+ def __len__(self):
92
+ if not hasattr(self, 'h5'):
93
+ self.open_h5()
94
+ return len(self.h5['sis'])
environment.yml ADDED
@@ -0,0 +1,271 @@
1
+ name: story
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - blas=1.0=mkl
10
+ - brotlipy=0.7.0=py38h27cfd23_1003
11
+ - bzip2=1.0.8=h7b6447c_0
12
+ - ca-certificates=2023.01.10=h06a4308_0
13
+ - certifi=2022.12.7=py38h06a4308_0
14
+ - cffi=1.15.1=py38h5eee18b_3
15
+ - cryptography=39.0.1=py38h9ce1e76_0
16
+ - cuda-cudart=11.7.99=0
17
+ - cuda-cupti=11.7.101=0
18
+ - cuda-libraries=11.7.1=0
19
+ - cuda-nvrtc=11.7.99=0
20
+ - cuda-nvtx=11.7.91=0
21
+ - cuda-runtime=11.7.1=0
22
+ - ffmpeg=4.3=hf484d3e_0
23
+ - flit-core=3.8.0=py38h06a4308_0
24
+ - freetype=2.12.1=h4a9f257_0
25
+ - giflib=5.2.1=h5eee18b_3
26
+ - gmp=6.2.1=h295c915_3
27
+ - gnutls=3.6.15=he1e5248_0
28
+ - idna=3.4=py38h06a4308_0
29
+ - intel-openmp=2021.4.0=h06a4308_3561
30
+ - jpeg=9e=h5eee18b_1
31
+ - lame=3.100=h7b6447c_0
32
+ - lcms2=2.12=h3be6417_0
33
+ - ld_impl_linux-64=2.38=h1181459_1
34
+ - lerc=3.0=h295c915_0
35
+ - libcublas=11.10.3.66=0
36
+ - libcufft=10.7.2.124=h4fbf590_0
37
+ - libcufile=1.6.0.25=0
38
+ - libcurand=10.3.2.56=0
39
+ - libcusolver=11.4.0.1=0
40
+ - libcusparse=11.7.4.91=0
41
+ - libdeflate=1.17=h5eee18b_0
42
+ - libffi=3.4.2=h6a678d5_6
43
+ - libgcc-ng=11.2.0=h1234567_1
44
+ - libgomp=11.2.0=h1234567_1
45
+ - libiconv=1.16=h7f8727e_2
46
+ - libidn2=2.3.2=h7f8727e_0
47
+ - libnpp=11.7.4.75=0
48
+ - libnvjpeg=11.8.0.2=0
49
+ - libpng=1.6.39=h5eee18b_0
50
+ - libstdcxx-ng=11.2.0=h1234567_1
51
+ - libtasn1=4.19.0=h5eee18b_0
52
+ - libtiff=4.5.0=h6a678d5_2
53
+ - libunistring=0.9.10=h27cfd23_0
54
+ - libwebp=1.2.4=h11a3e52_1
55
+ - libwebp-base=1.2.4=h5eee18b_1
56
+ - lz4-c=1.9.4=h6a678d5_0
57
+ - mkl=2021.4.0=h06a4308_640
58
+ - mkl-service=2.4.0=py38h7f8727e_0
59
+ - mkl_fft=1.3.1=py38hd3c417c_0
60
+ - mkl_random=1.2.2=py38h51133e4_0
61
+ - ncurses=6.4=h6a678d5_0
62
+ - nettle=3.7.3=hbbd107a_1
63
+ - numpy-base=1.23.5=py38h31eccc5_0
64
+ - openh264=2.1.1=h4ff587b_0
65
+ - openssl=1.1.1t=h7f8727e_0
66
+ - pip=23.0.1=py38h06a4308_0
67
+ - pycparser=2.21=pyhd3eb1b0_0
68
+ - pyopenssl=23.0.0=py38h06a4308_0
69
+ - pysocks=1.7.1=py38h06a4308_0
70
+ - python=3.8.16=h7a1cb2a_3
71
+ - pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0
72
+ - pytorch-cuda=11.7=h778d358_3
73
+ - pytorch-mutex=1.0=cuda
74
+ - readline=8.2=h5eee18b_0
75
+ - six=1.16.0=pyhd3eb1b0_1
76
+ - sqlite=3.41.1=h5eee18b_0
77
+ - tk=8.6.12=h1ccaba5_0
78
+ - typing_extensions=4.4.0=py38h06a4308_0
79
+ - urllib3=1.26.15=py38h06a4308_0
80
+ - wheel=0.38.4=py38h06a4308_0
81
+ - xz=5.2.10=h5eee18b_1
82
+ - zlib=1.2.13=h5eee18b_0
83
+ - zstd=1.5.4=hc292b87_0
84
+ - pip:
85
+ - absl-py==1.4.0
86
+ - accelerate==0.17.1
87
+ - aiofiles==23.1.0
88
+ - aiohttp==3.8.4
89
+ - aiosignal==1.3.1
90
+ - altair==4.2.2
91
+ - antlr4-python3-runtime==4.9.3
92
+ - anyio==3.6.2
93
+ - appdirs==1.4.4
94
+ - argon2-cffi==21.3.0
95
+ - argon2-cffi-bindings==21.2.0
96
+ - arrow==1.2.3
97
+ - asttokens==2.2.1
98
+ - async-timeout==4.0.2
99
+ - attrs==22.2.0
100
+ - backcall==0.2.0
101
+ - beautifulsoup4==4.11.2
102
+ - bleach==6.0.0
103
+ - cachetools==5.3.0
104
+ - chardet==5.1.0
105
+ - charset-normalizer==3.1.0
106
+ - click==8.1.3
107
+ - comm==0.1.2
108
+ - contourpy==1.0.7
109
+ - cycler==0.11.0
110
+ - debugpy==1.6.6
111
+ - decorator==5.1.1
112
+ - defusedxml==0.7.1
113
+ - diffusers==0.9.0
114
+ - docker-pycreds==0.4.0
115
+ - entrypoints==0.4
116
+ - executing==1.2.0
117
+ - fastapi==0.95.0
118
+ - fastjsonschema==2.16.3
119
+ - ffmpy==0.3.0
120
+ - filelock==3.10.0
121
+ - fire==0.5.0
122
+ - flatbuffers==23.3.3
123
+ - fonttools==4.39.3
124
+ - fqdn==1.5.1
125
+ - frozenlist==1.3.3
126
+ - fsspec==2023.3.0
127
+ - ftfy==6.1.1
128
+ - gitdb==4.0.10
129
+ - gitpython==3.1.31
130
+ - google-auth==2.16.2
131
+ - google-auth-oauthlib==0.4.6
132
+ - gradio==3.24.1
133
+ - gradio-client==0.0.5
134
+ - grpcio==1.51.3
135
+ - h11==0.14.0
136
+ - h5py==3.8.0
137
+ - httpcore==0.16.3
138
+ - httpx==0.23.3
139
+ - huggingface-hub==0.13.2
140
+ - hydra-core==1.3.2
141
+ - importlib-metadata==6.1.0
142
+ - importlib-resources==5.12.0
143
+ - ipykernel==6.21.3
144
+ - ipython==8.11.0
145
+ - ipython-genutils==0.2.0
146
+ - ipywidgets==8.0.4
147
+ - isoduration==20.11.0
148
+ - jedi==0.18.2
149
+ - jinja2==3.1.2
150
+ - jsonpointer==2.3
151
+ - jsonschema==4.17.3
152
+ - jupyter==1.0.0
153
+ - jupyter-client==8.0.3
154
+ - jupyter-console==6.6.3
155
+ - jupyter-core==5.3.0
156
+ - jupyter-events==0.6.3
157
+ - jupyter-server==2.5.0
158
+ - jupyter-server-terminals==0.4.4
159
+ - jupyterlab-pygments==0.2.2
160
+ - jupyterlab-widgets==3.0.5
161
+ - kiwisolver==1.4.4
162
+ - lightning-bolts==0.5.0
163
+ - linkify-it-py==2.0.0
164
+ - lora-diffusion==0.1.7
165
+ - markdown==3.4.1
166
+ - markdown-it-py==2.2.0
167
+ - markupsafe==2.1.2
168
+ - matplotlib==3.7.1
169
+ - matplotlib-inline==0.1.6
170
+ - mdit-py-plugins==0.3.3
171
+ - mdurl==0.1.2
172
+ - mediapipe==0.9.1.0
173
+ - mistune==2.0.5
174
+ - multidict==6.0.4
175
+ - nbclassic==0.5.3
176
+ - nbclient==0.7.2
177
+ - nbconvert==7.2.10
178
+ - nbformat==5.7.3
179
+ - nest-asyncio==1.5.6
180
+ - notebook==6.5.3
181
+ - notebook-shim==0.2.2
182
+ - numpy==1.24.2
183
+ - oauthlib==3.2.2
184
+ - omegaconf==2.3.0
185
+ - opencv-contrib-python==4.7.0.72
186
+ - opencv-python==4.7.0.72
187
+ - orjson==3.8.9
188
+ - packaging==23.0
189
+ - pandas==1.5.3
190
+ - pandocfilters==1.5.0
191
+ - parso==0.8.3
192
+ - pathtools==0.1.2
193
+ - pexpect==4.8.0
194
+ - pickleshare==0.7.5
195
+ - pillow==9.4.0
196
+ - pkgutil-resolve-name==1.3.10
197
+ - platformdirs==3.1.1
198
+ - prometheus-client==0.16.0
199
+ - prompt-toolkit==3.0.38
200
+ - protobuf==3.20.1
201
+ - psutil==5.9.4
202
+ - ptyprocess==0.7.0
203
+ - pure-eval==0.2.2
204
+ - pyasn1==0.4.8
205
+ - pyasn1-modules==0.2.8
206
+ - pydantic==1.10.7
207
+ - pydeprecate==0.3.2
208
+ - pydub==0.25.1
209
+ - pygments==2.14.0
210
+ - pyparsing==3.0.9
211
+ - pyrsistent==0.19.3
212
+ - python-dateutil==2.8.2
213
+ - python-json-logger==2.0.7
214
+ - python-multipart==0.0.6
215
+ - pytorch-lightning==1.6.5
216
+ - pytz==2023.3
217
+ - pyyaml==6.0
218
+ - pyzmq==25.0.1
219
+ - qtconsole==5.4.1
220
+ - qtpy==2.3.0
221
+ - regex==2022.10.31
222
+ - requests==2.28.2
223
+ - requests-oauthlib==1.3.1
224
+ - rfc3339-validator==0.1.4
225
+ - rfc3986==1.5.0
226
+ - rfc3986-validator==0.1.1
227
+ - rsa==4.9
228
+ - safetensors==0.3.0
229
+ - scipy==1.10.1
230
+ - semantic-version==2.10.0
231
+ - send2trash==1.8.0
232
+ - sentry-sdk==1.17.0
233
+ - setproctitle==1.3.2
234
+ - setuptools==59.5.0
235
+ - smmap==5.0.0
236
+ - sniffio==1.3.0
237
+ - soupsieve==2.4
238
+ - stack-data==0.6.2
239
+ - starlette==0.26.1
240
+ - tensorboard==2.12.0
241
+ - tensorboard-data-server==0.7.0
242
+ - tensorboard-plugin-wit==1.8.1
243
+ - termcolor==2.2.0
244
+ - terminado==0.17.1
245
+ - timm==0.6.12
246
+ - tinycss2==1.2.1
247
+ - tokenizers==0.13.2
248
+ - toolz==0.12.0
249
+ - torch==1.9.0
250
+ - torchaudio==0.9.0
251
+ - torchmetrics==0.11.4
252
+ - torchvision==0.10.0+cu111
253
+ - tornado==6.2
254
+ - tqdm==4.65.0
255
+ - traitlets==5.9.0
256
+ - transformers==4.28.1
257
+ - typing-extensions==4.5.0
258
+ - uc-micro-py==1.0.1
259
+ - uri-template==1.2.0
260
+ - uvicorn==0.21.1
261
+ - wandb==0.14.0
262
+ - wcwidth==0.2.6
263
+ - webcolors==1.12
264
+ - webencodings==0.5.1
265
+ - websocket-client==1.5.1
266
+ - websockets==11.0
267
+ - werkzeug==2.2.3
268
+ - widgetsnbextension==4.0.5
269
+ - yarl==1.8.2
270
+ - zipp==3.15.0
271
+ prefix: /root/anaconda3/envs/story
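The environment above should be reproducible with `conda env create -f environment.yml` followed by `conda activate story`; the `prefix:` line is machine-specific and can be removed before sharing.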
fid_utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ from scipy import linalg
3
+
4
+
5
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
6
+ mu1 = np.atleast_1d(mu1)
7
+ mu2 = np.atleast_1d(mu2)
8
+
9
+ sigma1 = np.atleast_2d(sigma1)
10
+ sigma2 = np.atleast_2d(sigma2)
11
+
12
+ assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
13
+ assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
14
+
15
+ diff = mu1 - mu2
16
+
17
+ # Product might be almost singular
18
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
19
+ if not np.isfinite(covmean).all():
20
+ print('fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps)
21
+ offset = np.eye(sigma1.shape[0]) * eps
22
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
23
+
24
+ # Numerical error might give slight imaginary component
25
+ if np.iscomplexobj(covmean):
26
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
27
+ m = np.max(np.abs(covmean.imag))
28
+ raise ValueError('Imaginary component {}'.format(m))
29
+ covmean = covmean.real
30
+
31
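+ # Frechet distance: ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))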
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
32
+
33
+
34
+ def calculate_fid_given_features(feature1, feature2):
35
+ mu1 = np.mean(feature1, axis=0)
36
+ sigma1 = np.cov(feature1, rowvar=False)
37
+ mu2 = np.mean(feature2, axis=0)
38
+ sigma2 = np.cov(feature2, rowvar=False)
39
+ fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
40
+
41
+ return fid_value
main.py ADDED
@@ -0,0 +1,537 @@
1
+ import inspect
2
+ import os
3
+
4
+ import cv2
5
+ import hydra
6
+ import numpy as np
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from PIL import Image
12
+ from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler
13
+ from omegaconf import DictConfig
14
+ from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
15
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
16
+ from pytorch_lightning.loggers import TensorBoardLogger
17
+ from pytorch_lightning.strategies import DDPStrategy
18
+ from torch import nn
19
+ from torch.utils.data import DataLoader
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer, CLIPTextModel
22
+
23
+ from fid_utils import calculate_fid_given_features
24
+ from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale
25
+
26
+ from models.blip_override.blip import blip_feature_extractor, init_tokenizer
27
+ from models.diffusers_override.unet_2d_condition import UNet2DConditionModel
28
+ from models.inception import InceptionV3
29
+ unet_target_replace_module = {"CrossAttention", "Attention", "GEGLU"}
30
+ #!/usr/bin/env python3
31
+ from transformers import CLIPProcessor
32
+ import transformers
33
+ from PIL import Image
34
+ import PIL.Image
35
+ import numpy as np
36
+ import torchvision.transforms as tvtrans
37
+ import requests
38
+ from io import BytesIO
39
+
40
+ class LightningDataset(pl.LightningDataModule):
41
+ def __init__(self, args: DictConfig):
42
+ super(LightningDataset, self).__init__()
43
+ self.kwargs = {"num_workers": args.num_workers, "persistent_workers": True if args.num_workers > 0 else False,
44
+ "pin_memory": True}
45
+ self.args = args
46
+
47
+ def setup(self, stage="fit"):
48
+ if self.args.dataset == "pororo":
49
+ import datasets.pororo as data
50
+ elif self.args.dataset == 'flintstones':
51
+ import datasets.flintstones as data
52
+ elif self.args.dataset == 'vistsis':
53
+ import datasets.vistsis as data
54
+ elif self.args.dataset == 'vistdii':
55
+ import datasets.vistdii as data
56
+ else:
57
+ raise ValueError("Unknown dataset: {}".format(self.args.dataset))
58
+ if stage == "fit":
59
+ self.train_data = data.StoryDataset("train", self.args)
60
+ self.val_data = data.StoryDataset("val", self.args)
61
+ if stage == "test":
62
+ self.test_data = data.StoryDataset("test", self.args)
63
+
64
+ def train_dataloader(self):
65
+ if not hasattr(self, 'trainloader'):
66
+ self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
67
+ return self.trainloader
68
+
69
+ def val_dataloader(self):
70
+ return DataLoader(self.val_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
71
+
72
+ def test_dataloader(self):
73
+ return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
74
+
75
+ def predict_dataloader(self):
76
+ return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
77
+
78
+ def get_length_of_train_dataloader(self):
79
+ if not hasattr(self, 'trainloader'):
80
+ self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
81
+ return len(self.trainloader)
82
+
83
+
84
+ class ARLDM(pl.LightningModule):
85
+ def __init__(self, args: DictConfig, steps_per_epoch=1):
86
+ super(ARLDM, self).__init__()
87
+ self.args = args
88
+ self.steps_per_epoch = steps_per_epoch
89
+ """
90
+ Configurations
91
+ """
92
+ self.task = args.task
93
+
94
+ if args.mode == 'sample':
95
+ if args.scheduler == "pndm":
96
+ self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
97
+ skip_prk_steps=True)
98
+ elif args.scheduler == "ddim":
99
+ self.scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
100
+ clip_sample=False, set_alpha_to_one=True)
101
+ else:
102
+ raise ValueError("Scheduler not supported")
103
+ self.fid_augment = transforms.Compose([
104
+ transforms.Resize([64, 64]),
105
+ transforms.ToTensor(),
106
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
107
+ ])
108
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
109
+ self.inception = InceptionV3([block_idx])
110
+
111
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
112
+ ##############################
113
+ #self.clip_tokenizer.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/tokenizer')
114
+ self.blip_tokenizer = init_tokenizer()
115
+ self.blip_image_processor = transforms.Compose([
116
+ transforms.Resize([224, 224]),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
119
+ ])
120
+ self.max_length = args.get(args.dataset).max_length
121
+
122
+ blip_image_null_token = self.blip_image_processor(
123
+ Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).float()
124
+ clip_text_null_token = self.clip_tokenizer([""], padding="max_length", max_length=self.max_length,
125
+ return_tensors="pt").input_ids
126
+ blip_text_null_token = self.blip_tokenizer([""], padding="max_length", max_length=self.max_length,
127
+ return_tensors="pt").input_ids
128
+
129
+ self.register_buffer('clip_text_null_token', clip_text_null_token)
130
+ self.register_buffer('blip_text_null_token', blip_text_null_token)
131
+ self.register_buffer('blip_image_null_token', blip_image_null_token)
132
+
133
+ self.text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5',
134
+ subfolder="text_encoder")
135
+ ############################################
136
+ #self.text_encoder.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/text_encoder')
137
+ self.text_encoder.resize_token_embeddings(args.get(args.dataset).clip_embedding_tokens)
138
+ # resize_position_embeddings
139
+ old_embeddings = self.text_encoder.text_model.embeddings.position_embedding
140
+ new_embeddings = self.text_encoder._get_resized_embeddings(old_embeddings, self.max_length)
141
+ self.text_encoder.text_model.embeddings.position_embedding = new_embeddings
142
+ self.text_encoder.config.max_position_embeddings = self.max_length
143
+ self.text_encoder.max_position_embeddings = self.max_length
144
+ self.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_length).expand((1, -1))
145
+
146
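+ # Learned embeddings that tag each context token by modality (0 = text caption, 1 = BLIP image-text feature) and by source-frame index (up to 5 frames)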
+ self.modal_type_embeddings = nn.Embedding(2, 768)
147
+ self.time_embeddings = nn.Embedding(5, 768)
148
+ self.mm_encoder = blip_feature_extractor(
149
+ # pretrained='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
150
+ pretrained='/root/lihui/StoryVisualization/save_pretrained/model_large.pth',
151
+ image_size=224, vit='large')#, local_files_only=True)
152
+ self.mm_encoder.text_encoder.resize_token_embeddings(args.get(args.dataset).blip_embedding_tokens)
153
+
154
+ self.vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae")
155
+ self.unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
156
+
157
+ self.noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
158
+ num_train_timesteps=1000)
159
+ # monkeypatch_or_replace_lora(
160
+ # self.unet,
161
+ # torch.load("lora/example_loras/analog_svd_rank4.safetensors"),
162
+ # r=4,
163
+ # target_replace_module=unet_target_replace_module,
164
+ # )
165
+ #
166
+ # tune_lora_scale(self.unet, 1.00)
167
+ #tune_lora_scale(self.text_encoder, 1.00)
168
+
169
+ # torch.manual_seed(0)
170
+ ###################################
171
+ #self.vae.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/vae')
172
+ #self.unet.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/unet')
173
+
174
+ # Freeze vae and unet
175
+ self.freeze_params(self.vae.parameters())
176
+ if args.freeze_resnet:
177
+ self.freeze_params([p for n, p in self.unet.named_parameters() if "attentions" not in n])
178
+
179
+ if args.freeze_blip and hasattr(self, "mm_encoder"):
180
+ self.freeze_params(self.mm_encoder.parameters())
181
+ self.unfreeze_params(self.mm_encoder.text_encoder.embeddings.word_embeddings.parameters())
182
+
183
+ if args.freeze_clip and hasattr(self, "text_encoder"):
184
+ self.freeze_params(self.text_encoder.parameters())
185
+ self.unfreeze_params(self.text_encoder.text_model.embeddings.token_embedding.parameters())
186
+
187
+ @staticmethod
188
+ def freeze_params(params):
189
+ for param in params:
190
+ param.requires_grad = False
191
+
192
+ @staticmethod
193
+ def unfreeze_params(params):
194
+ for param in params:
195
+ param.requires_grad = True
196
+
197
+ def configure_optimizers(self):
198
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=1e-4) # optim_bits=8
199
+ scheduler = LinearWarmupCosineAnnealingLR(optimizer,
200
+ warmup_epochs=self.args.warmup_epochs * self.steps_per_epoch,
201
+ max_epochs=self.args.max_epochs * self.steps_per_epoch)
202
+ optim_dict = {
203
+ 'optimizer': optimizer,
204
+ 'lr_scheduler': {
205
+ 'scheduler': scheduler, # The LR scheduler instance (required)
206
+ 'interval': 'step', # The unit of the scheduler's step size
207
+ }
208
+ }
209
+ return optim_dict
210
+
211
+ def forward(self, batch):
212
+ if self.args.freeze_clip and hasattr(self, "text_encoder"):
213
+ self.text_encoder.eval()
214
+ if self.args.freeze_blip and hasattr(self, "mm_encoder"):
215
+ self.mm_encoder.eval()
216
+ images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images = batch
217
+ B, V, S = captions.shape
218
+ src_V = V + 1 if self.task == 'continuation' else V
219
+ images = torch.flatten(images, 0, 1)
220
+ captions = torch.flatten(captions, 0, 1)
221
+ attention_mask = torch.flatten(attention_mask, 0, 1)
222
+ source_images = torch.flatten(source_images, 0, 1)
223
+ source_caption = torch.flatten(source_caption, 0, 1)
224
+ source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
225
+ # 1 is not masked, 0 is masked
226
+
227
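+ # Randomly drop the conditioning for ~10% of frames; their embeddings are replaced with null-token embeddings below, enabling classifier-free guidance at sampling time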
+ classifier_free_idx = np.random.rand(B * V) < 0.1
228
+
229
+ caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
230
+ source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
231
+ mode='multimodal').reshape(B, src_V * S, -1)
232
+ source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
233
+ caption_embeddings[classifier_free_idx] = \
234
+ self.text_encoder(self.clip_text_null_token).last_hidden_state[0]
235
+ source_embeddings[classifier_free_idx] = \
236
+ self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token, attention_mask=None,
237
+ mode='multimodal')[0].repeat(src_V, 1)
238
+ caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
239
+ source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
240
+ source_embeddings += self.time_embeddings(
241
+ torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
242
+ encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
243
+
244
+ attention_mask = torch.cat(
245
+ [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
246
+ attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
247
+ attention_mask[classifier_free_idx] = False
248
+
249
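+ # Upper-triangular mask over frames (True = masked): each frame may only attend to the context of frames that precede it (auto-regressive conditioning)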
+ # B, V, V, S
250
+ square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
251
+ square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
252
+ square_mask = square_mask.reshape(B * V, V * S)
253
+ attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
254
+
255
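+ # Encode target frames into the VAE latent space; 0.18215 is the Stable Diffusion v1 latent scaling factor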
+ latents = self.vae.encode(images).latent_dist.sample()
256
+ latents = latents * 0.18215
257
+
258
+ noise = torch.randn(latents.shape, device=self.device)
259
+ bsz = latents.shape[0]
260
+ timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=self.device).long()
261
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
262
+
263
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, attention_mask).sample
264
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
265
+ return loss
266
+
267
+ def sample(self, batch):
268
+ original_images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_test_images = batch
269
+ B, V, S = captions.shape
270
+ src_V = V + 1 if self.task == 'continuation' else V
271
+ original_images = torch.flatten(original_images, 0, 1)
272
+ captions = torch.flatten(captions, 0, 1)
273
+ attention_mask = torch.flatten(attention_mask, 0, 1)
274
+ source_images = torch.flatten(source_images, 0, 1)
275
+ source_caption = torch.flatten(source_caption, 0, 1)
276
+ source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
277
+
278
+ caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
279
+ source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
280
+ mode='multimodal').reshape(B, src_V * S, -1)
281
+ caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
282
+ source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
283
+ source_embeddings += self.time_embeddings(
284
+ torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
285
+ source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
286
+ encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
287
+
288
+ attention_mask = torch.cat(
289
+ [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
290
+ attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
291
+ # B, V, V, S
292
+ square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
293
+ square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
294
+ square_mask = square_mask.reshape(B * V, V * S)
295
+ attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
296
+
297
+ uncond_caption_embeddings = self.text_encoder(self.clip_text_null_token).last_hidden_state
298
+ uncond_source_embeddings = self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token,
299
+ attention_mask=None, mode='multimodal').repeat(1, src_V, 1)
300
+ uncond_caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
301
+ uncond_source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
302
+ uncond_source_embeddings += self.time_embeddings(
303
+ torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
304
+ uncond_embeddings = torch.cat([uncond_caption_embeddings, uncond_source_embeddings], dim=1)
305
+ uncond_embeddings = uncond_embeddings.expand(B * V, -1, -1)
306
+
307
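+ # Stack unconditional and conditional contexts so a single U-Net forward pass evaluates both classifier-free guidance branches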
+ encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
308
+ uncond_attention_mask = torch.zeros((B * V, (src_V + 1) * S), device=self.device).bool()
309
+ uncond_attention_mask[:, -V * S:] = square_mask
310
+ attention_mask = torch.cat([uncond_attention_mask, attention_mask], dim=0)
311
+
312
+ attention_mask = attention_mask.reshape(2, B, V, (src_V + 1) * S)
313
+ images = list()
314
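+ # Generate the V frames auto-regressively: each new frame is re-encoded by the BLIP multimodal encoder and written back into the conditioning for later frames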
+ for i in range(V):
315
+ encoder_hidden_states = encoder_hidden_states.reshape(2, B, V, (src_V + 1) * S, -1)
316
+ new_image = self.diffusion(encoder_hidden_states[:, :, i].reshape(2 * B, (src_V + 1) * S, -1),
317
+ attention_mask[:, :, i].reshape(2 * B, (src_V + 1) * S),
318
+ 512, 512, self.args.num_inference_steps, self.args.guidance_scale, 0.0)
319
+ images += new_image
320
+
321
+ new_image = torch.stack([self.blip_image_processor(im) for im in new_image]).to(self.device)
322
+
323
+ new_embedding = self.mm_encoder(new_image, # B,C,H,W
324
+ source_caption.reshape(B, src_V, S)[:, i + src_V - V],
325
+ source_attention_mask.reshape(B, src_V, S)[:, i + src_V - V],
326
+ mode='multimodal') # B, S, D
327
+ new_embedding = new_embedding.repeat_interleave(V, dim=0)
328
+ new_embedding += self.modal_type_embeddings(torch.tensor(1, device=self.device))
329
+ new_embedding += self.time_embeddings(torch.tensor(i + src_V - V, device=self.device))
330
+
331
+ encoder_hidden_states = encoder_hidden_states[1].reshape(B * V, (src_V + 1) * S, -1)
332
+ encoder_hidden_states[:, (i + 1 + src_V - V) * S:(i + 2 + src_V - V) * S] = new_embedding
333
+ encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
334
+
335
+ return original_images, images, texts, ori_test_images
336
+
337
+
338
+ def training_step(self, batch, batch_idx):
339
+ loss = self(batch)
340
+ self.log('loss/train_loss', loss, on_step=True, on_epoch=False, sync_dist=True, prog_bar=True)
341
+ return loss
342
+
343
+ def validation_step(self, batch, batch_idx):
344
+ loss = self(batch)
345
+ self.log('loss/val_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
346
+
347
+ def predict_step(self, batch, batch_idx, dataloader_idx=0):
348
+ original_images, images, texts, ori_test_images = self.sample(batch)
349
+ if self.args.calculate_fid:
350
+ original_images = original_images.cpu().numpy().astype('uint8')
351
+ original_images = [Image.fromarray(im, 'RGB') for im in original_images]
352
+
353
+ # ori_test_images = torch.stack(ori_test_images).cpu().numpy().astype('uint8')
354
+ # ori_test_images = [Image.fromarray(im, 'RGB') for im in ori_test_images]
355
+ ori = self.inception_feature(original_images).cpu().numpy()
356
+ gen = self.inception_feature(images).cpu().numpy()
357
+ else:
358
+ ori = None
359
+ gen = None
360
+
361
+ return images, ori, gen, ori_test_images, texts
362
+
363
+ def diffusion(self, encoder_hidden_states, attention_mask, height, width, num_inference_steps, guidance_scale, eta):
364
+ latents = torch.randn((encoder_hidden_states.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
365
+ device=self.device)
366
+
367
+ # set timesteps
368
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
369
+ extra_set_kwargs = {}
370
+ if accepts_offset:
371
+ extra_set_kwargs["offset"] = 1
372
+
373
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
374
+
375
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
376
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
377
+ latents = latents * self.scheduler.sigmas[0]
378
+
379
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
380
+ extra_step_kwargs = {}
381
+ if accepts_eta:
382
+ extra_step_kwargs["eta"] = eta
383
+
384
+ for i, t in enumerate(self.scheduler.timesteps):
385
+ # expand the latents if we are doing classifier free guidance
386
+ latent_model_input = torch.cat([latents] * 2)
387
+
388
+ # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states).sample
389
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states, attention_mask).sample
390
+
391
+ # perform guidance
392
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
393
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
394
+
395
+ # compute the previous noisy sample x_t -> x_t-1
396
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
397
+
398
+ # scale and decode the image latents with vae
399
+ latents = 1 / 0.18215 * latents
400
+ image = self.vae.decode(latents).sample
401
+
402
+ image = (image / 2 + 0.5).clamp(0, 1)
403
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
404
+
405
+ return self.numpy_to_pil(image)
406
+
407
+ @staticmethod
408
+ def numpy_to_pil(images):
409
+ """
410
+ Convert a numpy image or a batch of images to a PIL image.
411
+ """
412
+ if images.ndim == 3:
413
+ images = images[None, ...]
414
+ images = (images * 255).round().astype("uint8")
415
+ pil_images = [Image.fromarray(image, 'RGB') for image in images]
416
+
417
+ return pil_images
418
+
419
+ def inception_feature(self, images):
420
+ images = torch.stack([self.fid_augment(image) for image in images])
421
+ images = images.type(torch.FloatTensor).to(self.device)
422
+ images = (images + 1) / 2
423
+ images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
424
+ pred = self.inception(images)[0]
425
+
426
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
427
+ pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
428
+ return pred.reshape(-1, 2048)
429
+
430
+
431
+ def train(args: DictConfig) -> None:
432
+ dataloader = LightningDataset(args)
433
+ dataloader.setup('fit')
434
+ # dataloader.
435
+ model = ARLDM(args, steps_per_epoch=dataloader.get_length_of_train_dataloader())
436
+
437
+ logger = TensorBoardLogger(save_dir=os.path.join(args.ckpt_dir, args.run_name), name='log', default_hp_metric=False)
438
+
439
+ checkpoint_callback = ModelCheckpoint(
440
+ dirpath=os.path.join(args.ckpt_dir, args.run_name),
441
+ save_top_k=0,
442
+ save_last=True
443
+ )
444
+
445
+ lr_monitor = LearningRateMonitor(logging_interval='step')
446
+
447
+ callback_list = [lr_monitor, checkpoint_callback]
448
+
449
+ trainer = pl.Trainer(
450
+ accelerator='gpu',
451
+ devices=args.gpu_ids,
452
+ max_epochs=args.max_epochs,
453
+ benchmark=True,
454
+ logger=logger,
455
+ log_every_n_steps=1,
456
+ callbacks=callback_list,
457
+ strategy=DDPStrategy(find_unused_parameters=False)
458
+ )
459
+ trainer.fit(model, dataloader, ckpt_path=args.train_model_file)
460
+
461
+
462
+ def sample(args: DictConfig) -> None:
463
+
464
+ assert args.test_model_file is not None, "test_model_file cannot be None"
465
+ assert args.gpu_ids == 1 or len(args.gpu_ids) == 1, "Only one GPU is supported in test mode"
466
+ dataloader = LightningDataset(args)
467
+ dataloader.setup('test')
468
+ model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)
469
+
470
+ predictor = pl.Trainer(
471
+ accelerator='gpu',
472
+ devices=args.gpu_ids,
473
+ max_epochs=-1,
474
+ benchmark=True
475
+ )
476
+ predictions = predictor.predict(model, dataloader)
477
+ images = [elem for sublist in predictions for elem in sublist[0]]
478
+ ori_images = [elem for sublist in predictions for elem in sublist[3]]
479
+ ori_test_images = list()
480
+ if not os.path.exists(args.sample_output_dir):
481
+ try:
482
+ os.mkdir(args.sample_output_dir)
483
+ except OSError:
484
+ pass
485
+
486
+ text_list = [elem for sublist in predictions for elem in sublist[4]]
487
+ ################################
488
+ # print(f"index: {index}")
489
+ num_images = len(images)
490
+ num_groups = (num_images + 4) // 5 # total number of 5-image story groups needed
491
+
492
+ for g in range(num_groups):
493
+ print('Story {}:'.format(g + 1)) # print the story (group) index
494
+ start_index = g * 5 # start index of the current group
495
+ end_index = min(start_index + 5, num_images) # end index of the current group
496
+ for i in range(start_index, end_index):
497
+ print(text_list[i]) # print the corresponding caption text
498
+ images[i].save(
499
+ os.path.join(args.sample_output_dir, 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
500
+ # ori_images[i] = ori_images[i]
501
+ ori_images_pil = Image.fromarray(np.uint8(ori_images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
502
+ ori_test_images.append(ori_images_pil)
503
+ ori_images_pil.save(
504
+ os.path.join('/root/lihui/StoryVisualization/ori_test_images_epoch10', 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
505
+ # for i, im in enumerate(ori_images):
506
+ # file_path = '/root/lihui/StoryVisualization/ori_test_images/image{}.png'.format(i)
507
+ # cv2.imwrite(file_path, im)
508
+
509
+
510
+ if args.calculate_fid:
511
+ ori = np.array([elem for sublist in predictions for elem in sublist[1]])
512
+ gen = np.array([elem for sublist in predictions for elem in sublist[2]])
513
+ fid = calculate_fid_given_features(ori, gen)
514
+ print('FID: {}'.format(fid))
515
+
516
+
517
+
518
+
519
+
520
+ @hydra.main(config_path=".", config_name="config")
521
+ def main(args: DictConfig) -> None:
522
+ pl.seed_everything(args.seed)
523
+ if args.num_cpu_cores > 0:
524
+ torch.set_num_threads(args.num_cpu_cores)
525
+
526
+ if args.mode == 'train':
527
+ ############################
528
+ train(args)
529
+ elif args.mode == 'sample':
530
+ # dataloader = LightningDataset(args)
531
+ # dataloader.setup('test')
532
+ sample(args)
533
+
534
+
535
+
536
+ if __name__ == '__main__':
537
+ main()
models/blip_override/blip.py ADDED
@@ -0,0 +1,240 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+ from .vit import VisionTransformer, interpolate_pos_embed
13
+ from .med import BertModel, BertLMHeadModel
14
+ from transformers import BertTokenizer, BertConfig
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ import os
20
+ from urllib.parse import urlparse
21
+ from timm.models.hub import download_cached_file
22
+
23
+
24
+ class BLIP_Base(nn.Module):
25
+ def __init__(self,
26
+ med_config='models/blip_override/med_config.json',
27
+ image_size=224,
28
+ vit='base',
29
+ vit_grad_ckpt=False,
30
+ vit_ckpt_layer=0,
31
+ ):
32
+ """
33
+ Args:
34
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
35
+ image_size (int): input image size
36
+ vit (str): model size of vision transformer
37
+ """
38
+ super().__init__()
39
+
40
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
41
+ self.tokenizer = init_tokenizer()
42
+ med_config = BertConfig.from_json_file(med_config)
43
+ med_config.encoder_width = vision_width
44
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
45
+
46
+ def forward(self, image, text, attention_mask, mode):
47
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
48
+ if mode == 'image':
49
+ # return image features
50
+ image_embeds = self.visual_encoder(image)
51
+ return image_embeds
52
+
53
+ elif mode == 'text':
54
+ # return text features
55
+ text_output = self.text_encoder(text, attention_mask=attention_mask, return_dict=True, mode='text')
56
+ return text_output.last_hidden_state
57
+
58
+ elif mode == 'multimodal':
59
+ # return multimodal features
60
+ image_embeds = self.visual_encoder(image)
61
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
62
+
63
+ text[:, 0] = self.tokenizer.enc_token_id
64
+ output = self.text_encoder(text,
65
+ attention_mask=attention_mask,
66
+ encoder_hidden_states=image_embeds,
67
+ encoder_attention_mask=image_atts,
68
+ return_dict=True,
69
+ )
70
+ return output.last_hidden_state
71
+
72
+
73
+ class BLIP_Decoder(nn.Module):
74
+ def __init__(self,
75
+ med_config='models/blip_override/med_config.json',
76
+ image_size=384,
77
+ vit='base',
78
+ vit_grad_ckpt=False,
79
+ vit_ckpt_layer=0,
80
+ prompt='a picture of ',
81
+ ):
82
+ """
83
+ Args:
84
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
85
+ image_size (int): input image size
86
+ vit (str): model size of vision transformer
87
+ """
88
+ super().__init__()
89
+
90
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
91
+ self.tokenizer = init_tokenizer()
92
+ med_config = BertConfig.from_json_file(med_config)
93
+ med_config.encoder_width = vision_width
94
+ self.text_decoder = BertLMHeadModel(config=med_config)
95
+
96
+ self.prompt = prompt
97
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
98
+
99
+ def forward(self, image, caption):
100
+
101
+ image_embeds = self.visual_encoder(image)
102
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
103
+
104
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(
105
+ image.device)
106
+
107
+ text.input_ids[:, 0] = self.tokenizer.bos_token_id
108
+
109
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
110
+ decoder_targets[:, :self.prompt_length] = -100
111
+
112
+ decoder_output = self.text_decoder(text.input_ids,
113
+ attention_mask=text.attention_mask,
114
+ encoder_hidden_states=image_embeds,
115
+ encoder_attention_mask=image_atts,
116
+ labels=decoder_targets,
117
+ return_dict=True,
118
+ )
119
+ loss_lm = decoder_output.loss
120
+
121
+ return loss_lm
122
+
123
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9,
124
+ repetition_penalty=1.0):
125
+ image_embeds = self.visual_encoder(image)
126
+
127
+ if not sample:
128
+ image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
129
+
130
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
131
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}
132
+
133
+ prompt = [self.prompt] * image.size(0)
134
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
135
+ input_ids[:, 0] = self.tokenizer.bos_token_id
136
+ input_ids = input_ids[:, :-1]
137
+
138
+ if sample:
139
+ # nucleus sampling
140
+ outputs = self.text_decoder.generate(input_ids=input_ids,
141
+ max_length=max_length,
142
+ min_length=min_length,
143
+ do_sample=True,
144
+ top_p=top_p,
145
+ num_return_sequences=1,
146
+ eos_token_id=self.tokenizer.sep_token_id,
147
+ pad_token_id=self.tokenizer.pad_token_id,
148
+ repetition_penalty=1.1,
149
+ **model_kwargs)
150
+ else:
151
+ # beam search
152
+ outputs = self.text_decoder.generate(input_ids=input_ids,
153
+ max_length=max_length,
154
+ min_length=min_length,
155
+ num_beams=num_beams,
156
+ eos_token_id=self.tokenizer.sep_token_id,
157
+ pad_token_id=self.tokenizer.pad_token_id,
158
+ repetition_penalty=repetition_penalty,
159
+ **model_kwargs)
160
+
161
+ captions = []
162
+ for output in outputs:
163
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
164
+ captions.append(caption[len(self.prompt):])
165
+ return captions
166
+
167
+
168
+ def blip_decoder(pretrained='', **kwargs):
169
+ model = BLIP_Decoder(**kwargs)
170
+ if pretrained:
171
+ model, msg = load_checkpoint(model, pretrained)
172
+ assert (len(msg.missing_keys) == 0)
173
+ return model
174
+
175
+
176
+ def blip_feature_extractor(pretrained='', **kwargs):
177
+ model = BLIP_Base(**kwargs)
178
+ if pretrained:
179
+ model, msg = load_checkpoint(model, pretrained)
180
+ assert (len(msg.missing_keys) == 0)
181
+ return model
182
+
183
+
184
+ def init_tokenizer():
185
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
186
+ tokenizer.add_special_tokens({'bos_token': '[DEC]'})
187
+ tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
188
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
189
+ return tokenizer
190
+
191
+
192
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
193
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
194
+ assert use_grad_checkpointing is False, 'grad checkpointing is not supported yet'
195
+ if vit == 'base':
196
+ vision_width = 768
197
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
198
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
199
+ ckpt_layer=ckpt_layer,
200
+ drop_path_rate=0 or drop_path_rate
201
+ )
202
+ elif vit == 'large':
203
+ vision_width = 1024
204
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
205
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
206
+ ckpt_layer=ckpt_layer,
207
+ drop_path_rate=0.1 or drop_path_rate
208
+ )
209
+ return visual_encoder, vision_width
210
+
211
+
212
+ def is_url(url_or_filename):
213
+ parsed = urlparse(url_or_filename)
214
+ return parsed.scheme in ("http", "https")
215
+
216
+
217
+ def load_checkpoint(model, url_or_filename):
218
+ if is_url(url_or_filename):
219
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
220
+ checkpoint = torch.load(cached_file, map_location='cpu')
221
+ elif os.path.isfile(url_or_filename):
222
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
223
+ else:
224
+ raise RuntimeError('checkpoint url or path is invalid')
225
+
226
+ state_dict = checkpoint['model']
227
+
228
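+ # Interpolate the ViT position embeddings so checkpoints trained at a different image resolution can be loaded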
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
229
+ model.visual_encoder)
230
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
231
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
232
+ model.visual_encoder_m)
233
+ for key in model.state_dict().keys():
234
+ if key in state_dict.keys():
235
+ if state_dict[key].shape != model.state_dict()[key].shape:
236
+ del state_dict[key]
237
+
238
+ msg = model.load_state_dict(state_dict, strict=False)
239
+ print('load checkpoint from %s' % url_or_filename)
240
+ return model, msg
models/blip_override/med.py ADDED
@@ -0,0 +1,955 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
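For readers skimming `get_extended_attention_mask`: the returned mask is additive, not multiplicative, so it can simply be summed onto the raw attention scores. A tiny worked example of the non-decoder path, with made-up shapes:

```py
import torch

# Batch of 1, sequence length 4, last position is padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])
extended = attention_mask[:, None, None, :].float()   # [batch, 1, 1, seq] -> broadcasts over heads and queries
additive = (1.0 - extended) * -10000.0                 # 0.0 where attended, -10000.0 where masked
scores = torch.zeros(1, 1, 4, 4)                       # dummy attention scores
probs = (scores + additive).softmax(dim=-1)
print(probs[0, 0, 0])                                  # ~[0.333, 0.333, 0.333, 0.000]
```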
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
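The class above mirrors the stock `transformers` BERT encoder/decoder wiring, with the extra `mode` and `encoder_embeds` hooks. As a rough usage sketch, the equivalent cross-attention call on the stock `BertModel` (not this override) looks like this:

```py
import torch
from transformers import BertConfig, BertModel

config = BertConfig(is_decoder=True, add_cross_attention=True)
model = BertModel(config)                               # randomly initialized, shapes only

input_ids = torch.randint(0, config.vocab_size, (1, 5))
image_feats = torch.randn(1, 10, config.hidden_size)    # plays the role of encoder_hidden_states
out = model(
    input_ids=input_ids,
    encoder_hidden_states=image_feats,
    encoder_attention_mask=torch.ones(1, 10, dtype=torch.long),
)
print(out.last_hidden_state.shape)                       # torch.Size([1, 5, 768])
```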
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
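The loss in `BertLMHeadModel.forward` is a standard next-token objective: logits at position t are scored against the label at position t+1, with label smoothing 0.1. A toy sketch of the same shifting (random tensors, not real model outputs):

```py
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11                                       # toy vocabulary
logits = torch.randn(2, 6, vocab_size)                # [batch, seq, vocab], stand-in for prediction_scores
labels = torch.randint(0, vocab_size, (2, 6))

shifted_logits = logits[:, :-1, :].contiguous()       # position t predicts token t+1
shifted_labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
loss = loss_fct(shifted_logits.view(-1, vocab_size), shifted_labels.view(-1))
print(loss.item())
```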
models/blip_override/med_config.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
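For reference, a config file like this one is normally consumed through `BertConfig.from_json_file` (upstream BLIP loads its `med_config.json` the same way); keys that `BertConfig` does not know natively, such as `encoder_width`, end up as plain attributes on the config object:

```py
from transformers import BertConfig

# Path is the file added in this PR.
med_config = BertConfig.from_json_file("models/blip_override/med_config.json")
print(med_config.hidden_size, med_config.add_cross_attention)   # 768 True
print(med_config.encoder_width)                                  # 768 (kept as a plain attribute)
```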
models/blip_override/vit.py ADDED
@@ -0,0 +1,302 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+
22
+ class Mlp(nn.Module):
23
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
24
+ """
25
+
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ def forward(self, x, register_hook=False):
104
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
106
+ return x
107
+
108
+
109
+ class VisionTransformer(nn.Module):
110
+ """ Vision Transformer
111
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
+ https://arxiv.org/abs/2010.11929
113
+ """
114
+
115
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
116
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
117
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
118
+ use_grad_checkpointing=False, ckpt_layer=0):
119
+ """
120
+ Args:
121
+ img_size (int, tuple): input image size
122
+ patch_size (int, tuple): patch size
123
+ in_chans (int): number of input channels
124
+ num_classes (int): number of classes for classification head
125
+ embed_dim (int): embedding dimension
126
+ depth (int): depth of transformer
127
+ num_heads (int): number of attention heads
128
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
129
+ qkv_bias (bool): enable bias for qkv if True
130
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
131
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
132
+ drop_rate (float): dropout rate
133
+ attn_drop_rate (float): attention dropout rate
134
+ drop_path_rate (float): stochastic depth rate
135
+ norm_layer: (nn.Module): normalization layer
136
+ """
137
+ super().__init__()
138
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
139
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
140
+
141
+ self.patch_embed = PatchEmbed(
142
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
143
+
144
+ num_patches = self.patch_embed.num_patches
145
+
146
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
147
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
148
+ self.pos_drop = nn.Dropout(p=drop_rate)
149
+
150
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
151
+ self.blocks = nn.ModuleList([
152
+ Block(
153
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
155
+ use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer)
156
+ )
157
+ for i in range(depth)])
158
+ self.norm = norm_layer(embed_dim)
159
+
160
+ trunc_normal_(self.pos_embed, std=.02)
161
+ trunc_normal_(self.cls_token, std=.02)
162
+ self.apply(self._init_weights)
163
+
164
+ def _init_weights(self, m):
165
+ if isinstance(m, nn.Linear):
166
+ trunc_normal_(m.weight, std=.02)
167
+ if isinstance(m, nn.Linear) and m.bias is not None:
168
+ nn.init.constant_(m.bias, 0)
169
+ elif isinstance(m, nn.LayerNorm):
170
+ nn.init.constant_(m.bias, 0)
171
+ nn.init.constant_(m.weight, 1.0)
172
+
173
+ @torch.jit.ignore
174
+ def no_weight_decay(self):
175
+ return {'pos_embed', 'cls_token'}
176
+
177
+ def forward(self, x, register_blk=-1):
178
+ B = x.shape[0]
179
+ x = self.patch_embed(x)
180
+
181
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
182
+ x = torch.cat((cls_tokens, x), dim=1)
183
+
184
+ x = x + self.pos_embed[:, :x.size(1), :]
185
+ x = self.pos_drop(x)
186
+
187
+ for i, blk in enumerate(self.blocks):
188
+ x = blk(x, register_blk == i)
189
+ x = self.norm(x)
190
+
191
+ return x
192
+
193
+ @torch.jit.ignore()
194
+ def load_pretrained(self, checkpoint_path, prefix=''):
195
+ _load_weights(self, checkpoint_path, prefix)
196
+
197
+
198
+ @torch.no_grad()
199
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
200
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
201
+ """
202
+ import numpy as np
203
+
204
+ def _n2p(w, t=True):
205
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
206
+ w = w.flatten()
207
+ if t:
208
+ if w.ndim == 4:
209
+ w = w.transpose([3, 2, 0, 1])
210
+ elif w.ndim == 3:
211
+ w = w.transpose([2, 0, 1])
212
+ elif w.ndim == 2:
213
+ w = w.transpose([1, 0])
214
+ return torch.from_numpy(w)
215
+
216
+ w = np.load(checkpoint_path)
217
+ if not prefix and 'opt/target/embedding/kernel' in w:
218
+ prefix = 'opt/target/'
219
+
220
+ if hasattr(model.patch_embed, 'backbone'):
221
+ # hybrid
222
+ backbone = model.patch_embed.backbone
223
+ stem_only = not hasattr(backbone, 'stem')
224
+ stem = backbone if stem_only else backbone.stem
225
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
226
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
227
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
228
+ if not stem_only:
229
+ for i, stage in enumerate(backbone.stages):
230
+ for j, block in enumerate(stage.blocks):
231
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
232
+ for r in range(3):
233
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
234
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
235
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
236
+ if block.downsample is not None:
237
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
238
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
239
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
240
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
241
+ else:
242
+ embed_conv_w = adapt_input_conv(
243
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
244
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
245
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
246
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
247
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
248
+ if pos_embed_w.shape != model.pos_embed.shape:
249
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
250
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
251
+ model.pos_embed.copy_(pos_embed_w)
252
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
253
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
254
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
255
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
256
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
257
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
258
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
259
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
260
+ for i, block in enumerate(model.blocks.children()):
261
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
262
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
263
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
264
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
265
+ block.attn.qkv.weight.copy_(torch.cat([
266
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
267
+ block.attn.qkv.bias.copy_(torch.cat([
268
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
269
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
270
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
271
+ for r in range(2):
272
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
273
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
274
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
275
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
276
+
277
+
278
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
279
+ # interpolate position embedding
280
+ embedding_size = pos_embed_checkpoint.shape[-1]
281
+ num_patches = visual_encoder.patch_embed.num_patches
282
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
283
+ # height (== width) for the checkpoint position embedding
284
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
285
+ # height (== width) for the new position embedding
286
+ new_size = int(num_patches ** 0.5)
287
+
288
+ if orig_size != new_size:
289
+ # class_token and dist_token are kept unchanged
290
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
291
+ # only the position tokens are interpolated
292
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
293
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
294
+ pos_tokens = torch.nn.functional.interpolate(
295
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
296
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
297
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
298
+ print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
299
+
300
+ return new_pos_embed
301
+ else:
302
+ return pos_embed_checkpoint
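`interpolate_pos_embed` only touches the patch tokens; the cls/extra tokens are carried over unchanged. A standalone sketch of the resize step with toy sizes (a 14x14 grid resized to 16x16, small embedding dim):

```py
import torch
import torch.nn.functional as F

dim, old, new = 8, 14, 16
pos_embed = torch.randn(1, 1 + old * old, dim)          # checkpoint pos_embed: 1 cls token + patch tokens
cls_tok, patch_tok = pos_embed[:, :1], pos_embed[:, 1:]
patch_tok = patch_tok.reshape(1, old, old, dim).permute(0, 3, 1, 2)
patch_tok = F.interpolate(patch_tok, size=(new, new), mode='bicubic', align_corners=False)
patch_tok = patch_tok.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((cls_tok, patch_tok), dim=1)
print(new_pos_embed.shape)                               # torch.Size([1, 257, 8])
```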
models/diffusers_override/attention.py ADDED
@@ -0,0 +1,669 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.modeling_utils import ModelMixin
24
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
25
+ from diffusers.utils import BaseOutput
26
+ from diffusers.utils.import_utils import is_xformers_available
27
+
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
34
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
35
+ for the unnoised latent pixels.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ if is_xformers_available():
42
+ import xformers
43
+ import xformers.ops
44
+ else:
45
+ xformers = None
46
+
47
+
48
+ class Transformer2DModel(ModelMixin, ConfigMixin):
49
+ """
50
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
51
+ embeddings) inputs.
52
+
53
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
54
+ transformer action. Finally, reshape to image.
55
+
56
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
57
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
58
+ classes of unnoised image.
59
+
60
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
61
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
62
+
63
+ Parameters:
64
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
65
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
66
+ in_channels (`int`, *optional*):
67
+ Pass if the input is continuous. The number of channels in the input and output.
68
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
69
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
70
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
71
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
72
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
73
+ `ImagePositionalEmbeddings`.
74
+ num_vector_embeds (`int`, *optional*):
75
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
76
+ Includes the class for the masked latent pixel.
77
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
78
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
79
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
80
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
81
+ up to but not more than steps than `num_embeds_ada_norm`.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ num_layers: int = 1,
93
+ dropout: float = 0.0,
94
+ norm_num_groups: int = 32,
95
+ cross_attention_dim: Optional[int] = None,
96
+ attention_bias: bool = False,
97
+ sample_size: Optional[int] = None,
98
+ num_vector_embeds: Optional[int] = None,
99
+ activation_fn: str = "geglu",
100
+ num_embeds_ada_norm: Optional[int] = None,
101
+ ):
102
+ super().__init__()
103
+ self.num_attention_heads = num_attention_heads
104
+ self.attention_head_dim = attention_head_dim
105
+ inner_dim = num_attention_heads * attention_head_dim
106
+
107
+ # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
108
+ # Define whether input is continuous or discrete depending on configuration
109
+ self.is_input_continuous = in_channels is not None
110
+ self.is_input_vectorized = num_vector_embeds is not None
111
+
112
+ if self.is_input_continuous and self.is_input_vectorized:
113
+ raise ValueError(
114
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
115
+ " sure that either `in_channels` or `num_vector_embeds` is None."
116
+ )
117
+ elif not self.is_input_continuous and not self.is_input_vectorized:
118
+ raise ValueError(
119
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
120
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
121
+ )
122
+
123
+ # 2. Define input layers
124
+ if self.is_input_continuous:
125
+ self.in_channels = in_channels
126
+
127
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
128
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
129
+ elif self.is_input_vectorized:
130
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
131
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
132
+
133
+ self.height = sample_size
134
+ self.width = sample_size
135
+ self.num_vector_embeds = num_vector_embeds
136
+ self.num_latent_pixels = self.height * self.width
137
+
138
+ self.latent_image_embedding = ImagePositionalEmbeddings(
139
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
140
+ )
141
+
142
+ # 3. Define transformers blocks
143
+ self.transformer_blocks = nn.ModuleList(
144
+ [
145
+ BasicTransformerBlock(
146
+ inner_dim,
147
+ num_attention_heads,
148
+ attention_head_dim,
149
+ dropout=dropout,
150
+ cross_attention_dim=cross_attention_dim,
151
+ activation_fn=activation_fn,
152
+ num_embeds_ada_norm=num_embeds_ada_norm,
153
+ attention_bias=attention_bias,
154
+ )
155
+ for d in range(num_layers)
156
+ ]
157
+ )
158
+
159
+ # 4. Define output layers
160
+ if self.is_input_continuous:
161
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
162
+ elif self.is_input_vectorized:
163
+ self.norm_out = nn.LayerNorm(inner_dim)
164
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
165
+
166
+ def _set_attention_slice(self, slice_size):
167
+ for block in self.transformer_blocks:
168
+ block._set_attention_slice(slice_size)
169
+
170
+ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, timestep=None,
171
+ return_dict: bool = True):
172
+ """
173
+ Args:
174
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
175
+ When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
176
+ hidden_states
177
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
178
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
179
+ self-attention.
180
+ encoder_attention_mask ( `torch.LongTensor` of shape `(batch size, context)`, *optional*):
181
+ Attention mask for cross attention layer.
182
+ timestep ( `torch.long`, *optional*):
183
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
184
+ return_dict (`bool`, *optional*, defaults to `True`):
185
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
186
+
187
+ Returns:
188
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
189
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
190
+ tensor.
191
+ """
192
+ # 1. Input
193
+ if self.is_input_continuous:
194
+ batch, channel, height, weight = hidden_states.shape
195
+ residual = hidden_states
196
+ hidden_states = self.norm(hidden_states)
197
+ hidden_states = self.proj_in(hidden_states)
198
+ inner_dim = hidden_states.shape[1]
199
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
200
+ elif self.is_input_vectorized:
201
+ hidden_states = self.latent_image_embedding(hidden_states)
202
+
203
+ # 2. Blocks
204
+ for block in self.transformer_blocks:
205
+ hidden_states = block(hidden_states, context=encoder_hidden_states, mask=encoder_attention_mask,
206
+ timestep=timestep)
207
+
208
+ # 3. Output
209
+ if self.is_input_continuous:
210
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
211
+ hidden_states = self.proj_out(hidden_states)
212
+ output = hidden_states + residual
213
+ elif self.is_input_vectorized:
214
+ hidden_states = self.norm_out(hidden_states)
215
+ logits = self.out(hidden_states)
216
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
217
+ logits = logits.permute(0, 2, 1)
218
+
219
+ # log(p(x_0))
220
+ output = F.log_softmax(logits.double(), dim=1).float()
221
+
222
+ if not return_dict:
223
+ return (output,)
224
+
225
+ return Transformer2DModelOutput(sample=output)
226
+
227
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
228
+ for block in self.transformer_blocks:
229
+ block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
230
+
231
+
232
+ class AttentionBlock(nn.Module):
233
+ """
234
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
235
+ to the N-d case.
236
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
237
+ Uses three q, k, v linear layers to compute attention.
238
+
239
+ Parameters:
240
+ channels (`int`): The number of channels in the input and output.
241
+ num_head_channels (`int`, *optional*):
242
+ The number of channels in each head. If None, then `num_heads` = 1.
243
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
244
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
245
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ channels: int,
251
+ num_head_channels: Optional[int] = None,
252
+ norm_num_groups: int = 32,
253
+ rescale_output_factor: float = 1.0,
254
+ eps: float = 1e-5,
255
+ ):
256
+ super().__init__()
257
+ self.channels = channels
258
+
259
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
260
+ self.num_head_size = num_head_channels
261
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
262
+
263
+ # define q,k,v as linear layers
264
+ self.query = nn.Linear(channels, channels)
265
+ self.key = nn.Linear(channels, channels)
266
+ self.value = nn.Linear(channels, channels)
267
+
268
+ self.rescale_output_factor = rescale_output_factor
269
+ self.proj_attn = nn.Linear(channels, channels, bias=True)
270
+
271
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
272
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
273
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
274
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
275
+ return new_projection
276
+
277
+ def forward(self, hidden_states):
278
+ residual = hidden_states
279
+ batch, channel, height, width = hidden_states.shape
280
+
281
+ # norm
282
+ hidden_states = self.group_norm(hidden_states)
283
+
284
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
285
+
286
+ # proj to q, k, v
287
+ query_proj = self.query(hidden_states)
288
+ key_proj = self.key(hidden_states)
289
+ value_proj = self.value(hidden_states)
290
+
291
+ # transpose
292
+ query_states = self.transpose_for_scores(query_proj)
293
+ key_states = self.transpose_for_scores(key_proj)
294
+ value_states = self.transpose_for_scores(value_proj)
295
+
296
+ # get scores
297
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
298
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
299
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
300
+
301
+ # compute attention output
302
+ hidden_states = torch.matmul(attention_probs, value_states)
303
+
304
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
305
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
306
+ hidden_states = hidden_states.view(new_hidden_states_shape)
307
+
308
+ # compute next hidden_states
309
+ hidden_states = self.proj_attn(hidden_states)
310
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
311
+
312
+ # res connect and rescale
313
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
314
+ return hidden_states
315
+
316
+
317
+ class BasicTransformerBlock(nn.Module):
318
+ r"""
319
+ A basic Transformer block.
320
+
321
+ Parameters:
322
+ dim (`int`): The number of channels in the input and output.
323
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
324
+ attention_head_dim (`int`): The number of channels in each head.
325
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
326
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
327
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
328
+ num_embeds_ada_norm (:
329
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
330
+ attention_bias (:
331
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
332
+ """
333
+
334
+ def __init__(
335
+ self,
336
+ dim: int,
337
+ num_attention_heads: int,
338
+ attention_head_dim: int,
339
+ dropout=0.0,
340
+ cross_attention_dim: Optional[int] = None,
341
+ activation_fn: str = "geglu",
342
+ num_embeds_ada_norm: Optional[int] = None,
343
+ attention_bias: bool = False,
344
+ ):
345
+ super().__init__()
346
+ self.attn1 = CrossAttention(
347
+ query_dim=dim,
348
+ heads=num_attention_heads,
349
+ dim_head=attention_head_dim,
350
+ dropout=dropout,
351
+ bias=attention_bias,
352
+ ) # is a self-attention
353
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
354
+ self.attn2 = CrossAttention(
355
+ query_dim=dim,
356
+ cross_attention_dim=cross_attention_dim,
357
+ heads=num_attention_heads,
358
+ dim_head=attention_head_dim,
359
+ dropout=dropout,
360
+ bias=attention_bias,
361
+ ) # is self-attn if context is none
362
+
363
+ # layer norms
364
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
365
+ if self.use_ada_layer_norm:
366
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
367
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
368
+ else:
369
+ self.norm1 = nn.LayerNorm(dim)
370
+ self.norm2 = nn.LayerNorm(dim)
371
+ self.norm3 = nn.LayerNorm(dim)
372
+
373
+ def _set_attention_slice(self, slice_size):
374
+ self.attn1._slice_size = slice_size
375
+ self.attn2._slice_size = slice_size
376
+
377
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
378
+ if not is_xformers_available():
379
+ print("Here is how to install it")
380
+ raise ModuleNotFoundError(
381
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
382
+ " xformers",
383
+ name="xformers",
384
+ )
385
+ elif not torch.cuda.is_available():
386
+ raise ValueError(
387
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
388
+ " available for GPU "
389
+ )
390
+ else:
391
+ try:
392
+ # Make sure we can run the memory efficient attention
393
+ _ = xformers.ops.memory_efficient_attention(
394
+ torch.randn((1, 2, 40), device="cuda"),
395
+ torch.randn((1, 2, 40), device="cuda"),
396
+ torch.randn((1, 2, 40), device="cuda"),
397
+ )
398
+ except Exception as e:
399
+ raise e
400
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
401
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
402
+
403
+ def forward(self, hidden_states, context=None, mask=None, timestep=None):
404
+ # 1. Self-Attention
405
+ norm_hidden_states = (
406
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
407
+ )
408
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
409
+
410
+ # 2. Cross-Attention
411
+ norm_hidden_states = (
412
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
413
+ )
414
+ hidden_states = self.attn2(norm_hidden_states, context=context, mask=mask) + hidden_states
415
+
416
+ # 3. Feed-forward
417
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
418
+
419
+ return hidden_states
420
+
421
+
422
+ class CrossAttention(nn.Module):
423
+ r"""
424
+ A cross attention layer.
425
+
426
+ Parameters:
427
+ query_dim (`int`): The number of channels in the query.
428
+ cross_attention_dim (`int`, *optional*):
429
+ The number of channels in the context. If not given, defaults to `query_dim`.
430
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
431
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
432
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
433
+ bias (`bool`, *optional*, defaults to False):
434
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
435
+ """
436
+
437
+ def __init__(
438
+ self,
439
+ query_dim: int,
440
+ cross_attention_dim: Optional[int] = None,
441
+ heads: int = 8,
442
+ dim_head: int = 64,
443
+ dropout: float = 0.0,
444
+ bias=False,
445
+ ):
446
+ super().__init__()
447
+ inner_dim = dim_head * heads
448
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
449
+
450
+ self.scale = dim_head ** -0.5
451
+ self.heads = heads
452
+ # for slice_size > 0 the attention score computation
453
+ # is split across the batch axis to save memory
454
+ # You can set slice_size with `set_attention_slice`
455
+ self._slice_size = None
456
+ self._use_memory_efficient_attention_xformers = False
457
+
458
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
459
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
460
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
461
+
462
+ self.to_out = nn.ModuleList([])
463
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
464
+ self.to_out.append(nn.Dropout(dropout))
465
+
466
+ def reshape_heads_to_batch_dim(self, tensor):
467
+ batch_size, seq_len, dim = tensor.shape
468
+ head_size = self.heads
469
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
470
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
471
+ return tensor
472
+
473
+ def reshape_batch_dim_to_heads(self, tensor):
474
+ batch_size, seq_len, dim = tensor.shape
475
+ head_size = self.heads
476
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
477
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
478
+ return tensor
479
+
480
+ def forward(self, hidden_states, context=None, mask=None):
481
+ batch_size, sequence_length, _ = hidden_states.shape
482
+
483
+ query = self.to_q(hidden_states)
484
+ context = context if context is not None else hidden_states
485
+ key = self.to_k(context)
486
+ value = self.to_v(context)
487
+
488
+ dim = query.shape[-1]
489
+
490
+ query = self.reshape_heads_to_batch_dim(query)
491
+ key = self.reshape_heads_to_batch_dim(key)
492
+ value = self.reshape_heads_to_batch_dim(value)
493
+ mask = mask.repeat_interleave(self.heads, dim=0).unsqueeze(1) if mask is not None else None
494
+
495
+ # attention, what we cannot get enough of
496
+ if self._use_memory_efficient_attention_xformers:
497
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value)
498
+ else:
499
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
500
+ hidden_states = self._attention(query, key, value, mask)
501
+ else:
502
+ assert mask is None, "masking is not supported for sliced attention"
503
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
504
+
505
+ # linear proj
506
+ hidden_states = self.to_out[0](hidden_states)
507
+ # dropout
508
+ hidden_states = self.to_out[1](hidden_states)
509
+ return hidden_states
510
+
511
+ def _attention(self, query, key, value, mask):
512
+ # TODO: use baddbmm for better performance
513
+ if query.device.type == "mps":
514
+ # Better performance on mps (~20-25%)
515
+ attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
516
+ else:
517
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
518
+ attention_scores = attention_scores.masked_fill(mask.expand(attention_scores.shape), value=float("-inf")) \
519
+ if mask is not None else attention_scores
520
+ attention_probs = attention_scores.softmax(dim=-1)
521
+ # compute attention output
522
+
523
+ if query.device.type == "mps":
524
+ hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
525
+ else:
526
+ hidden_states = torch.matmul(attention_probs, value)
527
+
528
+ # reshape hidden_states
529
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
530
+ return hidden_states
531
+
532
+ def _sliced_attention(self, query, key, value, sequence_length, dim):
533
+ batch_size_attention = query.shape[0]
534
+ hidden_states = torch.zeros(
535
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
536
+ )
537
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
538
+ for i in range(hidden_states.shape[0] // slice_size):
539
+ start_idx = i * slice_size
540
+ end_idx = (i + 1) * slice_size
541
+ if query.device.type == "mps":
542
+ # Better performance on mps (~20-25%)
543
+ attn_slice = (
544
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
545
+ * self.scale
546
+ )
547
+ else:
548
+ attn_slice = (
549
+ torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
550
+ ) # TODO: use baddbmm for better performance
551
+ attn_slice = attn_slice.softmax(dim=-1)
552
+ if query.device.type == "mps":
553
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
554
+ else:
555
+ attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
556
+
557
+ hidden_states[start_idx:end_idx] = attn_slice
558
+
559
+ # reshape hidden_states
560
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
561
+ return hidden_states
562
+
563
+ def _memory_efficient_attention_xformers(self, query, key, value):
564
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
565
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
566
+ return hidden_states
567
+
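For reference, here is a standalone sketch of what `reshape_heads_to_batch_dim`, `_attention`, and `reshape_batch_dim_to_heads` above compute together, with the slicing, masking, and xformers branches stripped out. Function names and tensor sizes are assumptions chosen for the example, not part of this file.

```py
import torch


def split_heads(t: torch.Tensor, heads: int) -> torch.Tensor:
    b, n, d = t.shape
    return t.reshape(b, n, heads, d // heads).permute(0, 2, 1, 3).reshape(b * heads, n, d // heads)


def merge_heads(t: torch.Tensor, heads: int) -> torch.Tensor:
    bh, n, d = t.shape
    return t.reshape(bh // heads, heads, n, d).permute(0, 2, 1, 3).reshape(bh // heads, n, d * heads)


def attention(query, key, value, heads):
    q, k, v = (split_heads(t, heads) for t in (query, key, value))
    scores = torch.matmul(q, k.transpose(-1, -2)) * q.shape[-1] ** -0.5  # scale = dim_head ** -0.5
    return merge_heads(torch.matmul(scores.softmax(dim=-1), v), heads)


q = k = v = torch.randn(2, 16, 64)
print(attention(q, k, v, heads=8).shape)  # torch.Size([2, 16, 64])
```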
568
+
569
+ class FeedForward(nn.Module):
570
+ r"""
571
+ A feed-forward layer.
572
+
573
+ Parameters:
574
+ dim (`int`): The number of channels in the input.
575
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
576
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
577
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
578
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
579
+ """
580
+
581
+ def __init__(
582
+ self,
583
+ dim: int,
584
+ dim_out: Optional[int] = None,
585
+ mult: int = 4,
586
+ dropout: float = 0.0,
587
+ activation_fn: str = "geglu",
588
+ ):
589
+ super().__init__()
590
+ inner_dim = int(dim * mult)
591
+ dim_out = dim_out if dim_out is not None else dim
592
+
593
+ if activation_fn == "geglu":
594
+ geglu = GEGLU(dim, inner_dim)
595
+ elif activation_fn == "geglu-approximate":
596
+ geglu = ApproximateGELU(dim, inner_dim)
597
+
598
+ self.net = nn.ModuleList([])
599
+ # project in
600
+ self.net.append(geglu)
601
+ # project dropout
602
+ self.net.append(nn.Dropout(dropout))
603
+ # project out
604
+ self.net.append(nn.Linear(inner_dim, dim_out))
605
+
606
+ def forward(self, hidden_states):
607
+ for module in self.net:
608
+ hidden_states = module(hidden_states)
609
+ return hidden_states
610
+
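With the default `activation_fn="geglu"`, the `FeedForward` stack above is a gated projection to `mult * dim` hidden channels, dropout, and a linear projection back down. The sketch below reproduces that composition with plain `nn.Linear` layers and an inline gate; `proj_in`, `proj_out`, and `gated` are illustrative stand-ins for the modules defined in this file, and the sizes are arbitrary.

```py
import torch
import torch.nn.functional as F
from torch import nn

dim, mult, dropout = 64, 4, 0.0
inner_dim = dim * mult

proj_in = nn.Linear(dim, inner_dim * 2)   # GEGLU-style projection: value half + gate half
proj_out = nn.Linear(inner_dim, dim)      # project back to the block width


def gated(x):
    value, gate = proj_in(x).chunk(2, dim=-1)
    return value * F.gelu(gate)


x = torch.randn(2, 16, dim)
y = proj_out(F.dropout(gated(x), p=dropout))
print(y.shape)   # torch.Size([2, 16, 64])
```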
611
+
612
+ # feedforward
613
+ class GEGLU(nn.Module):
614
+ r"""
615
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
616
+
617
+ Parameters:
618
+ dim_in (`int`): The number of channels in the input.
619
+ dim_out (`int`): The number of channels in the output.
620
+ """
621
+
622
+ def __init__(self, dim_in: int, dim_out: int):
623
+ super().__init__()
624
+ self.proj = nn.Linear(dim_in, dim_out * 2)
625
+
626
+ def gelu(self, gate):
627
+ if gate.device.type != "mps":
628
+ return F.gelu(gate)
629
+ # mps: gelu is not implemented for float16
630
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
631
+
632
+ def forward(self, hidden_states):
633
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
634
+ return hidden_states * self.gelu(gate)
635
+
636
+
637
+ class ApproximateGELU(nn.Module):
638
+ """
639
+ The approximate form of Gaussian Error Linear Unit (GELU)
640
+
641
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
642
+ """
643
+
644
+ def __init__(self, dim_in: int, dim_out: int):
645
+ super().__init__()
646
+ self.proj = nn.Linear(dim_in, dim_out)
647
+
648
+ def forward(self, x):
649
+ x = self.proj(x)
650
+ return x * torch.sigmoid(1.702 * x)
651
+
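The sigmoid form used above, `x * sigmoid(1.702 * x)`, is a cheap approximation of the exact GELU. A quick numeric check (illustrative only) shows the two stay within a few hundredths of each other on small inputs:

```py
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
approx = x * torch.sigmoid(1.702 * x)
print(torch.max(torch.abs(approx - F.gelu(x))).item())   # small, on the order of 1e-2
```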
652
+
653
+ class AdaLayerNorm(nn.Module):
654
+ """
655
+ Norm layer modified to incorporate timestep embeddings.
656
+ """
657
+
658
+ def __init__(self, embedding_dim, num_embeddings):
659
+ super().__init__()
660
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
661
+ self.silu = nn.SiLU()
662
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
663
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
664
+
665
+ def forward(self, x, timestep):
666
+ emb = self.linear(self.silu(self.emb(timestep)))
667
+ scale, shift = torch.chunk(emb, 2)
668
+ x = self.norm(x) * (1 + scale) + shift
669
+ return x
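As a rough illustration of the adaptive layer norm above: a timestep id is embedded, mapped through SiLU and a linear layer to a scale and a shift, and applied on top of a `LayerNorm` with no learnable affine of its own. The sketch below uses made-up sizes and broadcasts the conditioning per batch item; it is a simplified stand-in, not the exact behavior of the class.

```py
import torch
import torch.nn.functional as F
from torch import nn

embedding_dim, num_embeddings = 16, 1000
emb = nn.Embedding(num_embeddings, embedding_dim)
to_scale_shift = nn.Linear(embedding_dim, embedding_dim * 2)
norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

x = torch.randn(2, 8, embedding_dim)          # (batch, tokens, channels)
timestep = torch.tensor([10, 500])            # one diffusion timestep per batch item
scale, shift = to_scale_shift(F.silu(emb(timestep))).chunk(2, dim=-1)
out = norm(x) * (1 + scale[:, None]) + shift[:, None]   # broadcast over tokens
print(out.shape)                               # torch.Size([2, 8, 16])
```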
models/diffusers_override/unet_2d_blocks.py ADDED
@@ -0,0 +1,1602 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .attention import AttentionBlock, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
20
+
21
+
22
+ def get_down_block(
23
+ down_block_type,
24
+ num_layers,
25
+ in_channels,
26
+ out_channels,
27
+ temb_channels,
28
+ add_downsample,
29
+ resnet_eps,
30
+ resnet_act_fn,
31
+ attn_num_head_channels,
32
+ resnet_groups=None,
33
+ cross_attention_dim=None,
34
+ downsample_padding=None,
35
+ ):
36
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
37
+ if down_block_type == "DownBlock2D":
38
+ return DownBlock2D(
39
+ num_layers=num_layers,
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ temb_channels=temb_channels,
43
+ add_downsample=add_downsample,
44
+ resnet_eps=resnet_eps,
45
+ resnet_act_fn=resnet_act_fn,
46
+ resnet_groups=resnet_groups,
47
+ downsample_padding=downsample_padding,
48
+ )
49
+ elif down_block_type == "AttnDownBlock2D":
50
+ return AttnDownBlock2D(
51
+ num_layers=num_layers,
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ temb_channels=temb_channels,
55
+ add_downsample=add_downsample,
56
+ resnet_eps=resnet_eps,
57
+ resnet_act_fn=resnet_act_fn,
58
+ resnet_groups=resnet_groups,
59
+ downsample_padding=downsample_padding,
60
+ attn_num_head_channels=attn_num_head_channels,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock2D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
65
+ return CrossAttnDownBlock2D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ resnet_groups=resnet_groups,
74
+ downsample_padding=downsample_padding,
75
+ cross_attention_dim=cross_attention_dim,
76
+ attn_num_head_channels=attn_num_head_channels,
77
+ )
78
+ elif down_block_type == "SkipDownBlock2D":
79
+ return SkipDownBlock2D(
80
+ num_layers=num_layers,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ downsample_padding=downsample_padding,
88
+ )
89
+ elif down_block_type == "AttnSkipDownBlock2D":
90
+ return AttnSkipDownBlock2D(
91
+ num_layers=num_layers,
92
+ in_channels=in_channels,
93
+ out_channels=out_channels,
94
+ temb_channels=temb_channels,
95
+ add_downsample=add_downsample,
96
+ resnet_eps=resnet_eps,
97
+ resnet_act_fn=resnet_act_fn,
98
+ downsample_padding=downsample_padding,
99
+ attn_num_head_channels=attn_num_head_channels,
100
+ )
101
+ elif down_block_type == "DownEncoderBlock2D":
102
+ return DownEncoderBlock2D(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ add_downsample=add_downsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ downsample_padding=downsample_padding,
111
+ )
112
+ elif down_block_type == "AttnDownEncoderBlock2D":
113
+ return AttnDownEncoderBlock2D(
114
+ num_layers=num_layers,
115
+ in_channels=in_channels,
116
+ out_channels=out_channels,
117
+ add_downsample=add_downsample,
118
+ resnet_eps=resnet_eps,
119
+ resnet_act_fn=resnet_act_fn,
120
+ resnet_groups=resnet_groups,
121
+ downsample_padding=downsample_padding,
122
+ attn_num_head_channels=attn_num_head_channels,
123
+ )
124
+ raise ValueError(f"{down_block_type} does not exist.")
125
+
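A small sketch of the dispatch pattern used by `get_down_block` (and `get_up_block` below): each entry of the UNet config's block-type tuple is passed to the factory, which strips an optional "UNetRes" prefix and matches on the remaining class name. The config values here are hypothetical examples, not taken from this repository.

```py
def resolve_block_type(block_type: str) -> str:
    # mirrors the prefix-stripping done at the top of get_down_block / get_up_block
    return block_type[7:] if block_type.startswith("UNetRes") else block_type


down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "UNetResDownBlock2D")
print([resolve_block_type(t) for t in down_block_types])
# ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']
```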
126
+
127
+ def get_up_block(
128
+ up_block_type,
129
+ num_layers,
130
+ in_channels,
131
+ out_channels,
132
+ prev_output_channel,
133
+ temb_channels,
134
+ add_upsample,
135
+ resnet_eps,
136
+ resnet_act_fn,
137
+ attn_num_head_channels,
138
+ resnet_groups=None,
139
+ cross_attention_dim=None,
140
+ ):
141
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
142
+ if up_block_type == "UpBlock2D":
143
+ return UpBlock2D(
144
+ num_layers=num_layers,
145
+ in_channels=in_channels,
146
+ out_channels=out_channels,
147
+ prev_output_channel=prev_output_channel,
148
+ temb_channels=temb_channels,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ )
154
+ elif up_block_type == "CrossAttnUpBlock2D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
157
+ return CrossAttnUpBlock2D(
158
+ num_layers=num_layers,
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ prev_output_channel=prev_output_channel,
162
+ temb_channels=temb_channels,
163
+ add_upsample=add_upsample,
164
+ resnet_eps=resnet_eps,
165
+ resnet_act_fn=resnet_act_fn,
166
+ resnet_groups=resnet_groups,
167
+ cross_attention_dim=cross_attention_dim,
168
+ attn_num_head_channels=attn_num_head_channels,
169
+ )
170
+ elif up_block_type == "AttnUpBlock2D":
171
+ return AttnUpBlock2D(
172
+ num_layers=num_layers,
173
+ in_channels=in_channels,
174
+ out_channels=out_channels,
175
+ prev_output_channel=prev_output_channel,
176
+ temb_channels=temb_channels,
177
+ add_upsample=add_upsample,
178
+ resnet_eps=resnet_eps,
179
+ resnet_act_fn=resnet_act_fn,
180
+ resnet_groups=resnet_groups,
181
+ attn_num_head_channels=attn_num_head_channels,
182
+ )
183
+ elif up_block_type == "SkipUpBlock2D":
184
+ return SkipUpBlock2D(
185
+ num_layers=num_layers,
186
+ in_channels=in_channels,
187
+ out_channels=out_channels,
188
+ prev_output_channel=prev_output_channel,
189
+ temb_channels=temb_channels,
190
+ add_upsample=add_upsample,
191
+ resnet_eps=resnet_eps,
192
+ resnet_act_fn=resnet_act_fn,
193
+ )
194
+ elif up_block_type == "AttnSkipUpBlock2D":
195
+ return AttnSkipUpBlock2D(
196
+ num_layers=num_layers,
197
+ in_channels=in_channels,
198
+ out_channels=out_channels,
199
+ prev_output_channel=prev_output_channel,
200
+ temb_channels=temb_channels,
201
+ add_upsample=add_upsample,
202
+ resnet_eps=resnet_eps,
203
+ resnet_act_fn=resnet_act_fn,
204
+ attn_num_head_channels=attn_num_head_channels,
205
+ )
206
+ elif up_block_type == "UpDecoderBlock2D":
207
+ return UpDecoderBlock2D(
208
+ num_layers=num_layers,
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ add_upsample=add_upsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ resnet_groups=resnet_groups,
215
+ )
216
+ elif up_block_type == "AttnUpDecoderBlock2D":
217
+ return AttnUpDecoderBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ add_upsample=add_upsample,
222
+ resnet_eps=resnet_eps,
223
+ resnet_act_fn=resnet_act_fn,
224
+ resnet_groups=resnet_groups,
225
+ attn_num_head_channels=attn_num_head_channels,
226
+ )
227
+ raise ValueError(f"{up_block_type} does not exist.")
228
+
229
+
230
+ class UNetMidBlock2D(nn.Module):
231
+ def __init__(
232
+ self,
233
+ in_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default",
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ resnet_pre_norm: bool = True,
242
+ attn_num_head_channels=1,
243
+ attention_type="default",
244
+ output_scale_factor=1.0,
245
+ **kwargs,
246
+ ):
247
+ super().__init__()
248
+
249
+ self.attention_type = attention_type
250
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
251
+
252
+ # there is always at least one resnet
253
+ resnets = [
254
+ ResnetBlock2D(
255
+ in_channels=in_channels,
256
+ out_channels=in_channels,
257
+ temb_channels=temb_channels,
258
+ eps=resnet_eps,
259
+ groups=resnet_groups,
260
+ dropout=dropout,
261
+ time_embedding_norm=resnet_time_scale_shift,
262
+ non_linearity=resnet_act_fn,
263
+ output_scale_factor=output_scale_factor,
264
+ pre_norm=resnet_pre_norm,
265
+ )
266
+ ]
267
+ attentions = []
268
+
269
+ for _ in range(num_layers):
270
+ attentions.append(
271
+ AttentionBlock(
272
+ in_channels,
273
+ num_head_channels=attn_num_head_channels,
274
+ rescale_output_factor=output_scale_factor,
275
+ eps=resnet_eps,
276
+ norm_num_groups=resnet_groups,
277
+ )
278
+ )
279
+ resnets.append(
280
+ ResnetBlock2D(
281
+ in_channels=in_channels,
282
+ out_channels=in_channels,
283
+ temb_channels=temb_channels,
284
+ eps=resnet_eps,
285
+ groups=resnet_groups,
286
+ dropout=dropout,
287
+ time_embedding_norm=resnet_time_scale_shift,
288
+ non_linearity=resnet_act_fn,
289
+ output_scale_factor=output_scale_factor,
290
+ pre_norm=resnet_pre_norm,
291
+ )
292
+ )
293
+
294
+ self.attentions = nn.ModuleList(attentions)
295
+ self.resnets = nn.ModuleList(resnets)
296
+
297
+ def forward(self, hidden_states, temb=None, encoder_states=None):
298
+ hidden_states = self.resnets[0](hidden_states, temb)
299
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
300
+ if self.attention_type == "default":
301
+ hidden_states = attn(hidden_states)
302
+ else:
303
+ hidden_states = attn(hidden_states, encoder_states)
304
+ hidden_states = resnet(hidden_states, temb)
305
+
306
+ return hidden_states
307
+
308
+
309
+ class UNetMidBlock2DCrossAttn(nn.Module):
310
+ def __init__(
311
+ self,
312
+ in_channels: int,
313
+ temb_channels: int,
314
+ dropout: float = 0.0,
315
+ num_layers: int = 1,
316
+ resnet_eps: float = 1e-6,
317
+ resnet_time_scale_shift: str = "default",
318
+ resnet_act_fn: str = "swish",
319
+ resnet_groups: int = 32,
320
+ resnet_pre_norm: bool = True,
321
+ attn_num_head_channels=1,
322
+ attention_type="default",
323
+ output_scale_factor=1.0,
324
+ cross_attention_dim=1280,
325
+ **kwargs,
326
+ ):
327
+ super().__init__()
328
+
329
+ self.attention_type = attention_type
330
+ self.attn_num_head_channels = attn_num_head_channels
331
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
332
+
333
+ # there is always at least one resnet
334
+ resnets = [
335
+ ResnetBlock2D(
336
+ in_channels=in_channels,
337
+ out_channels=in_channels,
338
+ temb_channels=temb_channels,
339
+ eps=resnet_eps,
340
+ groups=resnet_groups,
341
+ dropout=dropout,
342
+ time_embedding_norm=resnet_time_scale_shift,
343
+ non_linearity=resnet_act_fn,
344
+ output_scale_factor=output_scale_factor,
345
+ pre_norm=resnet_pre_norm,
346
+ )
347
+ ]
348
+ attentions = []
349
+
350
+ for _ in range(num_layers):
351
+ attentions.append(
352
+ Transformer2DModel(
353
+ attn_num_head_channels,
354
+ in_channels // attn_num_head_channels,
355
+ in_channels=in_channels,
356
+ num_layers=1,
357
+ cross_attention_dim=cross_attention_dim,
358
+ norm_num_groups=resnet_groups,
359
+ )
360
+ )
361
+ resnets.append(
362
+ ResnetBlock2D(
363
+ in_channels=in_channels,
364
+ out_channels=in_channels,
365
+ temb_channels=temb_channels,
366
+ eps=resnet_eps,
367
+ groups=resnet_groups,
368
+ dropout=dropout,
369
+ time_embedding_norm=resnet_time_scale_shift,
370
+ non_linearity=resnet_act_fn,
371
+ output_scale_factor=output_scale_factor,
372
+ pre_norm=resnet_pre_norm,
373
+ )
374
+ )
375
+
376
+ self.attentions = nn.ModuleList(attentions)
377
+ self.resnets = nn.ModuleList(resnets)
378
+
379
+ def set_attention_slice(self, slice_size):
380
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
381
+ raise ValueError(
382
+ f"Make sure slice_size {slice_size} is a divisor of "
383
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
384
+ )
385
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
386
+ raise ValueError(
387
+ f"Chunk_size {slice_size} has to be smaller or equal to "
388
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
389
+ )
390
+
391
+ for attn in self.attentions:
392
+ attn._set_attention_slice(slice_size)
393
+
394
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
395
+ for attn in self.attentions:
396
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
397
+
398
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, encoder_attention_mask=None):
399
+ hidden_states = self.resnets[0](hidden_states, temb)
400
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
401
+ hidden_states = attn(hidden_states, encoder_hidden_states, encoder_attention_mask).sample
402
+ hidden_states = resnet(hidden_states, temb)
403
+
404
+ return hidden_states
405
+
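The mid-block forward above always runs one resnet first, then alternates attention and resnet for `num_layers` pairs, all conditioned on the time embedding. The toy module below only reproduces that control flow; the convolution and identity layers are stand-ins for `ResnetBlock2D` and `Transformer2DModel`, and the shapes are arbitrary.

```py
import torch
from torch import nn


class ToyMidBlock(nn.Module):
    def __init__(self, channels: int, num_layers: int = 1):
        super().__init__()
        # num_layers attention/resnet pairs, plus the leading resnet
        self.resnets = nn.ModuleList([nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_layers + 1)])
        self.attentions = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

    def forward(self, hidden_states, temb=None):
        hidden_states = self.resnets[0](hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = resnet(attn(hidden_states))
        return hidden_states


print(ToyMidBlock(8)(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```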
406
+
407
+ class AttnDownBlock2D(nn.Module):
408
+ def __init__(
409
+ self,
410
+ in_channels: int,
411
+ out_channels: int,
412
+ temb_channels: int,
413
+ dropout: float = 0.0,
414
+ num_layers: int = 1,
415
+ resnet_eps: float = 1e-6,
416
+ resnet_time_scale_shift: str = "default",
417
+ resnet_act_fn: str = "swish",
418
+ resnet_groups: int = 32,
419
+ resnet_pre_norm: bool = True,
420
+ attn_num_head_channels=1,
421
+ attention_type="default",
422
+ output_scale_factor=1.0,
423
+ downsample_padding=1,
424
+ add_downsample=True,
425
+ ):
426
+ super().__init__()
427
+ resnets = []
428
+ attentions = []
429
+
430
+ self.attention_type = attention_type
431
+
432
+ for i in range(num_layers):
433
+ in_channels = in_channels if i == 0 else out_channels
434
+ resnets.append(
435
+ ResnetBlock2D(
436
+ in_channels=in_channels,
437
+ out_channels=out_channels,
438
+ temb_channels=temb_channels,
439
+ eps=resnet_eps,
440
+ groups=resnet_groups,
441
+ dropout=dropout,
442
+ time_embedding_norm=resnet_time_scale_shift,
443
+ non_linearity=resnet_act_fn,
444
+ output_scale_factor=output_scale_factor,
445
+ pre_norm=resnet_pre_norm,
446
+ )
447
+ )
448
+ attentions.append(
449
+ AttentionBlock(
450
+ out_channels,
451
+ num_head_channels=attn_num_head_channels,
452
+ rescale_output_factor=output_scale_factor,
453
+ eps=resnet_eps,
454
+ norm_num_groups=resnet_groups,
455
+ )
456
+ )
457
+
458
+ self.attentions = nn.ModuleList(attentions)
459
+ self.resnets = nn.ModuleList(resnets)
460
+
461
+ if add_downsample:
462
+ self.downsamplers = nn.ModuleList(
463
+ [
464
+ Downsample2D(
465
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
466
+ )
467
+ ]
468
+ )
469
+ else:
470
+ self.downsamplers = None
471
+
472
+ def forward(self, hidden_states, temb=None):
473
+ output_states = ()
474
+
475
+ for resnet, attn in zip(self.resnets, self.attentions):
476
+ hidden_states = resnet(hidden_states, temb)
477
+ hidden_states = attn(hidden_states)
478
+ output_states += (hidden_states,)
479
+
480
+ if self.downsamplers is not None:
481
+ for downsampler in self.downsamplers:
482
+ hidden_states = downsampler(hidden_states)
483
+
484
+ output_states += (hidden_states,)
485
+
486
+ return hidden_states, output_states
487
+
488
+
489
+ class CrossAttnDownBlock2D(nn.Module):
490
+ def __init__(
491
+ self,
492
+ in_channels: int,
493
+ out_channels: int,
494
+ temb_channels: int,
495
+ dropout: float = 0.0,
496
+ num_layers: int = 1,
497
+ resnet_eps: float = 1e-6,
498
+ resnet_time_scale_shift: str = "default",
499
+ resnet_act_fn: str = "swish",
500
+ resnet_groups: int = 32,
501
+ resnet_pre_norm: bool = True,
502
+ attn_num_head_channels=1,
503
+ cross_attention_dim=1280,
504
+ attention_type="default",
505
+ output_scale_factor=1.0,
506
+ downsample_padding=1,
507
+ add_downsample=True,
508
+ ):
509
+ super().__init__()
510
+ resnets = []
511
+ attentions = []
512
+
513
+ self.attention_type = attention_type
514
+ self.attn_num_head_channels = attn_num_head_channels
515
+
516
+ for i in range(num_layers):
517
+ in_channels = in_channels if i == 0 else out_channels
518
+ resnets.append(
519
+ ResnetBlock2D(
520
+ in_channels=in_channels,
521
+ out_channels=out_channels,
522
+ temb_channels=temb_channels,
523
+ eps=resnet_eps,
524
+ groups=resnet_groups,
525
+ dropout=dropout,
526
+ time_embedding_norm=resnet_time_scale_shift,
527
+ non_linearity=resnet_act_fn,
528
+ output_scale_factor=output_scale_factor,
529
+ pre_norm=resnet_pre_norm,
530
+ )
531
+ )
532
+ attentions.append(
533
+ Transformer2DModel(
534
+ attn_num_head_channels,
535
+ out_channels // attn_num_head_channels,
536
+ in_channels=out_channels,
537
+ num_layers=1,
538
+ cross_attention_dim=cross_attention_dim,
539
+ norm_num_groups=resnet_groups,
540
+ )
541
+ )
542
+ self.attentions = nn.ModuleList(attentions)
543
+ self.resnets = nn.ModuleList(resnets)
544
+
545
+ if add_downsample:
546
+ self.downsamplers = nn.ModuleList(
547
+ [
548
+ Downsample2D(
549
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
550
+ )
551
+ ]
552
+ )
553
+ else:
554
+ self.downsamplers = None
555
+
556
+ self.gradient_checkpointing = False
557
+
558
+ def set_attention_slice(self, slice_size):
559
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
560
+ raise ValueError(
561
+ f"Make sure slice_size {slice_size} is a divisor of "
562
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
563
+ )
564
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
565
+ raise ValueError(
566
+ f"Chunk_size {slice_size} has to be smaller or equal to "
567
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
568
+ )
569
+
570
+ for attn in self.attentions:
571
+ attn._set_attention_slice(slice_size)
572
+
573
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
574
+ for attn in self.attentions:
575
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
576
+
577
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, encoder_attention_mask=None):
578
+ output_states = ()
579
+
580
+ for resnet, attn in zip(self.resnets, self.attentions):
581
+ if self.training and self.gradient_checkpointing:
582
+
583
+ def create_custom_forward(module, return_dict=None):
584
+ def custom_forward(*inputs):
585
+ if return_dict is not None:
586
+ return module(*inputs, return_dict=return_dict)
587
+ else:
588
+ return module(*inputs)
589
+
590
+ return custom_forward
591
+
592
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
593
+ hidden_states = torch.utils.checkpoint.checkpoint(
594
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
595
+ encoder_attention_mask
596
+ )[0]
597
+ else:
598
+ hidden_states = resnet(hidden_states, temb)
599
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
600
+ encoder_attention_mask=encoder_attention_mask).sample
601
+
602
+ output_states += (hidden_states,)
603
+
604
+ if self.downsamplers is not None:
605
+ for downsampler in self.downsamplers:
606
+ hidden_states = downsampler(hidden_states)
607
+
608
+ output_states += (hidden_states,)
609
+
610
+ return hidden_states, output_states
611
+
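When `gradient_checkpointing` is enabled, the forward above wraps each submodule in a plain function so `torch.utils.checkpoint` can recompute its activations during the backward pass instead of storing them. A minimal, self-contained sketch of that pattern with a toy layer and assumed sizes:

```py
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward


x = torch.randn(4, 32, requires_grad=True)
out = checkpoint(create_custom_forward(layer), x)   # activations recomputed on backward
out.sum().backward()
print(x.grad.shape)   # torch.Size([4, 32])
```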
612
+
613
+ class DownBlock2D(nn.Module):
614
+ def __init__(
615
+ self,
616
+ in_channels: int,
617
+ out_channels: int,
618
+ temb_channels: int,
619
+ dropout: float = 0.0,
620
+ num_layers: int = 1,
621
+ resnet_eps: float = 1e-6,
622
+ resnet_time_scale_shift: str = "default",
623
+ resnet_act_fn: str = "swish",
624
+ resnet_groups: int = 32,
625
+ resnet_pre_norm: bool = True,
626
+ output_scale_factor=1.0,
627
+ add_downsample=True,
628
+ downsample_padding=1,
629
+ ):
630
+ super().__init__()
631
+ resnets = []
632
+
633
+ for i in range(num_layers):
634
+ in_channels = in_channels if i == 0 else out_channels
635
+ resnets.append(
636
+ ResnetBlock2D(
637
+ in_channels=in_channels,
638
+ out_channels=out_channels,
639
+ temb_channels=temb_channels,
640
+ eps=resnet_eps,
641
+ groups=resnet_groups,
642
+ dropout=dropout,
643
+ time_embedding_norm=resnet_time_scale_shift,
644
+ non_linearity=resnet_act_fn,
645
+ output_scale_factor=output_scale_factor,
646
+ pre_norm=resnet_pre_norm,
647
+ )
648
+ )
649
+
650
+ self.resnets = nn.ModuleList(resnets)
651
+
652
+ if add_downsample:
653
+ self.downsamplers = nn.ModuleList(
654
+ [
655
+ Downsample2D(
656
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
657
+ )
658
+ ]
659
+ )
660
+ else:
661
+ self.downsamplers = None
662
+
663
+ self.gradient_checkpointing = False
664
+
665
+ def forward(self, hidden_states, temb=None):
666
+ output_states = ()
667
+
668
+ for resnet in self.resnets:
669
+ if self.training and self.gradient_checkpointing:
670
+
671
+ def create_custom_forward(module):
672
+ def custom_forward(*inputs):
673
+ return module(*inputs)
674
+
675
+ return custom_forward
676
+
677
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
678
+ else:
679
+ hidden_states = resnet(hidden_states, temb)
680
+
681
+ output_states += (hidden_states,)
682
+
683
+ if self.downsamplers is not None:
684
+ for downsampler in self.downsamplers:
685
+ hidden_states = downsampler(hidden_states)
686
+
687
+ output_states += (hidden_states,)
688
+
689
+ return hidden_states, output_states
690
+
691
+
692
+ class DownEncoderBlock2D(nn.Module):
693
+ def __init__(
694
+ self,
695
+ in_channels: int,
696
+ out_channels: int,
697
+ dropout: float = 0.0,
698
+ num_layers: int = 1,
699
+ resnet_eps: float = 1e-6,
700
+ resnet_time_scale_shift: str = "default",
701
+ resnet_act_fn: str = "swish",
702
+ resnet_groups: int = 32,
703
+ resnet_pre_norm: bool = True,
704
+ output_scale_factor=1.0,
705
+ add_downsample=True,
706
+ downsample_padding=1,
707
+ ):
708
+ super().__init__()
709
+ resnets = []
710
+
711
+ for i in range(num_layers):
712
+ in_channels = in_channels if i == 0 else out_channels
713
+ resnets.append(
714
+ ResnetBlock2D(
715
+ in_channels=in_channels,
716
+ out_channels=out_channels,
717
+ temb_channels=None,
718
+ eps=resnet_eps,
719
+ groups=resnet_groups,
720
+ dropout=dropout,
721
+ time_embedding_norm=resnet_time_scale_shift,
722
+ non_linearity=resnet_act_fn,
723
+ output_scale_factor=output_scale_factor,
724
+ pre_norm=resnet_pre_norm,
725
+ )
726
+ )
727
+
728
+ self.resnets = nn.ModuleList(resnets)
729
+
730
+ if add_downsample:
731
+ self.downsamplers = nn.ModuleList(
732
+ [
733
+ Downsample2D(
734
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
735
+ )
736
+ ]
737
+ )
738
+ else:
739
+ self.downsamplers = None
740
+
741
+ def forward(self, hidden_states):
742
+ for resnet in self.resnets:
743
+ hidden_states = resnet(hidden_states, temb=None)
744
+
745
+ if self.downsamplers is not None:
746
+ for downsampler in self.downsamplers:
747
+ hidden_states = downsampler(hidden_states)
748
+
749
+ return hidden_states
750
+
751
+
752
+ class AttnDownEncoderBlock2D(nn.Module):
753
+ def __init__(
754
+ self,
755
+ in_channels: int,
756
+ out_channels: int,
757
+ dropout: float = 0.0,
758
+ num_layers: int = 1,
759
+ resnet_eps: float = 1e-6,
760
+ resnet_time_scale_shift: str = "default",
761
+ resnet_act_fn: str = "swish",
762
+ resnet_groups: int = 32,
763
+ resnet_pre_norm: bool = True,
764
+ attn_num_head_channels=1,
765
+ output_scale_factor=1.0,
766
+ add_downsample=True,
767
+ downsample_padding=1,
768
+ ):
769
+ super().__init__()
770
+ resnets = []
771
+ attentions = []
772
+
773
+ for i in range(num_layers):
774
+ in_channels = in_channels if i == 0 else out_channels
775
+ resnets.append(
776
+ ResnetBlock2D(
777
+ in_channels=in_channels,
778
+ out_channels=out_channels,
779
+ temb_channels=None,
780
+ eps=resnet_eps,
781
+ groups=resnet_groups,
782
+ dropout=dropout,
783
+ time_embedding_norm=resnet_time_scale_shift,
784
+ non_linearity=resnet_act_fn,
785
+ output_scale_factor=output_scale_factor,
786
+ pre_norm=resnet_pre_norm,
787
+ )
788
+ )
789
+ attentions.append(
790
+ AttentionBlock(
791
+ out_channels,
792
+ num_head_channels=attn_num_head_channels,
793
+ rescale_output_factor=output_scale_factor,
794
+ eps=resnet_eps,
795
+ norm_num_groups=resnet_groups,
796
+ )
797
+ )
798
+
799
+ self.attentions = nn.ModuleList(attentions)
800
+ self.resnets = nn.ModuleList(resnets)
801
+
802
+ if add_downsample:
803
+ self.downsamplers = nn.ModuleList(
804
+ [
805
+ Downsample2D(
806
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
807
+ )
808
+ ]
809
+ )
810
+ else:
811
+ self.downsamplers = None
812
+
813
+ def forward(self, hidden_states):
814
+ for resnet, attn in zip(self.resnets, self.attentions):
815
+ hidden_states = resnet(hidden_states, temb=None)
816
+ hidden_states = attn(hidden_states)
817
+
818
+ if self.downsamplers is not None:
819
+ for downsampler in self.downsamplers:
820
+ hidden_states = downsampler(hidden_states)
821
+
822
+ return hidden_states
823
+
824
+
825
+ class AttnSkipDownBlock2D(nn.Module):
826
+ def __init__(
827
+ self,
828
+ in_channels: int,
829
+ out_channels: int,
830
+ temb_channels: int,
831
+ dropout: float = 0.0,
832
+ num_layers: int = 1,
833
+ resnet_eps: float = 1e-6,
834
+ resnet_time_scale_shift: str = "default",
835
+ resnet_act_fn: str = "swish",
836
+ resnet_pre_norm: bool = True,
837
+ attn_num_head_channels=1,
838
+ attention_type="default",
839
+ output_scale_factor=np.sqrt(2.0),
840
+ downsample_padding=1,
841
+ add_downsample=True,
842
+ ):
843
+ super().__init__()
844
+ self.attentions = nn.ModuleList([])
845
+ self.resnets = nn.ModuleList([])
846
+
847
+ self.attention_type = attention_type
848
+
849
+ for i in range(num_layers):
850
+ in_channels = in_channels if i == 0 else out_channels
851
+ self.resnets.append(
852
+ ResnetBlock2D(
853
+ in_channels=in_channels,
854
+ out_channels=out_channels,
855
+ temb_channels=temb_channels,
856
+ eps=resnet_eps,
857
+ groups=min(in_channels // 4, 32),
858
+ groups_out=min(out_channels // 4, 32),
859
+ dropout=dropout,
860
+ time_embedding_norm=resnet_time_scale_shift,
861
+ non_linearity=resnet_act_fn,
862
+ output_scale_factor=output_scale_factor,
863
+ pre_norm=resnet_pre_norm,
864
+ )
865
+ )
866
+ self.attentions.append(
867
+ AttentionBlock(
868
+ out_channels,
869
+ num_head_channels=attn_num_head_channels,
870
+ rescale_output_factor=output_scale_factor,
871
+ eps=resnet_eps,
872
+ )
873
+ )
874
+
875
+ if add_downsample:
876
+ self.resnet_down = ResnetBlock2D(
877
+ in_channels=out_channels,
878
+ out_channels=out_channels,
879
+ temb_channels=temb_channels,
880
+ eps=resnet_eps,
881
+ groups=min(out_channels // 4, 32),
882
+ dropout=dropout,
883
+ time_embedding_norm=resnet_time_scale_shift,
884
+ non_linearity=resnet_act_fn,
885
+ output_scale_factor=output_scale_factor,
886
+ pre_norm=resnet_pre_norm,
887
+ use_in_shortcut=True,
888
+ down=True,
889
+ kernel="fir",
890
+ )
891
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
892
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
893
+ else:
894
+ self.resnet_down = None
895
+ self.downsamplers = None
896
+ self.skip_conv = None
897
+
898
+ def forward(self, hidden_states, temb=None, skip_sample=None):
899
+ output_states = ()
900
+
901
+ for resnet, attn in zip(self.resnets, self.attentions):
902
+ hidden_states = resnet(hidden_states, temb)
903
+ hidden_states = attn(hidden_states)
904
+ output_states += (hidden_states,)
905
+
906
+ if self.downsamplers is not None:
907
+ hidden_states = self.resnet_down(hidden_states, temb)
908
+ for downsampler in self.downsamplers:
909
+ skip_sample = downsampler(skip_sample)
910
+
911
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
912
+
913
+ output_states += (hidden_states,)
914
+
915
+ return hidden_states, output_states, skip_sample
916
+
917
+
918
+ class SkipDownBlock2D(nn.Module):
919
+ def __init__(
920
+ self,
921
+ in_channels: int,
922
+ out_channels: int,
923
+ temb_channels: int,
924
+ dropout: float = 0.0,
925
+ num_layers: int = 1,
926
+ resnet_eps: float = 1e-6,
927
+ resnet_time_scale_shift: str = "default",
928
+ resnet_act_fn: str = "swish",
929
+ resnet_pre_norm: bool = True,
930
+ output_scale_factor=np.sqrt(2.0),
931
+ add_downsample=True,
932
+ downsample_padding=1,
933
+ ):
934
+ super().__init__()
935
+ self.resnets = nn.ModuleList([])
936
+
937
+ for i in range(num_layers):
938
+ in_channels = in_channels if i == 0 else out_channels
939
+ self.resnets.append(
940
+ ResnetBlock2D(
941
+ in_channels=in_channels,
942
+ out_channels=out_channels,
943
+ temb_channels=temb_channels,
944
+ eps=resnet_eps,
945
+ groups=min(in_channels // 4, 32),
946
+ groups_out=min(out_channels // 4, 32),
947
+ dropout=dropout,
948
+ time_embedding_norm=resnet_time_scale_shift,
949
+ non_linearity=resnet_act_fn,
950
+ output_scale_factor=output_scale_factor,
951
+ pre_norm=resnet_pre_norm,
952
+ )
953
+ )
954
+
955
+ if add_downsample:
956
+ self.resnet_down = ResnetBlock2D(
957
+ in_channels=out_channels,
958
+ out_channels=out_channels,
959
+ temb_channels=temb_channels,
960
+ eps=resnet_eps,
961
+ groups=min(out_channels // 4, 32),
962
+ dropout=dropout,
963
+ time_embedding_norm=resnet_time_scale_shift,
964
+ non_linearity=resnet_act_fn,
965
+ output_scale_factor=output_scale_factor,
966
+ pre_norm=resnet_pre_norm,
967
+ use_in_shortcut=True,
968
+ down=True,
969
+ kernel="fir",
970
+ )
971
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
972
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
973
+ else:
974
+ self.resnet_down = None
975
+ self.downsamplers = None
976
+ self.skip_conv = None
977
+
978
+ def forward(self, hidden_states, temb=None, skip_sample=None):
979
+ output_states = ()
980
+
981
+ for resnet in self.resnets:
982
+ hidden_states = resnet(hidden_states, temb)
983
+ output_states += (hidden_states,)
984
+
985
+ if self.downsamplers is not None:
986
+ hidden_states = self.resnet_down(hidden_states, temb)
987
+ for downsampler in self.downsamplers:
988
+ skip_sample = downsampler(skip_sample)
989
+
990
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
991
+
992
+ output_states += (hidden_states,)
993
+
994
+ return hidden_states, output_states, skip_sample
995
+
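In the `Skip*DownBlock2D` classes above, a running image-resolution `skip_sample` is downsampled alongside the features and mixed back in through a 1x1 convolution. The sketch below only illustrates the shapes involved; `F.avg_pool2d` is an assumed stand-in for `FirDownsample2D`, and the channel counts are made up.

```py
import torch
import torch.nn.functional as F
from torch import nn

out_channels = 8
skip_conv = nn.Conv2d(3, out_channels, kernel_size=1)

hidden = torch.randn(1, out_channels, 16, 16)
skip = torch.randn(1, 3, 32, 32)              # RGB-resolution skip image

skip = F.avg_pool2d(skip, 2)                  # stand-in for the FIR downsampler
hidden = skip_conv(skip) + hidden             # project skip to feature channels and add
print(hidden.shape, skip.shape)               # torch.Size([1, 8, 16, 16]) torch.Size([1, 3, 16, 16])
```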
996
+
997
+ class AttnUpBlock2D(nn.Module):
998
+ def __init__(
999
+ self,
1000
+ in_channels: int,
1001
+ prev_output_channel: int,
1002
+ out_channels: int,
1003
+ temb_channels: int,
1004
+ dropout: float = 0.0,
1005
+ num_layers: int = 1,
1006
+ resnet_eps: float = 1e-6,
1007
+ resnet_time_scale_shift: str = "default",
1008
+ resnet_act_fn: str = "swish",
1009
+ resnet_groups: int = 32,
1010
+ resnet_pre_norm: bool = True,
1011
+ attention_type="default",
1012
+ attn_num_head_channels=1,
1013
+ output_scale_factor=1.0,
1014
+ add_upsample=True,
1015
+ ):
1016
+ super().__init__()
1017
+ resnets = []
1018
+ attentions = []
1019
+
1020
+ self.attention_type = attention_type
1021
+
1022
+ for i in range(num_layers):
1023
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1024
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1025
+
1026
+ resnets.append(
1027
+ ResnetBlock2D(
1028
+ in_channels=resnet_in_channels + res_skip_channels,
1029
+ out_channels=out_channels,
1030
+ temb_channels=temb_channels,
1031
+ eps=resnet_eps,
1032
+ groups=resnet_groups,
1033
+ dropout=dropout,
1034
+ time_embedding_norm=resnet_time_scale_shift,
1035
+ non_linearity=resnet_act_fn,
1036
+ output_scale_factor=output_scale_factor,
1037
+ pre_norm=resnet_pre_norm,
1038
+ )
1039
+ )
1040
+ attentions.append(
1041
+ AttentionBlock(
1042
+ out_channels,
1043
+ num_head_channels=attn_num_head_channels,
1044
+ rescale_output_factor=output_scale_factor,
1045
+ eps=resnet_eps,
1046
+ norm_num_groups=resnet_groups,
1047
+ )
1048
+ )
1049
+
1050
+ self.attentions = nn.ModuleList(attentions)
1051
+ self.resnets = nn.ModuleList(resnets)
1052
+
1053
+ if add_upsample:
1054
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1055
+ else:
1056
+ self.upsamplers = None
1057
+
1058
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1059
+ for resnet, attn in zip(self.resnets, self.attentions):
1060
+ # pop res hidden states
1061
+ res_hidden_states = res_hidden_states_tuple[-1]
1062
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1063
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1064
+
1065
+ hidden_states = resnet(hidden_states, temb)
1066
+ hidden_states = attn(hidden_states)
1067
+
1068
+ if self.upsamplers is not None:
1069
+ for upsampler in self.upsamplers:
1070
+ hidden_states = upsampler(hidden_states)
1071
+
1072
+ return hidden_states
1073
+
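All of the up blocks consume the skip connections the same way as the forward above: the tuple of down-block outputs is popped from the end (last in, first out) and concatenated with the current features along the channel dimension before each resnet. A small shape-only sketch with assumed sizes:

```py
import torch

hidden = torch.randn(1, 4, 8, 8)
skips = (torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8))  # collected on the way down

for _ in range(len(skips)):
    skip, skips = skips[-1], skips[:-1]          # pop the most recent skip
    hidden = torch.cat([hidden, skip], dim=1)    # channels grow: 4 -> 8 -> 12
print(hidden.shape)                              # torch.Size([1, 12, 8, 8])
```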
1074
+
1075
+ class CrossAttnUpBlock2D(nn.Module):
1076
+ def __init__(
1077
+ self,
1078
+ in_channels: int,
1079
+ out_channels: int,
1080
+ prev_output_channel: int,
1081
+ temb_channels: int,
1082
+ dropout: float = 0.0,
1083
+ num_layers: int = 1,
1084
+ resnet_eps: float = 1e-6,
1085
+ resnet_time_scale_shift: str = "default",
1086
+ resnet_act_fn: str = "swish",
1087
+ resnet_groups: int = 32,
1088
+ resnet_pre_norm: bool = True,
1089
+ attn_num_head_channels=1,
1090
+ cross_attention_dim=1280,
1091
+ attention_type="default",
1092
+ output_scale_factor=1.0,
1093
+ add_upsample=True,
1094
+ ):
1095
+ super().__init__()
1096
+ resnets = []
1097
+ attentions = []
1098
+
1099
+ self.attention_type = attention_type
1100
+ self.attn_num_head_channels = attn_num_head_channels
1101
+
1102
+ for i in range(num_layers):
1103
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1104
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1105
+
1106
+ resnets.append(
1107
+ ResnetBlock2D(
1108
+ in_channels=resnet_in_channels + res_skip_channels,
1109
+ out_channels=out_channels,
1110
+ temb_channels=temb_channels,
1111
+ eps=resnet_eps,
1112
+ groups=resnet_groups,
1113
+ dropout=dropout,
1114
+ time_embedding_norm=resnet_time_scale_shift,
1115
+ non_linearity=resnet_act_fn,
1116
+ output_scale_factor=output_scale_factor,
1117
+ pre_norm=resnet_pre_norm,
1118
+ )
1119
+ )
1120
+ attentions.append(
1121
+ Transformer2DModel(
1122
+ attn_num_head_channels,
1123
+ out_channels // attn_num_head_channels,
1124
+ in_channels=out_channels,
1125
+ num_layers=1,
1126
+ cross_attention_dim=cross_attention_dim,
1127
+ norm_num_groups=resnet_groups,
1128
+ )
1129
+ )
1130
+ self.attentions = nn.ModuleList(attentions)
1131
+ self.resnets = nn.ModuleList(resnets)
1132
+
1133
+ if add_upsample:
1134
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1135
+ else:
1136
+ self.upsamplers = None
1137
+
1138
+ self.gradient_checkpointing = False
1139
+
1140
+ def set_attention_slice(self, slice_size):
1141
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1142
+ raise ValueError(
1143
+ f"Make sure slice_size {slice_size} is a divisor of "
1144
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1145
+ )
1146
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1147
+ raise ValueError(
1148
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1149
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1150
+ )
1151
+
1152
+ for attn in self.attentions:
1153
+ attn._set_attention_slice(slice_size)
1154
+
1155
+ self.gradient_checkpointing = False
1156
+
1157
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1158
+ for attn in self.attentions:
1159
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
1160
+
1161
+ def forward(
1162
+ self,
1163
+ hidden_states,
1164
+ res_hidden_states_tuple,
1165
+ temb=None,
1166
+ encoder_hidden_states=None,
1167
+ encoder_attention_mask=None,
1168
+ upsample_size=None,
1169
+ ):
1170
+ for resnet, attn in zip(self.resnets, self.attentions):
1171
+ # pop res hidden states
1172
+ res_hidden_states = res_hidden_states_tuple[-1]
1173
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1174
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1175
+
1176
+ if self.training and self.gradient_checkpointing:
1177
+
1178
+ def create_custom_forward(module, return_dict=None):
1179
+ def custom_forward(*inputs):
1180
+ if return_dict is not None:
1181
+ return module(*inputs, return_dict=return_dict)
1182
+ else:
1183
+ return module(*inputs)
1184
+
1185
+ return custom_forward
1186
+
1187
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1188
+ hidden_states = torch.utils.checkpoint.checkpoint(
1189
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
1190
+ encoder_attention_mask
1191
+ )[0]
1192
+ else:
1193
+ hidden_states = resnet(hidden_states, temb)
1194
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
1195
+ encoder_attention_mask=encoder_attention_mask).sample
1196
+
1197
+ if self.upsamplers is not None:
1198
+ for upsampler in self.upsamplers:
1199
+ hidden_states = upsampler(hidden_states, upsample_size)
1200
+
1201
+ return hidden_states
1202
+
1203
+
1204
+ class UpBlock2D(nn.Module):
1205
+ def __init__(
1206
+ self,
1207
+ in_channels: int,
1208
+ prev_output_channel: int,
1209
+ out_channels: int,
1210
+ temb_channels: int,
1211
+ dropout: float = 0.0,
1212
+ num_layers: int = 1,
1213
+ resnet_eps: float = 1e-6,
1214
+ resnet_time_scale_shift: str = "default",
1215
+ resnet_act_fn: str = "swish",
1216
+ resnet_groups: int = 32,
1217
+ resnet_pre_norm: bool = True,
1218
+ output_scale_factor=1.0,
1219
+ add_upsample=True,
1220
+ ):
1221
+ super().__init__()
1222
+ resnets = []
1223
+
1224
+ for i in range(num_layers):
1225
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1226
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1227
+
1228
+ resnets.append(
1229
+ ResnetBlock2D(
1230
+ in_channels=resnet_in_channels + res_skip_channels,
1231
+ out_channels=out_channels,
1232
+ temb_channels=temb_channels,
1233
+ eps=resnet_eps,
1234
+ groups=resnet_groups,
1235
+ dropout=dropout,
1236
+ time_embedding_norm=resnet_time_scale_shift,
1237
+ non_linearity=resnet_act_fn,
1238
+ output_scale_factor=output_scale_factor,
1239
+ pre_norm=resnet_pre_norm,
1240
+ )
1241
+ )
1242
+
1243
+ self.resnets = nn.ModuleList(resnets)
1244
+
1245
+ if add_upsample:
1246
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1247
+ else:
1248
+ self.upsamplers = None
1249
+
1250
+ self.gradient_checkpointing = False
1251
+
1252
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1253
+ for resnet in self.resnets:
1254
+ # pop res hidden states
1255
+ res_hidden_states = res_hidden_states_tuple[-1]
1256
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1257
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1258
+
1259
+ if self.training and self.gradient_checkpointing:
1260
+
1261
+ def create_custom_forward(module):
1262
+ def custom_forward(*inputs):
1263
+ return module(*inputs)
1264
+
1265
+ return custom_forward
1266
+
1267
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1268
+ else:
1269
+ hidden_states = resnet(hidden_states, temb)
1270
+
1271
+ if self.upsamplers is not None:
1272
+ for upsampler in self.upsamplers:
1273
+ hidden_states = upsampler(hidden_states, upsample_size)
1274
+
1275
+ return hidden_states
1276
+
1277
+
1278
+ class UpDecoderBlock2D(nn.Module):
1279
+ def __init__(
1280
+ self,
1281
+ in_channels: int,
1282
+ out_channels: int,
1283
+ dropout: float = 0.0,
1284
+ num_layers: int = 1,
1285
+ resnet_eps: float = 1e-6,
1286
+ resnet_time_scale_shift: str = "default",
1287
+ resnet_act_fn: str = "swish",
1288
+ resnet_groups: int = 32,
1289
+ resnet_pre_norm: bool = True,
1290
+ output_scale_factor=1.0,
1291
+ add_upsample=True,
1292
+ ):
1293
+ super().__init__()
1294
+ resnets = []
1295
+
1296
+ for i in range(num_layers):
1297
+ input_channels = in_channels if i == 0 else out_channels
1298
+
1299
+ resnets.append(
1300
+ ResnetBlock2D(
1301
+ in_channels=input_channels,
1302
+ out_channels=out_channels,
1303
+ temb_channels=None,
1304
+ eps=resnet_eps,
1305
+ groups=resnet_groups,
1306
+ dropout=dropout,
1307
+ time_embedding_norm=resnet_time_scale_shift,
1308
+ non_linearity=resnet_act_fn,
1309
+ output_scale_factor=output_scale_factor,
1310
+ pre_norm=resnet_pre_norm,
1311
+ )
1312
+ )
1313
+
1314
+ self.resnets = nn.ModuleList(resnets)
1315
+
1316
+ if add_upsample:
1317
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1318
+ else:
1319
+ self.upsamplers = None
1320
+
1321
+ def forward(self, hidden_states):
1322
+ for resnet in self.resnets:
1323
+ hidden_states = resnet(hidden_states, temb=None)
1324
+
1325
+ if self.upsamplers is not None:
1326
+ for upsampler in self.upsamplers:
1327
+ hidden_states = upsampler(hidden_states)
1328
+
1329
+ return hidden_states
1330
+
1331
+
1332
+ class AttnUpDecoderBlock2D(nn.Module):
1333
+ def __init__(
1334
+ self,
1335
+ in_channels: int,
1336
+ out_channels: int,
1337
+ dropout: float = 0.0,
1338
+ num_layers: int = 1,
1339
+ resnet_eps: float = 1e-6,
1340
+ resnet_time_scale_shift: str = "default",
1341
+ resnet_act_fn: str = "swish",
1342
+ resnet_groups: int = 32,
1343
+ resnet_pre_norm: bool = True,
1344
+ attn_num_head_channels=1,
1345
+ output_scale_factor=1.0,
1346
+ add_upsample=True,
1347
+ ):
1348
+ super().__init__()
1349
+ resnets = []
1350
+ attentions = []
1351
+
1352
+ for i in range(num_layers):
1353
+ input_channels = in_channels if i == 0 else out_channels
1354
+
1355
+ resnets.append(
1356
+ ResnetBlock2D(
1357
+ in_channels=input_channels,
1358
+ out_channels=out_channels,
1359
+ temb_channels=None,
1360
+ eps=resnet_eps,
1361
+ groups=resnet_groups,
1362
+ dropout=dropout,
1363
+ time_embedding_norm=resnet_time_scale_shift,
1364
+ non_linearity=resnet_act_fn,
1365
+ output_scale_factor=output_scale_factor,
1366
+ pre_norm=resnet_pre_norm,
1367
+ )
1368
+ )
1369
+ attentions.append(
1370
+ AttentionBlock(
1371
+ out_channels,
1372
+ num_head_channels=attn_num_head_channels,
1373
+ rescale_output_factor=output_scale_factor,
1374
+ eps=resnet_eps,
1375
+ norm_num_groups=resnet_groups,
1376
+ )
1377
+ )
1378
+
1379
+ self.attentions = nn.ModuleList(attentions)
1380
+ self.resnets = nn.ModuleList(resnets)
1381
+
1382
+ if add_upsample:
1383
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1384
+ else:
1385
+ self.upsamplers = None
1386
+
1387
+ def forward(self, hidden_states):
1388
+ for resnet, attn in zip(self.resnets, self.attentions):
1389
+ hidden_states = resnet(hidden_states, temb=None)
1390
+ hidden_states = attn(hidden_states)
1391
+
1392
+ if self.upsamplers is not None:
1393
+ for upsampler in self.upsamplers:
1394
+ hidden_states = upsampler(hidden_states)
1395
+
1396
+ return hidden_states
1397
+
1398
+
1399
+ class AttnSkipUpBlock2D(nn.Module):
1400
+ def __init__(
1401
+ self,
1402
+ in_channels: int,
1403
+ prev_output_channel: int,
1404
+ out_channels: int,
1405
+ temb_channels: int,
1406
+ dropout: float = 0.0,
1407
+ num_layers: int = 1,
1408
+ resnet_eps: float = 1e-6,
1409
+ resnet_time_scale_shift: str = "default",
1410
+ resnet_act_fn: str = "swish",
1411
+ resnet_pre_norm: bool = True,
1412
+ attn_num_head_channels=1,
1413
+ attention_type="default",
1414
+ output_scale_factor=np.sqrt(2.0),
1415
+ upsample_padding=1,
1416
+ add_upsample=True,
1417
+ ):
1418
+ super().__init__()
1419
+ self.attentions = nn.ModuleList([])
1420
+ self.resnets = nn.ModuleList([])
1421
+
1422
+ self.attention_type = attention_type
1423
+
1424
+ for i in range(num_layers):
1425
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1426
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1427
+
1428
+ self.resnets.append(
1429
+ ResnetBlock2D(
1430
+ in_channels=resnet_in_channels + res_skip_channels,
1431
+ out_channels=out_channels,
1432
+ temb_channels=temb_channels,
1433
+ eps=resnet_eps,
1434
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1435
+ groups_out=min(out_channels // 4, 32),
1436
+ dropout=dropout,
1437
+ time_embedding_norm=resnet_time_scale_shift,
1438
+ non_linearity=resnet_act_fn,
1439
+ output_scale_factor=output_scale_factor,
1440
+ pre_norm=resnet_pre_norm,
1441
+ )
1442
+ )
1443
+
1444
+ self.attentions.append(
1445
+ AttentionBlock(
1446
+ out_channels,
1447
+ num_head_channels=attn_num_head_channels,
1448
+ rescale_output_factor=output_scale_factor,
1449
+ eps=resnet_eps,
1450
+ )
1451
+ )
1452
+
1453
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1454
+ if add_upsample:
1455
+ self.resnet_up = ResnetBlock2D(
1456
+ in_channels=out_channels,
1457
+ out_channels=out_channels,
1458
+ temb_channels=temb_channels,
1459
+ eps=resnet_eps,
1460
+ groups=min(out_channels // 4, 32),
1461
+ groups_out=min(out_channels // 4, 32),
1462
+ dropout=dropout,
1463
+ time_embedding_norm=resnet_time_scale_shift,
1464
+ non_linearity=resnet_act_fn,
1465
+ output_scale_factor=output_scale_factor,
1466
+ pre_norm=resnet_pre_norm,
1467
+ use_in_shortcut=True,
1468
+ up=True,
1469
+ kernel="fir",
1470
+ )
1471
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1472
+ self.skip_norm = torch.nn.GroupNorm(
1473
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1474
+ )
1475
+ self.act = nn.SiLU()
1476
+ else:
1477
+ self.resnet_up = None
1478
+ self.skip_conv = None
1479
+ self.skip_norm = None
1480
+ self.act = None
1481
+
1482
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1483
+ for resnet in self.resnets:
1484
+ # pop res hidden states
1485
+ res_hidden_states = res_hidden_states_tuple[-1]
1486
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1487
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1488
+
1489
+ hidden_states = resnet(hidden_states, temb)
1490
+
1491
+ hidden_states = self.attentions[0](hidden_states)
1492
+
1493
+ if skip_sample is not None:
1494
+ skip_sample = self.upsampler(skip_sample)
1495
+ else:
1496
+ skip_sample = 0
1497
+
1498
+ if self.resnet_up is not None:
1499
+ skip_sample_states = self.skip_norm(hidden_states)
1500
+ skip_sample_states = self.act(skip_sample_states)
1501
+ skip_sample_states = self.skip_conv(skip_sample_states)
1502
+
1503
+ skip_sample = skip_sample + skip_sample_states
1504
+
1505
+ hidden_states = self.resnet_up(hidden_states, temb)
1506
+
1507
+ return hidden_states, skip_sample
1508
+
1509
+
1510
+ class SkipUpBlock2D(nn.Module):
1511
+ def __init__(
1512
+ self,
1513
+ in_channels: int,
1514
+ prev_output_channel: int,
1515
+ out_channels: int,
1516
+ temb_channels: int,
1517
+ dropout: float = 0.0,
1518
+ num_layers: int = 1,
1519
+ resnet_eps: float = 1e-6,
1520
+ resnet_time_scale_shift: str = "default",
1521
+ resnet_act_fn: str = "swish",
1522
+ resnet_pre_norm: bool = True,
1523
+ output_scale_factor=np.sqrt(2.0),
1524
+ add_upsample=True,
1525
+ upsample_padding=1,
1526
+ ):
1527
+ super().__init__()
1528
+ self.resnets = nn.ModuleList([])
1529
+
1530
+ for i in range(num_layers):
1531
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1532
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1533
+
1534
+ self.resnets.append(
1535
+ ResnetBlock2D(
1536
+ in_channels=resnet_in_channels + res_skip_channels,
1537
+ out_channels=out_channels,
1538
+ temb_channels=temb_channels,
1539
+ eps=resnet_eps,
1540
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1541
+ groups_out=min(out_channels // 4, 32),
1542
+ dropout=dropout,
1543
+ time_embedding_norm=resnet_time_scale_shift,
1544
+ non_linearity=resnet_act_fn,
1545
+ output_scale_factor=output_scale_factor,
1546
+ pre_norm=resnet_pre_norm,
1547
+ )
1548
+ )
1549
+
1550
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1551
+ if add_upsample:
1552
+ self.resnet_up = ResnetBlock2D(
1553
+ in_channels=out_channels,
1554
+ out_channels=out_channels,
1555
+ temb_channels=temb_channels,
1556
+ eps=resnet_eps,
1557
+ groups=min(out_channels // 4, 32),
1558
+ groups_out=min(out_channels // 4, 32),
1559
+ dropout=dropout,
1560
+ time_embedding_norm=resnet_time_scale_shift,
1561
+ non_linearity=resnet_act_fn,
1562
+ output_scale_factor=output_scale_factor,
1563
+ pre_norm=resnet_pre_norm,
1564
+ use_in_shortcut=True,
1565
+ up=True,
1566
+ kernel="fir",
1567
+ )
1568
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1569
+ self.skip_norm = torch.nn.GroupNorm(
1570
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1571
+ )
1572
+ self.act = nn.SiLU()
1573
+ else:
1574
+ self.resnet_up = None
1575
+ self.skip_conv = None
1576
+ self.skip_norm = None
1577
+ self.act = None
1578
+
1579
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1580
+ for resnet in self.resnets:
1581
+ # pop res hidden states
1582
+ res_hidden_states = res_hidden_states_tuple[-1]
1583
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1584
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1585
+
1586
+ hidden_states = resnet(hidden_states, temb)
1587
+
1588
+ if skip_sample is not None:
1589
+ skip_sample = self.upsampler(skip_sample)
1590
+ else:
1591
+ skip_sample = 0
1592
+
1593
+ if self.resnet_up is not None:
1594
+ skip_sample_states = self.skip_norm(hidden_states)
1595
+ skip_sample_states = self.act(skip_sample_states)
1596
+ skip_sample_states = self.skip_conv(skip_sample_states)
1597
+
1598
+ skip_sample = skip_sample + skip_sample_states
1599
+
1600
+ hidden_states = self.resnet_up(hidden_states, temb)
1601
+
1602
+ return hidden_states, skip_sample
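For orientation, a minimal stand-alone sketch (dummy tensors only, no diffusers imports; the channel sizes are arbitrary and the resnet call itself is elided) of how the `forward` methods above consume `res_hidden_states_tuple`: the residuals stored by the matching down block are popped last-in-first-out and concatenated on the channel dimension before each resnet.

```python
import torch

hidden_states = torch.randn(1, 64, 16, 16)
# residuals saved by the corresponding down block, shallowest first
res_hidden_states_tuple = (
    torch.randn(1, 32, 16, 16),
    torch.randn(1, 64, 16, 16),
    torch.randn(1, 64, 16, 16),
)

for _ in range(len(res_hidden_states_tuple)):
    res_hidden_states = res_hidden_states_tuple[-1]            # pop from the end
    res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    x = torch.cat([hidden_states, res_hidden_states], dim=1)   # channel concat fed into the resnet
    print(x.shape)  # (1, 128, 16, 16) twice, then (1, 96, 16, 16)
```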
models/diffusers_override/unet_2d_condition.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.modeling_utils import ModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from .unet_2d_blocks import (
26
+ CrossAttnDownBlock2D,
27
+ CrossAttnUpBlock2D,
28
+ DownBlock2D,
29
+ UNetMidBlock2DCrossAttn,
30
+ UpBlock2D,
31
+ get_down_block,
32
+ get_up_block,
33
+ )
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class UNet2DConditionOutput(BaseOutput):
40
+ """
41
+ Args:
42
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
43
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
44
+ """
45
+
46
+ sample: torch.FloatTensor
47
+
48
+
49
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
50
+ r"""
51
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
52
+ and returns sample shaped output.
53
+
54
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
55
+ implements for all the models (such as downloading or saving, etc.)
56
+
57
+ Parameters:
58
+ sample_size (`int`, *optional*): The size of the input sample.
59
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
60
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
61
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
62
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
63
+ Whether to flip the sin to cos in the time embedding.
64
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
65
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
66
+ The tuple of downsample blocks to use.
67
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
68
+ The tuple of upsample blocks to use.
69
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
70
+ The tuple of output channels for each block.
71
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
72
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
73
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
74
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
76
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
77
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
78
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
79
+ """
80
+
81
+ _supports_gradient_checkpointing = True
82
+
83
+ @register_to_config
84
+ def __init__(
85
+ self,
86
+ sample_size: Optional[int] = None,
87
+ in_channels: int = 4,
88
+ out_channels: int = 4,
89
+ center_input_sample: bool = False,
90
+ flip_sin_to_cos: bool = True,
91
+ freq_shift: int = 0,
92
+ down_block_types: Tuple[str] = (
93
+ "CrossAttnDownBlock2D",
94
+ "CrossAttnDownBlock2D",
95
+ "CrossAttnDownBlock2D",
96
+ "DownBlock2D",
97
+ ),
98
+ up_block_types: Tuple[str] = (
99
+ "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
100
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
101
+ layers_per_block: int = 2,
102
+ downsample_padding: int = 1,
103
+ mid_block_scale_factor: float = 1,
104
+ act_fn: str = "silu",
105
+ norm_num_groups: int = 32,
106
+ norm_eps: float = 1e-5,
107
+ cross_attention_dim: int = 1280,
108
+ attention_head_dim: int = 8,
109
+ ):
110
+ super().__init__()
111
+
112
+ self.sample_size = sample_size
113
+ time_embed_dim = block_out_channels[0] * 4
114
+
115
+ # input
116
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
117
+
118
+ # time
119
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
120
+ timestep_input_dim = block_out_channels[0]
121
+
122
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
123
+
124
+ self.down_blocks = nn.ModuleList([])
125
+ self.mid_block = None
126
+ self.up_blocks = nn.ModuleList([])
127
+
128
+ # down
129
+ output_channel = block_out_channels[0]
130
+ for i, down_block_type in enumerate(down_block_types):
131
+ input_channel = output_channel
132
+ output_channel = block_out_channels[i]
133
+ is_final_block = i == len(block_out_channels) - 1
134
+
135
+ down_block = get_down_block(
136
+ down_block_type,
137
+ num_layers=layers_per_block,
138
+ in_channels=input_channel,
139
+ out_channels=output_channel,
140
+ temb_channels=time_embed_dim,
141
+ add_downsample=not is_final_block,
142
+ resnet_eps=norm_eps,
143
+ resnet_act_fn=act_fn,
144
+ resnet_groups=norm_num_groups,
145
+ cross_attention_dim=cross_attention_dim,
146
+ attn_num_head_channels=attention_head_dim,
147
+ downsample_padding=downsample_padding,
148
+ )
149
+ self.down_blocks.append(down_block)
150
+
151
+ # mid
152
+ self.mid_block = UNetMidBlock2DCrossAttn(
153
+ in_channels=block_out_channels[-1],
154
+ temb_channels=time_embed_dim,
155
+ resnet_eps=norm_eps,
156
+ resnet_act_fn=act_fn,
157
+ output_scale_factor=mid_block_scale_factor,
158
+ resnet_time_scale_shift="default",
159
+ cross_attention_dim=cross_attention_dim,
160
+ attn_num_head_channels=attention_head_dim,
161
+ resnet_groups=norm_num_groups,
162
+ )
163
+
164
+ # count how many layers upsample the images
165
+ self.num_upsamplers = 0
166
+
167
+ # up
168
+ reversed_block_out_channels = list(reversed(block_out_channels))
169
+ output_channel = reversed_block_out_channels[0]
170
+ for i, up_block_type in enumerate(up_block_types):
171
+ is_final_block = i == len(block_out_channels) - 1
172
+
173
+ prev_output_channel = output_channel
174
+ output_channel = reversed_block_out_channels[i]
175
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
176
+
177
+ # add upsample block for all BUT final layer
178
+ if not is_final_block:
179
+ add_upsample = True
180
+ self.num_upsamplers += 1
181
+ else:
182
+ add_upsample = False
183
+
184
+ up_block = get_up_block(
185
+ up_block_type,
186
+ num_layers=layers_per_block + 1,
187
+ in_channels=input_channel,
188
+ out_channels=output_channel,
189
+ prev_output_channel=prev_output_channel,
190
+ temb_channels=time_embed_dim,
191
+ add_upsample=add_upsample,
192
+ resnet_eps=norm_eps,
193
+ resnet_act_fn=act_fn,
194
+ resnet_groups=norm_num_groups,
195
+ cross_attention_dim=cross_attention_dim,
196
+ attn_num_head_channels=attention_head_dim,
197
+ )
198
+ self.up_blocks.append(up_block)
199
+ prev_output_channel = output_channel
200
+
201
+ # out
202
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
203
+ self.conv_act = nn.SiLU()
204
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
205
+
206
+ def set_attention_slice(self, slice_size):
207
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
208
+ raise ValueError(
209
+ f"Make sure slice_size {slice_size} is a divisor of "
210
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
211
+ )
212
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
213
+ raise ValueError(
214
+ f"Chunk_size {slice_size} has to be smaller or equal to "
215
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
216
+ )
217
+
218
+ for block in self.down_blocks:
219
+ if hasattr(block, "attentions") and block.attentions is not None:
220
+ block.set_attention_slice(slice_size)
221
+
222
+ self.mid_block.set_attention_slice(slice_size)
223
+
224
+ for block in self.up_blocks:
225
+ if hasattr(block, "attentions") and block.attentions is not None:
226
+ block.set_attention_slice(slice_size)
227
+
228
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
229
+ for block in self.down_blocks:
230
+ if hasattr(block, "attentions") and block.attentions is not None:
231
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
232
+
233
+ self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
234
+
235
+ for block in self.up_blocks:
236
+ if hasattr(block, "attentions") and block.attentions is not None:
237
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
238
+
239
+ def _set_gradient_checkpointing(self, module, value=False):
240
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
241
+ module.gradient_checkpointing = value
242
+
243
+ def forward(
244
+ self,
245
+ sample: torch.FloatTensor,
246
+ timestep: Union[torch.Tensor, float, int],
247
+ encoder_hidden_states: torch.Tensor,
248
+ encoder_attention_mask: torch.Tensor,
249
+ return_dict: bool = True,
250
+ ) -> Union[UNet2DConditionOutput, Tuple]:
251
+ r"""
252
+ Args:
253
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
254
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
255
+ encoder_hidden_states (`torch.FloatTensor`):
256
+ (batch_size, sequence_length, hidden_size) encoder hidden states
257
+ encoder_attention_mask (`torch.FloatTensor`):
258
+ (batch_size, sequence_length) encoder attention mask
259
+ return_dict (`bool`, *optional*, defaults to `True`):
260
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
261
+
262
+ Returns:
263
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
264
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
265
+ returning a tuple, the first element is the sample tensor.
266
+ """
267
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
268
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
269
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
270
+ # on the fly if necessary.
271
+ default_overall_up_factor = 2 ** self.num_upsamplers
272
+
273
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
274
+ forward_upsample_size = False
275
+ upsample_size = None
276
+
277
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
278
+ logger.info("Forward upsample size to force interpolation output size.")
279
+ forward_upsample_size = True
280
+
281
+ # 0. center input if necessary
282
+ if self.config.center_input_sample:
283
+ sample = 2 * sample - 1.0
284
+
285
+ # 1. time
286
+ timesteps = timestep
287
+ if not torch.is_tensor(timesteps):
288
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
289
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
290
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
291
+ timesteps = timesteps[None].to(sample.device)
292
+
293
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
294
+ timesteps = timesteps.expand(sample.shape[0])
295
+
296
+ t_emb = self.time_proj(timesteps)
297
+
298
+ # timesteps does not contain any weights and will always return f32 tensors
299
+ # but time_embedding might actually be running in fp16. so we need to cast here.
300
+ # there might be better ways to encapsulate this.
301
+ t_emb = t_emb.to(dtype=self.dtype)
302
+ emb = self.time_embedding(t_emb)
303
+
304
+ # 2. pre-process
305
+ sample = self.conv_in(sample)
306
+
307
+ # 3. down
308
+ down_block_res_samples = (sample,)
309
+ for downsample_block in self.down_blocks:
310
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
311
+ sample, res_samples = downsample_block(
312
+ hidden_states=sample,
313
+ temb=emb,
314
+ encoder_hidden_states=encoder_hidden_states,
315
+ encoder_attention_mask=encoder_attention_mask,
316
+ )
317
+ else:
318
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
319
+
320
+ down_block_res_samples += res_samples
321
+
322
+ # 4. mid
323
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
324
+ encoder_attention_mask=encoder_attention_mask)
325
+
326
+ # 5. up
327
+ for i, upsample_block in enumerate(self.up_blocks):
328
+ is_final_block = i == len(self.up_blocks) - 1
329
+
330
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
331
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
332
+
333
+ # if we have not reached the final block and need to forward the
334
+ # upsample size, we do it here
335
+ if not is_final_block and forward_upsample_size:
336
+ upsample_size = down_block_res_samples[-1].shape[2:]
337
+
338
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
339
+ sample = upsample_block(
340
+ hidden_states=sample,
341
+ temb=emb,
342
+ res_hidden_states_tuple=res_samples,
343
+ encoder_hidden_states=encoder_hidden_states,
344
+ encoder_attention_mask=encoder_attention_mask,
345
+ upsample_size=upsample_size,
346
+ )
347
+ else:
348
+ sample = upsample_block(
349
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
350
+ )
351
+ # 6. post-process
352
+ sample = self.conv_norm_out(sample)
353
+ sample = self.conv_act(sample)
354
+ sample = self.conv_out(sample)
355
+
356
+ if not return_dict:
357
+ return (sample,)
358
+
359
+ return UNet2DConditionOutput(sample=sample)
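A minimal smoke-test sketch for this override, showing the extra `encoder_attention_mask` argument threaded through `forward`. The shrunken config values and the mask shape below are illustrative assumptions (the documented defaults above are the real settings), and the sketch presumes the companion `unet_2d_blocks` override is importable alongside this file:

```python
import torch
from models.diffusers_override.unet_2d_condition import UNet2DConditionModel

# tiny config purely for a shape check (assumed values, not the training config)
unet = UNet2DConditionModel(
    sample_size=16,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=64,
    attention_head_dim=4,
    norm_num_groups=16,
)

latents = torch.randn(2, 4, 16, 16)        # (batch, channels, height, width)
timesteps = torch.randint(0, 1000, (2,))
text_emb = torch.randn(2, 77, 64)          # (batch, seq_len, cross_attention_dim)
text_mask = torch.ones(2, 77)              # the argument this override adds to forward()

out = unet(latents, timesteps, encoder_hidden_states=text_emb,
           encoder_attention_mask=text_mask)
print(out.sample.shape)                    # expected: torch.Size([2, 4, 16, 16])
```

The only interface change relative to stock diffusers is that `encoder_attention_mask` is a required argument and is passed down into every cross-attention block.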
models/inception.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=[DEFAULT_BLOCK_INDEX],
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = models.inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def fid_inception_v3():
167
+ """Build pretrained Inception model for FID computation
168
+
169
+ The Inception model for FID computation uses a different set of weights
170
+ and has a slightly different structure than torchvision's Inception.
171
+
172
+ This method first constructs torchvision's Inception and then patches the
173
+ necessary parts that are different in the FID Inception model.
174
+ """
175
+ inception = models.inception_v3(num_classes=1008,
176
+ aux_logits=False,
177
+ pretrained=False)
178
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
179
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
180
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
181
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
182
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
183
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
184
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
185
+ inception.Mixed_7b = FIDInceptionE_1(1280)
186
+ inception.Mixed_7c = FIDInceptionE_2(2048)
187
+
188
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
189
+ inception.load_state_dict(state_dict)
190
+ return inception
191
+
192
+
193
+ class FIDInceptionA(models.inception.InceptionA):
194
+ """InceptionA block patched for FID computation"""
195
+
196
+ def __init__(self, in_channels, pool_features):
197
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
198
+
199
+ def forward(self, x):
200
+ branch1x1 = self.branch1x1(x)
201
+
202
+ branch5x5 = self.branch5x5_1(x)
203
+ branch5x5 = self.branch5x5_2(branch5x5)
204
+
205
+ branch3x3dbl = self.branch3x3dbl_1(x)
206
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
207
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
208
+
209
+ # Patch: Tensorflow's average pool does not use the padded zero's in
210
+ # its average calculation
211
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
212
+ count_include_pad=False)
213
+ branch_pool = self.branch_pool(branch_pool)
214
+
215
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
216
+ return torch.cat(outputs, 1)
217
+
218
+
219
+ class FIDInceptionC(models.inception.InceptionC):
220
+ """InceptionC block patched for FID computation"""
221
+
222
+ def __init__(self, in_channels, channels_7x7):
223
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
224
+
225
+ def forward(self, x):
226
+ branch1x1 = self.branch1x1(x)
227
+
228
+ branch7x7 = self.branch7x7_1(x)
229
+ branch7x7 = self.branch7x7_2(branch7x7)
230
+ branch7x7 = self.branch7x7_3(branch7x7)
231
+
232
+ branch7x7dbl = self.branch7x7dbl_1(x)
233
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
234
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
235
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
236
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
237
+
238
+ # Patch: Tensorflow's average pool does not use the padded zero's in
239
+ # its average calculation
240
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
241
+ count_include_pad=False)
242
+ branch_pool = self.branch_pool(branch_pool)
243
+
244
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
245
+ return torch.cat(outputs, 1)
246
+
247
+
248
+ class FIDInceptionE_1(models.inception.InceptionE):
249
+ """First InceptionE block patched for FID computation"""
250
+
251
+ def __init__(self, in_channels):
252
+ super(FIDInceptionE_1, self).__init__(in_channels)
253
+
254
+ def forward(self, x):
255
+ branch1x1 = self.branch1x1(x)
256
+
257
+ branch3x3 = self.branch3x3_1(x)
258
+ branch3x3 = [
259
+ self.branch3x3_2a(branch3x3),
260
+ self.branch3x3_2b(branch3x3),
261
+ ]
262
+ branch3x3 = torch.cat(branch3x3, 1)
263
+
264
+ branch3x3dbl = self.branch3x3dbl_1(x)
265
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
266
+ branch3x3dbl = [
267
+ self.branch3x3dbl_3a(branch3x3dbl),
268
+ self.branch3x3dbl_3b(branch3x3dbl),
269
+ ]
270
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
271
+
272
+ # Patch: Tensorflow's average pool does not use the padded zero's in
273
+ # its average calculation
274
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
275
+ count_include_pad=False)
276
+ branch_pool = self.branch_pool(branch_pool)
277
+
278
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
279
+ return torch.cat(outputs, 1)
280
+
281
+
282
+ class FIDInceptionE_2(models.inception.InceptionE):
283
+ """Second InceptionE block patched for FID computation"""
284
+
285
+ def __init__(self, in_channels):
286
+ super(FIDInceptionE_2, self).__init__(in_channels)
287
+
288
+ def forward(self, x):
289
+ branch1x1 = self.branch1x1(x)
290
+
291
+ branch3x3 = self.branch3x3_1(x)
292
+ branch3x3 = [
293
+ self.branch3x3_2a(branch3x3),
294
+ self.branch3x3_2b(branch3x3),
295
+ ]
296
+ branch3x3 = torch.cat(branch3x3, 1)
297
+
298
+ branch3x3dbl = self.branch3x3dbl_1(x)
299
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
300
+ branch3x3dbl = [
301
+ self.branch3x3dbl_3a(branch3x3dbl),
302
+ self.branch3x3dbl_3b(branch3x3dbl),
303
+ ]
304
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
305
+
306
+ # Patch: The FID Inception model uses max pooling instead of average
307
+ # pooling. This is likely an error in this specific Inception
308
+ # implementation, as other Inception models use average pooling here
309
+ # (which matches the description in the paper).
310
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
311
+ branch_pool = self.branch_pool(branch_pool)
312
+
313
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
314
+ return torch.cat(outputs, 1)
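A short sketch of how this module is typically used for FID: pull the 2048-dimensional final-average-pool features for a batch of images in `[0, 1]` (random tensors stand in here for real and generated story frames):

```python
import torch
from models.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]     # final average pooling block
model = InceptionV3(output_blocks=[block_idx]).eval()

images = torch.rand(8, 3, 128, 128)                  # B x 3 x H x W, values in (0, 1)
with torch.no_grad():
    features = model(images)[0]                      # (8, 2048, 1, 1)
features = features.squeeze(-1).squeeze(-1)
# compute the mean/covariance of these features for the real and the generated sets;
# the Frechet distance between the two Gaussians is the FID score
```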
v1-5-pruned-emaonly.ckpt → pororo_100.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516
3
- size 4265380512
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b5d47440de7abbbbb2265e1d5ecbc1c5d4d3188434db3988cb13e7ec5fa7549
3
+ size 69568
readme-storyvisualization.md ADDED
@@ -0,0 +1,123 @@
+ ### 1. Cross-modal sequential image generation from narrative text
+
+ ## Environment setup
+ conda create -n arldm python=3.8
+ conda activate arldm
+ conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-lts
+ cd /root/lihui/StoryVisualization
+ pip install -r requirements.txt
+ ## Data preparation
+ Download the PororoSV dataset here.
+ To accelerate I/O, use the following script to convert the downloaded data to HDF5 (a quick way to check the result is sketched right after these commands):
+ python data_script/pororo_hdf5.py
+ --data_dir /path/to/pororo_data
+ --save_path /path/to/save_hdf5_file
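+
+ A short sanity check of the converted file (a sketch; the group/dataset layout is assumed from test.py in this repo: a "test" group holding a "text" dataset plus five JPEG-encoded "image0".."image4" datasets per story):
+ ```python
+ import cv2
+ import h5py
+
+ with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as h5:
+     test = h5['test']
+     print(len(test['text']), 'test stories')
+     print(test['text'][0].decode('utf-8').split('|'))          # five captions per story
+     frame = cv2.imdecode(test['image0'][0], cv2.IMREAD_COLOR)   # stacked video frames, 128 px per slice
+     print(frame.shape)
+ ```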
+ ## Configuration file (config.yaml)
+
+ #device
+ mode: sample # train or sample
+ ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
+ run_name: ARLDM # name for this run
+
+ #train
+ train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file to resume from, none to train from scratch
+
+ #sample
+ test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for testing
+ sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
+ ## Training
+ Set your directories and device configuration in config.yaml (mode: train) and run:
+ python main.py
+ ## Sampling
+ Set your directories and device configuration in config.yaml (mode: sample) and run:
+ python main.py
+ ## Citation
+ @article{pan2022synthesizing,
+ title={Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models},
+ author={Pan, Xichen and Qin, Pengda and Li, Yuhong and Xue, Hui and Chen, Wenhu},
+ journal={arXiv preprint arXiv:2211.10950},
+ year={2022}
+ }
+
+
+ ### 2. Super-resolution with Real-ESRGAN
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
+ [Paper]   [Project page]   [YouTube video]   [Bilibili video]   [Poster]   [PPT]
+ Xintao Wang, Liangbin Xie, Chao Dong, Ying Shan
+ Tencent ARC Lab; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
+ ## Environment
+ Python >= 3.7 (Anaconda or Miniconda recommended)
+ PyTorch >= 1.7
+ ## Installation
+ 1. Enter the pre-configured folder directly:
+ cd /root/lihui/StoryVisualization/Real-ESRGAN
+ 2. Or clone the project locally:
+ git clone https://github.com/xinntao/Real-ESRGAN.git
+ cd Real-ESRGAN
+ 3. Install the dependencies:
+ ```bash
+ # Install basicsr - https://github.com/xinntao/BasicSR
+ # We use BasicSR for both training and inference
+ pip install basicsr
+ # facexlib and gfpgan are used for face enhancement
+ pip install facexlib
+ pip install gfpgan
+ pip install -r requirements.txt
+ python setup.py develop
+ ```
+ ## Training
+ Pretrained model: RealESRGAN_x4plus_anime_6B
+ More information about and comparisons with waifu2x can be found in anime_model.md.
+ ## Download the model
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
+ ## Inference
+ python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
+ The results are saved in the results folder.
+ ## BibTeX citation
+ @Article{wang2021realesrgan,
+ title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
+ author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
+ journal={arXiv:2107.10833},
+ year={2021}
+ }
+
+
+ ### 3. Target character detection with YOLOv5
+ ## Installation
+ Clone the repo and install requirements.txt in a Python>=3.7.0 environment with PyTorch>=1.7.
+ git clone https://github.com/ultralytics/yolov5 # clone
+ cd /root/lihui/StoryVisualization
+ cd yolov5
+ pip install -r requirements.txt # install
+ ## Converting the images
+ cd /root/lihui/StoryVisualization
+ python transtoyolo.py
+ ## Inference with detect.py
+ detect.py runs inference on a variety of sources; models are downloaded automatically from the latest YOLOv5 release, and results are saved to runs/detect.
+ python detect.py --weights yolov5s.pt --source 0 # webcam
+ img.jpg # image
+ vid.mp4 # video
+ screen # screenshot
+ path/ # directory
+ list.txt # list of images
+ list.streams # list of streams
+ 'path/*.jpg' # glob
+ 'https://youtu.be/Zgi9g1ksQHc' # YouTube
+ 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
+ ## Training
+ The latest models and datasets are downloaded automatically from the YOLOv5 release. YOLOv5n/s/m/l/x take 1/2/4/6/8 days to train on a V100 GPU (multi-GPU training is faster). Use the largest --batch-size possible, or pass --batch-size -1 for YOLOv5 AutoBatch. The batch size shown below is for a V100-16GB.
+ python train.py --data xxx.yaml --epochs 500 --weights '' --cfg yolov5l --batch-size 64
+ # xxx.yaml is the data file generated by the conversion step above
+
+ ## License
+ YOLOv5 is available under two different licenses:
+ AGPL-3.0 License: see the LICENSE file for details.
+ Enterprise License: provides greater flexibility for commercial product development without the open-source requirements of AGPL-3.0. Typical use cases are embedding Ultralytics software and AI models in commercial products and applications. Apply for an Enterprise License at Ultralytics Licensing.
+
+
+ ### 4. Demo system
+
+ ## Specify the working directory and run:
+ cd /root/lihui/StoryVisualization/visualsystem
+ python main.py
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ pytorch_lightning<1.7.0
2
+ lightning-bolts
3
+ transformers==4.24.0
4
+ diffusers==0.7.2
5
+ timm
6
+ ftfy
7
+ hydra-core
8
+ opencv-python
9
+ h5py
10
+ scipy
run.sh ADDED
@@ -0,0 +1 @@
1
+ python main.py
test.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import h5py
+ import copy
+ import os
+ import random
+
+ import numpy
+ import numpy as np
+ from PIL import Image
+
+
+ def gettext(index):
+     with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as h5:
+         story = list()
+         h5 = h5['test']
+         # Read the text at the current index and decode it as UTF-8
+         texts = h5['text'][index].decode('utf-8').split('|')
+         symbol = '\n'
+         texts = symbol.join(texts)
+         texts = 'Story<' + str(index) + '> :' + '\n' + texts
+         print(texts)
+         return texts
+
+
+ # for i in range(1000):
+ #     gettext(i)
+
+ # Cut out the first 100 samples of the dataset
+ # ### working version ##############
+ # # import h5py
+ # # import numpy as np
+ # # from PIL import Image
+ # #
+ # #
+ # # # Create a subdirectory named "images" to save the images
+ # # os.makedirs("train_images", exist_ok=True)
+ # #
+ # # Create a new h5 file
+ # nf = h5py.File('/root/lihui/StoryVisualization/pororo_100.h5', "w")
+ # with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as f:
+ #     test_group = f['test']
+ #     texts = np.array(test_group['text'][()])
+ #     ngroup = nf.create_group('test')
+ #     ntext = ngroup.create_dataset('text', (100,), dtype=h5py.string_dtype(encoding='utf-8'))
+ #     for i in range(100):
+ #         ntext[i] = texts[i]
+ #         print(f"sample {i}:")
+ #         # for j in range(5):
+ #         #     # Build a fixed filename to save the image
+ #         #     # filename = os.path.join("images", f"image_{i}_{j}.png")
+ #         #     # # Write the image bytes from the HDF5 file to disk
+ #         #     # with open(filename, "wb") as img_file:
+ #         #     #     img_file.write(test_group[f'image{j}'][i])
+ #         #     # Print the text and the filename
+ #         #     ntext[i] = '|'.join(texts[i].decode('utf-8').split('|')[j])
+ #         #     print(f"image {j} saved to file: {filename}")
+ #         print(ntext[i])
+ # nf.close()
+
+ # Save the test-set images: randomly crop one 128-pixel frame from each stacked video strip
+ with h5py.File(r'C:\Users\zjlab\Desktop\StoryVisualization\pororo.h5', 'r') as h5:
+     h5 = h5['test']
+
+     for index in range(len(h5['text'])):
+         # index = int(index + 1)
+         # print(index)
+         images = list()
+         for i in range(5):
+             # Read one image of the story (and its matching text) from the h5 file
+             im = h5['image{}'.format(i)][index]
+             # print(im)
+             # pil_img = Image.fromarray(im)
+             # # save the image
+             # pil_img.save(os.path.join('/root/lihui/StoryVisualization/ori_test_images', '{:04d}.png'.format(i)))
+             # Decode the JPEG-encoded bytes
+             im = cv2.imdecode(im, cv2.IMREAD_COLOR)
+             # Randomly pick one 128-pixel-high slice
+             idx = random.randint(0, im.shape[0] // 128 - 1)
+             # Append the slice to the images list
+             images.append(im[idx * 128: (idx + 1) * 128])
+         # Deep copy so the originals do not change with images later
+         # ori_images = copy.deepcopy(images)
+
+         # Save the original test images
+         # for i, im in enumerate(images):
+         #     file_path = 'C:/Users/zjlab/Desktop/StoryVisualization/test_images/group{:02d}_image{:02d}.png'.format(
+         #         index + 1,
+         #         i + 1)
+         #     cv2.imwrite(file_path, im)
+
+         # Save all five frames of the story (cv2 decodes to BGR, so convert to RGB for PIL)
+         for i, im_slice in enumerate(images):
+             ori_images_pil = Image.fromarray(cv2.cvtColor(im_slice, cv2.COLOR_BGR2RGB))
+             ori_images_pil.save(
+                 os.path.join('C:/Users/zjlab/Desktop/StoryVisualization/test_images',
+                              'group{:02d}_image{:02d}.png'.format(index + 1, i + 1)))
transtoyolo.py ADDED
@@ -0,0 +1,320 @@
+ # -*- coding: utf-8 -*-
+
+ import os
+ import numpy as np
+ import json
+ from glob import glob
+ import cv2
+ import shutil
+ import yaml
+ from sklearn.model_selection import train_test_split
+ from tqdm import tqdm
+
+
+ # current working directory
+ ROOT_DIR = os.getcwd()
+
+ '''
+ Unify the image format
+ '''
+ def change_image_format(label_path=ROOT_DIR, suffix='.png'):
+     """
+     Convert every image in the folder to a single format, e.g. '.jpg'
+     :param suffix: image file extension
+     :param label_path: current data path
+     :return:
+     """
+     externs = ['png', 'jpg', 'JPEG', 'BMP', 'bmp']
+     files = list()
+     # collect all images whose extension is listed in externs
+     for extern in externs:
+         files.extend(glob(label_path + "\\*." + extern))
+     # iterate over the images and convert their format
+     for file in files:
+         name = ''.join(file.split('.')[:-1])
+         file_suffix = file.split('.')[-1]
+         if file_suffix != suffix.split('.')[-1]:
+             # new name with the target suffix
+             new_name = name + suffix
+             # read the image
+             image = cv2.imread(file)
+             # re-save it in the target format
+             cv2.imwrite(new_name, image)
+             # delete the old image
+             os.remove(file)
+
+
+
+ '''
+ Read every JSON file and collect all class labels
+ '''
+ def get_all_class(file_list, label_path=ROOT_DIR):
+     """
+     Collect all classes present in the data from the JSON files
+     :param file_list: all file names under the current path
+     :param label_path: current data path
+     :return:
+     """
+     # initialize the class list
+     classes = list()
+     # read the label of every shape in every JSON file and add it to classes
+     for filename in tqdm(file_list):
+         json_path = os.path.join(label_path, filename + '.json')
+         json_file = json.load(open(json_path, "r", encoding="utf-8"))
+         for item in json_file["shapes"]:
+             label_class = item['label']
+             if label_class not in classes:
+                 classes.append(label_class)
+     print('read file done')
+     return classes
+
+
+ '''
+ Split into training, validation and test sets
+ '''
+ def split_dataset(label_path, test_size=0.3, isUseTest=False, useNumpyShuffle=False):
+     """
+     Split the files into training, validation and test sets
+     :param useNumpyShuffle: split the dataset with numpy shuffling
+     :param test_size: fraction used for the test or validation split
+     :param isUseTest: whether to create a test set, False by default
+     :param label_path: current data path
+     :return:
+     """
+     # collect all JSON files
+     files = glob(label_path + "\\*.json")
+     files = [i.replace("\\", "/").split("/")[-1].split(".json")[0] for i in files]
+
+     if useNumpyShuffle:
+         file_length = len(files)
+         index = np.arange(file_length)
+         np.random.seed(32)
+         np.random.shuffle(index)  # random split
+
+         test_files = None
+         # whether a test set is used
+         if isUseTest:
+             trainval_files, test_files = np.array(files)[index[:int(file_length * (1 - test_size))]], np.array(files)[
+                 index[int(file_length * (1 - test_size)):]]
+         else:
+             trainval_files = files
+         # split into training and validation sets
+         train_files, val_files = np.array(trainval_files)[index[:int(len(trainval_files) * (1 - test_size))]], \
+                                  np.array(trainval_files)[index[int(len(trainval_files) * (1 - test_size)):]]
+     else:
+         test_files = None
+         if isUseTest:
+             trainval_files, test_files = train_test_split(files, test_size=test_size, random_state=55)
+         else:
+             trainval_files = files
+         train_files, val_files = train_test_split(trainval_files, test_size=test_size, random_state=55)
+
+     return train_files, val_files, test_files, files
+
+
+ '''
+ Create the YOLOv5 train/valid/test folders
+ '''
+ def create_save_file(label_path=ROOT_DIR):
+     """
+     Create the image and label folders expected at training time
+     :param label_path: current data path
+     :return:
+     """
+     # training set
+     train_image = os.path.join(label_path, 'train', 'images')
+     if not os.path.exists(train_image):
+         os.makedirs(train_image)
+     train_label = os.path.join(label_path, 'train', 'labels')
+     if not os.path.exists(train_label):
+         os.makedirs(train_label)
+     # validation set
+     val_image = os.path.join(label_path, 'valid', 'images')
+     if not os.path.exists(val_image):
+         os.makedirs(val_image)
+     val_label = os.path.join(label_path, 'valid', 'labels')
+     if not os.path.exists(val_label):
+         os.makedirs(val_label)
+     # test set
+     test_image = os.path.join(label_path, 'test', 'images')
+     if not os.path.exists(test_image):
+         os.makedirs(test_image)
+     test_label = os.path.join(label_path, 'test', 'labels')
+     if not os.path.exists(test_label):
+         os.makedirs(test_label)
+     return train_image, train_label, val_image, val_label, test_image, test_label
+
+
+
+ '''
+ Convert a box: given the image size, return the normalized center point, width and height
+ '''
+ def convert(size, box):
+     # width scale
+     dw = 1. / (size[0])
+     # height scale
+     dh = 1. / (size[1])
+
+     x = (box[0] + box[1]) / 2.0 - 1
+     y = (box[2] + box[3]) / 2.0 - 1
+     # width
+     w = box[1] - box[0]
+     # height
+     h = box[3] - box[2]
+
+     x = x * dw
+     w = w * dw
+     y = y * dh
+     h = h * dh
+     return x, y, w, h
+
+
+ '''
+ Move images and label files into the train/valid/test folders
+ '''
+ def push_into_file(file, images, labels, label_path=ROOT_DIR, suffix='.jpg'):
+     """
+     Move the generated files into the image/label subfolders of the train/valid/test directories
+     :param file: list of file names
+     :param images: destination path for images
+     :param labels: destination path for labels
+     :param label_path: current data path
+     :param suffix: image file extension
+     :return:
+     """
+     # iterate over all files
+     for filename in file:
+         # image file
+         image_file = os.path.join(label_path, filename + suffix)
+         # label file
+         label_file = os.path.join(label_path, filename + '.txt')
+         # YOLOv5 image folder
+         if not os.path.exists(os.path.join(images, filename + suffix)):
+             try:
+                 shutil.move(image_file, images)
+             except OSError:
+                 pass
+         # YOLOv5 label folder
+         if not os.path.exists(os.path.join(labels, filename + suffix)):
+             try:
+                 shutil.move(label_file, labels)
+             except OSError:
+                 pass
+
+ '''
+ Convert labelme-style JSON annotations to YOLO txt labels
+ '''
+ def json2txt(classes, txt_Name='allfiles', label_path=ROOT_DIR, suffix='.png'):
+     """
+     Convert the JSON files to YOLO txt files and move the JSON files into a dedicated folder
+     :param classes: class names
+     :param txt_Name: name of the txt file listing all image paths
+     :param label_path: current data path
+     :param suffix: image file extension
+     :return:
+     """
+     store_json = os.path.join(label_path, 'json')
+     if not os.path.exists(store_json):
+         os.makedirs(store_json)
+
+     _, _, _, files = split_dataset(label_path)
+     if not os.path.exists(os.path.join(label_path, 'tmp')):
+         os.makedirs(os.path.join(label_path, 'tmp'))
+
+     list_file = open(os.path.join(label_path, 'tmp', '%s.txt' % txt_Name), 'w')
+     for json_file_ in tqdm(files):
+         json_filename = os.path.join(label_path, json_file_ + ".json")
+         imagePath = os.path.join(label_path, json_file_ + suffix)
+         list_file.write('%s\n' % imagePath)
+         out_file = open('%s/%s.txt' % (label_path, json_file_), 'w')
+         json_file = json.load(open(json_filename, "r", encoding="utf-8"))
+         if os.path.exists(imagePath):
+             height, width, channels = cv2.imread(imagePath).shape
+             for multi in json_file["shapes"]:
+                 if len(multi["points"][0]) == 0:
+                     out_file.write('')
+                     continue
+                 points = np.array(multi["points"])
+                 xmin = min(points[:, 0]) if min(points[:, 0]) > 0 else 0
+                 xmax = max(points[:, 0]) if max(points[:, 0]) > 0 else 0
+                 ymin = min(points[:, 1]) if min(points[:, 1]) > 0 else 0
+                 ymax = max(points[:, 1]) if max(points[:, 1]) > 0 else 0
+                 label = multi["label"]
+                 if xmax <= xmin:
+                     pass
+                 elif ymax <= ymin:
+                     pass
+                 else:
+                     cls_id = classes.index(label)
+                     b = (float(xmin), float(xmax), float(ymin), float(ymax))
+                     bb = convert((width, height), b)
+                     out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
+                     # print(json_filename, xmin, ymin, xmax, ymax, cls_id)
+         out_file.close()
+         if not os.path.exists(os.path.join(store_json, json_file_ + '.json')):
+             try:
+                 shutil.move(json_filename, store_json)
+             except OSError:
+                 pass
+     list_file.close()
+
+ '''
+ Create the data.yaml file
+ '''
+ def create_yaml(classes, label_path, isUseTest=False):
+     nc = len(classes)
+     if not isUseTest:
+         desired_caps = {
+             'path': label_path,
+             'train': 'train/images',
+             'val': 'valid/images',
+             'nc': nc,
+             'names': classes
+         }
+     else:
+         desired_caps = {
+             'path': label_path,
+             'train': 'train/images',
+             'val': 'valid/images',
+             'test': 'test/images',
+             'nc': nc,
+             'names': classes
+         }
+     yamlpath = os.path.join(label_path, "data" + ".yaml")
+
+     # write the yaml file
+     with open(yamlpath, "w+", encoding="utf-8") as f:
+         for key, val in desired_caps.items():
+             yaml.dump({key: val}, f, default_flow_style=False)
+
+
+ # First make sure every image in the folder shares one extension, e.g. .jpg; if another
+ # extension is used, change suffix accordingly, e.g. .png
+ def ChangeToYolo5(label_path=r"D:\storydata", suffix='.png', test_size=0.1, isUseTest=False):
+     """
+     Generate the final files in the standard YOLOv5 layout
+     :param test_size: fraction used for the test or validation split
+     :param label_path: current data path
+     :param suffix: file extension
+     :param isUseTest: whether to create a test set
+     :return:
+     """
+     # step 1: unify the image format
+     change_image_format(label_path)
+     # step 2: split train/valid/test according to the JSON files
+     train_files, val_files, test_file, files = split_dataset(label_path, test_size=test_size, isUseTest=isUseTest)
+     # step 3: collect all classes from the JSON files
+     classes = get_all_class(files, label_path=label_path)
+     # step 4: convert JSON to txt labels and move the JSON files into their own folder
+     json2txt(classes, label_path=label_path, suffix=suffix)
+     # step 5: create the yaml file needed for YOLOv5 training
+     create_yaml(classes, label_path, isUseTest=isUseTest)
+     # step 6: create the YOLOv5 train/valid/test folders
+     train_image, train_label, val_image, val_label, test_image, test_label = create_save_file(label_path)
+     # step 7: move all images and label files into the corresponding sets
+     push_into_file(train_files, train_image, train_label, suffix=suffix)  # move files into the training set
+     push_into_file(val_files, val_image, val_label, suffix=suffix)  # move files into the validation set
+     if test_file is not None:  # if a test set exists, move its files as well
+         push_into_file(test_file, test_image, test_label, suffix=suffix)
+     print('create dataset done')
+
+
+ if __name__ == "__main__":
+     ChangeToYolo5()
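A possible invocation with explicit arguments (a sketch; `D:\storydata` is the script's default path and is assumed to hold labelme-style `.json` annotations next to the `.png` frames):

```python
from transtoyolo import ChangeToYolo5

# builds train/valid(/test) image and label folders plus data.yaml under label_path
ChangeToYolo5(label_path=r"D:\storydata", suffix=".png", test_size=0.1, isUseTest=False)
# then train YOLOv5 on the generated config, e.g.:
#   python train.py --data D:/storydata/data.yaml --epochs 500 --weights '' --cfg yolov5l --batch-size 64
```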
v1-5-pruned-emaonly.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa
3
- size 4265146304
v1-5-pruned.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1a189f0be69d6106a48548e7626207dddd7042a418dbf372cefd05e0cdba61b6
3
- size 7703324286