Zhouyan248 committed on
Commit
0035a82
1 Parent(s): 9a47fb7

Upload 119 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4 +0 -0
  3. Close-up_essence_is_poured_from_bottleKodak_Vision.png +0 -0
  4. README.md +65 -12
  5. The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 +0 -0
  6. The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 +0 -0
  7. The_picture_shows_the_beauty_of_the_sea.png +0 -0
  8. The_picture_shows_the_beauty_of_the_sea_.jpg +0 -0
  9. __pycache__/download.cpython-310.pyc +0 -0
  10. __pycache__/download.cpython-311.pyc +0 -0
  11. __pycache__/download.cpython-39.pyc +0 -0
  12. __pycache__/utils.cpython-310.pyc +0 -0
  13. __pycache__/utils.cpython-311.pyc +0 -0
  14. __pycache__/utils.cpython-39.pyc +0 -0
  15. app.py +183 -0
  16. configs/sample_i2v.yaml +36 -0
  17. configs/sample_transition.yaml +33 -0
  18. datasets/__pycache__/video_transforms.cpython-311.pyc +0 -0
  19. datasets/__pycache__/video_transforms.cpython-39.pyc +0 -0
  20. datasets/video_transforms.py +472 -0
  21. diffusion/__init__.py +47 -0
  22. diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  23. diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
  24. diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  25. diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  26. diffusion/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
  27. diffusion/__pycache__/diffusion_utils.cpython-311.pyc +0 -0
  28. diffusion/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
  29. diffusion/__pycache__/diffusion_utils.cpython-39.pyc +0 -0
  30. diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc +0 -0
  31. diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc +0 -0
  32. diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc +0 -0
  33. diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc +0 -0
  34. diffusion/__pycache__/respace.cpython-310.pyc +0 -0
  35. diffusion/__pycache__/respace.cpython-311.pyc +0 -0
  36. diffusion/__pycache__/respace.cpython-38.pyc +0 -0
  37. diffusion/__pycache__/respace.cpython-39.pyc +0 -0
  38. diffusion/diffusion_utils.py +88 -0
  39. diffusion/gaussian_diffusion.py +931 -0
  40. diffusion/respace.py +130 -0
  41. diffusion/timestep_sampler.py +150 -0
  42. download.py +44 -0
  43. env.yaml +20 -0
  44. huggingface-i2v/__init__.py +0 -0
  45. huggingface-i2v/requirements.txt +0 -0
  46. image_to_video/__init__.py +221 -0
  47. image_to_video/__pycache__/__init__.cpython-311.pyc +0 -0
  48. input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png +0 -0
  49. input/i2v/The_picture_shows_the_beauty_of_the_sea.png +0 -0
  50. input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ input/transition/1/2-Wide[[:space:]]angle[[:space:]]shot[[:space:]]of[[:space:]]an[[:space:]]alien[[:space:]]planet[[:space:]]with[[:space:]]cherry[[:space:]]blossom[[:space:]]forest-2.png filter=lfs diff=lfs merge=lfs -text
Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4 ADDED
Binary file (301 kB).
 
Close-up_essence_is_poured_from_bottleKodak_Vision.png ADDED
README.md CHANGED
@@ -1,12 +1,65 @@
- ---
- title: SEINE
- emoji: 📚
- colorFrom: yellow
- colorTo: red
- sdk: gradio
- sdk_version: 4.7.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SEINE
+ This repository is the official implementation of [SEINE](https://arxiv.org/abs/2310.20700).
+
+ **[SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction](https://arxiv.org/abs/2310.20700)**
+
+ [Arxiv Report](https://arxiv.org/abs/2310.20700) | [Project Page](https://vchitect.github.io/SEINE-project/)
+
+ <img src="seine.gif" width="800">
+
+ ## Setup for Inference
+
+ ### Prepare Environment
+ ```
+ conda env create -f env.yaml
+ conda activate seine
+ ```
+
+ ### Download our model and the T2I base model
+ Download our model checkpoint from [Google Drive](https://drive.google.com/drive/folders/1cWfeDzKJhpb0m6HA5DoMOH0_ItuUY95b?usp=sharing) and save it to the ```pre-trained``` directory.
+
+ Our model is based on Stable Diffusion v1.4; you may download [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) into the ```pre-trained``` directory as well.
+
+ Now under `./pre-trained`, you should be able to see the following:
+ ```
+ ├── pretrained_models
+ │   ├── seine.pt
+ │   ├── stable-diffusion-v1-4
+ │   │   ├── ...
+ └── └── ├── ...
+     ├── ...
+ ```
+
+ #### Inference for I2V
+ ```bash
+ python sample_scripts/with_mask_sample.py --config configs/sample_i2v.yaml
+ ```
+ The generated video will be saved in ```./results/i2v```.
+
+ #### Inference for Transition
+ ```bash
+ python sample_scripts/with_mask_sample.py --config configs/sample_transition.yaml
+ ```
+ The generated video will be saved in ```./results/transition```.
+
+ #### More Details
+ You can modify ```./configs/sample_i2v.yaml``` (or ```./configs/sample_transition.yaml```) to change the generation conditions.
+ For example:
+ ```ckpt``` specifies the model checkpoint.
+ ```text_prompt``` describes the content of the video.
+ ```input_path``` specifies the path to the input image.
+
+ ## BibTeX
+ ```bibtex
+ @article{chen2023seine,
+   title={SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction},
+   author={Chen, Xinyuan and Wang, Yaohui and Zhang, Lingjun and Zhuang, Shaobin and Ma, Xin and Yu, Jiashuo and Wang, Yali and Lin, Dahua and Qiao, Yu and Liu, Ziwei},
+   journal={arXiv preprint arXiv:2310.20700},
+   year={2023}
+ }
+ ```
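As a supplement to the "More Details" note above, here is a minimal sketch (not part of this commit) of editing the sampling config programmatically with OmegaConf, which app.py already uses to load these YAML files. The field names mirror configs/sample_i2v.yaml; the output filename my_run.yaml is just an illustrative name.

```python
# Hypothetical helper: tweak configs/sample_i2v.yaml before running the sampler.
# Assumes OmegaConf (already used by app.py) and the field names from the shipped config.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sample_i2v.yaml")

# Point at the checkpoint and the conditioning image, and set the prompt.
cfg.ckpt = "pre-trained/seine.pt"
cfg.input_path = "input/i2v/The_picture_shows_the_beauty_of_the_sea.png"
cfg.text_prompt = ["slow motion"]

# Fewer sampling steps for a quick smoke test.
cfg.num_sampling_steps = 50

# Write the edited config and pass it to the sample script, e.g.:
#   python sample_scripts/with_mask_sample.py --config my_run.yaml
OmegaConf.save(config=cfg, f="my_run.yaml")
```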
The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4 ADDED
Binary file (397 kB).
 
The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4 ADDED
Binary file (439 kB).
 
The_picture_shows_the_beauty_of_the_sea.png ADDED
The_picture_shows_the_beauty_of_the_sea_.jpg ADDED
__pycache__/download.cpython-310.pyc ADDED
Binary file (1.29 kB).
 
__pycache__/download.cpython-311.pyc ADDED
Binary file (1.85 kB).
 
__pycache__/download.cpython-39.pyc ADDED
Binary file (1.29 kB).
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (10.4 kB).
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (19.2 kB).
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (10.5 kB).
 
app.py ADDED
@@ -0,0 +1,183 @@
+ import gradio as gr
+ from image_to_video import model_i2v_fun, get_input, auto_inpainting, setup_seed
+ from omegaconf import OmegaConf
+ import torch
+ from diffusers.utils.import_utils import is_xformers_available
+ import torchvision
+ from utils import mask_generation_before
+ import os
+ import cv2
+ 
+ # Hard-coded cluster path; point this at configs/sample_i2v.yaml in your own checkout.
+ config_path = "/mnt/petrelfs/zhouyan/project/i2v/configs/sample_i2v.yaml"
+ args = OmegaConf.load(config_path)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ # ------- get model ---------------
+ # model_i2V = model_i2v_fun()
+ # model_i2V.to("cuda")
+ 
+ # vae, model, text_encoder, diffusion = model_i2v_fun(args)
+ # vae.to(device)
+ # model.to(device)
+ # text_encoder.to(device)
+ 
+ # if args.use_fp16:
+ #     vae.to(dtype=torch.float16)
+ #     model.to(dtype=torch.float16)
+ #     text_encoder.to(dtype=torch.float16)
+ 
+ # if args.enable_xformers_memory_efficient_attention and device == "cuda":
+ #     if is_xformers_available():
+ #         model.enable_xformers_memory_efficient_attention()
+ #     else:
+ #         raise ValueError("xformers is not available. Make sure it is installed correctly")
+ 
+ 
+ css = """
+ h1 {
+   text-align: center;
+ }
+ #component-0 {
+   max-width: 730px;
+   margin: auto;
+ }
+ """
+ 
+ 
+ def infer(prompt, image_inp, seed_inp, ddim_steps):
+     setup_seed(seed_inp)
+     args.num_sampling_steps = ddim_steps
+     # First check the type returned by the Image component.
+     print(prompt, seed_inp, ddim_steps, type(image_inp))
+     img = cv2.imread(image_inp)
+     new_size = [img.shape[0], img.shape[1]]
+     # if img.shape[0] == 512 and img.shape[1] == 512:
+     #     args.image_size = [512, 512]
+     # elif img.shape[0] == 320 and img.shape[1] == 512:
+     #     args.image_size = [320, 512]
+     # elif img.shape[0] == 292 and img.shape[1] == 512:
+     #     args.image_size = [292, 512]
+     # else:
+     #     raise ValueError("Please enter image of right size")
+     # print(args.image_size)
+     args.image_size = new_size
+ 
+     # Models are rebuilt for every request; caching them at module level would avoid the reload.
+     vae, model, text_encoder, diffusion = model_i2v_fun(args)
+     vae.to(device)
+     model.to(device)
+     text_encoder.to(device)
+ 
+     if args.use_fp16:
+         vae.to(dtype=torch.float16)
+         model.to(dtype=torch.float16)
+         text_encoder.to(dtype=torch.float16)
+ 
+     if args.enable_xformers_memory_efficient_attention and device == "cuda":
+         if is_xformers_available():
+             model.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+ 
+     video_input, reserve_frames = get_input(image_inp, args)
+     video_input = video_input.to(device).unsqueeze(0)
+     mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
+     masked_video = video_input * (mask == 0)
+     # Note: this overwrites the user-supplied prompt with a fixed camera-motion prompt.
+     prompt = "tilt up, high quality, stable "
+     prompt = prompt + args.additional_prompt
+     video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device)
+     # Map from [-1, 1] to uint8 and reorder to (T, H, W, C) for write_video.
+     video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+     torchvision.io.write_video(os.path.join(args.save_img_path, prompt + '.mp4'), video_, fps=8)
+ 
+     # video = model_i2V(prompt, image_inp, seed_inp, ddim_steps)
+ 
+     return os.path.join(args.save_img_path, prompt + '.mp4')
+ 
+ 
+ def clean():
+     # gr.Video.update() was removed in Gradio 4.x; there, return gr.Video(value=None) instead.
+     # return gr.Image.update(value=None, visible=False), gr.Video.update(value=None)
+     return gr.Video.update(value=None)
+ 
+ 
+ title = """
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+     <div
+         style="
+             display: inline-flex;
+             align-items: center;
+             gap: 0.8rem;
+             font-size: 1.75rem;
+         "
+     >
+         <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
+             SEINE: Image-to-Video generation
+         </h1>
+     </div>
+     <p style="margin-bottom: 10px; font-size: 94%">
+         Apply SEINE to generate a video
+     </p>
+ </div>
+ """
+ 
+ 
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown("<font color=red size=10><center>SEINE: Image-to-Video generation</center></font>")
+     with gr.Column(elem_id="col-container"):
+         # gr.HTML(title)
+ 
+         with gr.Row():
+             with gr.Column():
+                 image_inp = gr.Image(type='filepath')
+ 
+             with gr.Column():
+                 prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in")
+ 
+                 with gr.Row():
+                     # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in")
+                     ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
+                     seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=250, elem_id="seed-in")
+ 
+                 # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
+ 
+         submit_btn = gr.Button("Generate video")
+         clean_btn = gr.Button("Clean video")
+ 
+     video_out = gr.Video(label="Video result", elem_id="video-output", width=800)
+     inputs = [prompt, image_inp, seed_inp, ddim_steps]
+     outputs = [video_out]
+     # The example columns are (image, prompt, seed, steps); this order matches the `inputs`
+     # list passed to gr.Examples but not infer()'s (prompt, image, ...) signature, which only
+     # matters if the examples are ever executed (cache_examples is False).
+     ex = gr.Examples(
+         examples=[
+             ["/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg", "A video of the beauty of the sea", 123, 50],
+             ["/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea.png", "A video of the beauty of the sea", 123, 50],
+             ["/mnt/petrelfs/zhouyan/project/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png", "A video of close-up essence is poured from bottleKodak Vision", 123, 50],
+         ],
+         fn=infer,
+         inputs=[image_inp, prompt, seed_inp, ddim_steps],
+         outputs=[video_out],
+         cache_examples=False,
+     )
+     ex.dataset.headers = [""]
+     # gr.Markdown("<center>some examples</center>")
+     # with gr.Row():
+     #     gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea_.jpg")
+     #     gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/The_picture_shows_the_beauty_of_the_sea.png")
+     #     gr.Image(value="/mnt/petrelfs/zhouyan/project/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png")
+     # with gr.Row():
+     #     gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_11301.mp4")
+     #     gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/The-picture-shows-the-beauty-of-the-sea-and-at-the-sam_slow-motion_0000_6600.mp4")
+     #     gr.Video(value="/mnt/petrelfs/zhouyan/project/i2v/Close-up-essence-is-poured-from-bottleKodak-Vision3-50_slow-motion_0000_001.mp4")
+     # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False)
+     clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
+     submit_btn.click(infer, inputs, outputs)
+     # share_button.click(None, [], [], _js=share_js)
+ 
+ 
+ demo.queue(max_size=12).launch(server_name="0.0.0.0", server_port=7861)
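For reference, here is a small self-contained sketch (not from this commit) of the tensor post-processing that infer() performs before writing its result: mapping a clip from the model's [-1, 1] range to uint8 frames in (T, H, W, C) order for torchvision.io.write_video. The random tensor stands in for the decoded video, and demo_clip.mp4 is an arbitrary output name.

```python
# Minimal sketch of the video export step used at the end of infer() in app.py.
import torch
import torchvision

# Stand-in for a decoded clip: 16 frames, 3 channels, 512x512, values in [-1, 1].
video_clip = torch.rand(16, 3, 512, 512) * 2 - 1

# Map [-1, 1] -> [0, 255], round, clamp, cast, and reorder to (T, H, W, C).
video_u8 = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(torch.uint8)
video_u8 = video_u8.permute(0, 2, 3, 1).cpu()

# write_video expects a (T, H, W, C) uint8 tensor.
torchvision.io.write_video("demo_clip.mp4", video_u8, fps=8)
```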
configs/sample_i2v.yaml ADDED
@@ -0,0 +1,36 @@
+ 
+ ckpt: "/mnt/petrelfs/share_data/chenxinyuan/code/SEINE-release/pre-trained/seine.pt"
+ # save_img_path: "./results/i2v/"
+ save_img_path: "/mnt/petrelfs/share_data/zhouyan/gradio_i2v/"
+ pretrained_model_path: "pre-trained/stable-diffusion-v1-4/"
+ 
+ # model config:
+ model: TAVU
+ num_frames: 16
+ frame_interval: 1
+ image_size: [512, 512]
+ # image_size: [320, 512]
+ 
+ # model speedup
+ use_compile: False
+ use_fp16: True
+ enable_xformers_memory_efficient_attention: True
+ img_path: "/mnt/petrelfs/zhouyan/tmp/last"
+ 
+ # sample config:
+ seed:
+ run_time: 13
+ cfg_scale: 8.0
+ sample_method: 'ddpm'
+ num_sampling_steps: 250
+ text_prompt: ["slow motion"]
+ additional_prompt: ", slow motion."
+ negative_prompt: ""
+ do_classifier_free_guidance: True
+ 
+ # autoregressive config:
+ # input_path: "/mnt/petrelfs/zhouyan/tmp/未来上海/WechatIMG9434.jpg"
+ input_path: "/mnt/petrelfs/zhouyan/tmp/last"
+ researve_frame: 1
+ mask_type: "first1"
+ use_mask: True
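The `seed:` field above is deliberately left empty. A minimal sketch (not part of this commit) of how an empty seed could be resolved into a reproducible run is shown below; app.py uses a setup_seed helper from image_to_video whose internals are not included here, so seeding torch/numpy/random directly is an assumption.

```python
# Sketch (assumption, not the repo's setup_seed): resolve the empty `seed:` field
# from configs/sample_i2v.yaml into a concrete, reproducible seed.
import random

import numpy as np
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sample_i2v.yaml")

# Empty YAML value loads as None; fall back to a random seed in that case.
seed = cfg.seed if cfg.seed is not None else random.randint(0, 2**31 - 1)
print(f"sampling with seed {seed}")

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
```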
configs/sample_transition.yaml ADDED
@@ -0,0 +1,33 @@
+ 
+ ckpt: "pre-trained/0020000.pt"
+ save_img_path: "./results/transition/"
+ pretrained_model_path: "pre-trained/stable-diffusion-v1-4/"
+ 
+ # model config:
+ model: TAVU
+ num_frames: 16
+ frame_interval: 1
+ # image_size: [240, 560]
+ # image_size: [320, 512]
+ image_size: [512, 512]
+ 
+ # model speedup
+ use_compile: False
+ use_fp16: True
+ enable_xformers_memory_efficient_attention: True
+ 
+ # sample config:
+ seed:
+ run_time: 13
+ cfg_scale: 8.0
+ sample_method: 'ddpm'
+ num_sampling_steps: 250
+ text_prompt: ['smooth transition']
+ additional_prompt: "smooth transition."
+ negative_prompt: ""
+ do_classifier_free_guidance: True
+ 
+ # autoregressive config:
+ input_path: 'input/transition/1'
+ mask_type: "onelast1"
+ use_mask: True
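For transition generation, `input_path` points at a folder of endpoint images (this commit ships input/transition/1 with at least one PNG tracked via LFS). A minimal sketch of checking that folder and then launching the documented command is given below; the assumption that exactly two endpoint frames are expected and any naming convention are not specified in this commit.

```python
# Sketch: sanity-check a transition input folder, then run the documented command.
# Assumption: the folder referenced by `input_path` (e.g., input/transition/1) holds the
# start and end frames; the exact naming scheme is not shown in this commit.
import os
import subprocess

input_dir = "input/transition/1"
images = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)
print(f"{len(images)} frame(s) found in {input_dir}: {images}")
assert len(images) >= 2, "transition generation needs a start and an end frame"

# Command from the README above.
subprocess.run(
    ["python", "sample_scripts/with_mask_sample.py",
     "--config", "configs/sample_transition.yaml"],
    check=True,
)
```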
datasets/__pycache__/video_transforms.cpython-311.pyc ADDED
Binary file (23.3 kB).
 
datasets/__pycache__/video_transforms.cpython-39.pyc ADDED
Binary file (14.8 kB).
 
datasets/video_transforms.py ADDED
@@ -0,0 +1,472 @@
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+ from PIL import Image
6
+
7
+ def _is_tensor_video_clip(clip):
8
+ if not torch.is_tensor(clip):
9
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
10
+
11
+ if not clip.ndimension() == 4:
12
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
13
+
14
+ return True
15
+
16
+
17
+ def center_crop_arr(pil_image, image_size):
18
+ """
19
+ Center cropping implementation from ADM.
20
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
21
+ """
22
+ while min(*pil_image.size) >= 2 * image_size:
23
+ pil_image = pil_image.resize(
24
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
25
+ )
26
+
27
+ scale = image_size / min(*pil_image.size)
28
+ pil_image = pil_image.resize(
29
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
30
+ )
31
+
32
+ arr = np.array(pil_image)
33
+ crop_y = (arr.shape[0] - image_size) // 2
34
+ crop_x = (arr.shape[1] - image_size) // 2
35
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
36
+
37
+
38
+ def crop(clip, i, j, h, w):
39
+ """
40
+ Args:
41
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
42
+ """
43
+ if len(clip.size()) != 4:
44
+ raise ValueError("clip should be a 4D tensor")
45
+ return clip[..., i : i + h, j : j + w]
46
+
47
+
48
+ def resize(clip, target_size, interpolation_mode):
49
+ if len(target_size) != 2:
50
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
51
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
52
+
53
+ def resize_scale(clip, target_size, interpolation_mode):
54
+ if len(target_size) != 2:
55
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
56
+ H, W = clip.size(-2), clip.size(-1)
57
+ scale_ = target_size[0] / min(H, W)
58
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
59
+
60
+ def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
61
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
62
+
63
+ def resize_scale_with_height(clip, target_size, interpolation_mode):
64
+ H, W = clip.size(-2), clip.size(-1)
65
+ scale_ = target_size / H
66
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
67
+
68
+ def resize_scale_with_weight(clip, target_size, interpolation_mode):
69
+ H, W = clip.size(-2), clip.size(-1)
70
+ scale_ = target_size / W
71
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
72
+
73
+
74
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
75
+ """
76
+ Do spatial cropping and resizing to the video clip
77
+ Args:
78
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
79
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
80
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
81
+ h (int): Height of the cropped region.
82
+ w (int): Width of the cropped region.
83
+ size (tuple(int, int)): height and width of resized clip
84
+ Returns:
85
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
86
+ """
87
+ if not _is_tensor_video_clip(clip):
88
+ raise ValueError("clip should be a 4D torch.tensor")
89
+ clip = crop(clip, i, j, h, w)
90
+ clip = resize(clip, size, interpolation_mode)
91
+ return clip
92
+
93
+
94
+ def center_crop(clip, crop_size):
95
+ if not _is_tensor_video_clip(clip):
96
+ raise ValueError("clip should be a 4D torch.tensor")
97
+ h, w = clip.size(-2), clip.size(-1)
98
+ # print(clip.shape)
99
+ th, tw = crop_size
100
+ if h < th or w < tw:
101
+ # print(h, w)
102
+ raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
103
+
104
+ i = int(round((h - th) / 2.0))
105
+ j = int(round((w - tw) / 2.0))
106
+ return crop(clip, i, j, th, tw)
107
+
108
+
109
+ def center_crop_using_short_edge(clip):
110
+ if not _is_tensor_video_clip(clip):
111
+ raise ValueError("clip should be a 4D torch.tensor")
112
+ h, w = clip.size(-2), clip.size(-1)
113
+ if h < w:
114
+ th, tw = h, h
115
+ i = 0
116
+ j = int(round((w - tw) / 2.0))
117
+ else:
118
+ th, tw = w, w
119
+ i = int(round((h - th) / 2.0))
120
+ j = 0
121
+ return crop(clip, i, j, th, tw)
122
+
123
+
124
+ def random_shift_crop(clip):
125
+ '''
126
+ Slide along the long edge, with the short edge as crop size
127
+ '''
128
+ if not _is_tensor_video_clip(clip):
129
+ raise ValueError("clip should be a 4D torch.tensor")
130
+ h, w = clip.size(-2), clip.size(-1)
131
+
132
+ if h <= w:
133
+ long_edge = w
134
+ short_edge = h
135
+ else:
136
+ long_edge = h
137
+ short_edge =w
138
+
139
+ th, tw = short_edge, short_edge
140
+
141
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
142
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
143
+ return crop(clip, i, j, th, tw)
144
+
145
+
146
+ def to_tensor(clip):
147
+ """
148
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
149
+ permute the dimensions of clip tensor
150
+ Args:
151
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
152
+ Return:
153
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
154
+ """
155
+ _is_tensor_video_clip(clip)
156
+ if not clip.dtype == torch.uint8:
157
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
158
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
159
+ return clip.float() / 255.0
160
+
161
+
162
+ def normalize(clip, mean, std, inplace=False):
163
+ """
164
+ Args:
165
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
166
+ mean (tuple): pixel RGB mean. Size is (3)
167
+ std (tuple): pixel standard deviation. Size is (3)
168
+ Returns:
169
+ normalized clip (torch.tensor): Size is (T, C, H, W)
170
+ """
171
+ if not _is_tensor_video_clip(clip):
172
+ raise ValueError("clip should be a 4D torch.tensor")
173
+ if not inplace:
174
+ clip = clip.clone()
175
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
176
+ # print(mean)
177
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
178
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
179
+ return clip
180
+
181
+
182
+ def hflip(clip):
183
+ """
184
+ Args:
185
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
186
+ Returns:
187
+ flipped clip (torch.tensor): Size is (T, C, H, W)
188
+ """
189
+ if not _is_tensor_video_clip(clip):
190
+ raise ValueError("clip should be a 4D torch.tensor")
191
+ return clip.flip(-1)
192
+
193
+
194
+ class RandomCropVideo:
195
+ def __init__(self, size):
196
+ if isinstance(size, numbers.Number):
197
+ self.size = (int(size), int(size))
198
+ else:
199
+ self.size = size
200
+
201
+ def __call__(self, clip):
202
+ """
203
+ Args:
204
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
205
+ Returns:
206
+ torch.tensor: randomly cropped video clip.
207
+ size is (T, C, OH, OW)
208
+ """
209
+ i, j, h, w = self.get_params(clip)
210
+ return crop(clip, i, j, h, w)
211
+
212
+ def get_params(self, clip):
213
+ h, w = clip.shape[-2:]
214
+ th, tw = self.size
215
+
216
+ if h < th or w < tw:
217
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
218
+
219
+ if w == tw and h == th:
220
+ return 0, 0, h, w
221
+
222
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
223
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
224
+
225
+ return i, j, th, tw
226
+
227
+ def __repr__(self) -> str:
228
+ return f"{self.__class__.__name__}(size={self.size})"
229
+
230
+ class CenterCropResizeVideo:
231
+ '''
232
+ First use the short side for cropping length,
233
+ center crop video, then resize to the specified size
234
+ '''
235
+ def __init__(
236
+ self,
237
+ size,
238
+ interpolation_mode="bilinear",
239
+ ):
240
+ if isinstance(size, tuple):
241
+ if len(size) != 2:
242
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
243
+ self.size = size
244
+ else:
245
+ self.size = (size, size)
246
+
247
+ self.interpolation_mode = interpolation_mode
248
+
249
+
250
+ def __call__(self, clip):
251
+ """
252
+ Args:
253
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
254
+ Returns:
255
+ torch.tensor: scale resized / center cropped video clip.
256
+ size is (T, C, crop_size, crop_size)
257
+ """
258
+ # print(clip.shape)
259
+ clip_center_crop = center_crop_using_short_edge(clip)
260
+ # print(clip_center_crop.shape) 320 512
261
+ clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
262
+ return clip_center_crop_resize
263
+
264
+ def __repr__(self) -> str:
265
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
266
+
267
+ class WebVideo320512:
268
+ def __init__(
269
+ self,
270
+ size,
271
+ interpolation_mode="bilinear",
272
+ ):
273
+ if isinstance(size, tuple):
274
+ if len(size) != 2:
275
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
276
+ self.size = size
277
+ else:
278
+ self.size = (size, size)
279
+
280
+ self.interpolation_mode = interpolation_mode
281
+
282
+
283
+ def __call__(self, clip):
284
+ """
285
+ Args:
286
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
287
+ Returns:
288
+ torch.tensor: scale resized / center cropped video clip.
289
+ size is (T, C, crop_size, crop_size)
290
+ """
291
+ # add aditional one pixel for avoiding error in center crop
292
+ h, w = clip.size(-2), clip.size(-1)
293
+ # print('before resize', clip.shape)
294
+ if h < 320:
295
+ clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
296
+ # print('after h resize', clip.shape)
297
+ if w < 512:
298
+ clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)
299
+ # print('after w resize', clip.shape)
300
+ clip_center_crop = center_crop(clip, self.size)
301
+ # print(clip_center_crop.shape)
302
+ return clip_center_crop
303
+
304
+ def __repr__(self) -> str:
305
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
306
+
307
+ class UCFCenterCropVideo:
308
+ '''
309
+ First scale to the specified size in equal proportion to the short edge,
310
+ then center cropping
311
+ '''
312
+ def __init__(
313
+ self,
314
+ size,
315
+ interpolation_mode="bilinear",
316
+ ):
317
+ if isinstance(size, tuple):
318
+ if len(size) != 2:
319
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
320
+ self.size = size
321
+ else:
322
+ self.size = (size, size)
323
+
324
+ self.interpolation_mode = interpolation_mode
325
+
326
+
327
+ def __call__(self, clip):
328
+ """
329
+ Args:
330
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
331
+ Returns:
332
+ torch.tensor: scale resized / center cropped video clip.
333
+ size is (T, C, crop_size, crop_size)
334
+ """
335
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
336
+ clip_center_crop = center_crop(clip_resize, self.size)
337
+ return clip_center_crop
338
+
339
+ def __repr__(self) -> str:
340
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
341
+
342
+
343
+ class CenterCropVideo:
344
+ def __init__(
345
+ self,
346
+ size,
347
+ interpolation_mode="bilinear",
348
+ ):
349
+ if isinstance(size, tuple):
350
+ if len(size) != 2:
351
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
352
+ self.size = size
353
+ else:
354
+ self.size = (size, size)
355
+
356
+ self.interpolation_mode = interpolation_mode
357
+
358
+
359
+ def __call__(self, clip):
360
+ """
361
+ Args:
362
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
363
+ Returns:
364
+ torch.tensor: center cropped video clip.
365
+ size is (T, C, crop_size, crop_size)
366
+ """
367
+ clip_center_crop = center_crop(clip, self.size)
368
+ return clip_center_crop
369
+
370
+ def __repr__(self) -> str:
371
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
372
+
373
+
374
+ class NormalizeVideo:
375
+ """
376
+ Normalize the video clip by mean subtraction and division by standard deviation
377
+ Args:
378
+ mean (3-tuple): pixel RGB mean
379
+ std (3-tuple): pixel RGB standard deviation
380
+ inplace (boolean): whether do in-place normalization
381
+ """
382
+
383
+ def __init__(self, mean, std, inplace=False):
384
+ self.mean = mean
385
+ self.std = std
386
+ self.inplace = inplace
387
+
388
+ def __call__(self, clip):
389
+ """
390
+ Args:
391
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
392
+ """
393
+ return normalize(clip, self.mean, self.std, self.inplace)
394
+
395
+ def __repr__(self) -> str:
396
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
397
+
398
+
399
+ class ToTensorVideo:
400
+ """
401
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
402
+ permute the dimensions of clip tensor
403
+ """
404
+
405
+ def __init__(self):
406
+ pass
407
+
408
+ def __call__(self, clip):
409
+ """
410
+ Args:
411
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
412
+ Return:
413
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
414
+ """
415
+ return to_tensor(clip)
416
+
417
+ def __repr__(self) -> str:
418
+ return self.__class__.__name__
419
+
420
+
421
+ class ResizeVideo():
422
+ '''
423
+ First use the short side for cropping length,
424
+ center crop video, then resize to the specified size
425
+ '''
426
+ def __init__(
427
+ self,
428
+ size,
429
+ interpolation_mode="bilinear",
430
+ ):
431
+ if isinstance(size, tuple):
432
+ if len(size) != 2:
433
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
434
+ self.size = size
435
+ else:
436
+ self.size = (size, size)
437
+
438
+ self.interpolation_mode = interpolation_mode
439
+
440
+
441
+ def __call__(self, clip):
442
+ """
443
+ Args:
444
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
445
+ Returns:
446
+ torch.tensor: scale resized / center cropped video clip.
447
+ size is (T, C, crop_size, crop_size)
448
+ """
449
+ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
450
+ return clip_resize
451
+
452
+ def __repr__(self) -> str:
453
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
454
+
455
+ # ------------------------------------------------------------
456
+ # --------------------- Sampling ---------------------------
457
+ # ------------------------------------------------------------
458
+ class TemporalRandomCrop(object):
459
+ """Temporally crop the given frame indices at a random location.
460
+
461
+ Args:
462
+ size (int): Desired length of frames will be seen in the model.
463
+ """
464
+
465
+ def __init__(self, size):
466
+ self.size = size
467
+
468
+ def __call__(self, total_frames):
469
+ rand_end = max(0, total_frames - self.size - 1)
470
+ begin_index = random.randint(0, rand_end)
471
+ end_index = min(begin_index + self.size, total_frames)
472
+ return begin_index, end_index
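The transforms above are plain callables, so they can be chained on a video tensor directly. Below is a minimal usage sketch (not part of this commit); it assumes datasets/ is importable as a package and uses torchvision's Compose only for chaining. Note that most transforms here document a (T, C, H, W) layout while NormalizeVideo's docstring expects channel-first (C, T, H, W), hence the permute, and that center_crop_arr references np without the module importing numpy, so that helper also needs an `import numpy as np` before use.

```python
# Sketch of applying the transforms defined in datasets/video_transforms.py to a dummy clip.
import torch
from torchvision import transforms

from datasets import video_transforms  # assumption: datasets/ is an importable package

# Fake uint8 clip: 16 frames, 3 channels, 360x640.
clip = torch.randint(0, 256, (16, 3, 360, 640), dtype=torch.uint8)

spatial = transforms.Compose([
    video_transforms.ToTensorVideo(),             # uint8 -> float in [0, 1], (T, C, H, W)
    video_transforms.CenterCropResizeVideo(512),  # short-edge center crop, then resize to 512x512
])
clip = spatial(clip)                              # (16, 3, 512, 512)

# NormalizeVideo / normalize broadcast the stats over the first dimension (channels).
mean = std = (0.5, 0.5, 0.5)
clip = video_transforms.normalize(clip.permute(1, 0, 2, 3), mean, std)  # (C, T, H, W)

# Pick a random temporal window of 8 frame indices out of a 100-frame video.
sampler = video_transforms.TemporalRandomCrop(8)
start, end = sampler(100)
print(clip.shape, (start, end))
```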
diffusion/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ # learn_sigma=True,
17
+ learn_sigma=False, # for unet
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
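A short usage sketch for the factory above (not part of this commit): an empty `timestep_respacing` keeps the full 1000-step chain, while a count string produces a respaced sampler. The "250" string follows the space_timesteps convention inherited from the OpenAI IDDPM code credited at the top of the file; whether respace.py keeps exactly that interface is an assumption.

```python
# Sketch: build the diffusion objects defined in diffusion/__init__.py.
from diffusion import create_diffusion

# Full 1000-step chain (empty respacing string keeps all steps).
train_diffusion = create_diffusion(timestep_respacing="")
print(train_diffusion.num_timesteps)  # expected: 1000

# Respaced chain for sampling; "250" is assumed to follow the upstream
# space_timesteps convention of comma-separated section counts.
sample_diffusion = create_diffusion(timestep_respacing="250")
print(sample_diffusion.num_timesteps)  # expected: 250
```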
diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.01 kB).
 
diffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.53 kB).
 
diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (997 Bytes).
 
diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (975 Bytes).
 
diffusion/__pycache__/diffusion_utils.cpython-310.pyc ADDED
Binary file (2.83 kB).
 
diffusion/__pycache__/diffusion_utils.cpython-311.pyc ADDED
Binary file (4.59 kB).
 
diffusion/__pycache__/diffusion_utils.cpython-38.pyc ADDED
Binary file (2.87 kB).
 
diffusion/__pycache__/diffusion_utils.cpython-39.pyc ADDED
Binary file (2.83 kB).
 
diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc ADDED
Binary file (25 kB).
 
diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc ADDED
Binary file (40.5 kB).
 
diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc ADDED
Binary file (25 kB).
 
diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc ADDED
Binary file (24.9 kB).
 
diffusion/__pycache__/respace.cpython-310.pyc ADDED
Binary file (4.98 kB).
 
diffusion/__pycache__/respace.cpython-311.pyc ADDED
Binary file (7.78 kB).
 
diffusion/__pycache__/respace.cpython-38.pyc ADDED
Binary file (5.06 kB).
 
diffusion/__pycache__/respace.cpython-39.pyc ADDED
Binary file (5.07 kB).
 
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
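A quick numerical check of the helpers above (not part of this commit), using values that can be derived by hand from the closed-form KL implemented in normal_kl: identical Gaussians give 0, and a unit mean shift with unit variance gives 0.5.

```python
# Quick numerical check of the helpers in diffusion/diffusion_utils.py.
import torch as th

from diffusion.diffusion_utils import approx_standard_normal_cdf, normal_kl

# KL(N(0,1) || N(0,1)) should be exactly 0 per the closed form above.
zero = th.zeros(4)
print(normal_kl(zero, zero, zero, zero))  # tensor([0., 0., 0., 0.])

# KL(N(1,1) || N(0,1)) = 0.5 * (mean difference)^2 = 0.5.
print(normal_kl(th.ones(1), th.zeros(1), th.zeros(1), th.zeros(1)))  # tensor([0.5000])

# The tanh-based CDF approximation should stay close to the exact normal CDF.
x = th.linspace(-3, 3, 7)
exact = th.distributions.Normal(0.0, 1.0).cdf(x)
print((approx_standard_normal_cdf(x) - exact).abs().max())  # small, well under 1e-2
```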
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,931 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ # diffuser stable diffusion
115
+ # beta_start=scale * 0.00085,
116
+ # beta_end=scale * 0.012,
117
+ num_diffusion_timesteps=num_diffusion_timesteps,
118
+ )
119
+ elif schedule_name == "squaredcos_cap_v2":
120
+ return betas_for_alpha_bar(
121
+ num_diffusion_timesteps,
122
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
123
+ )
124
+ else:
125
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
126
+
127
+
128
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
129
+ """
130
+ Create a beta schedule that discretizes the given alpha_t_bar function,
131
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
132
+ :param num_diffusion_timesteps: the number of betas to produce.
133
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
134
+ produces the cumulative product of (1-beta) up to that
135
+ part of the diffusion process.
136
+ :param max_beta: the maximum beta to use; use values lower than 1 to
137
+ prevent singularities.
138
+ """
139
+ betas = []
140
+ for i in range(num_diffusion_timesteps):
141
+ t1 = i / num_diffusion_timesteps
142
+ t2 = (i + 1) / num_diffusion_timesteps
143
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
144
+ return np.array(betas)
145
+
146
+
147
+ class GaussianDiffusion:
148
+ """
149
+ Utilities for training and sampling diffusion models.
150
+ Original ported from this codebase:
151
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
152
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
153
+ starting at T and going to 1.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ *,
159
+ betas,
160
+ model_mean_type,
161
+ model_var_type,
162
+ loss_type
163
+ ):
164
+
165
+ self.model_mean_type = model_mean_type
166
+ self.model_var_type = model_var_type
167
+ self.loss_type = loss_type
168
+
169
+ # Use float64 for accuracy.
170
+ betas = np.array(betas, dtype=np.float64)
171
+ self.betas = betas
172
+ assert len(betas.shape) == 1, "betas must be 1-D"
173
+ assert (betas > 0).all() and (betas <= 1).all()
174
+
175
+ self.num_timesteps = int(betas.shape[0])
176
+
177
+ alphas = 1.0 - betas
178
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
179
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
180
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
181
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
182
+
183
+ # calculations for diffusion q(x_t | x_{t-1}) and others
184
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
185
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
186
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
187
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
188
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
189
+
190
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
191
+ self.posterior_variance = (
192
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
193
+ )
194
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
195
+ self.posterior_log_variance_clipped = np.log(
196
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
197
+ ) if len(self.posterior_variance) > 1 else np.array([])
198
+
199
+ self.posterior_mean_coef1 = (
200
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
201
+ )
202
+ self.posterior_mean_coef2 = (
203
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
204
+ )
205
+
206
+ def q_mean_variance(self, x_start, t):
207
+ """
208
+ Get the distribution q(x_t | x_0).
209
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
210
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
211
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
212
+ """
213
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
214
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
215
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
216
+ return mean, variance, log_variance
217
+
218
+ def q_sample(self, x_start, t, noise=None):
219
+ """
220
+ Diffuse the data for a given number of diffusion steps.
221
+ In other words, sample from q(x_t | x_0).
222
+ :param x_start: the initial data batch.
223
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
224
+ :param noise: if specified, the split-out normal noise.
225
+ :return: A noisy version of x_start.
226
+ """
227
+ if noise is None:
228
+ noise = th.randn_like(x_start)
229
+ assert noise.shape == x_start.shape
230
+ return (
231
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
232
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
233
+ )
234
+
235
+ def q_posterior_mean_variance(self, x_start, x_t, t):
236
+ """
237
+ Compute the mean and variance of the diffusion posterior:
238
+ q(x_{t-1} | x_t, x_0)
239
+ """
240
+ assert x_start.shape == x_t.shape
241
+ posterior_mean = (
242
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
243
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
244
+ )
245
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
246
+ posterior_log_variance_clipped = _extract_into_tensor(
247
+ self.posterior_log_variance_clipped, t, x_t.shape
248
+ )
249
+ assert (
250
+ posterior_mean.shape[0]
251
+ == posterior_variance.shape[0]
252
+ == posterior_log_variance_clipped.shape[0]
253
+ == x_start.shape[0]
254
+ )
255
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
256
+
257
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
258
+ mask=None, x_start=None, use_concat=False):
259
+ """
260
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
261
+ the initial x, x_0.
262
+ :param model: the model, which takes a signal and a batch of timesteps
263
+ as input.
264
+ :param x: the [N x C x ...] tensor at time t.
265
+ :param t: a 1-D Tensor of timesteps.
266
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
267
+ :param denoised_fn: if not None, a function which applies to the
268
+ x_start prediction before it is used to sample. Applies before
269
+ clip_denoised.
270
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
271
+ pass to the model. This can be used for conditioning.
272
+ :return: a dict with the following keys:
273
+ - 'mean': the model mean output.
274
+ - 'variance': the model variance output.
275
+ - 'log_variance': the log of 'variance'.
276
+ - 'pred_xstart': the prediction for x_0.
277
+ """
278
+ if model_kwargs is None:
279
+ model_kwargs = {}
280
+
281
+ B, F, C = x.shape[:3]
282
+ assert t.shape == (B,)
283
+ if use_concat:
284
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
285
+ else:
286
+ model_output = model(x, t, **model_kwargs)
287
+ try:
288
+ model_output = model_output.sample # for tav unet
289
+ except:
290
+ pass
291
+ # model_output = model(x, t, **model_kwargs)
292
+ if isinstance(model_output, tuple):
293
+ model_output, extra = model_output
294
+ else:
295
+ extra = None
296
+
297
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
298
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
299
+ model_output, model_var_values = th.split(model_output, C, dim=2)
300
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
301
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
302
+ # The model_var_values is [-1, 1] for [min_var, max_var].
303
+ frac = (model_var_values + 1) / 2
304
+ model_log_variance = frac * max_log + (1 - frac) * min_log
305
+ model_variance = th.exp(model_log_variance)
306
+ else:
307
+ model_variance, model_log_variance = {
308
+ # for fixedlarge, we set the initial (log-)variance like so
309
+ # to get a better decoder log likelihood.
310
+ ModelVarType.FIXED_LARGE: (
311
+ np.append(self.posterior_variance[1], self.betas[1:]),
312
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
313
+ ),
314
+ ModelVarType.FIXED_SMALL: (
315
+ self.posterior_variance,
316
+ self.posterior_log_variance_clipped,
317
+ ),
318
+ }[self.model_var_type]
319
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
320
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
321
+
322
+ def process_xstart(x):
323
+ if denoised_fn is not None:
324
+ x = denoised_fn(x)
325
+ if clip_denoised:
326
+ return x.clamp(-1, 1)
327
+ return x
328
+
329
+ if self.model_mean_type == ModelMeanType.START_X:
330
+ pred_xstart = process_xstart(model_output)
331
+ else:
332
+ pred_xstart = process_xstart(
333
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
334
+ )
335
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
336
+
337
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
338
+ return {
339
+ "mean": model_mean,
340
+ "variance": model_variance,
341
+ "log_variance": model_log_variance,
342
+ "pred_xstart": pred_xstart,
343
+ "extra": extra,
344
+ }
345
+
346
+ def _predict_xstart_from_eps(self, x_t, t, eps):
347
+ assert x_t.shape == eps.shape
348
+ return (
349
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
350
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
351
+ )
352
+
353
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
354
+ return (
355
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
356
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
357
+
358
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute the mean for the previous step, given a function cond_fn that
361
+ computes the gradient of a conditional log probability with respect to
362
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
363
+ condition on y.
364
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
365
+ """
366
+ gradient = cond_fn(x, t, **model_kwargs)
367
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
368
+ return new_mean
369
+
370
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
371
+ """
372
+ Compute what the p_mean_variance output would have been, should the
373
+ model's score function be conditioned by cond_fn.
374
+ See condition_mean() for details on cond_fn.
375
+ Unlike condition_mean(), this instead uses the conditioning strategy
376
+ from Song et al (2020).
377
+ """
378
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
379
+
380
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
381
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
382
+
383
+ out = p_mean_var.copy()
384
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
385
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
386
+ return out
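For context, the `cond_fn` expected by `condition_mean` / `condition_score` returns grad_x log p(y | x_t). A hedged sketch of a classifier-based guidance function in that shape (`classifier` and `guidance_scale` are placeholders, not part of this repo):

```python
import torch as th
import torch.nn.functional as F

def make_cond_fn(classifier, y, guidance_scale=1.0):
    # Build a cond_fn(x, t, **kwargs) that returns the gradient of log p(y | x_t).
    def cond_fn(x, t, **kwargs):
        with th.enable_grad():  # sampling runs under th.no_grad(), so re-enable autograd here
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(y)), y]
            return th.autograd.grad(selected.sum(), x_in)[0] * guidance_scale
    return cond_fn
```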
387
+
388
+ def p_sample(
389
+ self,
390
+ model,
391
+ x,
392
+ t,
393
+ clip_denoised=True,
394
+ denoised_fn=None,
395
+ cond_fn=None,
396
+ model_kwargs=None,
397
+ mask=None,
398
+ x_start=None,
399
+ use_concat=False
400
+ ):
401
+ """
402
+ Sample x_{t-1} from the model at the given timestep.
403
+ :param model: the model to sample from.
404
+ :param x: the current tensor at x_{t-1}.
405
+ :param t: the value of t, starting at 0 for the first diffusion step.
406
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
407
+ :param denoised_fn: if not None, a function which applies to the
408
+ x_start prediction before it is used to sample.
409
+ :param cond_fn: if not None, this is a gradient function that acts
410
+ similarly to the model.
411
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
412
+ pass to the model. This can be used for conditioning.
413
+ :return: a dict containing the following keys:
414
+ - 'sample': a random sample from the model.
415
+ - 'pred_xstart': a prediction of x_0.
416
+ """
417
+ out = self.p_mean_variance(
418
+ model,
419
+ x,
420
+ t,
421
+ clip_denoised=clip_denoised,
422
+ denoised_fn=denoised_fn,
423
+ model_kwargs=model_kwargs,
424
+ mask=mask,
425
+ x_start=x_start,
426
+ use_concat=use_concat
427
+ )
428
+ noise = th.randn_like(x)
429
+ nonzero_mask = (
430
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
431
+ ) # no noise when t == 0
432
+ if cond_fn is not None:
433
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
434
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
435
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
436
+
437
+ def p_sample_loop(
438
+ self,
439
+ model,
440
+ shape,
441
+ noise=None,
442
+ clip_denoised=True,
443
+ denoised_fn=None,
444
+ cond_fn=None,
445
+ model_kwargs=None,
446
+ device=None,
447
+ progress=False,
448
+ mask=None,
449
+ x_start=None,
450
+ use_concat=False,
451
+ ):
452
+ """
453
+ Generate samples from the model.
454
+ :param model: the model module.
455
+ :param shape: the shape of the samples, (N, C, H, W).
456
+ :param noise: if specified, the noise from the encoder to sample.
457
+ Should be of the same shape as `shape`.
458
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
459
+ :param denoised_fn: if not None, a function which applies to the
460
+ x_start prediction before it is used to sample.
461
+ :param cond_fn: if not None, this is a gradient function that acts
462
+ similarly to the model.
463
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
464
+ pass to the model. This can be used for conditioning.
465
+ :param device: if specified, the device to create the samples on.
466
+ If not specified, use a model parameter's device.
467
+ :param progress: if True, show a tqdm progress bar.
468
+ :return: a non-differentiable batch of samples.
469
+ """
470
+ final = None
471
+ for sample in self.p_sample_loop_progressive(
472
+ model,
473
+ shape,
474
+ noise=noise,
475
+ clip_denoised=clip_denoised,
476
+ denoised_fn=denoised_fn,
477
+ cond_fn=cond_fn,
478
+ model_kwargs=model_kwargs,
479
+ device=device,
480
+ progress=progress,
481
+ mask=mask,
482
+ x_start=x_start,
483
+ use_concat=use_concat
484
+ ):
485
+ final = sample
486
+ return final["sample"]
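A toy, self-contained way to exercise `p_sample_loop`, assuming the constructor keeps the `betas` / `model_mean_type` / `model_var_type` / `loss_type` keywords from the OpenAI code this file is modified from; the zero-output "model" is only a stand-in for the repo's video UNet:

```python
import numpy as np
import torch
import torch.nn as nn
from diffusion.gaussian_diffusion import (
    GaussianDiffusion, ModelMeanType, ModelVarType, LossType,
)

class DummyEps(nn.Module):
    def forward(self, x, t, **kwargs):
        return torch.zeros_like(x)  # pretend eps prediction

diffusion = GaussianDiffusion(
    betas=np.linspace(1e-4, 0.02, 1000),   # standard linear DDPM schedule
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
sample = diffusion.p_sample_loop(DummyEps(), shape=(1, 4, 16, 32, 32), device="cpu")
print(sample.shape)  # torch.Size([1, 4, 16, 32, 32]), i.e. b, c, f, h, w latents
```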
487
+
488
+ def p_sample_loop_progressive(
489
+ self,
490
+ model,
491
+ shape,
492
+ noise=None,
493
+ clip_denoised=True,
494
+ denoised_fn=None,
495
+ cond_fn=None,
496
+ model_kwargs=None,
497
+ device=None,
498
+ progress=False,
499
+ mask=None,
500
+ x_start=None,
501
+ use_concat=False
502
+ ):
503
+ """
504
+ Generate samples from the model and yield intermediate samples from
505
+ each timestep of diffusion.
506
+ Arguments are the same as p_sample_loop().
507
+ Returns a generator over dicts, where each dict is the return value of
508
+ p_sample().
509
+ """
510
+ if device is None:
511
+ device = next(model.parameters()).device
512
+ assert isinstance(shape, (tuple, list))
513
+ if noise is not None:
514
+ img = noise
515
+ else:
516
+ img = th.randn(*shape, device=device)
517
+ indices = list(range(self.num_timesteps))[::-1]
518
+
519
+ if progress:
520
+ # Lazy import so that we don't depend on tqdm.
521
+ from tqdm.auto import tqdm
522
+
523
+ indices = tqdm(indices)
524
+
525
+ for i in indices:
526
+ t = th.tensor([i] * shape[0], device=device)
527
+ with th.no_grad():
528
+ out = self.p_sample(
529
+ model,
530
+ img,
531
+ t,
532
+ clip_denoised=clip_denoised,
533
+ denoised_fn=denoised_fn,
534
+ cond_fn=cond_fn,
535
+ model_kwargs=model_kwargs,
536
+ mask=mask,
537
+ x_start=x_start,
538
+ use_concat=use_concat
539
+ )
540
+ yield out
541
+ img = out["sample"]
542
+
543
+ def ddim_sample(
544
+ self,
545
+ model,
546
+ x,
547
+ t,
548
+ clip_denoised=True,
549
+ denoised_fn=None,
550
+ cond_fn=None,
551
+ model_kwargs=None,
552
+ eta=0.0,
553
+ mask=None,
554
+ x_start=None,
555
+ use_concat=False
556
+ ):
557
+ """
558
+ Sample x_{t-1} from the model using DDIM.
559
+ Same usage as p_sample().
560
+ """
561
+ out = self.p_mean_variance(
562
+ model,
563
+ x,
564
+ t,
565
+ clip_denoised=clip_denoised,
566
+ denoised_fn=denoised_fn,
567
+ model_kwargs=model_kwargs,
568
+ mask=mask,
569
+ x_start=x_start,
570
+ use_concat=use_concat
571
+ )
572
+ if cond_fn is not None:
573
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
574
+
575
+ # Usually our model outputs epsilon, but we re-derive it
576
+ # in case we used x_start or x_prev prediction.
577
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
578
+
579
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
580
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
581
+ sigma = (
582
+ eta
583
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
584
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
585
+ )
586
+ # Equation 12.
587
+ noise = th.randn_like(x)
588
+ mean_pred = (
589
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
590
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
591
+ )
592
+ nonzero_mask = (
593
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
594
+ ) # no noise when t == 0
595
+ sample = mean_pred + nonzero_mask * sigma * noise
596
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
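The `sigma` above is the only place `eta` enters DDIM sampling: `eta = 0` makes the update deterministic, larger values re-inject noise. A small numeric check of that expression:

```python
import math

def ddim_sigma(eta, alpha_bar, alpha_bar_prev):
    return (eta
            * math.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * math.sqrt(1 - alpha_bar / alpha_bar_prev))

print(ddim_sigma(0.0, 0.5, 0.7))            # 0.0 -> deterministic DDIM step
print(round(ddim_sigma(1.0, 0.5, 0.7), 3))  # ~0.414 -> stochastic, DDPM-like step
```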
597
+
598
+ def ddim_reverse_sample(
599
+ self,
600
+ model,
601
+ x,
602
+ t,
603
+ clip_denoised=True,
604
+ denoised_fn=None,
605
+ cond_fn=None,
606
+ model_kwargs=None,
607
+ eta=0.0,
608
+ ):
609
+ """
610
+ Sample x_{t+1} from the model using DDIM reverse ODE.
611
+ """
612
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
613
+ out = self.p_mean_variance(
614
+ model,
615
+ x,
616
+ t,
617
+ clip_denoised=clip_denoised,
618
+ denoised_fn=denoised_fn,
619
+ model_kwargs=model_kwargs,
620
+ )
621
+ if cond_fn is not None:
622
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
623
+ # Usually our model outputs epsilon, but we re-derive it
624
+ # in case we used x_start or x_prev prediction.
625
+ eps = (
626
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
627
+ - out["pred_xstart"]
628
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
629
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
630
+
631
+ # Equation 12. reversed
632
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
633
+
634
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
635
+
636
+ def ddim_sample_loop(
637
+ self,
638
+ model,
639
+ shape,
640
+ noise=None,
641
+ clip_denoised=True,
642
+ denoised_fn=None,
643
+ cond_fn=None,
644
+ model_kwargs=None,
645
+ device=None,
646
+ progress=False,
647
+ eta=0.0,
648
+ mask=None,
649
+ x_start=None,
650
+ use_concat=False
651
+ ):
652
+ """
653
+ Generate samples from the model using DDIM.
654
+ Same usage as p_sample_loop().
655
+ """
656
+ final = None
657
+ for sample in self.ddim_sample_loop_progressive(
658
+ model,
659
+ shape,
660
+ noise=noise,
661
+ clip_denoised=clip_denoised,
662
+ denoised_fn=denoised_fn,
663
+ cond_fn=cond_fn,
664
+ model_kwargs=model_kwargs,
665
+ device=device,
666
+ progress=progress,
667
+ eta=eta,
668
+ mask=mask,
669
+ x_start=x_start,
670
+ use_concat=use_concat
671
+ ):
672
+ final = sample
673
+ return final["sample"]
674
+
675
+ def ddim_sample_loop_progressive(
676
+ self,
677
+ model,
678
+ shape,
679
+ noise=None,
680
+ clip_denoised=True,
681
+ denoised_fn=None,
682
+ cond_fn=None,
683
+ model_kwargs=None,
684
+ device=None,
685
+ progress=False,
686
+ eta=0.0,
687
+ mask=None,
688
+ x_start=None,
689
+ use_concat=False
690
+ ):
691
+ """
692
+ Use DDIM to sample from the model and yield intermediate samples from
693
+ each timestep of DDIM.
694
+ Same usage as p_sample_loop_progressive().
695
+ """
696
+ if device is None:
697
+ device = next(model.parameters()).device
698
+ assert isinstance(shape, (tuple, list))
699
+ if noise is not None:
700
+ img = noise
701
+ else:
702
+ img = th.randn(*shape, device=device)
703
+ indices = list(range(self.num_timesteps))[::-1]
704
+
705
+ if progress:
706
+ # Lazy import so that we don't depend on tqdm.
707
+ from tqdm.auto import tqdm
708
+
709
+ indices = tqdm(indices)
710
+
711
+ for i in indices:
712
+ t = th.tensor([i] * shape[0], device=device)
713
+ with th.no_grad():
714
+ out = self.ddim_sample(
715
+ model,
716
+ img,
717
+ t,
718
+ clip_denoised=clip_denoised,
719
+ denoised_fn=denoised_fn,
720
+ cond_fn=cond_fn,
721
+ model_kwargs=model_kwargs,
722
+ eta=eta,
723
+ mask=mask,
724
+ x_start=x_start,
725
+ use_concat=use_concat
726
+ )
727
+ yield out
728
+ img = out["sample"]
729
+
730
+ def _vb_terms_bpd(
731
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
732
+ ):
733
+ """
734
+ Get a term for the variational lower-bound.
735
+ The resulting units are bits (rather than nats, as one might expect).
736
+ This allows for comparison to other papers.
737
+ :return: a dict with the following keys:
738
+ - 'output': a shape [N] tensor of NLLs or KLs.
739
+ - 'pred_xstart': the x_0 predictions.
740
+ """
741
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
742
+ x_start=x_start, x_t=x_t, t=t
743
+ )
744
+ out = self.p_mean_variance(
745
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
746
+ )
747
+ kl = normal_kl(
748
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
749
+ )
750
+ kl = mean_flat(kl) / np.log(2.0)
751
+
752
+ decoder_nll = -discretized_gaussian_log_likelihood(
753
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
754
+ )
755
+ assert decoder_nll.shape == x_start.shape
756
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
757
+
758
+ # At the first timestep return the decoder NLL,
759
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
760
+ output = th.where((t == 0), decoder_nll, kl)
761
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
762
+
763
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, use_mask=False):
764
+ """
765
+ Compute training losses for a single timestep.
766
+ :param model: the model to evaluate loss on.
767
+ :param x_start: the [N x C x ...] tensor of inputs.
768
+ :param t: a batch of timestep indices.
769
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
770
+ pass to the model. This can be used for conditioning.
771
+ :param noise: if specified, the specific Gaussian noise to try to remove.
772
+ :return: a dict with the key "loss" containing a tensor of shape [N].
773
+ Some mean or variance settings may also have other keys.
774
+ """
775
+ if model_kwargs is None:
776
+ model_kwargs = {}
777
+ if noise is None:
778
+ noise = th.randn_like(x_start)
779
+ x_t = self.q_sample(x_start, t, noise=noise)
780
+ if use_mask:
781
+ x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
782
+ terms = {}
783
+
784
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
785
+ terms["loss"] = self._vb_terms_bpd(
786
+ model=model,
787
+ x_start=x_start,
788
+ x_t=x_t,
789
+ t=t,
790
+ clip_denoised=False,
791
+ model_kwargs=model_kwargs,
792
+ )["output"]
793
+ if self.loss_type == LossType.RESCALED_KL:
794
+ terms["loss"] *= self.num_timesteps
795
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
796
+ model_output = model(x_t, t, **model_kwargs)
797
+ try:
+     # model_output = model(x_t, t, **model_kwargs).sample
+     model_output = model_output.sample  # tav unet (diffusers-style) returns an output object
+ except AttributeError:
+     pass  # plain tensor output: nothing to unwrap
802
+ # model_output = model(x_t, t, **model_kwargs)
803
+
804
+ if self.model_var_type in [
805
+ ModelVarType.LEARNED,
806
+ ModelVarType.LEARNED_RANGE,
807
+ ]:
808
+ B, F, C = x_t.shape[:3]
809
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
810
+ model_output, model_var_values = th.split(model_output, C, dim=2)
811
+ # Learn the variance using the variational bound, but don't let
812
+ # it affect our mean prediction.
813
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
814
+ terms["vb"] = self._vb_terms_bpd(
815
+ model=lambda *args, r=frozen_out: r,
816
+ x_start=x_start,
817
+ x_t=x_t,
818
+ t=t,
819
+ clip_denoised=False,
820
+ )["output"]
821
+ if self.loss_type == LossType.RESCALED_MSE:
822
+ # Divide by 1000 for equivalence with initial implementation.
823
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
824
+ terms["vb"] *= self.num_timesteps / 1000.0
825
+
826
+ target = {
827
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
828
+ x_start=x_start, x_t=x_t, t=t
829
+ )[0],
830
+ ModelMeanType.START_X: x_start,
831
+ ModelMeanType.EPSILON: noise,
832
+ }[self.model_mean_type]
833
+ # assert model_output.shape == target.shape == x_start.shape
834
+ if use_mask:
835
+ terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
836
+ else:
837
+ terms["mse"] = mean_flat((target - model_output) ** 2)
838
+ if "vb" in terms:
839
+ terms["loss"] = terms["mse"] + terms["vb"]
840
+ else:
841
+ terms["loss"] = terms["mse"]
842
+ else:
843
+ raise NotImplementedError(self.loss_type)
844
+
845
+ return terms
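A minimal training-step sketch built on `training_losses`, under the same constructor assumption as above; the `nn.Conv3d` denoiser and optimizer settings are placeholders for illustration:

```python
import numpy as np
import torch
import torch.nn as nn
from diffusion.gaussian_diffusion import (
    GaussianDiffusion, ModelMeanType, ModelVarType, LossType,
)

diffusion = GaussianDiffusion(
    betas=np.linspace(1e-4, 0.02, 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
model = nn.Conv3d(4, 4, kernel_size=1)      # stand-in denoiser: latents in, eps out
denoise_fn = lambda x, t, **kw: model(x)    # this toy model ignores the timestep
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

x_start = torch.randn(2, 4, 16, 32, 32)     # b, c, f, h, w latents
t = torch.randint(0, diffusion.num_timesteps, (2,))
loss = diffusion.training_losses(denoise_fn, x_start, t)["loss"].mean()
loss.backward()
opt.step()
print(float(loss))
```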
846
+
847
+ def _prior_bpd(self, x_start):
848
+ """
849
+ Get the prior KL term for the variational lower-bound, measured in
850
+ bits-per-dim.
851
+ This term can't be optimized, as it only depends on the encoder.
852
+ :param x_start: the [N x C x ...] tensor of inputs.
853
+ :return: a batch of [N] KL values (in bits), one per batch element.
854
+ """
855
+ batch_size = x_start.shape[0]
856
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
857
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
858
+ kl_prior = normal_kl(
859
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
860
+ )
861
+ return mean_flat(kl_prior) / np.log(2.0)
862
+
863
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
864
+ """
865
+ Compute the entire variational lower-bound, measured in bits-per-dim,
866
+ as well as other related quantities.
867
+ :param model: the model to evaluate loss on.
868
+ :param x_start: the [N x C x ...] tensor of inputs.
869
+ :param clip_denoised: if True, clip denoised samples.
870
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
871
+ pass to the model. This can be used for conditioning.
872
+ :return: a dict containing the following keys:
873
+ - total_bpd: the total variational lower-bound, per batch element.
874
+ - prior_bpd: the prior term in the lower-bound.
875
+ - vb: an [N x T] tensor of terms in the lower-bound.
876
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
877
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
878
+ """
879
+ device = x_start.device
880
+ batch_size = x_start.shape[0]
881
+
882
+ vb = []
883
+ xstart_mse = []
884
+ mse = []
885
+ for t in list(range(self.num_timesteps))[::-1]:
886
+ t_batch = th.tensor([t] * batch_size, device=device)
887
+ noise = th.randn_like(x_start)
888
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
889
+ # Calculate VLB term at the current timestep
890
+ with th.no_grad():
891
+ out = self._vb_terms_bpd(
892
+ model,
893
+ x_start=x_start,
894
+ x_t=x_t,
895
+ t=t_batch,
896
+ clip_denoised=clip_denoised,
897
+ model_kwargs=model_kwargs,
898
+ )
899
+ vb.append(out["output"])
900
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
901
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
902
+ mse.append(mean_flat((eps - noise) ** 2))
903
+
904
+ vb = th.stack(vb, dim=1)
905
+ xstart_mse = th.stack(xstart_mse, dim=1)
906
+ mse = th.stack(mse, dim=1)
907
+
908
+ prior_bpd = self._prior_bpd(x_start)
909
+ total_bpd = vb.sum(dim=1) + prior_bpd
910
+ return {
911
+ "total_bpd": total_bpd,
912
+ "prior_bpd": prior_bpd,
913
+ "vb": vb,
914
+ "xstart_mse": xstart_mse,
915
+ "mse": mse,
916
+ }
917
+
918
+
919
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
920
+ """
921
+ Extract values from a 1-D numpy array for a batch of indices.
922
+ :param arr: the 1-D numpy array.
923
+ :param timesteps: a tensor of indices into the array to extract.
924
+ :param broadcast_shape: a larger shape of K dimensions with the batch
925
+ dimension equal to the length of timesteps.
926
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
927
+ """
928
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
929
+ while len(res.shape) < len(broadcast_shape):
930
+ res = res[..., None]
931
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
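`_extract_into_tensor` is the small workhorse behind every `alphas_cumprod`-style lookup above: gather one scalar per batch element's timestep, then broadcast it over the latent shape. For instance:

```python
import numpy as np
import torch as th
from diffusion.gaussian_diffusion import _extract_into_tensor

arr = np.linspace(0.1, 1.0, 10)                   # e.g. one value per diffusion step
t = th.tensor([0, 9])                             # a batch of two timesteps
res = _extract_into_tensor(arr, t, (2, 4, 16, 32, 32))
print(res.shape)                                  # torch.Size([2, 4, 16, 32, 32])
print(res[0].unique(), res[1].unique())           # tensor([0.1000]) tensor([1.])
```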
diffusion/respace.py ADDED
@@ -0,0 +1,130 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+ import torch
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
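Two concrete calls, matching the docstring above:

```python
from diffusion.respace import space_timesteps

print(sorted(space_timesteps(100, [10])))      # [0, 11, 22, ..., 99]: 10 evenly spread steps
print(sorted(space_timesteps(100, "ddim25")))  # [0, 4, 8, ..., 96]: fixed DDIM striding
```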
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ # @torch.compile
95
+ def training_losses(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
99
+
100
+ def condition_mean(self, cond_fn, *args, **kwargs):
101
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102
+
103
+ def condition_score(self, cond_fn, *args, **kwargs):
104
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105
+
106
+ def _wrap_model(self, model):
107
+ if isinstance(model, _WrappedModel):
108
+ return model
109
+ return _WrappedModel(
110
+ model, self.timestep_map, self.original_num_steps
111
+ )
112
+
113
+ def _scale_timesteps(self, t):
114
+ # Scaling is done by the wrapped model.
115
+ return t
116
+
117
+
118
+ class _WrappedModel:
119
+ def __init__(self, model, timestep_map, original_num_steps):
120
+ self.model = model
121
+ self.timestep_map = timestep_map
122
+ # self.rescale_timesteps = rescale_timesteps
123
+ self.original_num_steps = original_num_steps
124
+
125
+ def __call__(self, x, ts, **kwargs):
126
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127
+ new_ts = map_tensor[ts]
128
+ # if self.rescale_timesteps:
129
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130
+ return self.model(x, new_ts, **kwargs)
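Putting the respacing together: build a 1000-step schedule but keep only 50 DDIM steps. In the repo this wiring lives in `diffusion/__init__.py` (`create_diffusion`); the constructor keywords below are the usual ones from the OpenAI codebases and are an assumption here:

```python
import numpy as np
from diffusion.respace import SpacedDiffusion, space_timesteps
from diffusion.gaussian_diffusion import ModelMeanType, ModelVarType, LossType

diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(1000, "ddim50"),
    betas=np.linspace(1e-4, 0.02, 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
print(diffusion.num_timesteps)      # 50 respaced steps
print(len(diffusion.timestep_map))  # 50 original indices that _WrappedModel feeds the model
```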
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24+
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
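How these samplers are meant to be used in a training loop (a sketch; `TinyDiffusion` only exists to supply the `num_timesteps` attribute the samplers read):

```python
import torch
from diffusion.timestep_sampler import create_named_schedule_sampler

class TinyDiffusion:
    num_timesteps = 1000

sampler = create_named_schedule_sampler("uniform", TinyDiffusion())
t, weights = sampler.sample(batch_size=4, device="cpu")
print(t.shape, weights)  # 4 sampled timesteps, all-ones importance weights
# In training: loss = (weights * diffusion.training_losses(model, x_start, t)["loss"]).mean()
```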
download.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Functions for downloading pre-trained DiT models
9
+ """
10
+ from torchvision.datasets.utils import download_url
11
+ import torch
12
+ import os
13
+
14
+
+ # `pretrained_models` is referenced by download_model() below but was missing from this
+ # file; these checkpoint names come from the original DiT release and may need adjusting.
+ pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
+
16
+ def find_model(model_name):
17
+
18
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
19
+
20
+ if "ema" in checkpoint: # supports checkpoints from train.py
21
+ print('EMA weights found in checkpoint; using them.')
22
+ checkpoint = checkpoint["ema"]
23
+ return checkpoint
24
+
25
+
26
+ def download_model(model_name):
27
+ """
28
+ Downloads a pre-trained DiT model from the web.
29
+ """
30
+ assert model_name in pretrained_models
31
+ local_path = f'pretrained_models/{model_name}'
32
+ if not os.path.isfile(local_path):
33
+ os.makedirs('pretrained_models', exist_ok=True)
34
+ web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
35
+ download_url(web_path, 'pretrained_models')
36
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
37
+ return model
38
+
39
+
40
+ if __name__ == "__main__":
41
+ # Download all DiT checkpoints
42
+ for model in pretrained_models:
43
+ download_model(model)
44
+ print('Done.')
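In this repo's sampling code, `find_model` is mainly used to load a local checkpoint and prefer its EMA weights; the path below is a placeholder:

```python
from download import find_model

ckpt_path = "pretrained/your_checkpoint.pt"  # placeholder: point this at a real checkpoint
state_dict = find_model(ckpt_path)           # returns checkpoint["ema"] when present
# model.load_state_dict(state_dict)          # `model` would come from models.get_models(args)
```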
env.yaml ADDED
@@ -0,0 +1,20 @@
1
+ name: seine
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - python=3.9.16
9
+ - pytorch=2.0.1
10
+ - pytorch-cuda=11.7
11
+ - torchvision=0.15.2
12
+ - pip
13
+ - pip:
14
+ - decord==0.6.0
15
+ - diffusers==0.15.0
16
+ - imageio==2.29.0
17
+ - transformers==4.29.2
18
+ - xformers==0.0.20
19
+ - einops
20
+ - omegaconf
huggingface-i2v/__init__.py ADDED
File without changes
huggingface-i2v/requirements.txt ADDED
File without changes
image_to_video/__init__.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import docx
5
+ try:
6
+ import utils
7
+
8
+ from diffusion import create_diffusion
9
+ from download import find_model
10
+ except ImportError:  # running as a script: fix sys.path below and re-import
11
+ # sys.path.append(os.getcwd())
12
+ sys.path.append(os.path.split(sys.path[0])[0])
13
+ # sys.path[0]
14
+ # os.path.split(sys.path[0])
15
+
16
+
17
+ import utils
18
+
19
+ from diffusion import create_diffusion
20
+ from download import find_model
21
+
22
+ import torch
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ torch.backends.cudnn.allow_tf32 = True
25
+ import argparse
26
+ import torchvision
27
+
28
+ from einops import rearrange
29
+ from models import get_models
30
+ from torchvision.utils import save_image
31
+ from diffusers.models import AutoencoderKL
32
+ from models.clip import TextEmbedder
33
+ from omegaconf import OmegaConf
34
+ from PIL import Image
35
+ import numpy as np
36
+ from torchvision import transforms
37
+ sys.path.append("..")
38
+ from datasets import video_transforms
39
+ from utils import mask_generation_before
40
+ from natsort import natsorted
41
+ from diffusers.utils.import_utils import is_xformers_available
42
+
43
+ config_path = "configs/sample_i2v.yaml"  # repo-relative config (was a machine-specific absolute path)
44
+ args = OmegaConf.load(config_path)
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+ print(args)
47
+
48
+ def model_i2v_fun(args):
49
+ if args.seed:
50
+ torch.manual_seed(args.seed)
51
+ torch.set_grad_enabled(False)
52
+ if args.ckpt is None:
53
+ raise ValueError("Please specify a checkpoint path using --ckpt <path>")
54
+ latent_h = args.image_size[0] // 8
55
+ latent_w = args.image_size[1] // 8
56
+ args.image_h = args.image_size[0]
57
+ args.image_w = args.image_size[1]
58
+ args.latent_h = latent_h
59
+ args.latent_w = latent_w
60
+ print("loading model")
61
+ model = get_models(args).to(device)
62
+
63
+ if args.use_compile:
64
+ model = torch.compile(model)
65
+ ckpt_path = args.ckpt
66
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
67
+ model.load_state_dict(state_dict)
68
+
69
+ print('loading success')
70
+
71
+ model.eval()
72
+ pretrained_model_path = args.pretrained_model_path
73
+ diffusion = create_diffusion(str(args.num_sampling_steps))
74
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
75
+ text_encoder = TextEmbedder(pretrained_model_path).to(device)
76
+ # if args.use_fp16:
77
+ # print('Warning: using half precision for inference')
78
+ # vae.to(dtype=torch.float16)
79
+ # model.to(dtype=torch.float16)
80
+ # text_encoder.to(dtype=torch.float16)
81
+
82
+ return vae, model, text_encoder, diffusion
83
+
84
+
85
+ def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
86
+ b,f,c,h,w=video_input.shape
87
+ latent_h = args.image_size[0] // 8
88
+ latent_w = args.image_size[1] // 8
89
+
90
+ # prepare inputs
91
+ if args.use_fp16:
92
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
93
+ masked_video = masked_video.to(dtype=torch.float16)
94
+ mask = mask.to(dtype=torch.float16)
95
+ else:
96
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
97
+
98
+
99
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
100
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
101
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
102
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
103
+
104
+ # classifier_free_guidance
105
+ if args.do_classifier_free_guidance:
106
+ masked_video = torch.cat([masked_video] * 2)
107
+ mask = torch.cat([mask] * 2)
108
+ z = torch.cat([z] * 2)
109
+ prompt_all = [prompt] + [args.negative_prompt]
110
+
111
+ else:
112
+ masked_video = masked_video
113
+ mask = mask
114
+ z = z
115
+ prompt_all = [prompt]
116
+
117
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
118
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
119
+ class_labels=None,
120
+ cfg_scale=args.cfg_scale,
121
+ use_fp16=args.use_fp16,) # tav unet
122
+
123
+ # Sample images:
124
+ if args.sample_method == 'ddim':
125
+ samples = diffusion.ddim_sample_loop(
126
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
127
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
128
+ )
129
+ elif args.sample_method == 'ddpm':
130
+ samples = diffusion.p_sample_loop(
131
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
132
+ mask=mask, x_start=masked_video, use_concat=args.use_mask
133
+ )
134
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
135
+ if args.use_fp16:
136
+ samples = samples.to(dtype=torch.float16)
137
+
138
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
139
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
140
+ return video_clip
141
+
142
+ def get_input(path,args):
143
+ input_path = path
144
+ # input_path = args.input_path
145
+ transform_video = transforms.Compose([
146
+ video_transforms.ToTensorVideo(), # TCHW
147
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
148
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
149
+ ])
150
+ temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval)
151
+ if input_path is not None:
152
+ print(f'loading video from {input_path}')
153
+ if os.path.isdir(input_path):
154
+ file_list = os.listdir(input_path)
155
+ video_frames = []
156
+ if args.mask_type.startswith('onelast'):
157
+ num = int(args.mask_type.split('onelast')[-1])
158
+ # get first and last frame
159
+ first_frame_path = os.path.join(input_path, natsorted(file_list)[0])
160
+ last_frame_path = os.path.join(input_path, natsorted(file_list)[-1])
161
+ first_frame = torch.as_tensor(np.array(Image.open(first_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
162
+ last_frame = torch.as_tensor(np.array(Image.open(last_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
163
+ for i in range(num):
164
+ video_frames.append(first_frame)
165
+ # add zeros to frames
166
+ num_zeros = args.num_frames-2*num
167
+ for i in range(num_zeros):
168
+ zeros = torch.zeros_like(first_frame)
169
+ video_frames.append(zeros)
170
+ for i in range(num):
171
+ video_frames.append(last_frame)
172
+ n = 0
173
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
174
+ video_frames = transform_video(video_frames)
175
+ else:
176
+ for file in file_list:
177
+ if file.endswith('jpg') or file.endswith('png'):
178
+ image = torch.as_tensor(np.array(Image.open(os.path.join(input_path,file)), dtype=np.uint8, copy=True)).unsqueeze(0)
179
+ video_frames.append(image)
180
+ else:
181
+ continue
182
+ n = 0
183
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
184
+ video_frames = transform_video(video_frames)
185
+ return video_frames, n
186
+ elif os.path.isfile(input_path):
187
+ _, full_file_name = os.path.split(input_path)
188
+ file_name, extention = os.path.splitext(full_file_name)
189
+ if extention == '.jpg' or extention == '.png':
190
+ # raise TypeError('a single image is not supported yet!!')
191
+ print("reading video from an image")
192
+ video_frames = []
193
+ num = int(args.mask_type.split('first')[-1])
194
+ first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0)
195
+ for i in range(num):
196
+ video_frames.append(first_frame)
197
+ num_zeros = args.num_frames-num
198
+ for i in range(num_zeros):
199
+ zeros = torch.zeros_like(first_frame)
200
+ video_frames.append(zeros)
201
+ n = 0
202
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
203
+ video_frames = transform_video(video_frames)
204
+ return video_frames, n
205
+ else:
206
+ raise TypeError(f'{extention} is not supported !!')
207
+ else:
208
+ raise ValueError('Please check your path input!!')
209
+ else:
210
+ # raise ValueError('Need to give a video or some images')
211
+ print('given video is None, using text to video')
212
+ video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8)
213
+ args.mask_type = 'all'
214
+ video_frames = transform_video(video_frames)
215
+ n = 0
216
+ return video_frames, n
217
+
218
+ def setup_seed(seed):
219
+ torch.manual_seed(seed)
220
+ torch.cuda.manual_seed_all(seed)
221
+
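An end-to-end sketch of how the pieces of this module combine, as if run at the bottom of the file (so `args` and `device` from above are in scope). The `mask_generation_before` call signature and the `masked_video` construction are assumptions taken from how the repo's sampling scripts use these helpers, and the prompt is just an example:

```python
image_path = "input/i2v/The_picture_shows_the_beauty_of_the_sea.png"
video_frames, _ = get_input(image_path, args)        # f, c, h, w: first frame plus zero padding
video_input = video_frames.to(device).unsqueeze(0)   # b, f, c, h, w

# Assumed signature, mirroring the sampling scripts in this repo.
mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
masked_video = video_input * (mask == 0)             # keep only the conditioning frames

vae, model, text_encoder, diffusion = model_i2v_fun(args)
video_clip = auto_inpainting(args, video_input, masked_video, mask,
                             "the sea at sunset, gentle waves",  # example prompt
                             vae, text_encoder, diffusion, model, device)
# video_clip: [num_frames, 3, H, W] in roughly [-1, 1]; rescale to uint8 before saving.
```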
image_to_video/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (13.4 kB).
 
input/i2v/Close-up_essence_is_poured_from_bottleKodak_Vision.png ADDED
input/i2v/The_picture_shows_the_beauty_of_the_sea.png ADDED
input/i2v/The_picture_shows_the_beauty_of_the_sea_and_at_the_same.png ADDED