Anshul Nasery committed
Commit 44f2ca8
1 Parent(s): 93b2b48

Demo commit

data/README.md ADDED
@@ -0,0 +1,14 @@
1
+ # Evaluation Preprocessing
2
+
3
+ ## MSR-VTT
4
+ Download the bounding box annotations for MSR-VTT from [here](https://drive.google.com/file/d/1OQvoR5zkohz5GpZxT0-fN1CPY9LjKT6y/view?usp=sharing). This is a pickle file containing a dictionary; each element holds the video ID, the caption, the subject of the caption, and a sequence of bounding boxes. These annotations were generated with `get_fg_obj.py`.
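+
+ A minimal sketch for loading and inspecting the annotation pickle (the filename and access pattern below are illustrative; the exact schema is whatever `get_fg_obj.py` stored, so print one record and check):
+
+ ```python
+ import pickle
+
+ # Path to the downloaded annotation file (rename as needed)
+ with open("msrvtt_bbox_annotations.pkl", "rb") as f:
+     annotations = pickle.load(f)
+
+ # Grab one record, whether the top-level container is a dict or a list
+ record = next(iter(annotations.values())) if isinstance(annotations, dict) else annotations[0]
+ print(len(annotations))
+ print(record)  # should show the video id, caption, caption subject, and bounding-box sequence
+ ```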
5
+ The MSR-VTT videos themselves can be downloaded from [this link](https://cove.thecvf.com/datasets/839). We use the script from the [StyleGAN-v repo](https://github.com/universome/stylegan-v/blob/master/src/scripts/convert_videos_to_frames.py) to pre-process the videos and convert them into frames.
6
+
7
+ ### Pre-processing
8
+ Our pre-processing pipeline works as follows. We first extract the subject of each caption using spaCy. This subject is then fed into OWL-ViT to obtain per-frame bounding boxes. If no bounding box is found for a subject, we move on to the next caption from the dataset. If at least one bounding box is found, we linearly interpolate bounding boxes for the frames where detection failed.
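+
+ A minimal sketch of the subject-extraction step, assuming spaCy's `en_core_web_sm` model (the exact model and dependency labels used in `get_fg_obj.py` may differ):
+
+ ```python
+ import spacy
+
+ nlp = spacy.load("en_core_web_sm")  # install with: python -m spacy download en_core_web_sm
+
+ def caption_subject(caption: str):
+     """Return the noun-phrase subject of a caption, or None if no subject is found."""
+     doc = nlp(caption)
+     for token in doc:
+         if token.dep_ in ("nsubj", "nsubjpass"):
+             # Prefer the full noun chunk containing the subject token
+             for chunk in doc.noun_chunks:
+                 if chunk.start <= token.i < chunk.end:
+                     return chunk.text
+             return token.text
+     return None
+
+ print(caption_subject("A deer is standing in a snowy field."))  # e.g. "A deer"
+ ```
+
+ The extracted subject is then used as the text query for the OWL-ViT zero-shot detector (see the detection loop in `get_bbox.py` below).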
9
+
10
+ ## ssv2-ST
11
+ Similar pre-processing is applied to this dataset, except that a larger OWL-ViT model is used and the first noun chunk is extracted instead of the subject. The larger model significantly slows down pre-processing. Downloading the dataset is somewhat involved; follow the instructions [here](https://github.com/MikeWangWZHL/Paxion#dataset-setup). Once the dataset is downloaded, run `generate_ssv2_st.py`.
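+
+ The noun-chunk variant is a one-liner with the same spaCy pipeline (again an illustrative sketch; `generate_ssv2_st.py` itself reads pre-extracted nouns from the annotation JSON):
+
+ ```python
+ import spacy
+
+ nlp = spacy.load("en_core_web_sm")
+
+ def first_noun_chunk(caption: str):
+     """Return the first noun chunk of a caption, or None if the parse yields none."""
+     return next((chunk.text for chunk in nlp(caption).noun_chunks), None)
+
+ print(first_noun_chunk("Pushing a plate off the table."))  # typically "a plate"
+ ```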
12
+
13
+ ## Interactive Motion Control (IMC)
14
+ We generate bounding boxes for this dataset with `generate_imc.py`. The prompts are in `custom_prompts.csv` (prompt, subject, bounding-box size, aspect ratio, motion type), and the subset of prompts actually used is listed in `filtered_prompts.txt`.
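+
+ Each record that `generate_imc.py` writes to `custom_prompts_with_bb.pkl` is a dictionary with the prompt, subject, motion/size/aspect-ratio attributes, and a list of per-frame binary masks. A quick sanity check of the output:
+
+ ```python
+ import pickle
+
+ with open("custom_prompts_with_bb.pkl", "rb") as f:
+     records = pickle.load(f)
+
+ rec = records[0]
+ print(rec["prompt"], "|", rec["subject"], "|", rec["motion"])
+ # `frames` holds one 224x224 binary mask per frame, with 1s inside the bounding box
+ print(len(rec["frames"]), rec["frames"][0].shape)
+ ```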
data/custom_prompts.csv ADDED
@@ -0,0 +1,100 @@
1
+ A butterfly resting on a flower.,Butterfly,Small,Square,Stationary
2
+ A koala clinging to a eucalyptus tree.,Koala,Medium,Rectangle Vertical,Stationary
3
+ A peacock displaying its feathers in a garden.,Peacock,Medium,Rectangle Horizontal,Stationary
4
+ A frog sitting on a lily pad in a pond.,Frog,Small,Square,Stationary
5
+ A deer standing in a snowy field.,Deer,Medium,Rectangle Horizontal,Stationary
6
+ A horse grazing in a meadow.,Horse,Medium,Rectangle Horizontal,Stationary
7
+ A squirrel holding an acorn in a park.,Squirrel,Small,Rectangle Vertical,Stationary
8
+ A parrot perched on a branch in the rainforest.,Parrot,Small,Rectangle Vertical,Stationary
9
+ A fox sitting in a forest clearing.,Fox,Medium,Rectangle Horizontal,Stationary
10
+ A swan floating gracefully on a lake.,Swan,Medium,Rectangle Horizontal,Stationary
11
+ A panda munching bamboo in a bamboo forest.,Panda,Large,Rectangle Vertical,Stationary
12
+ A hummingbird hovering near a flower.,Hummingbird,Small,Rectangle Vertical,Stationary
13
+ A penguin standing on an iceberg.,Penguin,Medium,Rectangle Vertical,Stationary
14
+ A lion lying in the savanna grass.,Lion,Large,Rectangle Horizontal,Stationary
15
+ An owl perched silently in a tree at night.,Owl,Medium,Rectangle Vertical,Stationary
16
+ A goat standing on a rocky hillside.,Goat,Medium,Rectangle Horizontal,Stationary
17
+ A dolphin just breaking the ocean surface.,Dolphin,Medium,Rectangle Horizontal,Stationary
18
+ A camel resting in a desert landscape.,Camel,Large,Rectangle Horizontal,Stationary
19
+ A kangaroo standing in the Australian outback.,Kangaroo,Medium,Rectangle Vertical,Stationary
20
+ An eagle sitting atop a mountain cliff.,Eagle,Small,Rectangle Vertical,Stationary
21
+ An ancient clock tower in a historic city square.,Clock Tower,Large,Rectangle Vertical,Stationary
22
+ A rustic wooden bridge over a tranquil stream.,Wooden Bridge,Medium,Rectangle Horizontal,Stationary
23
+ A grand piano in an elegant concert hall.,Grand Piano,Medium,Rectangle Horizontal,Stationary
24
+ A vintage car parked in front of a classic diner.,Vintage Car,Medium,Rectangle Horizontal,Stationary
25
+ A majestic lighthouse on a rocky coastline.,Lighthouse,Large,Rectangle Vertical,Stationary
26
+ A colorful hot air balloon tethered to the ground.,Hot Air Balloon,Large,Rectangle Vertical,Stationary
27
+ A medieval castle overlooking a scenic valley.,Castle,Large,Rectangle Horizontal,Stationary
28
+ A traditional windmill in a field of tulips.,Windmill,Large,Rectangle Vertical,Stationary
29
+ An intricate sculpture in a modern art museum.,Sculpture,Medium,Rectangle Vertical,Stationary
30
+ A red British telephone box on a city street.,Telephone Box,Medium,Rectangle Vertical,Stationary
31
+ A classic steam train stationed at an old railway platform.,Steam Train,Large,Rectangle Horizontal,Stationary
32
+ An old-fashioned street lamp on a foggy night.,Street Lamp,Medium,Rectangle Vertical,Stationary
33
+ A snow-covered cabin in a winter landscape.,Cabin,Medium,Rectangle Horizontal,Stationary
34
+ A beautifully crafted sundial in a botanical garden.,Sundial,Small,Square,Stationary
35
+ An ornate fountain in a public park.,Fountain,Medium,Square,Stationary
36
+ A weathered rowboat on a peaceful lakeshore.,Rowboat,Small,Rectangle Horizontal,Stationary
37
+ A detailed mural on the side of an urban building.,Mural,Large,Rectangle Horizontal,Stationary
38
+ A historical monument in a busy city center.,Monument,Large,Rectangle Vertical,Stationary
39
+ A charming gazebo in a lush garden.,Gazebo,Medium,Rectangle Horizontal,Stationary
40
+ A striking skyscraper against a city skyline.,Skyscraper,Large,Rectangle Vertical,Stationary
41
+ A cheetah sprinting across the savanna.,Cheetah,Medium,Rectangle Horizontal,Left to right
42
+ A school of fish swimming in a coral reef.,School of Fish,Medium,Rectangle Horizontal,Zig-zag
43
+ A hummingbird darting around a flower garden.,Hummingbird,Small,Square,Zig-zag
44
+ A horse galloping through a meadow.,Horse,Large,Rectangle Horizontal,Left to right
45
+ A squirrel scampering up a tree trunk.,Squirrel,Small,Rectangle Vertical,Up to down
46
+ A flock of geese flying in a V-formation.,Flock of Geese,Large,Rectangle Horizontal,Left to right
47
+ A rabbit hopping through a grassy field.,Rabbit,Small,Rectangle Horizontal,Zig-zag
48
+ A dolphin leaping out of the ocean waves.,Dolphin,Medium,Rectangle Horizontal,Up to down
49
+ A bee buzzing around a blooming sunflower.,Bee,Small,Square,Zig-zag
50
+ A butterfly fluttering over a meadow of wildflowers.,Butterfly,Small,Square,Zig-zag
51
+ A kangaroo bounding across the Australian outback.,Kangaroo,Medium,Rectangle Horizontal,Left to right
52
+ A hawk soaring in the sky above a mountain range.,Hawk,Medium,Rectangle Horizontal,Left to right
53
+ A spider spinning a web in the morning light.,Spider,Small,Square,Zig-zag
54
+ A snake slithering through a tropical rainforest.,Snake,Medium,Rectangle Horizontal,Zig-zag
55
+ A dog running to catch a frisbee in a park.,Dog,Medium,Rectangle Horizontal,Left to right
56
+ A cat playfully chasing a ball of yarn.,Cat,Small,Square,Zig-zag
57
+ A herd of elephants migrating across the African plains.,Herd of Elephants,Large,Rectangle Horizontal,Left to right
58
+ A bat flitting through the night sky.,Bat,Small,Rectangle Horizontal,Zig-zag
59
+ A group of penguins waddling on an Antarctic ice sheet.,Group of Penguins,Medium,Rectangle Horizontal,Left to right
60
+ A monkey swinging from branch to branch in the jungle.,Monkey,Medium,Rectangle Horizontal,Zig-zag
61
+ A woodpecker climbing up a tree trunk.,Woodpecker,Small,Rectangle Vertical,Up to down
62
+ A squirrel descending a tree after gathering nuts.,Squirrel,Small,Rectangle Vertical,Up to down
63
+ A snake slithering down a rocky hill.,Snake,Medium,Rectangle Vertical,Up to down
64
+ A bird diving towards the water to catch fish.,Bird,Small,Rectangle Vertical,Up to down
65
+ A monkey climbing up a vine in the rainforest.,Monkey,Medium,Rectangle Vertical,Up to down
66
+ A cat jumping down from a high fence.,Cat,Small,Rectangle Vertical,Up to down
67
+ An eagle descending from the sky to its nest.,Eagle,Medium,Rectangle Vertical,Up to down
68
+ A frog leaping up to catch a fly.,Frog,Small,Rectangle Vertical,Up to down
69
+ A spider descending on its web from a branch.,Spider,Small,Rectangle Vertical,Up to down
70
+ A mountain goat scaling a steep cliff.,Mountain Goat,Medium,Rectangle Vertical,Up to down
71
+ A koala climbing up a eucalyptus tree.,Koala,Small,Rectangle Vertical,Up to down
72
+ A bear climbing down a tree after spotting a threat.,Bear,Large,Rectangle Vertical,Up to down
73
+ A parrot flying upwards towards the treetops.,Parrot,Small,Rectangle Vertical,Up to down
74
+ A squirrel jumping from one tree to another,Squirrel,Small,Rectangle Vertical,Up to down
75
+ A bat swooping down from a cave's ceiling.,Bat,Small,Rectangle Vertical,Up to down
76
+ A duck diving underwater in search of food.,Duck,Medium,Rectangle Vertical,Up to down
77
+ A kangaroo hopping down a gentle slope.,Kangaroo,Medium,Rectangle Vertical,Up to down
78
+ A rabbit burrowing downwards into its warren.,Rabbit,Small,Rectangle Vertical,Up to down
79
+ A raccoon climbing up a city lamppost.,Raccoon,Medium,Rectangle Vertical,Up to down
80
+ An owl swooping down on its prey during the night.,Owl,Medium,Rectangle Vertical,Up to down
81
+ A train chugging along a mountainous landscape.,Train,Large,Rectangle Horizontal,Left to right
82
+ A hot air balloon drifting across a clear sky.,Hot Air Balloon,Large,Rectangle Vertical,Up to down
83
+ A sailboat gliding over the ocean waves.,Sailboat,Medium,Rectangle Horizontal,Left to right
84
+ A vintage car cruising down a coastal highway.,Vintage Car,Medium,Rectangle Horizontal,Left to right
85
+ A windmill turning its blades in a gentle breeze.,Windmill,Large,Rectangle Vertical,Left to right
86
+ A Ferris wheel rotating at a lively carnival.,Ferris Wheel,Large,Rectangle Vertical,Left to right
87
+ A satellite orbiting Earth in outer space.,Satellite,Small,Rectangle Horizontal,Left to right
88
+ A red double-decker bus moving through London streets.,Double-Decker Bus,Large,Rectangle Vertical,Left to right
89
+ A bicycle rider pedaling through a city park.,Bicycle Rider,Medium,Rectangle Horizontal,Left to right
90
+ A skateboarder performing tricks at a skate park.,Skateboarder,Small,Rectangle Horizontal,Left to right
91
+ A gondola floating down a Venetian canal.,Gondola,Medium,Rectangle Horizontal,Left to right
92
+ A jet plane flying high in the sky.,Jet Plane,Large,Rectangle Horizontal,Left to right
93
+ A helicopter hovering above a cityscape.,Helicopter,Medium,Rectangle Horizontal,Left to right
94
+ A roller coaster looping in an amusement park.,Roller Coaster,Large,Rectangle Horizontal,Left to right
95
+ A leaf falling gently from a tree.,Leaf,Small,Rectangle Vertical,Up to down
96
+ A mechanical clock's hands ticking forward.,Clock,Medium,Square,Left to right
97
+ A streetcar trundling down tracks in a historic district.,Streetcar,Medium,Rectangle Horizontal,Left to right
98
+ A rocket launching into space from a launchpad.,Rocket,Large,Rectangle Vertical,Up to down
99
+ A paper plane gliding in the air.,Paper Plane,Small,Rectangle Horizontal,Left to right
100
+ An escalator carrying people up in a shopping mall.,Escalator,Large,Rectangle Vertical,Up to down
data/filtered_prompts.txt ADDED
@@ -0,0 +1,125 @@
1
+ A woodpecker climbing up a tree trunk.
2
+ A woodpecker climbing up a tree trunk.
3
+ A woodpecker climbing up a tree trunk.
4
+ A squirrel descending a tree after gathering nuts.
5
+ A squirrel descending a tree after gathering nuts.
6
+ A squirrel descending a tree after gathering nuts.
7
+ A bird diving towards the water to catch fish.
8
+ A bird diving towards the water to catch fish.
9
+ A bird diving towards the water to catch fish.
10
+ A frog leaping up to catch a fly.
11
+ A frog leaping up to catch a fly.
12
+ A frog leaping up to catch a fly.
13
+ A spider descending on its web from a branch.
14
+ A spider descending on its web from a branch.
15
+ A spider descending on its web from a branch.
16
+ A parrot flying upwards towards the treetops.
17
+ A parrot flying upwards towards the treetops.
18
+ A parrot flying upwards towards the treetops.
19
+ A squirrel jumping from one tree to another
20
+ A squirrel jumping from one tree to another
21
+ A squirrel jumping from one tree to another
22
+ A rabbit burrowing downwards into its warren.
23
+ A rabbit burrowing downwards into its warren.
24
+ A rabbit burrowing downwards into its warren.
25
+ A satellite orbiting Earth in outer space.
26
+ A satellite orbiting Earth in outer space.
27
+ A satellite orbiting Earth in outer space.
28
+ A skateboarder performing tricks at a skate park.
29
+ A skateboarder performing tricks at a skate park.
30
+ A skateboarder performing tricks at a skate park.
31
+ A leaf falling gently from a tree.
32
+ A leaf falling gently from a tree.
33
+ A leaf falling gently from a tree.
34
+ A paper plane gliding in the air.
35
+ A paper plane gliding in the air.
36
+ A bear climbing down a tree after spotting a threat.
37
+ A bear climbing down a tree after spotting a threat.
38
+ A bear climbing down a tree after spotting a threat.
39
+ A duck diving underwater in search of food.
40
+ A duck diving underwater in search of food.
41
+ A duck diving underwater in search of food.
42
+ A kangaroo hopping down a gentle slope.
43
+ A kangaroo hopping down a gentle slope.
44
+ A kangaroo hopping down a gentle slope.
45
+ An owl swooping down on its prey during the night.
46
+ An owl swooping down on its prey during the night.
47
+ An owl swooping down on its prey during the night.
48
+ A hot air balloon drifting across a clear sky.
49
+ A hot air balloon drifting across a clear sky.
50
+ A hot air balloon drifting across a clear sky.
51
+ A sailboat gliding over the ocean waves.
52
+ A sailboat gliding over the ocean waves.
53
+ A sailboat gliding over the ocean waves.
54
+ A vintage car cruising down a coastal highway.
55
+ A vintage car cruising down a coastal highway.
56
+ A vintage car cruising down a coastal highway.
57
+ A red double-decker bus moving through London streets.
58
+ A red double-decker bus moving through London streets.
59
+ A red double-decker bus moving through London streets.
60
+ A jet plane flying high in the sky.
61
+ A jet plane flying high in the sky.
62
+ A jet plane flying high in the sky.
63
+ A helicopter hovering above a cityscape.
64
+ A helicopter hovering above a cityscape.
65
+ A helicopter hovering above a cityscape.
66
+ A roller coaster looping in an amusement park.
67
+ A roller coaster looping in an amusement park.
68
+ A roller coaster looping in an amusement park.
69
+ A streetcar trundling down tracks in a historic district.
70
+ A streetcar trundling down tracks in a historic district.
71
+ A streetcar trundling down tracks in a historic district.
72
+ A rocket launching into space from a launchpad.
73
+ A rocket launching into space from a launchpad.
74
+ A rocket launching into space from a launchpad.
75
+ A deer standing in a snowy field.
76
+ A deer standing in a snowy field.
77
+ A deer standing in a snowy field.
78
+ A horse grazing in a meadow.
79
+ A horse grazing in a meadow.
80
+ A horse grazing in a meadow.
81
+ A fox sitting in a forest clearing.
82
+ A fox sitting in a forest clearing.
83
+ A fox sitting in a forest clearing.
84
+ A swan floating gracefully on a lake.
85
+ A swan floating gracefully on a lake.
86
+ A swan floating gracefully on a lake.
87
+ A panda munching bamboo in a bamboo forest.
88
+ A panda munching bamboo in a bamboo forest.
89
+ A panda munching bamboo in a bamboo forest.
90
+ A penguin standing on an iceberg.
91
+ A penguin standing on an iceberg.
92
+ A penguin standing on an iceberg.
93
+ A lion lying in the savanna grass.
94
+ A lion lying in the savanna grass.
95
+ A lion lying in the savanna grass.
96
+ An owl perched silently in a tree at night.
97
+ An owl perched silently in a tree at night.
98
+ An owl perched silently in a tree at night.
99
+ A dolphin just breaking the ocean surface.
100
+ A dolphin just breaking the ocean surface.
101
+ A dolphin just breaking the ocean surface.
102
+ A camel resting in a desert landscape.
103
+ A camel resting in a desert landscape.
104
+ A camel resting in a desert landscape.
105
+ A kangaroo standing in the Australian outback.
106
+ A kangaroo standing in the Australian outback.
107
+ A kangaroo standing in the Australian outback.
108
+ A grand piano in an elegant concert hall.
109
+ A grand piano in an elegant concert hall.
110
+ A grand piano in an elegant concert hall.
111
+ A vintage car parked in front of a classic diner.
112
+ A vintage car parked in front of a classic diner.
113
+ A vintage car parked in front of a classic diner.
114
+ A colorful hot air balloon tethered to the ground.
115
+ A colorful hot air balloon tethered to the ground.
116
+ A colorful hot air balloon tethered to the ground.
117
+ A red British telephone box on a city street.
118
+ A red British telephone box on a city street.
119
+ A red British telephone box on a city street.
120
+ A classic steam train stationed at an old railway platform.
121
+ A classic steam train stationed at an old railway platform.
122
+ A classic steam train stationed at an old railway platform.
123
+ An old-fashioned street lamp on a foggy night.
124
+ An old-fashioned street lamp on a foggy night.
125
+ An old-fashioned street lamp on a foggy night.
data/generate_imc.py ADDED
@@ -0,0 +1,196 @@
1
+ import json
2
+ import numpy as np
3
+ import random
4
+ import csv
5
+ import pickle
6
+ import tqdm
7
+
8
+ def clamp(x, min_val, max_val):
9
+ return int(max(min(x, max_val), min_val))
10
+
11
+ def generate_moving_frames_simpler(canvas_size, num_frames, aspect_ratio, bounding_box_size, motion_type, up_to_down_strict=False, keep_in_frame=True):
12
+ # Mapping size to bounding box dimensions
13
+ size_mapping = {'Small': 0.25, 'Medium': 0.3, 'Large': 0.3}
14
+ aspect_ratio_mapping = {'Rectangle Vertical': (1.33, 1), 'Rectangle Horizontal': (1, 1.33), 'Square': (1, 1)}
15
+
16
+ # Calculate bounding box dimensions
17
+ ratio = aspect_ratio_mapping[aspect_ratio]
18
+ box_height = int(canvas_size[0] * size_mapping[bounding_box_size] * ratio[0])
19
+ box_width = int(canvas_size[1] * size_mapping[bounding_box_size] * ratio[1])
20
+
21
+ x_init_pos = [0.1 * canvas_size[1], 0.25 * canvas_size[1], 0.45*canvas_size[1], 0.7 * canvas_size[1]]
22
+ y_init_pos = [0.1 * canvas_size[0], 0.25 * canvas_size[0], 0.45*canvas_size[0], 0.7 * canvas_size[0]]
23
+
24
+ speed_dir = random.choice([-1,1]) # random.randint(1, 3)*4
25
+ # print('-'*20)
26
+ # print(motion_type)
27
+ if 'up' in motion_type.lower():
28
+ # Freedom in horizontal init
29
+ # Vertical init depends on upward or downward motion
30
+ pos_x = random.choice(x_init_pos) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
31
+ if up_to_down_strict == 'up':
32
+ # pos_y = np.random.choice(y_init_pos[2:]) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
33
+ speed_dir = -1.
34
+ elif up_to_down_strict == 'down':
35
+ # pos_y = np.random.choice(y_init_pos[2:]) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
36
+ speed_dir = 1.
37
+ # y_end_max = canvas_size[0] - box_height
38
+
39
+ # else:
40
+ if speed_dir == 1.:
41
+ pos_y = np.random.choice(y_init_pos[:2]) + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
42
+ y_end_max = canvas_size[0] - box_height
43
+ else:
44
+ pos_y = np.random.choice(y_init_pos[2:]) + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
45
+ y_end_max = box_height
46
+ max_speed = np.abs(y_end_max - pos_y) / num_frames
47
+
48
+ speed = random.randint(2, 4)*4
49
+ speed = min(speed, max_speed)
50
+ speed = speed_dir * speed
51
+ elif 'left' in motion_type.lower():
52
+ # Freedom in vertical init
53
+ # Horizontal init depends on upward or downward motion
54
+ pos_y = random.choice(y_init_pos) + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
55
+ if up_to_down_strict:
56
+ speed_dir = 1.
57
+
58
+ if speed_dir == 1.:
59
+ pos_x = np.random.choice(x_init_pos[:2]) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
60
+ x_end_max = canvas_size[1] - box_width
61
+ else:
62
+ pos_x = np.random.choice(x_init_pos[2:]) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
63
+ x_end_max = box_width
64
+ max_speed = np.abs(x_end_max - pos_x) / num_frames
65
+
66
+ speed = random.randint(2, 4)*4
67
+ speed = min(speed, max_speed)
68
+ speed = speed_dir * speed
69
+
70
+ else:
71
+ speed_dir_y = random.choice([-1,1])
72
+ if speed_dir == 1.:
73
+ pos_x = np.random.choice(x_init_pos[:2]) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
74
+ x_end_max = canvas_size[1] - box_width
75
+ else:
76
+ pos_x = np.random.choice(x_init_pos[2:]) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
77
+ x_end_max = box_width
78
+
79
+ if speed_dir_y == 1.:
80
+ pos_y = np.random.choice(y_init_pos[:2]) + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
81
+ y_end_max = canvas_size[0] - box_height
82
+ else:
83
+ pos_y = np.random.choice(y_init_pos[2:]) + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
84
+ y_end_max = box_height
85
+ max_speed_x = np.abs(x_end_max - pos_x) / num_frames
86
+ max_speed_y = np.abs(y_end_max - pos_y) / num_frames
87
+ speed_x = random.randint(2, 4)*4
88
+ speed_y = random.randint(2, 4)*4
89
+ speed_x = min(speed_x, max_speed_x)
90
+ speed_y = min(speed_y, max_speed_y)
91
+ speed_x, speed_y = (speed_dir * speed_x, speed_dir_y * speed_y)
92
+
93
+ frames = []
94
+
95
+
96
+ for _ in range(num_frames):
97
+ canvas = np.zeros(canvas_size)
98
+
99
+ # Determine movement direction and apply movement
100
+ if motion_type == "Left to right":
101
+ pos_x = (pos_x + speed) # % (canvas_size[1] - box_width)
102
+ pos_y = pos_y + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
103
+ elif motion_type == "Up to down":
104
+ pos_y = (pos_y + speed) # % (canvas_size[0] - box_height)
105
+ pos_x = pos_x + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
106
+ elif motion_type == "Zig-zag":
107
+ # Zig-zag motion alternates between horizontal and vertical movement
108
+ if _ % 2 == 0:
109
+ pos_x = (pos_x + speed_x) # % (canvas_size[1] - box_width)
110
+ else:
111
+ pos_y = (pos_y + speed_y) # % (canvas_size[0] - box_height)
112
+ canvas[clamp(pos_y, 0, canvas_size[0]):clamp(pos_y + box_height, 0, canvas_size[0]),
113
+ clamp(pos_x, 0, canvas_size[1]):clamp(pos_x + box_width, 0, canvas_size[1])] = 1
114
+
115
+ # Add frame to the list
116
+ frames.append(canvas)
117
+
118
+ return frames
119
+
120
+
121
+ def generate_stationary_frames_simpler(canvas_size, num_frames, aspect_ratio, bounding_box_size):
122
+ # Mapping size to bounding box dimensions
123
+ size_mapping = {'Small': 0.25, 'Medium': 0.3, 'Large': 0.3}
124
+ aspect_ratio_mapping = {'Rectangle Vertical': (1.33, 1), 'Rectangle Horizontal': (1, 1.33), 'Square': (1, 1)}
125
+
126
+ # Calculate bounding box dimensions
127
+ ratio = aspect_ratio_mapping[aspect_ratio]
128
+ box_height = int(canvas_size[0] * size_mapping[bounding_box_size] * ratio[0])
129
+ box_width = int(canvas_size[1] * size_mapping[bounding_box_size] * ratio[1])
130
+
131
+ x_init_pos = [0.1 * canvas_size[1], 0.25 * canvas_size[1], 0.45*canvas_size[1], 0.7 * canvas_size[1]]
132
+ y_init_pos = [0.1 * canvas_size[0], 0.25 * canvas_size[0], 0.45*canvas_size[0], 0.7 * canvas_size[0]]
133
+
134
+ pos_x = np.random.choice(x_init_pos) + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
135
+ pos_y = np.random.choice(y_init_pos) + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
136
+ # Initialize frames
137
+ frames = []
138
+ for _ in range(num_frames):
139
+ canvas = np.zeros(canvas_size)
140
+
141
+ # Determine movement direction and apply movement
142
+ pos_y = pos_y + random.randint(int(-0.01 * canvas_size[0]), int(0.01 * canvas_size[0]))
143
+ pos_x = pos_x + random.randint(int(-0.01 * canvas_size[1]), int(0.01 * canvas_size[1]))
144
+
145
+ canvas[clamp(pos_y, 0, canvas_size[0]):clamp(pos_y + box_height, 0, canvas_size[0]),
146
+ clamp(pos_x, 0, canvas_size[1]):clamp(pos_x + box_width, 0, canvas_size[1])] = 1
147
+
148
+ # Add frame to the list
149
+ frames.append(canvas)
150
+
151
+
152
+ return frames
153
+
154
+
155
+ input_file_path = "custom_prompts.csv"
156
+ output_file_path = "custom_prompts_with_bb.pkl"
157
+ num_videos_per_prompt = 3
158
+ video_id = 1100
159
+ all_records = []
160
+ frames_per_prompts = 3
161
+ num_frames = 16
162
+ with open('filtered_prompts.txt') as f:
163
+ GOOD_PROMPTS = set([x.strip() for x in f.readlines()])
164
+ with open(input_file_path, "r") as f:
165
+ data = csv.reader(f)
166
+ for row in tqdm.tqdm(data):
167
+ prompt = row[0]
168
+ prompt = prompt.replace('herd of', '').replace('group of', '').replace('flock of', '').replace('school of', '').replace('escalator', 'elevator')
169
+ subject = row[1].lower().replace('herd of', '').replace('group of', '').replace('flock of', '').replace('school of', '').replace('escalator', 'elevator')
170
+ if prompt not in GOOD_PROMPTS:
171
+ continue
172
+ canvas_size = (224, 224)
173
+ frames = []
174
+ if row[-1] == "Stationary":
175
+ for _ in range(frames_per_prompts):
176
+ frames.append(generate_stationary_frames_simpler(canvas_size, num_frames, row[3], row[2]))
177
+ else:
178
+ for _ in range(frames_per_prompts):
179
+ up_to_down_strict = False
180
+ if "up" in prompt.lower() or 'ascending' in prompt.lower():
181
+ up_to_down_strict = 'up'
182
+ elif "down" in prompt.lower() or 'descending' in prompt.lower():
183
+ up_to_down_strict = 'down'
184
+ else:
185
+ up_to_down_strict = False
186
+ frames.append(generate_moving_frames_simpler(canvas_size, num_frames, row[3], row[2], row[4], up_to_down_strict))
187
+
188
+ for i in range(frames_per_prompts):
189
+ record_dict = {"video_id": video_id, "prompt": prompt, "frames": frames[i], "subject": row[1], "motion": row[4], "aspect_ratio": row[3], "bounding_box_size": row[2]}
190
+ all_records.append(record_dict)
191
+ video_id += 1
192
+ print(f"Wrote {len(all_records)} records to {output_file_path}")
193
+ with open(output_file_path, "wb") as f:
194
+ pickle.dump(all_records, f)
195
+
196
+
data/generate_ssv2_st.py ADDED
@@ -0,0 +1,150 @@
1
+ #@title Get bounding boxes for the subject
2
+ from transformers import pipeline
3
+ from moviepy.editor import VideoFileClip
4
+ from PIL import Image
5
+ import os
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import tqdm
9
+ import pickle
10
+ import torch
11
+
12
+ checkpoint = "google/owlvit-large-patch14"
13
+ detector = pipeline(model=checkpoint, task="zero-shot-object-detection", cache_dir="/coc/pskynet4/yashjain/", device='cuda:0')
14
+
15
+
16
+ # from transformers import Owlv2Processor, Owlv2ForObjectDetection
17
+
18
+ # processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
19
+ # model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
20
+
21
+ # def owl_inference(image, text):
22
+ # inputs = inputs = processor(text=text, images=image, return_tensors="pt")
23
+ # outputs = model(**inputs)
24
+ # target_sizes = torch.Tensor([image.size[::-1]])
25
+ # results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
26
+ # return results[0]['boxes']
27
+
28
+ def find_surrounding_masks(mask_presence):
29
+ # Finds the indices of the surrounding masks for each gap
30
+ gap_info = []
31
+ start = None
32
+
33
+ for i, present in enumerate(mask_presence):
34
+ if present and start is not None:
35
+ end = i
36
+ gap_info.append((start, end))
37
+ start = None
38
+ elif not present and start is None and i > 0:
39
+ start = i - 1
40
+
41
+ # Handle the special case where the gap is at the end
42
+ if start is not None:
43
+ gap_info.append((start, len(mask_presence)))
44
+
45
+ return gap_info
46
+
47
+ def copy_edge_masks(mask_list, mask_presence):
48
+ if not mask_presence[-1]:
49
+ # Find the last present mask and copy it to the end
50
+ for i in reversed(range(len(mask_presence))):
51
+ if mask_presence[i]:
52
+ mask_list[i+1:] = [mask_list[i]] * (len(mask_presence) - i - 1)
53
+ break
54
+
55
+ def interpolate_masks(mask_list, mask_presence):
56
+ # Ensure the mask list and mask presence list are the same length
57
+ assert len(mask_list) == len(mask_presence), "Mask list and presence list must have the same length."
58
+
59
+ # Copy edge masks if there are gaps at the start or end
60
+ # copy_edge_masks(mask_list, mask_presence)
61
+
62
+ # Find surrounding masks for gaps
63
+ gap_info = find_surrounding_masks(mask_presence)
64
+
65
+ # Interpolate the masks in the gaps
66
+ for start, end in gap_info:
67
+ end = min(end, len(mask_list)-1)
68
+ num_steps = end - start - 1
69
+ prev_mask = mask_list[start]
70
+ next_mask = mask_list[end]
71
+ step = (next_mask - prev_mask) / (num_steps + 1)
72
+ interpolated_masks = [(prev_mask + step * (i + 1)).round().astype(int) for i in range(num_steps)]
73
+ mask_list[start + 1:end] = interpolated_masks
74
+
75
+ return mask_list
76
+
77
+ def get_bounding_boxes(clip_path, subject):
78
+ # Read video from the path
79
+ clip = VideoFileClip(clip_path)
80
+ all_bboxes = []
81
+ bbox_present = []
82
+
83
+ num_bb = 0
84
+
85
+ for fidx,frame in enumerate(clip.iter_frames()):
86
+ if fidx > 24: break
87
+
88
+ frame = Image.fromarray(frame)
89
+
90
+ predictions = detector(
91
+ frame,
92
+ candidate_labels=[subject,],
93
+ )
94
+ try:
95
+
96
+ bbox = predictions[0]["box"]
97
+
98
+ bbox = (bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"])
99
+
100
+ # Get a zeros array of the same size as the frame
101
+ canvas = np.zeros(frame.size[::-1])
102
+ # Draw the bounding box on the canvas
103
+ canvas[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 1
104
+ # Add the canvas to the list of bounding boxes
105
+ all_bboxes.append(canvas)
106
+ bbox_present.append(True)
107
+ num_bb += 1
108
+ except Exception as e:
109
+
110
+ # Append an empty canvas, we will interpolate later
111
+ all_bboxes.append(np.zeros(frame.size[::-1]))
112
+ bbox_present.append(False)
113
+ continue
114
+
115
+ # Design decision
116
+ interpolated_masks = interpolate_masks(all_bboxes, bbox_present)
117
+ return interpolated_masks, num_bb
118
+
119
+ import json
120
+ BASE_DIR = '/scr/clips_downsampled_5fps_downsized_224x224'
121
+ annotations = json.load(open('/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/ssv2_label_ssv2_template/ssv2_ret_label_val_small_filtered.json', 'r'))
122
+
123
+ records_with_masks = []
124
+ ridx = 0
125
+ for idx,record in tqdm.tqdm(enumerate(annotations)):
126
+ video_id = record['video']
127
+ print(f"{record['caption']} - {record['nouns']}")
128
+ # for video_id in video_ids:
129
+ new_record = record.copy()
130
+ new_record['video'] = video_id.replace('webm', 'mp4')
131
+ all_masks = []
132
+ all_num_bb = []
133
+ for subject in record['nouns']:
134
+ masks, num_bb = get_bounding_boxes(clip_path=os.path.join(BASE_DIR, video_id.replace('webm', 'mp4')), subject=subject)
135
+ all_masks.append(masks)
136
+ all_num_bb.append(num_bb)
137
+ try:
138
+ print(f"{record['video']} , subj - {record['nouns']}, bb - {all_num_bb}")
139
+ except:
140
+ continue
141
+ new_record['masks'] = all_masks
142
+ records_with_masks.append(new_record)
143
+ ridx += 1
144
+
145
+ if ridx % 100 == 0:
146
+ with open(f'/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl', 'wb') as f:
147
+ pickle.dump(records_with_masks, f)
148
+
149
+ with open(f'/gscratch/sewoong/anasery/datasets/ssv2/datasets/SSv2/SSv2_label_with_two_obj_masks.pkl', 'wb') as f:
150
+ pickle.dump(records_with_masks, f)
data/get_bbox.py ADDED
@@ -0,0 +1,84 @@
1
+ #@title Get bounding boxes for the subject
2
+ from transformers import pipeline
3
+ from moviepy.editor import VideoFileClip
4
+ from PIL import Image
5
+ import os
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import tqdm
9
+ import pickle
10
+ import torch
11
+
12
+ checkpoint = "google/owlvit-large-patch14"
13
+ detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device='cuda:0')
14
+
15
+
16
+ def get_bounding_boxes(clip_path, subject):
17
+ # Read video from the path
18
+ clip = VideoFileClip(clip_path)
19
+ all_bboxes = []
20
+ bbox_present = []
21
+
22
+ num_bb = 0
23
+ for fidx,frame in enumerate(clip.iter_frames()):
24
+ frame = Image.fromarray(frame)
25
+
26
+ predictions = detector(
27
+ frame,
28
+ candidate_labels=[subject,],
29
+ )
30
+ try:
31
+
32
+ bbox = predictions[0]["box"]
33
+
34
+ bbox = (bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"])
35
+
36
+ # Get a zeros array of the same size as the frame
37
+ canvas = np.zeros(frame.size[::-1])
38
+ # Draw the bounding box on the canvas
39
+ canvas[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 1
40
+ # Add the canvas to the list of bounding boxes
41
+ all_bboxes.append(canvas)
42
+ bbox_present.append(True)
43
+ num_bb += 1
44
+ except Exception as e:
45
+
46
+ # Append an empty canvas, we will interpolate later
47
+ all_bboxes.append(np.zeros(frame.size[::-1]))
48
+ bbox_present.append(False)
49
+ continue
50
+ return all_bboxes, num_bb
51
+
52
+ import pickle as pkl
53
+ dir_path = '/your/result/path'
54
+
55
+ video_filename = '2_of_40_2.mp4'
56
+ output_bbox = []
57
+ with open("/ssv2dataset/path.pkl", "rb") as f:
58
+ data = pkl.load(f)
59
+ dataset_size = len(data)
60
+ failed_cnt = 0
61
+ for i, d in tqdm.tqdm(enumerate(data)):
62
+ try:
63
+ # print(f"{d['subject']} || {d['caption']} || {d['video']}")
64
+ filename = d['video'].split('.')[0]
65
+ video_path = os.path.join(dir_path, filename, video_filename)
66
+ fg_object = d['subject']
67
+ masks, num_bb = get_bounding_boxes(video_path, fg_object)
68
+
69
+ output_bbox.append({
70
+ 'caption': d['caption'],
71
+ 'video': d['video'],
72
+ 'subject': d['subject'],
73
+ 'mask': masks,
74
+ 'num_bb': num_bb
75
+ })
76
+ # print(num_bb)
77
+ except:
78
+ print(f"Missed #{i} with Caption: {d['caption']}")
79
+ failed_cnt += 1
80
+
81
+ with open(f"/output/path/iou_eval/ssv2_modelscope_{video_filename.split('.')[0]}_bbox-v2.pkl", "wb") as f:
82
+ pkl.dump(output_bbox, f)
83
+
84
+ print(f"Failed: {failed_cnt} out of {dataset_size}")
env.yaml ADDED
@@ -0,0 +1,107 @@
1
+ name: llmgd-3
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - blas=1.0=mkl
9
+ - brotlipy=0.7.0=py310h7f8727e_1002
10
+ - bzip2=1.0.8=h7b6447c_0
11
+ - ca-certificates=2023.08.22=h06a4308_0
12
+ - certifi=2023.7.22=py310h06a4308_0
13
+ - cffi=1.15.1=py310h5eee18b_3
14
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
15
+ - cryptography=41.0.3=py310hdda0065_0
16
+ - cudatoolkit=11.3.1=h2bc3f7f_2
17
+ - ffmpeg=4.3=hf484d3e_0
18
+ - freetype=2.12.1=h4a9f257_0
19
+ - giflib=5.2.1=h5eee18b_3
20
+ - gmp=6.2.1=h295c915_3
21
+ - gnutls=3.6.15=he1e5248_0
22
+ - idna=3.4=py310h06a4308_0
23
+ - intel-openmp=2023.1.0=hdb19cb5_46305
24
+ - jpeg=9e=h5eee18b_1
25
+ - lame=3.100=h7b6447c_0
26
+ - lcms2=2.12=h3be6417_0
27
+ - ld_impl_linux-64=2.38=h1181459_1
28
+ - lerc=3.0=h295c915_0
29
+ - libdeflate=1.17=h5eee18b_1
30
+ - libffi=3.4.4=h6a678d5_0
31
+ - libgcc-ng=11.2.0=h1234567_1
32
+ - libgomp=11.2.0=h1234567_1
33
+ - libiconv=1.16=h7f8727e_2
34
+ - libidn2=2.3.4=h5eee18b_0
35
+ - libpng=1.6.39=h5eee18b_0
36
+ - libstdcxx-ng=11.2.0=h1234567_1
37
+ - libtasn1=4.19.0=h5eee18b_0
38
+ - libtiff=4.5.1=h6a678d5_0
39
+ - libunistring=0.9.10=h27cfd23_0
40
+ - libuuid=1.41.5=h5eee18b_0
41
+ - libwebp=1.3.2=h11a3e52_0
42
+ - libwebp-base=1.3.2=h5eee18b_0
43
+ - lz4-c=1.9.4=h6a678d5_0
44
+ - mkl=2023.1.0=h213fc3f_46343
45
+ - mkl-service=2.4.0=py310h5eee18b_1
46
+ - mkl_fft=1.3.8=py310h5eee18b_0
47
+ - mkl_random=1.2.4=py310hdb19cb5_0
48
+ - ncurses=6.4=h6a678d5_0
49
+ - nettle=3.7.3=hbbd107a_1
50
+ - numpy=1.26.0=py310h5f9d8c6_0
51
+ - numpy-base=1.26.0=py310hb5e798b_0
52
+ - openh264=2.1.1=h4ff587b_0
53
+ - openjpeg=2.4.0=h3ad879b_0
54
+ - openssl=3.0.11=h7f8727e_2
55
+ - pillow=10.0.1=py310ha6cbd5a_0
56
+ - pip=23.3=py310h06a4308_0
57
+ - pycparser=2.21=pyhd3eb1b0_0
58
+ - pyopenssl=23.2.0=py310h06a4308_0
59
+ - pysocks=1.7.1=py310h06a4308_0
60
+ - python=3.10.13=h955ad1f_0
61
+ - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0
62
+ - pytorch-mutex=1.0=cuda
63
+ - readline=8.2=h5eee18b_0
64
+ - requests=2.31.0=py310h06a4308_0
65
+ - setuptools=68.0.0=py310h06a4308_0
66
+ - sqlite=3.41.2=h5eee18b_0
67
+ - tbb=2021.8.0=hdb19cb5_0
68
+ - tk=8.6.12=h1ccaba5_0
69
+ - torchaudio=0.12.1=py310_cu113
70
+ - torchvision=0.13.1=py310_cu113
71
+ - typing_extensions=4.7.1=py310h06a4308_0
72
+ - tzdata=2023c=h04d1e81_0
73
+ - urllib3=1.26.16=py310h06a4308_0
74
+ - wheel=0.41.2=py310h06a4308_0
75
+ - xz=5.4.2=h5eee18b_0
76
+ - zlib=1.2.13=h5eee18b_0
77
+ - zstd=1.5.5=hc292b87_0
78
+ - pip:
79
+ - accelerate==0.23.0
80
+ - accelerator==2023.7.18.dev1
81
+ - av==10.0.0
82
+ - beautifulsoup4==4.12.2
83
+ - bottle==0.12.25
84
+ - diffusers==0.21.4
85
+ - filelock==3.12.4
86
+ - fsspec==2023.9.2
87
+ - gdown==4.7.1
88
+ - huggingface-hub==0.17.3
89
+ - imageio==2.31.5
90
+ - importlib-metadata==6.8.0
91
+ - opencv-python==4.8.1.78
92
+ - packaging==23.2
93
+ - psutil==5.9.6
94
+ - pyyaml==6.0.1
95
+ - regex==2023.10.3
96
+ - safetensors==0.4.0
97
+ - setproctitle==1.3.3
98
+ - six==1.16.0
99
+ - soupsieve==2.5
100
+ - tokenizers==0.14.1
101
+ - tqdm==4.66.1
102
+ - transformers==4.34.1
103
+ - waitress==2.1.2
104
+ - zipp==3.17.0
105
+ - moviepy
106
+ - gradio
107
+ prefix: /home/yasjain/miniconda3/envs/llmgd-3
src/app.py ADDED
@@ -0,0 +1,222 @@
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ from gradio_utils import *
5
+
6
+ def image_mod(image):
7
+ return image.rotate(45)
8
+
9
+ import os
10
+
11
+ import sys
12
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
13
+
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+
21
+
22
+
23
+ from models.pipelines import TextToVideoSDPipelineSpatialAware
24
+
25
+
26
+
27
+ NUM_POINTS = 3
28
+ NUM_FRAMES = 24
29
+ LARGE_BOX_SIZE = 256
30
+
31
+
32
+ def generate_video(pipe, overall_prompt, latents, get_latents=False, num_frames=24, num_inference_steps=50, fg_masks=None,
33
+ fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
34
+
35
+ video_frames = pipe(overall_prompt, num_frames=num_frames, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
36
+ frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt, custom_attention_mask=custom_attention_mask, fg_prompt=fg_prompt,
37
+ make_attention_mask_2d=True, attention_mask_block_diagonal=True, height=320, width=576 ).frames
38
+ if get_latents:
39
+ video_latents = pipe(overall_prompt, num_frames=num_frames, latents=latents, num_inference_steps=num_inference_steps, output_type="latent").frames
40
+ return video_frames, video_latents
41
+
42
+ return video_frames
43
+
44
+
45
+ # def generate_bb(prompt, fg_object, aspect_ratio, size, trajectory):
46
+
47
+ # if len(trajectory['layers']) < NUM_POINTS:
48
+ # raise ValueError
49
+ # final_canvas = torch.zeros((NUM_FRAMES,320,576))
50
+
51
+ # bbox_size_x = LARGE_BOX_SIZE if size == "large" else int(LARGE_BOX_SIZE * 0.75) if size == "medium" else LARGE_BOX_SIZE//2
52
+ # bbox_size_y = bbox_size_x if aspect_ratio == "square" else int(bbox_size_x * 0.75) if aspect_ratio == "horizontal" else int(bbox_size_x * 1.25)
53
+
54
+ # bbox_coords = []
55
+ # # TODO add checks for trajectory
56
+ # for t in trajectory['layers']:
57
+ # bbox_coords.append([int(t.sum(axis=-2).argmax()*576/800), int(t.sum(axis=-1)[140:460].argmax())])
58
+ # bbox_coords = np.array(bbox_coords)
59
+ # # Make a list of length 24
60
+ # # Each element is a list of length 2
61
+ # # First element is the x coordinate of the bbox
62
+ # # Second element is a set of y coordinates of the bbox
63
+ # new_bbox_coords = [np.zeros(2,) for i in range(NUM_FRAMES)]
64
+ # divisor = int(NUM_FRAMES / (NUM_POINTS-1))
65
+ # for i in range(NUM_POINTS-1):
66
+ # new_bbox_coords[i*divisor] = bbox_coords[i]
67
+ # new_bbox_coords[-1] = bbox_coords[-1]
68
+
69
+ # # Linearly interpolate in the middle
70
+ # for i in range(NUM_POINTS-1):
71
+ # for j in range(1,divisor):
72
+ # new_bbox_coords[i*divisor+j][1] = int((bbox_coords[i][0] * (divisor-j) + bbox_coords[(i+1)][0] * j) / divisor)
73
+ # new_bbox_coords[i*divisor+j][0] = int((bbox_coords[i][1] * (divisor-j) + bbox_coords[(i+1)][1] * j) / divisor)
74
+
75
+ # for i in range(NUM_FRAMES):
76
+ # x = int(new_bbox_coords[i][0])
77
+ # y = int(new_bbox_coords[i][1])
78
+ # final_canvas[i,int(x-bbox_size_x/2):int(x+bbox_size_x/2), int(y-bbox_size_y/2):int(y+bbox_size_y/2)] = 1
79
+
80
+ # torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+ # try:
82
+ # pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
83
+ # "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
84
+ # except:
85
+ # pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
86
+ # "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
87
+
88
+ # fg_masks = F.interpolate(final_canvas.unsqueeze(1), size=(40,72), mode="nearest").to(torch_device)
89
+
90
+ # # Save fg_masks as images
91
+ # for i in range(NUM_FRAMES):
92
+ # cv2.imwrite(f"./fg_masks/frame_{i:04d}.png", fg_masks[i,0].cpu().numpy()*255)
93
+
94
+
95
+
96
+ # seed = 2
97
+ # random_latents = torch.randn([1, 4, NUM_FRAMES, 40, 72], generator=torch.Generator().manual_seed(seed)).to(torch_device)
98
+ # overall_prompt = f"A realistic lively {prompt}"
99
+ # video_frames = generate_video(pipe, overall_prompt, random_latents, get_latents=False, num_frames=NUM_FRAMES, num_inference_steps=40,
100
+ # fg_masks=fg_masks, fg_masked_latents=None, frozen_steps=2, frozen_prompt=None, fg_prompt=fg_object)
101
+
102
+ # return create_video(video_frames,fps=8, type="final")
103
+
104
+
105
+ def interpolate_points(points, target_length):
106
+ print(points)
107
+ if len(points) == target_length:
108
+ return points
109
+ elif len(points) > target_length:
110
+ # Subsample the points uniformly
111
+ indices = np.round(np.linspace(0, len(points) - 1, target_length)).astype(int)
112
+ return [points[i] for i in indices]
113
+ else:
114
+ # Linearly interpolate to get more points
115
+ interpolated_points = []
116
+ num_points_to_add = target_length - len(points)
117
+ points_added_per_segment = num_points_to_add // (len(points) - 1)
118
+
119
+ for i in range(len(points) - 1):
120
+ start, end = points[i], points[i + 1]
121
+ interpolated_points.append(start)
122
+ for j in range(1, points_added_per_segment + 1):
123
+ fraction = j / (points_added_per_segment + 1)
124
+ new_point = np.round(start + fraction * (end - start))
125
+ interpolated_points.append(new_point)
126
+
127
+ # Add the last point
128
+ interpolated_points.append(points[-1])
129
+
130
+ # If there are still not enough points, add extras at the end
131
+ while len(interpolated_points) < target_length:
132
+ interpolated_points.append(points[-1])
133
+
134
+ return interpolated_points
135
+
136
+
137
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
138
+
139
+
140
+ try:
141
+ pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
142
+ "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
143
+ except:
144
+ pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
145
+ "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
146
+
147
+
148
+ def generate_bb(prompt, fg_object, aspect_ratio, size, motion_direction, trajectory):
149
+
150
+ # if len(trajectory['layers']) < NUM_POINTS:
151
+ # raise ValueError
152
+ final_canvas = torch.zeros((NUM_FRAMES,320//8,576//8))
153
+
154
+ bbox_size_x = LARGE_BOX_SIZE if size == "large" else int(LARGE_BOX_SIZE * 0.75) if size == "medium" else LARGE_BOX_SIZE//2
155
+ bbox_size_y = bbox_size_x if aspect_ratio == "square" else int(bbox_size_x * 1.33) if aspect_ratio == "horizontal" else int(bbox_size_x * 0.75)
156
+
157
+ bbox_coords = []
158
+
159
+ image = trajectory['composite']
160
+ print(image.shape)
161
+
162
+ image = cv2.resize(image,(576, 320))
163
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
164
+ _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY_INV)
165
+ contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
166
+
167
+
168
+ # Process each contour
169
+ bbox_points = []
170
+ for contour in contours:
171
+ # You can approximate the contour to reduce the number of points
172
+ epsilon = 0.01 * cv2.arcLength(contour, True)
173
+ approx = cv2.approxPolyDP(contour, epsilon, True)
174
+
175
+ # Extracting and printing coordinates
176
+ for point in approx:
177
+ y, x = point.ravel()
178
+ if x in range(1,319) and y in range(1,575):
179
+ bbox_points.append([x,y])
180
+
181
+ if motion_direction in ['l2r', 'r2l']:
182
+ sorted_points = sorted(bbox_points, key=lambda x: x[1], reverse=motion_direction=="r2l")
183
+ else:
184
+ sorted_points = sorted(bbox_points, key=lambda x: x[0], reverse=motion_direction=="d2u")
185
+ target_length = 24
186
+ final_points = interpolate_points(np.array(sorted_points), target_length)
187
+
188
+ # Remember to reverse the co-ordinates
189
+ for i in range(NUM_FRAMES):
190
+ x = int(final_points[i][0])
191
+ y = int(final_points[i][1])
192
+ # Added Padding
193
+ final_canvas[i, max(int(x-bbox_size_x/2),16) // 8:min(int(x+bbox_size_x/2), 304)// 8,
194
+ max(int(y-bbox_size_y/2),16)// 8:min(int(y+bbox_size_y/2),560)// 8] = 1
195
+
196
+
197
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
198
+ fg_masks = final_canvas.unsqueeze(1).to(torch_device)
199
+ # # Save fg_masks as images
200
+ for i in range(NUM_FRAMES):
201
+ cv2.imwrite(f"./fg_masks/frame_{i:04d}.png", fg_masks[i,0].cpu().numpy()*255)
202
+
203
+ seed = 2
204
+ random_latents = torch.randn([1, 4, NUM_FRAMES, 40, 72], generator=torch.Generator().manual_seed(seed)).to(torch_device)
205
+ overall_prompt = f"A realistic lively {prompt}"
206
+ video_frames = generate_video(pipe, overall_prompt, random_latents, get_latents=False, num_frames=NUM_FRAMES, num_inference_steps=40,
207
+ fg_masks=fg_masks, fg_masked_latents=None, frozen_steps=2, frozen_prompt=None, fg_prompt=fg_object)
208
+
209
+ return create_video(video_frames,fps=8, type="final")
210
+
211
+
212
+
213
+ demo = gr.Interface(
214
+ fn=generate_bb,
215
+ inputs=["text", "text", gr.Radio(choices=["square", "horizontal", "vertical"]), gr.Radio(choices=["small", "medium", "large"]), gr.Radio(choices=["l2r", "r2l", "u2d", "d2u"]),
216
+ gr.Paint(value={'background':np.zeros((320,576)), 'layers': [], 'composite': np.zeros((320,576))},type="numpy", image_mode="RGB", height=320, width=576)],
217
+ outputs=gr.Video(),
218
+ )
219
+
220
+
221
+ if __name__ == "__main__":
222
+ demo.launch(share=True)
src/app_modelscope.py ADDED
@@ -0,0 +1,262 @@
1
+ from models.pipelines import TextToVideoSDPipelineSpatialAware
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import cv2
5
+ import sys
6
+ import gradio as gr
7
+ import os
8
+ import numpy as np
9
+ from gradio_utils import *
10
+
11
+
12
+ def image_mod(image):
13
+ return image.rotate(45)
14
+
15
+
16
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
17
+
18
+
19
+ NUM_POINTS = 3
20
+ NUM_FRAMES = 16
21
+ LARGE_BOX_SIZE = 176
22
+
23
+
24
+ def generate_video(pipe, overall_prompt, latents, get_latents=False, num_frames=24, num_inference_steps=50, fg_masks=None,
25
+ fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
26
+
27
+ video_frames = pipe(overall_prompt, num_frames=num_frames, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
28
+ frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt, custom_attention_mask=custom_attention_mask, fg_prompt=fg_prompt,
29
+ make_attention_mask_2d=True, attention_mask_block_diagonal=True, height=256, width=256).frames
30
+ if get_latents:
31
+ video_latents = pipe(overall_prompt, num_frames=num_frames, latents=latents,
32
+ num_inference_steps=num_inference_steps, output_type="latent").frames
33
+ return video_frames, video_latents
34
+
35
+ return video_frames
36
+
37
+
38
+ # def generate_bb(prompt, fg_object, aspect_ratio, size, trajectory):
39
+
40
+ # if len(trajectory['layers']) < NUM_POINTS:
41
+ # raise ValueError
42
+ # final_canvas = torch.zeros((NUM_FRAMES,320,576))
43
+
44
+ # bbox_size_x = LARGE_BOX_SIZE if size == "large" else int(LARGE_BOX_SIZE * 0.75) if size == "medium" else LARGE_BOX_SIZE//2
45
+ # bbox_size_y = bbox_size_x if aspect_ratio == "square" else int(bbox_size_x * 0.75) if aspect_ratio == "horizontal" else int(bbox_size_x * 1.25)
46
+
47
+ # bbox_coords = []
48
+ # # TODO add checks for trajectory
49
+ # for t in trajectory['layers']:
50
+ # bbox_coords.append([int(t.sum(axis=-2).argmax()*576/800), int(t.sum(axis=-1)[140:460].argmax())])
51
+ # bbox_coords = np.array(bbox_coords)
52
+ # # Make a list of length 24
53
+ # # Each element is a list of length 2
54
+ # # First element is the x coordinate of the bbox
55
+ # # Second element is a set of y coordinates of the bbox
56
+ # new_bbox_coords = [np.zeros(2,) for i in range(NUM_FRAMES)]
57
+ # divisor = int(NUM_FRAMES / (NUM_POINTS-1))
58
+ # for i in range(NUM_POINTS-1):
59
+ # new_bbox_coords[i*divisor] = bbox_coords[i]
60
+ # new_bbox_coords[-1] = bbox_coords[-1]
61
+
62
+ # # Linearly interpolate in the middle
63
+ # for i in range(NUM_POINTS-1):
64
+ # for j in range(1,divisor):
65
+ # new_bbox_coords[i*divisor+j][1] = int((bbox_coords[i][0] * (divisor-j) + bbox_coords[(i+1)][0] * j) / divisor)
66
+ # new_bbox_coords[i*divisor+j][0] = int((bbox_coords[i][1] * (divisor-j) + bbox_coords[(i+1)][1] * j) / divisor)
67
+
68
+ # for i in range(NUM_FRAMES):
69
+ # x = int(new_bbox_coords[i][0])
70
+ # y = int(new_bbox_coords[i][1])
71
+ # final_canvas[i,int(x-bbox_size_x/2):int(x+bbox_size_x/2), int(y-bbox_size_y/2):int(y+bbox_size_y/2)] = 1
72
+
73
+ # torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ # try:
75
+ # pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
76
+ # "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
77
+ # except:
78
+ # pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
79
+ # "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
80
+
81
+ # fg_masks = F.interpolate(final_canvas.unsqueeze(1), size=(40,72), mode="nearest").to(torch_device)
82
+
83
+ # # Save fg_masks as images
84
+ # for i in range(NUM_FRAMES):
85
+ # cv2.imwrite(f"./fg_masks/frame_{i:04d}.png", fg_masks[i,0].cpu().numpy()*255)
86
+
87
+
88
+ # seed = 2
89
+ # random_latents = torch.randn([1, 4, NUM_FRAMES, 40, 72], generator=torch.Generator().manual_seed(seed)).to(torch_device)
90
+ # overall_prompt = f"A realistic lively {prompt}"
91
+ # video_frames = generate_video(pipe, overall_prompt, random_latents, get_latents=False, num_frames=NUM_FRAMES, num_inference_steps=40,
92
+ # fg_masks=fg_masks, fg_masked_latents=None, frozen_steps=2, frozen_prompt=None, fg_prompt=fg_object)
93
+
94
+ # return create_video(video_frames,fps=8, type="final")
95
+
96
+
97
+ def interpolate_points(points, target_length):
98
+ print(points)
99
+ if len(points) == target_length:
100
+ return points
101
+ elif len(points) > target_length:
102
+ # Subsample the points uniformly
103
+ indices = np.round(np.linspace(
104
+ 0, len(points) - 1, target_length)).astype(int)
105
+ return [points[i] for i in indices]
106
+ else:
107
+ # Linearly interpolate to get more points
108
+ interpolated_points = []
109
+ num_points_to_add = target_length - len(points)
110
+ points_added_per_segment = num_points_to_add // (len(points) - 1)
111
+
112
+ for i in range(len(points) - 1):
113
+ start, end = points[i], points[i + 1]
114
+ interpolated_points.append(start)
115
+ for j in range(1, points_added_per_segment + 1):
116
+ fraction = j / (points_added_per_segment + 1)
117
+ new_point = np.round(start + fraction * (end - start))
118
+ interpolated_points.append(new_point)
119
+
120
+ # Add the last point
121
+ interpolated_points.append(points[-1])
122
+
123
+ # If there are still not enough points, add extras at the end
124
+ while len(interpolated_points) < target_length:
125
+ interpolated_points.append(points[-1])
126
+
127
+ return interpolated_points
128
+
129
+
130
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131
+
132
+
133
+ try:
134
+ pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
135
+ "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float, variant="fp32").to(torch_device)
136
+ except:
137
+ pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
138
+ "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float, variant="fp32").to(torch_device)
139
+
140
+
141
+ def generate_bb(prompt, fg_object, aspect_ratio, size, motion_direction, seed, peekaboo_steps, trajectory):
142
+
143
+ if not set(fg_object.split()).issubset(set(prompt.split())):
144
+ raise gr.Error("Foreground object should be present in the video prompt")
145
+ # if len(trajectory['layers']) < NUM_POINTS:
146
+ # raise ValueError
147
+ final_canvas = torch.zeros((NUM_FRAMES, 256//8, 256//8))
148
+
149
+ bbox_size_x = LARGE_BOX_SIZE if size == "large" else int(
150
+ LARGE_BOX_SIZE * 0.75) if size == "medium" else LARGE_BOX_SIZE//2
151
+ bbox_size_y = bbox_size_x if aspect_ratio == "square" else int(
152
+ bbox_size_x * 1.33) if aspect_ratio == "horizontal" else int(bbox_size_x * 0.75)
153
+
154
+ bbox_coords = []
155
+
156
+ image = trajectory['composite']
157
+ print(image.shape)
158
+
159
+ image = cv2.resize(image, (256, 256))
160
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
161
+ _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY_INV)
162
+ contours, _ = cv2.findContours(
163
+ thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
164
+
165
+ # Process each contour
166
+ bbox_points = []
167
+ for contour in contours:
168
+ # You can approximate the contour to reduce the number of points
169
+ epsilon = 0.01 * cv2.arcLength(contour, True)
170
+ approx = cv2.approxPolyDP(contour, epsilon, True)
171
+
172
+ # Extracting and printing coordinates
173
+ for point in approx:
174
+ y, x = point.ravel()
175
+ if 0 < x < 255 and 0 < y < 255:
176
+ # bbox_points.append([min(max(x, 32), 256-32),min(max(y, 32), 256-32)])
177
+ bbox_points.append([min(max(x, 0), 256), min(max(y, 0), 256)])
178
+
179
+ if motion_direction in ['Left to Right', 'Right to Left']:
180
+ sorted_points = sorted(
181
+ bbox_points, key=lambda x: x[1], reverse=motion_direction == "Right to Left")
182
+ else:
183
+ sorted_points = sorted(
184
+ bbox_points, key=lambda x: x[0], reverse=motion_direction == "Down to Up")
185
+ target_length = NUM_FRAMES
186
+ final_points = interpolate_points(np.array(sorted_points), target_length)
187
+
188
+ # Remember to reverse the co-ordinates
189
+ for i in range(NUM_FRAMES):
190
+ x = int(final_points[i][0])
191
+ y = int(final_points[i][1])
192
+ # Added Padding
193
+ final_canvas[i, max(int(x-bbox_size_x/2), 0) // 8:min(int(x+bbox_size_x/2), 256) // 8,
194
+ max(int(y-bbox_size_y/2), 0) // 8:min(int(y+bbox_size_y/2), 256) // 8] = 1
195
+
196
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
197
+ fg_masks = final_canvas.unsqueeze(1).to(torch_device)
198
+ # # Save fg_masks as images
199
+ for i in range(NUM_FRAMES):
200
+ cv2.imwrite(f"./fg_masks/frame_{i:04d}.png",
201
+ fg_masks[i, 0].cpu().numpy()*255)
202
+
203
+ seed = int(seed)  # ensure an integer seed for torch.Generator().manual_seed
204
+ random_latents = torch.randn([1, 4, NUM_FRAMES, 32, 32], generator=torch.Generator(
205
+ ).manual_seed(seed)).to(torch_device)
206
+ overall_prompt = f"{prompt} , high quality"
207
+ video_frames = generate_video(pipe, overall_prompt, random_latents, get_latents=False, num_frames=NUM_FRAMES, num_inference_steps=40,
208
+ fg_masks=fg_masks, fg_masked_latents=None, frozen_steps=int(peekaboo_steps), frozen_prompt=None, fg_prompt=fg_object)
209
+ video_frames_original = generate_video(pipe, overall_prompt, random_latents, get_latents=False, num_frames=NUM_FRAMES, num_inference_steps=40,
210
+ fg_masks=None, fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, fg_prompt=None)
211
+
212
+ return create_video(video_frames_original, fps=8, type="modelscope"), create_video(video_frames, fps=8, type="final")
213
+
214
+
215
+ instructions_md = """
216
+ ## Usage Instructions
217
+ - **Video Prompt**: Enter a brief description of the scene you want to generate.
218
+ - **Foreground Object**: Specify the main object in the video.
219
+ - **Aspect Ratio**: Choose the aspect ratio for the bounding box.
220
+ - **Size of the Bounding Box**: Select how large the foreground object should be.
221
+ - **Trajectory of the Bounding Box**: Draw the trajectory of the bounding box.
222
+ - **Motion Direction**: Indicate the direction of movement for the object.
223
+ - **Geek Settings**: Advanced settings for fine-tuning (optional).
224
+ - **Generate Video**: Click the button to create your video.
225
+
226
+ Feel free to experiment with different settings to see how they affect the output!
227
+ """
228
+
229
+ with gr.Blocks() as demo:
230
+ gr.Markdown("""
231
+ # Peekaboo Demo
232
+ """)
233
+ with gr.Row():
234
+ video_1 = gr.Video(label="Original Modelscope Video")
235
+ video_2 = gr.Video(label="Peekaboo Video")
236
+
237
+
238
+ with gr.Accordion(label="Usage Instructions", open=False):
239
+ gr.Markdown(instructions_md)
240
+ with gr.Group("User Input"):
241
+ txt_1 = gr.Textbox(lines=1, label="Video Prompt", value="Darth Vader surfing on some waves")
242
+ txt_2 = gr.Textbox(lines=1, label="Foreground Object in the Video Prompt", value="Darth Vader")
243
+ aspect_ratio = gr.Radio(choices=["square", "horizontal", "vertical"], label="Aspect Ratio", value="vertical")
244
+ trajectory = gr.Paint(value={'background': np.zeros((256, 256)), 'layers': [], 'composite': np.zeros((256, 256))}, type="numpy", image_mode="RGB", height=256, width=256, label="Trajectory of the Bounding Box")
245
+ size = gr.Radio(choices=["small", "medium", "large"], label="Size of the Bounding Box", value="medium")
246
+ motion_direction = gr.Radio(choices=["Left to Right", "Right to Left", "Up to Down", "Down to Up"], label="Motion Direction", value="Left to Right")
247
+
248
+ with gr.Accordion(label="Geek settings", open=False):
249
+ with gr.Group():
250
+ seed = gr.Slider(0, 10, step=1., value=2, label="Seed")
251
+ peekaboo_steps = gr.Slider(0, 20, step=1., value=2, label="Number of Peekaboo Steps")
252
+
253
+
254
+ btn = gr.Button(value="Generate Video")
255
+
256
+ btn.click(generate_bb, inputs=[txt_1, txt_2, aspect_ratio, size, motion_direction, seed, peekaboo_steps, trajectory], outputs=[video_1, video_2])
257
+
258
+
259
+
260
+
261
+ if __name__ == "__main__":
262
+ demo.launch(share=True)
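For reference, the trajectory handling in `generate_bb` above reduces to three steps: collect contour points from the drawn sketch, resample them to one point per frame, and stamp a fixed-size box into a latent-resolution grid. Below is a minimal, self-contained sketch of that conversion; `boxes_from_trajectory`, the 256x256 canvas, the /8 latent downsampling, the box size and `NUM_FRAMES = 24` are illustrative assumptions, not names from this repository (the app itself uses `interpolate_points` and `final_canvas`).

```python
import numpy as np
import torch

NUM_FRAMES = 24  # assumed frame count, matching the value used in src/generation.py

def boxes_from_trajectory(points, box_hw=(64, 48)):
    """points: [x, y] pixel coordinates on a 256x256 canvas, sorted along the motion direction."""
    points = np.asarray(points, dtype=float)
    # Resample the trajectory to one point per frame by linear interpolation.
    t_old = np.linspace(0.0, 1.0, len(points))
    t_new = np.linspace(0.0, 1.0, NUM_FRAMES)
    xs = np.interp(t_new, t_old, points[:, 0])
    ys = np.interp(t_new, t_old, points[:, 1])

    # Stamp one box per frame into a grid at 1/8 of the canvas resolution (the latent grid).
    canvas = torch.zeros(NUM_FRAMES, 256 // 8, 256 // 8)
    bh, bw = box_hw
    for i, (x, y) in enumerate(zip(xs, ys)):
        x0, x1 = max(int(x - bh / 2), 0) // 8, min(int(x + bh / 2), 256) // 8
        y0, y1 = max(int(y - bw / 2), 0) // 8, min(int(y + bw / 2), 256) // 8
        canvas[i, x0:x1, y0:y1] = 1
    return canvas

masks = boxes_from_trajectory([[40, 40], [128, 128], [216, 216]])
print(masks.shape)  # torch.Size([24, 32, 32]); unsqueeze(1) before passing as fg_masks
```

The app's `interpolate_points` performs the same resampling with explicit subsampling and per-segment interpolation instead of `np.interp`.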
src/generation.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+
3
+ import sys
4
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
5
+
6
+ import warnings
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import tqdm
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torchvision.io as vision_io
14
+
15
+
16
+
17
+ from models.pipelines import TextToVideoSDPipelineSpatialAware
18
+ from diffusers.utils import export_to_video
19
+ from PIL import Image
20
+ import torchvision
21
+
22
+
23
+
24
+ import warnings
25
+ warnings.filterwarnings("ignore")
26
+
27
+ OUTPUT_PATH = "/scr/demo"
28
+
29
+ def generate_video(pipe, overall_prompt, latents, get_latents=False, num_frames=24, num_inference_steps=50, fg_masks=None,
30
+ fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
31
+
32
+ video_frames = pipe(overall_prompt, num_frames=num_frames, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
33
+ frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt, custom_attention_mask=custom_attention_mask, fg_prompt=fg_prompt,
34
+ make_attention_mask_2d=True, attention_mask_block_diagonal=True, height=320, width=576 ).frames
35
+ if get_latents:
36
+ video_latents = pipe(overall_prompt, num_frames=num_frames, latents=latents, num_inference_steps=num_inference_steps, output_type="latent").frames
37
+ return video_frames, video_latents
38
+
39
+ return video_frames
40
+
41
+ def save_frames(path):
42
+ video, audio, video_info = vision_io.read_video(f"{path}.mp4", pts_unit='sec')
43
+
44
+ # Number of frames
45
+ num_frames = video.size(0)
46
+
47
+ # Save each frame
48
+ os.makedirs(f"{path}", exist_ok=True)
49
+ for i in range(num_frames):
50
+ frame = video[i, :, :, :].numpy()
51
+ # Convert from C x H x W to H x W x C and from torch tensor to PIL Image
52
+ # frame = frame.permute(1, 2, 0).numpy()
53
+ img = Image.fromarray(frame.astype('uint8'))
54
+ img.save(f"{path}/frame_{i:04d}.png")
55
+
56
+ if __name__ == "__main__":
57
+ # Example usage
58
+ num_frames = 24
59
+ save_path = "video"
60
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ random_latents = torch.randn([1, 4, num_frames, 40, 72], generator=torch.Generator().manual_seed(2)).to(torch_device)
62
+
63
+ try:
64
+ pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
65
+ "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
66
+ except:
67
+ pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
68
+ "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
69
+
70
+ # Generate video
71
+
72
+
73
+ bbox_mask = torch.zeros([24, 1, 40, 72], device=torch_device)
74
+ bbox_mask_2 = torch.zeros([24, 1, 40, 72], device=torch_device)
75
+
76
+
77
+ x_start = [10 + (i % 3) for i in range(num_frames)] # Simulating slight movement in x
78
+ x_end = [30 + (i % 3) for i in range(num_frames)] # Simulating slight movement in x
79
+ y_start = [10 for _ in range(num_frames)] # Static y start as the bear is seated/standing
80
+ y_end = [25 for _ in range(num_frames)] # Static y end, considering the size of the guitar
81
+
82
+ # Populate the bbox_mask tensor with ones where the bounding box is located
83
+ for i in range(num_frames):
84
+ bbox_mask[i, :, x_start[i]:x_end[i], y_start[i]:y_end[i]] = 1
85
+ bbox_mask_2[i, :, x_start[i]:x_end[i], 72-y_end[i]:72-y_start[i]] = 1
86
+
87
+ # fg_masks = bbox_mask
88
+ fg_masks = [bbox_mask, bbox_mask_2]
89
+
90
+
91
+
92
+ frozen_prompt = None
93
+ fg_masked_latents = None
94
+ fg_objects = []
95
+ prompts = []
96
+ prompts = [
97
+ (["cat", "goldfish bowl"], "A cat curiously staring at a goldfish bowl on a sunny windowsill."),
98
+ (["Superman", "Batman"], "Superman and Batman standing side by side in a heroic pose against a city skyline."),
99
+ (["rose", "daisy"], "A rose and a daisy in a small vase on a rustic wooden table."),
100
+ (["Harry Potter", "Hermione Granger"], "Harry Potter and Hermione Granger studying a magical map."),
101
+ (["butterfly", "dragonfly"], "A butterfly and a dragonfly resting on a leaf in a vibrant garden."),
102
+ (["teddy bear", "toy train"], "A teddy bear and a toy train on a child's playmat in a brightly lit room."),
103
+ (["frog", "turtle"], "A frog and a turtle sitting on a lily pad in a serene pond."),
104
+ (["Mickey Mouse", "Donald Duck"], "Mickey Mouse and Donald Duck enjoying a day at the beach, building a sandcastle."),
105
+ (["penguin", "seal"], "A penguin and a seal lounging on an iceberg in the Antarctic."),
106
+ (["lion", "zebra"], "A lion and a zebra peacefully drinking water from the same pond in the savannah.")
107
+ ]
108
+
109
+ for fg_object, overall_prompt in prompts:
110
+ os.makedirs(f"{OUTPUT_PATH}/{save_path}/{overall_prompt}-mask", exist_ok=True)
111
+ try:
112
+ for i in range(num_frames):
113
+ torchvision.utils.save_image(fg_masks[0][i,0], f"{OUTPUT_PATH}/{save_path}/{overall_prompt}-mask/frame_{i:04d}_0.png")
114
+ torchvision.utils.save_image(fg_masks[1][i,0], f"{OUTPUT_PATH}/{save_path}/{overall_prompt}-mask/frame_{i:04d}_1.png")
115
+ except:
116
+ pass
117
+ print(fg_object, overall_prompt)
118
+ seed = 2
119
+ random_latents = torch.randn([1, 4, num_frames, 40, 72], generator=torch.Generator().manual_seed(seed)).to(torch_device)
120
+ for num_inference_steps in range(40,50,10):
121
+ for frozen_steps in [0, 1, 2]:
122
+ video_frames = generate_video(pipe, overall_prompt, random_latents, get_latents=False, num_frames=num_frames, num_inference_steps=num_inference_steps,
123
+ fg_masks=fg_masks, fg_masked_latents=fg_masked_latents, frozen_steps=frozen_steps, frozen_prompt=frozen_prompt, fg_prompt=fg_object)
124
+ # Save video frames
125
+ os.makedirs(f"{OUTPUT_PATH}/{save_path}/{overall_prompt}", exist_ok=True)
126
+ video_path = export_to_video(video_frames, f"{OUTPUT_PATH}/{save_path}/{overall_prompt}/{frozen_steps}_of_{num_inference_steps}_{seed}_masked.mp4")
127
+ save_frames(f"{OUTPUT_PATH}/{save_path}/{overall_prompt}/{frozen_steps}_of_{num_inference_steps}_{seed}_masked")
128
+
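As a usage note, the mask construction in this script boils down to stamping one rectangle per frame into a `(num_frames, 1, H, W)` tensor at latent resolution and collecting one such tensor per foreground prompt. The sketch below reproduces that pattern in isolation; the 40x72 grid matches the 320x576 resolution used above, while the helper name and box coordinates are illustrative.

```python
import torch

num_frames, H, W = 24, 40, 72  # latent-resolution grid for 320x576 zeroscope output

def moving_box_mask(x0, x1, y0, y1):
    """One binary rectangle per frame, with the slight per-frame jitter used in the script."""
    mask = torch.zeros(num_frames, 1, H, W)
    for i in range(num_frames):
        shift = i % 3  # mimics x_start/x_end drifting by (i % 3) above
        mask[i, :, x0 + shift:x1 + shift, y0:y1] = 1
    return mask

left_box = moving_box_mask(10, 30, 10, 25)
right_box = left_box.flip(-1)      # mirror the box horizontally, as the second mask does above
fg_masks = [left_box, right_box]   # one mask per foreground prompt, e.g. ["cat", "goldfish bowl"]
```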
src/gradio_utils.py ADDED
@@ -0,0 +1,60 @@
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ from moviepy.editor import *
7
+
8
+ def get_frames(video_in):
9
+ frames = []
10
+ #resize the video
11
+ clip = VideoFileClip(video_in)
12
+
13
+ #check fps
14
+ if clip.fps > 30:
15
+ print("video rate is over 30, resetting to 30")
16
+ clip_resized = clip.resize(height=512)
17
+ clip_resized.write_videofile("video_resized.mp4", fps=30)
18
+ else:
19
+ print("video rate is OK")
20
+ clip_resized = clip.resize(height=512)
21
+ clip_resized.write_videofile("video_resized.mp4", fps=clip.fps)
22
+
23
+ print("video resized to 512 height")
24
+
25
+ # Opens the Video file with CV2
26
+ cap= cv2.VideoCapture("video_resized.mp4")
27
+
28
+ fps = cap.get(cv2.CAP_PROP_FPS)
29
+ print("video fps: " + str(fps))
30
+ i=0
31
+ while(cap.isOpened()):
32
+ ret, frame = cap.read()
33
+ if not ret:
34
+ break
35
+ cv2.imwrite('kang'+str(i)+'.jpg',frame)
36
+ frames.append('kang'+str(i)+'.jpg')
37
+ i+=1
38
+
39
+ cap.release()
40
+ cv2.destroyAllWindows()
41
+ print("broke the video into frames")
42
+
43
+ return frames, fps
44
+
45
+
46
+ def convert(gif):
47
+ if gif != None:
48
+ clip = VideoFileClip(gif.name)
49
+ clip.write_videofile("my_gif_video.mp4")
50
+ return "my_gif_video.mp4"
51
+ else:
52
+ pass
53
+
54
+
55
+ def create_video(frames, fps, type):
56
+ print("building video result")
57
+ clip = ImageSequenceClip(frames, fps=fps)
58
+ clip.write_videofile(type + "_result.mp4", fps=fps)
59
+
60
+ return type + "_result.mp4"
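A quick usage sketch for the helpers above (the file name is hypothetical): split an input clip into frames, then reassemble them at the same frame rate.

```python
# Hypothetical round trip using get_frames and create_video as defined above.
frames, fps = get_frames("input.mp4")                      # writes kang<i>.jpg files, returns their paths
result_path = create_video(frames, fps=fps, type="final")  # writes and returns "final_result.mp4"
```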
src/image_generation.py ADDED
@@ -0,0 +1,366 @@
1
+
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
4
+ from transformers import pipeline
5
+ import torchvision
6
+ from PIL import Image
7
+ from models.t2i_pipeline import StableDiffusionPipelineSpatialAware
8
+ import torchvision.io as vision_io
9
+ import torch.nn.functional as F
10
+ import torch
11
+ import tqdm
12
+ import numpy as np
13
+ import cv2
14
+ import warnings
15
+ import time
16
+ import tempfile
17
+ import argparse
18
+ import glob
19
+ import multiprocessing as mp
20
+ import os
21
+ import random
22
+
23
+ # fmt: off
24
+ import sys
25
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
26
+ # fmt: on
27
+
28
+
29
+ warnings.filterwarnings("ignore")
30
+
31
+ # constants
32
+ WINDOW_NAME = "demo"
33
+
34
+
35
+ def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
36
+ fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
37
+ '''
38
+ Main function that calls the image diffusion model
39
+ latents: the input noise tensor from which generation starts
40
+ get_latents: if True, returns the latents for each frame
41
+ '''
42
+
43
+ image = pipe(overall_prompt, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
44
+ frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt, custom_attention_mask=custom_attention_mask, output_type='pil',
45
+ fg_prompt=fg_prompt, make_attention_mask_2d=True, attention_mask_block_diagonal=True).images[0]
46
+ torch.save(image, "img.pt")
47
+
48
+ if get_latents:
49
+ video_latents = pipe(overall_prompt, latents=latents,
50
+ num_inference_steps=num_inference_steps, output_type="latent").images
51
+ torch.save(video_latents, "img_latents.pt")
52
+ return image, video_latents
53
+
54
+ return image
55
+
56
+
57
+ def save_frames(path):
58
+ video, audio, video_info = vision_io.read_video(
59
+ f"demo3/{path}.mp4", pts_unit='sec')
60
+
61
+ # Number of frames
62
+ num_frames = video.size(0)
63
+
64
+ # Save each frame
65
+ os.makedirs(f"demo3/{path}", exist_ok=True)
66
+ for i in range(num_frames):
67
+ frame = video[i, :, :, :].numpy()
68
+ # Convert from C x H x W to H x W x C and from torch tensor to PIL Image
69
+ # frame = frame.permute(1, 2, 0).numpy()
70
+ img = Image.fromarray(frame.astype('uint8'))
71
+ img.save(f"demo3/{path}/frame_{i:04d}.png")
72
+
73
+
74
+ def create_boxes():
75
+ img_width = 96
76
+ img_height = 96
77
+
78
+ # initialize bboxes list
79
+ sbboxes = []
80
+
81
+ # object dimensions
82
+ for object_size in [20, 30, 40, 50, 60]:
83
+ obj_width, obj_height = object_size, object_size
84
+
85
+ # starting position
86
+ start_x = 3
87
+ start_y = 4
88
+
89
+ # calculate total size occupied by the objects in the grid
90
+ total_obj_width = 3 * obj_width
91
+ total_obj_height = 3 * obj_height
92
+
93
+ # determine horizontal and vertical spacings
94
+ spacing_horizontal = (img_width - total_obj_width - start_x) // 2
95
+ spacing_vertical = (img_height - total_obj_height - start_y) // 2
96
+
97
+ for i in range(3):
98
+ for j in range(3):
99
+ x_start = start_x + i * (obj_width + spacing_horizontal)
100
+ y_start = start_y + j * (obj_height + spacing_vertical)
101
+ # Corrected to img_width to include the last pixel
102
+ x_end = min(x_start + obj_width, img_width)
103
+ # Corrected to img_height to include the last pixel
104
+ y_end = min(y_start + obj_height, img_height)
105
+ sbboxes.append([x_start, y_start, x_end, y_end])
106
+
107
+ mask_id = 0
108
+ masks_list = []
109
+
110
+ for sbbox in sbboxes:
111
+ smask = torch.zeros(1, 1, 96, 96)
112
+ smask[0, 0, sbbox[1]:sbbox[3], sbbox[0]:sbbox[2]] = 1.0
113
+ masks_list.append(smask)
114
+ # torchvision.utils.save_image(smask, f"{SAVE_DIR}/masks/mask_{mask_id}.png") # save masks as images
115
+ mask_id += 1
116
+
117
+ return masks_list
118
+
119
+
120
+ def objects_list():
121
+ objects_settings = [
122
+ ("apple", "on a table"),
123
+ ("ball", "in a park"),
124
+ ("cat", "on a couch"),
125
+ ("dog", "in a backyard"),
126
+ ("elephant", "in a jungle"),
127
+ ("fountain pen", "on a desk"),
128
+ ("guitar", "on a stage"),
129
+ ("helicopter", "in the sky"),
130
+ ("island", "in the sea"),
131
+ ("jar", "on a shelf"),
132
+ ("kite", "in the sky"),
133
+ ("lamp", "in a room"),
134
+ ("motorbike", "on a road"),
135
+ ("notebook", "on a table"),
136
+ ("owl", "on a tree"),
137
+ ("piano", "in a hall"),
138
+ ("queen", "in a castle"),
139
+ ("robot", "in a lab"),
140
+ ("snake", "in a forest"),
141
+ ("tent", "in the mountains"),
142
+ ("umbrella", "on a beach"),
143
+ ("violin", "in an orchestra"),
144
+ ("wheel", "in a garage"),
145
+ ("xylophone", "in a music class"),
146
+ ("yacht", "in a marina"),
147
+ ("zebra", "in a savannah"),
148
+ ("aeroplane", "in the clouds"),
149
+ ("bridge", "over a river"),
150
+ ("computer", "in an office"),
151
+ ("dragon", "in a cave"),
152
+ ("egg", "in a nest"),
153
+ ("flower", "in a garden"),
154
+ ("globe", "in a library"),
155
+ ("hat", "on a rack"),
156
+ ("ice cube", "in a glass"),
157
+ ("jewelry", "in a box"),
158
+ ("kangaroo", "in a desert"),
159
+ ("lion", "in a den"),
160
+ ("mug", "on a counter"),
161
+ ("nest", "on a branch"),
162
+ ("octopus", "in the ocean"),
163
+ ("parrot", "in a rainforest"),
164
+ ("quilt", "on a bed"),
165
+ ("rose", "in a vase"),
166
+ ("ship", "in a dock"),
167
+ ("train", "on the tracks"),
168
+ ("utensils", "in a kitchen"),
169
+ ("vase", "on a window sill"),
170
+ ("watch", "in a store"),
171
+ ("x-ray", "in a hospital"),
172
+ ("yarn", "in a basket"),
173
+ ("zeppelin", "above a city"),
174
+ ]
175
+ objects_settings.extend([
176
+ ("muffin", "on a bakery shelf"),
177
+ ("notebook", "on a student's desk"),
178
+ ("owl", "in a tree"),
179
+ ("piano", "in a concert hall"),
180
+ ("quill", "on parchment"),
181
+ ("robot", "in a factory"),
182
+ ("snake", "in the grass"),
183
+ ("telescope", "in an observatory"),
184
+ ("umbrella", "at the beach"),
185
+ ("violin", "in an orchestra"),
186
+ ("whale", "in the ocean"),
187
+ ("xylophone", "in a music store"),
188
+ ("yacht", "in a marina"),
189
+ ("zebra", "on a savanna"),
190
+
191
+ # Kitchen items
192
+ ("spoon", "in a drawer"),
193
+ ("plate", "in a cupboard"),
194
+ ("cup", "on a shelf"),
195
+ ("frying pan", "on a stove"),
196
+ ("jar", "in the refrigerator"),
197
+
198
+ # Office items
199
+ ("computer", "in an office"),
200
+ ("printer", "by a desk"),
201
+ ("chair", "around a conference table"),
202
+ ("lamp", "on a workbench"),
203
+ ("calendar", "on a wall"),
204
+
205
+ # Outdoor items
206
+ ("bicycle", "on a street"),
207
+ ("tent", "in a campsite"),
208
+ ("fire", "in a fireplace"),
209
+ ("mountain", "in the distance"),
210
+ ("river", "through the woods"),
211
+
212
+
213
+ # and so on ...
214
+ ])
215
+
216
+ # To expedite the generation, you can combine themes and objects:
217
+
218
+ themes = [
219
+ ("wild animals", ["tiger", "lion", "cheetah",
220
+ "giraffe", "hippopotamus"], "in the wild"),
221
+ ("household items", ["sofa", "tv", "clock",
222
+ "vase", "photo frame"], "in a living room"),
223
+ ("clothes", ["shirt", "pants", "shoes",
224
+ "hat", "jacket"], "in a wardrobe"),
225
+ ("musical instruments", ["drum", "trumpet",
226
+ "harp", "saxophone", "tuba"], "in a band"),
227
+ ("cosmic entities", ["planet", "star",
228
+ "comet", "nebula", "asteroid"], "in space"),
229
+ # ... add more themes
230
+ ]
231
+
232
+ # Using the themes to extend our list
233
+ for theme_name, theme_objects, theme_location in themes:
234
+ for theme_object in theme_objects:
235
+ objects_settings.append((theme_object, theme_location))
236
+
237
+ # Sports equipment
238
+ objects_settings.extend([
239
+ ("basketball", "on a court"),
240
+ ("golf ball", "on a golf course"),
241
+ ("tennis racket", "on a tennis court"),
242
+ ("baseball bat", "in a stadium"),
243
+ ("hockey stick", "on an ice rink"),
244
+ ("football", "on a field"),
245
+ ("skateboard", "in a skatepark"),
246
+ ("boxing gloves", "in a boxing ring"),
247
+ ("ski", "on a snowy slope"),
248
+ ("surfboard", "on a beach shore"),
249
+ ])
250
+
251
+ # Toys and games
252
+ objects_settings.extend([
253
+ ("teddy bear", "on a child's bed"),
254
+ ("doll", "in a toy store"),
255
+ ("toy car", "on a carpet"),
256
+ ("board game", "on a table"),
257
+ ("yo-yo", "in a child's hand"),
258
+ ("kite", "in the sky on a windy day"),
259
+ ("Lego bricks", "on a construction table"),
260
+ ("jigsaw puzzle", "partially completed"),
261
+ ("rubik's cube", "on a shelf"),
262
+ ("action figure", "on display"),
263
+ ])
264
+
265
+ # Transportation
266
+ objects_settings.extend([
267
+ ("bus", "at a bus stop"),
268
+ ("motorcycle", "on a road"),
269
+ ("helicopter", "landing on a pad"),
270
+ ("scooter", "on a sidewalk"),
271
+ ("train", "at a station"),
272
+ ("bicycle", "parked by a post"),
273
+ ("boat", "in a harbor"),
274
+ ("tractor", "on a farm"),
275
+ ("airplane", "taking off from a runway"),
276
+ ("submarine", "below sea level"),
277
+ ])
278
+
279
+ # Medieval theme
280
+ objects_settings.extend([
281
+ ("castle", "on a hilltop"),
282
+ ("knight", "riding a horse"),
283
+ ("bow and arrow", "in an archery range"),
284
+ ("crown", "in a treasure chest"),
285
+ ("dragon", "flying over mountains"),
286
+ ("shield", "next to a warrior"),
287
+ ("dagger", "on a wooden table"),
288
+ ("torch", "lighting a dark corridor"),
289
+ ("scroll", "sealed with wax"),
290
+ ("cauldron", "with bubbling potion"),
291
+ ])
292
+
293
+ # Modern technology
294
+ objects_settings.extend([
295
+ ("smartphone", "on a charger"),
296
+ ("laptop", "in a cafe"),
297
+ ("headphones", "around a neck"),
298
+ ("camera", "on a tripod"),
299
+ ("drone", "flying over a park"),
300
+ ("USB stick", "plugged into a computer"),
301
+ ("watch", "on a wrist"),
302
+ ("microphone", "on a podcast desk"),
303
+ ("tablet", "with a digital pen"),
304
+ ("VR headset", "ready for gaming"),
305
+ ])
306
+
307
+ # Nature
308
+ objects_settings.extend([
309
+ ("tree", "in a forest"),
310
+ ("flower", "in a garden"),
311
+ ("mountain", "on a horizon"),
312
+ ("cloud", "in a blue sky"),
313
+ ("waterfall", "in a scenic location"),
314
+ ("beach", "next to an ocean"),
315
+ ("cactus", "in a desert"),
316
+ ("volcano", "erupting with lava"),
317
+ ("coral", "under the sea"),
318
+ ("moon", "in a night sky"),
319
+ ])
320
+
321
+ prompts = [f"A {obj} {setting}" for obj, setting in objects_settings]
322
+
323
+ return objects_settings
324
+
325
+
326
+ if __name__ == "__main__":
327
+ SAVE_DIR = "/scr/image/"
328
+ save_path = "img43-att_mask"
329
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
330
+ random_latents = torch.randn(
331
+ [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)
332
+
333
+ try:
334
+ pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
335
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32", cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
336
+ except:
337
+ pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
338
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32").to(torch_device)
339
+
340
+ fg_object = "apple" # fg object stores the object to be masked
341
+ # overall_prompt stores the full text prompt fed to the pipeline
342
+ overall_prompt = f"An {fg_object} on a plate"
343
+ os.makedirs(f"{SAVE_DIR}/{overall_prompt}", exist_ok=True)
344
+
345
+ masks_list = create_boxes()
346
+
347
+ # torch.save(f"{overall_prompt}+masked", "prompt.pt")
348
+ obj_settings = objects_list() # full list of (object, setting) pairs
349
+ for obj_setting in obj_settings[120:]:
350
+ fg_object = obj_setting[0]
351
+ overall_prompt = f"A {obj_setting[0]} {obj_setting[1]}"
352
+ print(overall_prompt)
353
+
354
+ # randomly select 3 mask indices from the full list of masks
355
+ selected_mask_ids = random.sample(range(len(masks_list)), 3)
356
+ for mask_id in selected_mask_ids:
357
+ os.makedirs(
358
+ f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}", exist_ok=True)
359
+ torchvision.utils.save_image(
360
+ masks_list[mask_id][0][0], f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/mask.png")
361
+ for frozen_steps in range(0, 5):
362
+ img = generate_image(pipe, overall_prompt, random_latents, get_latents=False, num_inference_steps=50, fg_masks=masks_list[mask_id].to(
363
+ torch_device), fg_masked_latents=None, frozen_steps=frozen_steps, frozen_prompt=None, fg_prompt=fg_object)
364
+
365
+ img.save(
366
+ f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/{frozen_steps}.png")
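To make the grid arithmetic in `create_boxes` concrete, here are the placement formulas worked through for a single object size (20 px on the 96x96 latent canvas); the values are just the evaluated versions of the expressions above.

```python
img, start_x, start_y, obj = 96, 3, 4, 20
spacing_h = (img - 3 * obj - start_x) // 2   # (96 - 60 - 3) // 2 = 16
spacing_v = (img - 3 * obj - start_y) // 2   # (96 - 60 - 4) // 2 = 16
xs = [start_x + i * (obj + spacing_h) for i in range(3)]  # [3, 39, 75]
ys = [start_y + j * (obj + spacing_v) for j in range(3)]  # [4, 40, 76]
# Nine obj x obj boxes spread across the canvas; the last box ends at 76 + 20 = 96.
```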
src/make_image_grid.py ADDED
@@ -0,0 +1,57 @@
1
+ from PIL import Image
2
+ import os
3
+ import re
4
+ import tqdm
5
+
6
+
7
+ def create_grid(directory, save_path=None):
8
+ # Get list of image files in the directory
9
+ files = sorted([f for f in os.listdir(directory) if f.endswith('.png')])
10
+
11
+ # Assuming all images are the same size
12
+ img_sample = Image.open(os.path.join(directory, files[0]))
13
+ width, height = img_sample.size
14
+
15
+ # Calculate grid dimensions for 24 images
16
+ grid_width = width * 6 # 6 images in each row
17
+ grid_height = height * 4 # 4 rows
18
+
19
+ # Create new image for the grid
20
+ grid_img = Image.new('RGB', (grid_width, grid_height))
21
+
22
+ for idx, file in enumerate(files):
23
+ img = Image.open(os.path.join(directory, file))
24
+ x = idx % 6 * width # 6 images in each row
25
+ y = idx // 6 * height # 4 rows
26
+ grid_img.paste(img, (x, y))
27
+
28
+ if save_path:
29
+ grid_img.save(f'{save_path}/{directory.split("/")[-1]}_grid.png')
30
+ else:
31
+ grid_img.save(f'{directory}_grid.png')
32
+
33
+
34
+ def list_subdirectories(parent_directory):
35
+ # Regex pattern to match the subdirectory naming convention
36
+ pattern = re.compile(r"\d+_of_\d+_masked1")
37
+
38
+ # List all subdirectories
39
+ subdirs = [d for d in os.listdir(parent_directory) if os.path.isdir(os.path.join(parent_directory, d))]
40
+
41
+ # Filter subdirectories based on naming convention
42
+ matching_subdirs = [d for d in subdirs if pattern.match(d)]
43
+
44
+ return matching_subdirs
45
+
46
+ # List of directories
47
+
48
+ # for prompt in ["A cat walking in a park", "A dog running in a park", " A wooden barrel drifting on a river", "A kite flying in the sky"]:
49
+ # for prompt in ["A car driving on the road"]:
50
+ # try:
51
+ # directories = list_subdirectories(f"demo4/video41-att_mask/{prompt}")
52
+ # except FileNotFoundError:
53
+ # print(f"Directory not found: {prompt}")
54
+ # continue
55
+ # os.makedirs(f"demo4/{prompt}/consolidated_grids", exist_ok=True)
56
+ # for directory in tqdm.tqdm(directories):
57
+ # create_grid(os.path.join(f"demo4/video41-att_mask/{prompt}", directory), save_path=f"demo4/{prompt}/consolidated_grids")
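The commented-out driver at the bottom shows the intended use of `create_grid`; a minimal call with a hypothetical results directory looks like this.

```python
import os

# Tile the 24 frame PNGs in one results directory into a 6x4 grid image.
os.makedirs("demo4/A cat walking in a park/consolidated_grids", exist_ok=True)
create_grid("demo4/video41-att_mask/A cat walking in a park/2_of_40_masked1",
            save_path="demo4/A cat walking in a park/consolidated_grids")
```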
src/models/__init__.py ADDED
File without changes
src/models/attention.py ADDED
@@ -0,0 +1,612 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
21
+ from diffusers.models.activations import get_activation
22
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
23
+ from diffusers.models.lora import LoRACompatibleLinear
24
+
25
+ from .attention_processor import Attention
26
+
27
+ import math
28
+
29
+ @maybe_allow_in_graph
30
+ class GatedSelfAttentionDense(nn.Module):
31
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
32
+ super().__init__()
33
+
34
+ # we need a linear projection since we need cat visual feature and obj feature
35
+ self.linear = nn.Linear(context_dim, query_dim)
36
+
37
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
38
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
39
+
40
+ self.norm1 = nn.LayerNorm(query_dim)
41
+ self.norm2 = nn.LayerNorm(query_dim)
42
+
43
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
44
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
45
+
46
+ self.enabled = True
47
+
48
+ def forward(self, x, objs):
49
+ if not self.enabled:
50
+ return x
51
+
52
+ n_visual = x.shape[1]
53
+ objs = self.linear(objs)
54
+
55
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
56
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
57
+
58
+ return x
59
+
60
+
61
+ @maybe_allow_in_graph
62
+ class BasicTransformerBlock(nn.Module):
63
+ r"""
64
+ A basic Transformer block.
65
+
66
+ Parameters:
67
+ dim (`int`): The number of channels in the input and output.
68
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
69
+ attention_head_dim (`int`): The number of channels in each head.
70
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
71
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
72
+ only_cross_attention (`bool`, *optional*):
73
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
74
+ double_self_attention (`bool`, *optional*):
75
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
76
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
77
+ num_embeds_ada_norm (:
78
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
79
+ attention_bias (:
80
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ dim: int,
86
+ num_attention_heads: int,
87
+ attention_head_dim: int,
88
+ dropout=0.0,
89
+ cross_attention_dim: Optional[int] = None,
90
+ activation_fn: str = "geglu",
91
+ num_embeds_ada_norm: Optional[int] = None,
92
+ attention_bias: bool = False,
93
+ only_cross_attention: bool = False,
94
+ double_self_attention: bool = False,
95
+ upcast_attention: bool = False,
96
+ norm_elementwise_affine: bool = True,
97
+ norm_type: str = "layer_norm",
98
+ final_dropout: bool = False,
99
+ attention_type: str = "default",
100
+ ):
101
+ super().__init__()
102
+ self.only_cross_attention = only_cross_attention
103
+
104
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
105
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
106
+
107
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
108
+ raise ValueError(
109
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
110
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
111
+ )
112
+
113
+ # Define 3 blocks. Each block has its own normalization layer.
114
+ # 1. Self-Attn
115
+ if self.use_ada_layer_norm:
116
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
117
+ elif self.use_ada_layer_norm_zero:
118
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
119
+ else:
120
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
121
+ self.attn1 = Attention(
122
+ query_dim=dim,
123
+ heads=num_attention_heads,
124
+ dim_head=attention_head_dim,
125
+ dropout=dropout,
126
+ bias=attention_bias,
127
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
128
+ upcast_attention=upcast_attention,
129
+ )
130
+
131
+ # 2. Cross-Attn
132
+ if cross_attention_dim is not None or double_self_attention:
133
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
134
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
135
+ # the second cross attention block.
136
+ self.norm2 = (
137
+ AdaLayerNorm(dim, num_embeds_ada_norm)
138
+ if self.use_ada_layer_norm
139
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
140
+ )
141
+ self.attn2 = Attention(
142
+ query_dim=dim,
143
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
144
+ heads=num_attention_heads,
145
+ dim_head=attention_head_dim,
146
+ dropout=dropout,
147
+ bias=attention_bias,
148
+ upcast_attention=upcast_attention,
149
+ ) # is self-attn if encoder_hidden_states is none
150
+ else:
151
+ self.norm2 = None
152
+ self.attn2 = None
153
+
154
+ # 3. Feed-forward
155
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
156
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
157
+
158
+ # 4. Fuser
159
+ if attention_type == "gated" or attention_type == "gated-text-image":
160
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
161
+
162
+ # let chunk size default to None
163
+ self._chunk_size = None
164
+ self._chunk_dim = 0
165
+
166
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
167
+ # Sets chunk feed-forward
168
+ self._chunk_size = chunk_size
169
+ self._chunk_dim = dim
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: torch.FloatTensor,
174
+ attention_mask: Optional[torch.FloatTensor] = None,
175
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
176
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
177
+ timestep: Optional[torch.LongTensor] = None,
178
+ cross_attention_kwargs: Dict[str, Any] = None,
179
+ class_labels: Optional[torch.LongTensor] = None,
180
+ **kwargs,
181
+ ):
182
+ # Notice that normalization is always applied before the real computation in the following blocks.
183
+
184
+ if attention_mask is not None and not isinstance(attention_mask, list):
185
+ if attention_mask is not None and hidden_states.shape[1] != attention_mask.shape[-1]:
186
+ tmp = attention_mask.clone()
187
+ scale_factor = int(math.sqrt(attention_mask.shape[-1] // hidden_states.shape[1]))
188
+ try:
189
+ tmp = tmp.reshape(tmp.shape[0], 40, 72)
190
+ except:
191
+ try:
192
+ tmp = tmp.reshape(tmp.shape[0], 32, 32) # MSR-VTT
193
+ except:
194
+ tmp = tmp.reshape(tmp.shape[0], 96, 96)
195
+ tmp = tmp[:, ::scale_factor, ::scale_factor]
196
+ tmp = tmp.reshape(tmp.shape[0], 1, -1)
197
+ attention_mask = tmp
198
+
199
+ if attention_mask is not None:
200
+ tmp = attention_mask.clone()
201
+ tmp = tmp.view(tmp.shape[0], -1,1)/(-10000)
202
+ tmp = (1-tmp)
203
+ orig_attn_mask = attention_mask.clone()
204
+ else:
205
+ # tmp = 0
206
+ tmp = 1
207
+ orig_attn_mask = None
208
+
209
+ if attention_mask is not None and 'make_2d_attention_mask' in kwargs and kwargs['make_2d_attention_mask'] == True:
210
+ # We broadcast and take element wise AND. Note that addition is equivalent to AND here, since we are dealing with -10000 and 0.
211
+ attention_mask_2d = attention_mask + attention_mask.permute(0,2,1)
212
+ # Get it back to original range. This step is optional tbh
213
+ attention_mask_2d = torch.where(attention_mask_2d < 0., -10000, 0).type(attention_mask.dtype)
214
+
215
+ if 'block_diagonal_attention' in kwargs and kwargs['block_diagonal_attention'] == True:
216
+ tmp_attention = torch.where(attention_mask < 0., 0., -10000.) # allow background
217
+ tmp_attention = tmp_attention + tmp_attention.permute(0,2,1)
218
+ tmp_attention = torch.where(tmp_attention < 0., -10000, 0)
219
+ attention_mask_2d = attention_mask_2d * tmp_attention
220
+ attention_mask_2d = torch.where(attention_mask_2d.abs() < 1.,0., -10000.).type(attention_mask.dtype)
221
+ attention_mask = attention_mask_2d
222
+
223
+
224
+ # Multiple objects
225
+ elif attention_mask is not None and isinstance(attention_mask, list):
226
+ if hidden_states.shape[1] != attention_mask[0].shape[-1]:
227
+ new_attention_mask = []
228
+ for attn_mask in attention_mask:
229
+ tmp = attn_mask.clone()
230
+ scale_factor = int(math.sqrt(attn_mask.shape[-1] // hidden_states.shape[1]))
231
+ try:
232
+ tmp = tmp.reshape(tmp.shape[0], 40, 72)
233
+ except:
234
+ tmp = tmp.reshape(tmp.shape[0], 32, 32)
235
+ tmp = tmp[:, ::scale_factor, ::scale_factor]
236
+ tmp = tmp.reshape(tmp.shape[0], 1, -1)
237
+ new_attention_mask.append(tmp)
238
+ attention_mask = new_attention_mask
239
+
240
+ orig_attn_mask = []
241
+ for attn_mask in attention_mask:
242
+ tmp = attn_mask.clone()
243
+
244
+ tmp = tmp.view(tmp.shape[0], -1,1)/(-10000)
245
+ tmp = (1-tmp)
246
+
247
+ orig_attn_mask.append(attn_mask.clone())
248
+
249
+
250
+ if 'make_2d_attention_mask' in kwargs and kwargs['make_2d_attention_mask'] == True:
251
+ # We broadcast and take element wise AND. Note that addition is equivalent to AND here, since we are dealing with -10000 and 0.
252
+ attn_mask_2d = []
253
+ for attn_mask in attention_mask:
254
+ attention_mask_2d = attn_mask + attn_mask.permute(0,2,1)
255
+ # Get it back to original range. This step is optional tbh
256
+ attention_mask_2d = torch.where(attention_mask_2d < 0., -10000, 0).type(attn_mask.dtype)
257
+ attn_mask_2d.append(attention_mask_2d)
258
+ attention_mask_2d = torch.prod(torch.stack(attn_mask_2d, dim=0), dim=0)
259
+ attention_mask_2d = torch.where(attention_mask_2d.abs() < 1.,0., -10000.).type(attn_mask.dtype)
260
+ if 'block_diagonal_attention' in kwargs and kwargs['block_diagonal_attention'] == True:
261
+ tmp_attention = torch.where(torch.prod(torch.stack(attention_mask,dim=0),dim=0).abs() < 1., -10000., 0.) # Check this well
262
+ tmp_attention = tmp_attention + tmp_attention.permute(0,2,1)
263
+ tmp_attention = torch.where(tmp_attention < 0., -10000, 0)
264
+ attention_mask_2d = attention_mask_2d * tmp_attention
265
+ attention_mask_2d = torch.where(attention_mask_2d.abs() < 1.,0., -10000.).type(attention_mask_2d.dtype)
266
+ attention_mask = attention_mask_2d
267
+
268
+ else:
269
+ tmp = 1
270
+ orig_attn_mask = None
271
+
272
+ if self.use_ada_layer_norm:
273
+ norm_hidden_states = self.norm1(hidden_states, timestep)
274
+ elif self.use_ada_layer_norm_zero:
275
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
276
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
277
+ )
278
+ else:
279
+ norm_hidden_states = self.norm1(hidden_states)
280
+
281
+ # 1. Retrieve lora scale.
282
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
283
+
284
+ # 2. Prepare GLIGEN inputs
285
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
286
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
287
+
288
+
289
+ # breakpoint()
290
+
291
+ ## self-attention amongst fg
292
+ attn_output = self.attn1(
293
+ norm_hidden_states, # + tmp,
294
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
295
+ attention_mask=attention_mask,
296
+ **cross_attention_kwargs,
297
+ )
298
+
299
+
300
+ if self.use_ada_layer_norm_zero:
301
+ attn_output = gate_msa.unsqueeze(1) * attn_output
302
+ hidden_states = attn_output + hidden_states
303
+
304
+ if attention_mask is not None:
305
+ tmp = 1-tmp
306
+
307
+ # 2.5 GLIGEN Control
308
+ if gligen_kwargs is not None:
309
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
310
+ # 2.5 ends
311
+
312
+ # 3. Cross-Attention
313
+ if self.attn2 is not None:
314
+ norm_hidden_states = (
315
+ self.norm2(hidden_states*tmp, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states*tmp)
316
+ )
317
+
318
+
319
+ if encoder_attention_mask is None:
320
+ attn_output = self.attn2(
321
+ norm_hidden_states,
322
+ encoder_hidden_states=encoder_hidden_states,
323
+ attention_mask=encoder_attention_mask,
324
+ **cross_attention_kwargs,
325
+ )
326
+
327
+ if encoder_attention_mask is not None: # Encoder attention mask is not None
328
+
329
+ if 'block_diagonal_attention' in kwargs and kwargs['block_diagonal_attention'] == True:
330
+
331
+ if not isinstance(orig_attn_mask, list):
332
+ orig_attn_mask = torch.where(orig_attn_mask < 0., 0., -10000.).type(orig_attn_mask.dtype).to(orig_attn_mask.device)
333
+ encoder_attention_mask_2d = encoder_attention_mask + orig_attn_mask.permute(0,2,1)
334
+ encoder_attention_mask_2d = torch.where(encoder_attention_mask_2d < 0., -10000, 0).type(encoder_attention_mask.dtype)
335
+
336
+ inverted_encoder_attention_mask = torch.where(encoder_attention_mask < 0., 0., -10000.).type(encoder_attention_mask.dtype)
337
+ inverted_encoder_attention_mask[:,:,0] = -10000 # CLS token
338
+
339
+ inverted_orig_mask = torch.where(orig_attn_mask < 0., 0., -10000.).type(orig_attn_mask.dtype)
340
+ inverted_encoder_attention_mask_2d = inverted_encoder_attention_mask + inverted_orig_mask.permute(0,2,1)
341
+
342
+ encoder_attention_mask_2d = encoder_attention_mask_2d * inverted_encoder_attention_mask_2d
343
+ encoder_attention_mask_2d = torch.where(encoder_attention_mask_2d.abs() < 1.,0., -10000.).type(encoder_attention_mask.dtype)
344
+
345
+ encoder_attention_mask = encoder_attention_mask_2d
346
+ else:
347
+ orig_attn_mask = [torch.where(orig_attn_mask_ < 0., 0., -10000.).type(orig_attn_mask_.dtype).to(orig_attn_mask_.device) for orig_attn_mask_ in orig_attn_mask]
348
+ encoder_attention_mask_2d = [encoder_attention_mask_ + orig_attn_mask_.permute(0,2,1) for encoder_attention_mask_, orig_attn_mask_ in zip(encoder_attention_mask, orig_attn_mask)]
349
+ encoder_attention_mask_2d = [torch.where(encoder_attention_mask_2d_ < 0., -10000, 0).type(encoder_attention_mask_2d_.dtype) for encoder_attention_mask_2d_ in encoder_attention_mask_2d]
350
+
351
+ inverted_encoder_attention_mask = torch.where(torch.sum(torch.stack(encoder_attention_mask, dim=0),dim=0) < 0., 0., -10000.).type(encoder_attention_mask[0].dtype)
352
+ inverted_encoder_attention_mask[:,:,0] = -10000 # CLS token
353
+
354
+ inverted_orig_mask = torch.where(torch.sum(torch.stack(orig_attn_mask,dim=0),dim=0) < 0., 0., -10000.).type(orig_attn_mask[0].dtype)
355
+ inverted_encoder_attention_mask_2d = inverted_encoder_attention_mask + inverted_orig_mask.permute(0,2,1)
356
+
357
+ encoder_attention_mask_2d = torch.where(torch.sum(torch.stack(encoder_attention_mask_2d, dim=0), dim=0) < 0., -10000., 0.)
358
+ encoder_attention_mask_2d = encoder_attention_mask_2d * inverted_encoder_attention_mask_2d
359
+ encoder_attention_mask_2d = torch.where(encoder_attention_mask_2d.abs() < 1.,0., -10000.).type(encoder_attention_mask[0].dtype)
360
+
361
+ encoder_attention_mask = encoder_attention_mask_2d
362
+
363
+ norm_hidden_states = (
364
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
365
+ )
366
+ ## cross-attention amongst bg
367
+ attn_output = self.attn2(
368
+ norm_hidden_states,
369
+ encoder_hidden_states=encoder_hidden_states,
370
+ attention_mask=encoder_attention_mask,
371
+ **cross_attention_kwargs,
372
+ )
373
+
374
+ del encoder_attention_mask_2d, inverted_encoder_attention_mask, inverted_encoder_attention_mask_2d, inverted_orig_mask, orig_attn_mask, attention_mask_2d, tmp_attention
375
+ torch.cuda.empty_cache()
376
+
377
+ hidden_states = attn_output + hidden_states
378
+
379
+ else:
380
+ norm_hidden_states2 = (
381
+ self.norm2(hidden_states*(1-tmp), timestep) if self.use_ada_layer_norm else self.norm2(hidden_states*(1-tmp))
382
+ )
383
+ encoder_attention_mask2 = torch.where(encoder_attention_mask < 0., 0., -10000.).type(encoder_attention_mask.dtype).to(encoder_attention_mask.device)
384
+ encoder_attention_mask2[:, :, 0] = -10000
385
+ attn_output2 = self.attn2(
386
+ norm_hidden_states2,
387
+ encoder_hidden_states=encoder_hidden_states,
388
+ attention_mask=encoder_attention_mask2,
389
+ **cross_attention_kwargs,
390
+ )
391
+
392
+ hidden_states = attn_output*tmp + attn_output2*(1-tmp)+ hidden_states
393
+ else:
394
+ hidden_states = attn_output*tmp + hidden_states
395
+
396
+ # 4. Feed-forward
397
+ norm_hidden_states = self.norm3(hidden_states)
398
+
399
+ if self.use_ada_layer_norm_zero:
400
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
401
+
402
+ if self._chunk_size is not None:
403
+ # "feed_forward_chunk_size" can be used to save memory
404
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
405
+ raise ValueError(
406
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
407
+ )
408
+
409
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
410
+ ff_output = torch.cat(
411
+ [
412
+ self.ff(hid_slice, scale=lora_scale)
413
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
414
+ ],
415
+ dim=self._chunk_dim,
416
+ )
417
+ else:
418
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
419
+
420
+ if self.use_ada_layer_norm_zero:
421
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
422
+
423
+ hidden_states = ff_output + hidden_states
424
+
425
+ return hidden_states
426
+
427
+ class FeedForward(nn.Module):
428
+ r"""
429
+ A feed-forward layer.
430
+
431
+ Parameters:
432
+ dim (`int`): The number of channels in the input.
433
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
434
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
435
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
436
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
437
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
438
+ """
439
+
440
+ def __init__(
441
+ self,
442
+ dim: int,
443
+ dim_out: Optional[int] = None,
444
+ mult: int = 4,
445
+ dropout: float = 0.0,
446
+ activation_fn: str = "geglu",
447
+ final_dropout: bool = False,
448
+ ):
449
+ super().__init__()
450
+ inner_dim = int(dim * mult)
451
+ dim_out = dim_out if dim_out is not None else dim
452
+
453
+ if activation_fn == "gelu":
454
+ act_fn = GELU(dim, inner_dim)
455
+ if activation_fn == "gelu-approximate":
456
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
457
+ elif activation_fn == "geglu":
458
+ act_fn = GEGLU(dim, inner_dim)
459
+ elif activation_fn == "geglu-approximate":
460
+ act_fn = ApproximateGELU(dim, inner_dim)
461
+
462
+ self.net = nn.ModuleList([])
463
+ # project in
464
+ self.net.append(act_fn)
465
+ # project dropout
466
+ self.net.append(nn.Dropout(dropout))
467
+ # project out
468
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
469
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
470
+ if final_dropout:
471
+ self.net.append(nn.Dropout(dropout))
472
+
473
+ def forward(self, hidden_states, scale: float = 1.0):
474
+ for module in self.net:
475
+ if isinstance(module, (LoRACompatibleLinear, GEGLU)):
476
+ hidden_states = module(hidden_states, scale)
477
+ else:
478
+ hidden_states = module(hidden_states)
479
+ return hidden_states
480
+
481
+
482
+ class GELU(nn.Module):
483
+ r"""
484
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
485
+ """
486
+
487
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
488
+ super().__init__()
489
+ self.proj = nn.Linear(dim_in, dim_out)
490
+ self.approximate = approximate
491
+
492
+ def gelu(self, gate):
493
+ if gate.device.type != "mps":
494
+ return F.gelu(gate, approximate=self.approximate)
495
+ # mps: gelu is not implemented for float16
496
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
497
+
498
+ def forward(self, hidden_states):
499
+ hidden_states = self.proj(hidden_states)
500
+ hidden_states = self.gelu(hidden_states)
501
+ return hidden_states
502
+
503
+
504
+ class GEGLU(nn.Module):
505
+ r"""
506
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
507
+
508
+ Parameters:
509
+ dim_in (`int`): The number of channels in the input.
510
+ dim_out (`int`): The number of channels in the output.
511
+ """
512
+
513
+ def __init__(self, dim_in: int, dim_out: int):
514
+ super().__init__()
515
+ self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
516
+
517
+ def gelu(self, gate):
518
+ if gate.device.type != "mps":
519
+ return F.gelu(gate)
520
+ # mps: gelu is not implemented for float16
521
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
522
+
523
+ def forward(self, hidden_states, scale: float = 1.0):
524
+ hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
525
+ return hidden_states * self.gelu(gate)
526
+
527
+
528
+ class ApproximateGELU(nn.Module):
529
+ """
530
+ The approximate form of Gaussian Error Linear Unit (GELU)
531
+
532
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
533
+ """
534
+
535
+ def __init__(self, dim_in: int, dim_out: int):
536
+ super().__init__()
537
+ self.proj = nn.Linear(dim_in, dim_out)
538
+
539
+ def forward(self, x):
540
+ x = self.proj(x)
541
+ return x * torch.sigmoid(1.702 * x)
542
+
543
+
544
+ class AdaLayerNorm(nn.Module):
545
+ """
546
+ Norm layer modified to incorporate timestep embeddings.
547
+ """
548
+
549
+ def __init__(self, embedding_dim, num_embeddings):
550
+ super().__init__()
551
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
552
+ self.silu = nn.SiLU()
553
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
554
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
555
+
556
+ def forward(self, x, timestep):
557
+ emb = self.linear(self.silu(self.emb(timestep)))
558
+ scale, shift = torch.chunk(emb, 2)
559
+ x = self.norm(x) * (1 + scale) + shift
560
+ return x
561
+
562
+
563
+ class AdaLayerNormZero(nn.Module):
564
+ """
565
+ Norm layer adaptive layer norm zero (adaLN-Zero).
566
+ """
567
+
568
+ def __init__(self, embedding_dim, num_embeddings):
569
+ super().__init__()
570
+
571
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
572
+
573
+ self.silu = nn.SiLU()
574
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
575
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
576
+
577
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
578
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
579
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
580
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
581
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
582
+
583
+
584
+ class AdaGroupNorm(nn.Module):
585
+ """
586
+ GroupNorm layer modified to incorporate timestep embeddings.
587
+ """
588
+
589
+ def __init__(
590
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
591
+ ):
592
+ super().__init__()
593
+ self.num_groups = num_groups
594
+ self.eps = eps
595
+
596
+ if act_fn is None:
597
+ self.act = None
598
+ else:
599
+ self.act = get_activation(act_fn)
600
+
601
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
602
+
603
+ def forward(self, x, emb):
604
+ if self.act:
605
+ emb = self.act(emb)
606
+ emb = self.linear(emb)
607
+ emb = emb[:, :, None, None]
608
+ scale, shift = emb.chunk(2, dim=1)
609
+
610
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
611
+ x = x * (1 + scale) + shift
612
+ return x
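The 2D-mask construction in `BasicTransformerBlock.forward` relies on the encoding 0 = attend, -10000 = block: adding a mask to its transpose behaves like a logical AND over query/key positions, and a `torch.where` threshold restores the two-level range. A toy example with four tokens (shapes and values chosen purely for illustration):

```python
import torch

# 1D mask over 4 tokens: 0 = attend, -10000 = block; tokens 1 and 2 are foreground.
m = torch.tensor([[-10000., 0., 0., -10000.]]).unsqueeze(1)   # shape (1, 1, 4)

# Broadcasting the mask against its transpose and adding acts as a logical AND:
# the sum is 0 only where both the query and the key position are foreground.
m2d = m + m.permute(0, 2, 1)                                  # shape (1, 4, 4)
m2d = torch.where(m2d < 0., -10000., 0.)                      # re-threshold to two levels
print(m2d[0])
# Only entries (1,1), (1,2), (2,1), (2,2) are 0: foreground queries attend
# only to foreground keys; every other pair is masked out.
```

The block-diagonal variant above applies the same trick a second time with the roles of foreground and background swapped and multiplies the two masks, so foreground tokens attend to foreground and background tokens to background.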
src/models/attention_processor.py ADDED
@@ -0,0 +1,1662 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from importlib import import_module
15
+ from typing import Callable, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import deprecate, logging
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
24
+ from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
25
+ import torchvision
26
+ import math
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ if is_xformers_available():
32
+ import xformers
33
+ import xformers.ops
34
+ else:
35
+ xformers = None
36
+
37
+
38
+ @maybe_allow_in_graph
39
+ class Attention(nn.Module):
40
+ r"""
41
+ A cross attention layer.
42
+
43
+ Parameters:
44
+ query_dim (`int`): The number of channels in the query.
45
+ cross_attention_dim (`int`, *optional*):
46
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
47
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
48
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ bias (`bool`, *optional*, defaults to False):
51
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ query_dim: int,
57
+ cross_attention_dim: Optional[int] = None,
58
+ heads: int = 8,
59
+ dim_head: int = 64,
60
+ dropout: float = 0.0,
61
+ bias=False,
62
+ upcast_attention: bool = False,
63
+ upcast_softmax: bool = False,
64
+ cross_attention_norm: Optional[str] = None,
65
+ cross_attention_norm_num_groups: int = 32,
66
+ added_kv_proj_dim: Optional[int] = None,
67
+ norm_num_groups: Optional[int] = None,
68
+ spatial_norm_dim: Optional[int] = None,
69
+ out_bias: bool = True,
70
+ scale_qk: bool = True,
71
+ only_cross_attention: bool = False,
72
+ eps: float = 1e-5,
73
+ rescale_output_factor: float = 1.0,
74
+ residual_connection: bool = False,
75
+ _from_deprecated_attn_block=False,
76
+ processor: Optional["AttnProcessor"] = None,
77
+ ):
78
+ super().__init__()
79
+ self.inner_dim = dim_head * heads
80
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
81
+ self.upcast_attention = upcast_attention
82
+ self.upcast_softmax = upcast_softmax
83
+ self.rescale_output_factor = rescale_output_factor
84
+ self.residual_connection = residual_connection
85
+ self.dropout = dropout
86
+
87
+ # we make use of this private variable to know whether this class is loaded
88
+ # with a deprecated state dict so that we can convert it on the fly
89
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
90
+
91
+ self.scale_qk = scale_qk
92
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
93
+
94
+ self.heads = heads
95
+ # for slice_size > 0 the attention score computation
96
+ # is split across the batch axis to save memory
97
+ # You can set slice_size with `set_attention_slice`
98
+ self.sliceable_head_dim = heads
99
+
100
+ self.added_kv_proj_dim = added_kv_proj_dim
101
+ self.only_cross_attention = only_cross_attention
102
+
103
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
104
+ raise ValueError(
105
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
106
+ )
107
+
108
+ if norm_num_groups is not None:
109
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
110
+ else:
111
+ self.group_norm = None
112
+
113
+ if spatial_norm_dim is not None:
114
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
115
+ else:
116
+ self.spatial_norm = None
117
+
118
+ if cross_attention_norm is None:
119
+ self.norm_cross = None
120
+ elif cross_attention_norm == "layer_norm":
121
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
122
+ elif cross_attention_norm == "group_norm":
123
+ if self.added_kv_proj_dim is not None:
124
+ # The given `encoder_hidden_states` are initially of shape
125
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
126
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
127
+ # before the projection, so we need to use `added_kv_proj_dim` as
128
+ # the number of channels for the group norm.
129
+ norm_cross_num_channels = added_kv_proj_dim
130
+ else:
131
+ norm_cross_num_channels = self.cross_attention_dim
132
+
133
+ self.norm_cross = nn.GroupNorm(
134
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
135
+ )
136
+ else:
137
+ raise ValueError(
138
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
139
+ )
140
+
141
+ self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
142
+
143
+ if not self.only_cross_attention:
144
+ # only relevant for the `AddedKVProcessor` classes
145
+ self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
146
+ self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
147
+ else:
148
+ self.to_k = None
149
+ self.to_v = None
150
+
151
+ if self.added_kv_proj_dim is not None:
152
+ self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
153
+ self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
154
+
155
+ self.to_out = nn.ModuleList([])
156
+ self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
157
+ self.to_out.append(nn.Dropout(dropout))
158
+
159
+ # set attention processor
160
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
161
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
162
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
163
+ if processor is None:
164
+ processor = (
165
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
166
+ )
167
+ self.set_processor(processor)
168
+
169
+ def set_use_memory_efficient_attention_xformers(
170
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
171
+ ):
172
+ is_lora = hasattr(self, "processor") and isinstance(
173
+ self.processor,
174
+ LORA_ATTENTION_PROCESSORS,
175
+ )
176
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
177
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
178
+ )
179
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
180
+ self.processor,
181
+ (
182
+ AttnAddedKVProcessor,
183
+ AttnAddedKVProcessor2_0,
184
+ SlicedAttnAddedKVProcessor,
185
+ XFormersAttnAddedKVProcessor,
186
+ LoRAAttnAddedKVProcessor,
187
+ ),
188
+ )
189
+
190
+ if use_memory_efficient_attention_xformers:
191
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
192
+ raise NotImplementedError(
193
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
194
+ )
195
+ if not is_xformers_available():
196
+ raise ModuleNotFoundError(
197
+ (
198
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
199
+ " xformers"
200
+ ),
201
+ name="xformers",
202
+ )
203
+ elif not torch.cuda.is_available():
204
+ raise ValueError(
205
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
206
+ " only available for GPU "
207
+ )
208
+ else:
209
+ try:
210
+ # Make sure we can run the memory efficient attention
211
+ _ = xformers.ops.memory_efficient_attention(
212
+ torch.randn((1, 2, 40), device="cuda"),
213
+ torch.randn((1, 2, 40), device="cuda"),
214
+ torch.randn((1, 2, 40), device="cuda"),
215
+ )
216
+ except Exception as e:
217
+ raise e
218
+
219
+ if is_lora:
220
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
221
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
222
+ processor = LoRAXFormersAttnProcessor(
223
+ hidden_size=self.processor.hidden_size,
224
+ cross_attention_dim=self.processor.cross_attention_dim,
225
+ rank=self.processor.rank,
226
+ attention_op=attention_op,
227
+ )
228
+ processor.load_state_dict(self.processor.state_dict())
229
+ processor.to(self.processor.to_q_lora.up.weight.device)
230
+ elif is_custom_diffusion:
231
+ processor = CustomDiffusionXFormersAttnProcessor(
232
+ train_kv=self.processor.train_kv,
233
+ train_q_out=self.processor.train_q_out,
234
+ hidden_size=self.processor.hidden_size,
235
+ cross_attention_dim=self.processor.cross_attention_dim,
236
+ attention_op=attention_op,
237
+ )
238
+ processor.load_state_dict(self.processor.state_dict())
239
+ if hasattr(self.processor, "to_k_custom_diffusion"):
240
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
241
+ elif is_added_kv_processor:
242
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
243
+ # which uses this type of cross attention ONLY because the attention mask of format
244
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
245
+ # throw warning
246
+ logger.info(
247
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
248
+ )
249
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
250
+ else:
251
+ processor = XFormersAttnProcessor(attention_op=attention_op)
252
+ else:
253
+ if is_lora:
254
+ attn_processor_class = (
255
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
256
+ )
257
+ processor = attn_processor_class(
258
+ hidden_size=self.processor.hidden_size,
259
+ cross_attention_dim=self.processor.cross_attention_dim,
260
+ rank=self.processor.rank,
261
+ )
262
+ processor.load_state_dict(self.processor.state_dict())
263
+ processor.to(self.processor.to_q_lora.up.weight.device)
264
+ elif is_custom_diffusion:
265
+ processor = CustomDiffusionAttnProcessor(
266
+ train_kv=self.processor.train_kv,
267
+ train_q_out=self.processor.train_q_out,
268
+ hidden_size=self.processor.hidden_size,
269
+ cross_attention_dim=self.processor.cross_attention_dim,
270
+ )
271
+ processor.load_state_dict(self.processor.state_dict())
272
+ if hasattr(self.processor, "to_k_custom_diffusion"):
273
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
274
+ else:
275
+ # set attention processor
276
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
277
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
278
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
279
+ processor = (
280
+ AttnProcessor2_0()
281
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
282
+ else AttnProcessor()
283
+ )
284
+
285
+ self.set_processor(processor)
286
+
287
+ def set_attention_slice(self, slice_size):
288
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
289
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
290
+
291
+ if slice_size is not None and self.added_kv_proj_dim is not None:
292
+ processor = SlicedAttnAddedKVProcessor(slice_size)
293
+ elif slice_size is not None:
294
+ processor = SlicedAttnProcessor(slice_size)
295
+ elif self.added_kv_proj_dim is not None:
296
+ processor = AttnAddedKVProcessor()
297
+ else:
298
+ # set attention processor
299
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
300
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
301
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
302
+ processor = (
303
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
304
+ )
305
+
306
+ self.set_processor(processor)
307
+
308
+ def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
309
+ if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
310
+ deprecate(
311
+ "set_processor to offload LoRA",
312
+ "0.26.0",
313
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
314
+ )
315
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
316
+ # We need to remove all LoRA layers
317
+ # Don't forget to remove ALL `_remove_lora` from the codebase
318
+ for module in self.modules():
319
+ if hasattr(module, "set_lora_layer"):
320
+ module.set_lora_layer(None)
321
+
322
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
323
+ # pop `processor` from `self._modules`
324
+ if (
325
+ hasattr(self, "processor")
326
+ and isinstance(self.processor, torch.nn.Module)
327
+ and not isinstance(processor, torch.nn.Module)
328
+ ):
329
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
330
+ self._modules.pop("processor")
331
+
332
+ self.processor = processor
333
+
334
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
335
+ if not return_deprecated_lora:
336
+ return self.processor
337
+
338
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
339
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
340
+ # with PEFT is completed.
341
+ is_lora_activated = {
342
+ name: module.lora_layer is not None
343
+ for name, module in self.named_modules()
344
+ if hasattr(module, "lora_layer")
345
+ }
346
+
347
+ # 1. if no layer has a LoRA activated we can return the processor as usual
348
+ if not any(is_lora_activated.values()):
349
+ return self.processor
350
+
351
+ # LoRA is not applied to `add_k_proj` or `add_v_proj`, so drop them from the check
352
+ is_lora_activated.pop("add_k_proj", None)
353
+ is_lora_activated.pop("add_v_proj", None)
354
+ # 2. else it is not possible that only some layers have LoRA activated
355
+ if not all(is_lora_activated.values()):
356
+ raise ValueError(
357
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
358
+ )
359
+
360
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
361
+ non_lora_processor_cls_name = self.processor.__class__.__name__
362
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
363
+
364
+ hidden_size = self.inner_dim
365
+
366
+ # now create a LoRA attention processor from the LoRA layers
367
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
368
+ kwargs = {
369
+ "cross_attention_dim": self.cross_attention_dim,
370
+ "rank": self.to_q.lora_layer.rank,
371
+ "network_alpha": self.to_q.lora_layer.network_alpha,
372
+ "q_rank": self.to_q.lora_layer.rank,
373
+ "q_hidden_size": self.to_q.lora_layer.out_features,
374
+ "k_rank": self.to_k.lora_layer.rank,
375
+ "k_hidden_size": self.to_k.lora_layer.out_features,
376
+ "v_rank": self.to_v.lora_layer.rank,
377
+ "v_hidden_size": self.to_v.lora_layer.out_features,
378
+ "out_rank": self.to_out[0].lora_layer.rank,
379
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
380
+ }
381
+
382
+ if hasattr(self.processor, "attention_op"):
383
+ kwargs["attention_op"] = self.processor.attention_op
384
+
385
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
386
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
387
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
388
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
389
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
390
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
391
+ lora_processor = lora_processor_cls(
392
+ hidden_size,
393
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
394
+ rank=self.to_q.lora_layer.rank,
395
+ network_alpha=self.to_q.lora_layer.network_alpha,
396
+ )
397
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
398
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
399
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
400
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
401
+
402
+ # only save if used
403
+ if self.add_k_proj.lora_layer is not None:
404
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
405
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
406
+ else:
407
+ lora_processor.add_k_proj_lora = None
408
+ lora_processor.add_v_proj_lora = None
409
+ else:
410
+ raise ValueError(f"{lora_processor_cls} does not exist.")
411
+
412
+ return lora_processor
413
+
414
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
415
+ # The `Attention` class can call different attention processors / attention functions
416
+ # here we simply pass along all tensors to the selected processor class
417
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
418
+ return self.processor(
419
+ self,
420
+ hidden_states,
421
+ encoder_hidden_states=encoder_hidden_states,
422
+ attention_mask=attention_mask,
423
+ **cross_attention_kwargs,
424
+ )
425
+
426
+ def batch_to_head_dim(self, tensor):
427
+ head_size = self.heads
428
+ batch_size, seq_len, dim = tensor.shape
429
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
430
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
431
+ return tensor
432
+
433
+ def head_to_batch_dim(self, tensor, out_dim=3):
434
+ head_size = self.heads
435
+ batch_size, seq_len, dim = tensor.shape
436
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
437
+ tensor = tensor.permute(0, 2, 1, 3)
438
+
439
+ if out_dim == 3:
440
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
441
+
442
+ return tensor
443
+
444
+ def get_attention_scores(self, query, key, attention_mask=None):
445
+ dtype = query.dtype
446
+ if self.upcast_attention:
447
+ query = query.float()
448
+ key = key.float()
449
+
450
+ if attention_mask is None:
451
+ baddbmm_input = torch.empty(
452
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
453
+ )
454
+ beta = 0
455
+ else:
457
+ baddbmm_input = attention_mask
458
+ beta = 1
459
+
460
+ attention_scores = torch.baddbmm(
461
+ baddbmm_input,
462
+ query,
463
+ key.transpose(-1, -2),
464
+ beta=beta,
465
+ alpha=self.scale,
466
+ )
467
+ del baddbmm_input
468
+
469
+ if self.upcast_softmax:
470
+ attention_scores = attention_scores.float()
471
+
472
+ attention_probs = attention_scores.softmax(dim=-1)
473
+ del attention_scores
474
+
475
+ attention_probs = attention_probs.to(dtype)
476
+
477
+ return attention_probs
478
+
479
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
480
+ if batch_size is None:
481
+ deprecate(
482
+ "batch_size=None",
483
+ "0.22.0",
484
+ (
485
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
486
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
487
+ " `prepare_attention_mask` when preparing the attention_mask."
488
+ ),
489
+ )
490
+ batch_size = 1
491
+
492
+ head_size = self.heads
493
+ if attention_mask is None:
494
+ return attention_mask
495
+
496
+ current_length: int = attention_mask.shape[-1]
497
+ if current_length != target_length:
498
+ if attention_mask.device.type == "mps":
499
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
500
+ # Instead, we can manually construct the padding tensor.
501
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
502
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
503
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
504
+ else:
505
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
506
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
507
+ # remaining_length: int = target_length - current_length
508
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
509
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
510
+
511
+ if out_dim == 3:
512
+ if attention_mask.shape[0] < batch_size * head_size:
513
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
514
+ elif out_dim == 4:
515
+ attention_mask = attention_mask.unsqueeze(1)
516
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
517
+
518
+ return attention_mask
519
+
520
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
521
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
522
+
523
+ if isinstance(self.norm_cross, nn.LayerNorm):
524
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
525
+ elif isinstance(self.norm_cross, nn.GroupNorm):
526
+ # Group norm norms along the channels dimension and expects
527
+ # input to be in the shape of (N, C, *). In this case, we want
528
+ # to norm along the hidden dimension, so we need to move
529
+ # (batch_size, sequence_length, hidden_size) ->
530
+ # (batch_size, hidden_size, sequence_length)
531
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
532
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
533
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
534
+ else:
535
+ assert False
536
+
537
+ return encoder_hidden_states
538
+
539
+
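Because `Attention.forward` hands everything to `self.processor`, any callable with the processor signature below can be installed via `set_processor`. A hedged sketch of a wrapper processor (the `LoggingAttnProcessor` name and the print statement are illustrative, not part of this commit):

class LoggingAttnProcessor:
    """Illustrative processor: log the query length, then run the default path."""

    def __init__(self):
        self.inner = AttnProcessor()

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        print("attention over", hidden_states.shape[1], "query tokens")
        return self.inner(attn, hidden_states,
                          encoder_hidden_states=encoder_hidden_states,
                          attention_mask=attention_mask, **kwargs)

# usage (attn_module is assumed to be an Attention instance from this file):
# attn_module.set_processor(LoggingAttnProcessor())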
540
+ class AttnProcessor:
541
+ r"""
542
+ Default processor for performing attention-related computations.
543
+ """
544
+
545
+ def __call__(
546
+ self,
547
+ attn: Attention,
548
+ hidden_states,
549
+ encoder_hidden_states=None,
550
+ attention_mask=None,
551
+ temb=None,
552
+ scale=1.0,
553
+ ):
554
+ residual = hidden_states
555
+
556
+ if attn.spatial_norm is not None:
557
+ hidden_states = attn.spatial_norm(hidden_states, temb)
558
+
559
+ input_ndim = hidden_states.ndim
560
+
561
+ if input_ndim == 4:
562
+ batch_size, channel, height, width = hidden_states.shape
563
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
564
+
565
+ batch_size, sequence_length, _ = (
566
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
567
+ )
568
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
569
+
570
+ if attn.group_norm is not None:
571
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
572
+
573
+ query = attn.to_q(hidden_states, scale=scale)
574
+
575
+ if encoder_hidden_states is None:
576
+ encoder_hidden_states = hidden_states
577
+ elif attn.norm_cross:
578
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
579
+
580
+ key = attn.to_k(encoder_hidden_states, scale=scale)
581
+ value = attn.to_v(encoder_hidden_states, scale=scale)
582
+
583
+ query = attn.head_to_batch_dim(query)
584
+ key = attn.head_to_batch_dim(key)
585
+ value = attn.head_to_batch_dim(value)
586
+
587
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
591
+ hidden_states = torch.bmm(attention_probs, value)
592
+ hidden_states = attn.batch_to_head_dim(hidden_states)
593
+
594
+
595
+ # linear proj
596
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
597
+ # dropout
598
+ hidden_states = attn.to_out[1](hidden_states)
599
+
600
+ if input_ndim == 4:
601
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
602
+
603
+ if attn.residual_connection:
604
+ hidden_states = hidden_states + residual
605
+
606
+ hidden_states = hidden_states / attn.rescale_output_factor
607
+
608
+ return hidden_states
609
+
610
+
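The default processor above is the classic formulation: `get_attention_scores` evaluates softmax(mask + Q K^T / sqrt(d_head)) via `baddbmm`, and the probabilities are applied to the values with `bmm`. A small stand-alone shape and equivalence check with arbitrary dimensions:

import torch

batch_heads, q_len, kv_len, d_head = 16, 64, 77, 40
q = torch.randn(batch_heads, q_len, d_head)
k = torch.randn(batch_heads, kv_len, d_head)
v = torch.randn(batch_heads, kv_len, d_head)

scale = d_head ** -0.5
scores = torch.baddbmm(torch.zeros(batch_heads, q_len, kv_len),
                       q, k.transpose(-1, -2), beta=0, alpha=scale)
probs = scores.softmax(dim=-1)
out = torch.bmm(probs, v)  # (batch_heads, q_len, d_head)

# same result as the closed-form expression
assert torch.allclose(out, (q @ k.transpose(-1, -2) * scale).softmax(-1) @ v, atol=1e-5)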
611
+ class CustomDiffusionAttnProcessor(nn.Module):
612
+ r"""
613
+ Processor for implementing attention for the Custom Diffusion method.
614
+
615
+ Args:
616
+ train_kv (`bool`, defaults to `True`):
617
+ Whether to newly train the key and value matrices corresponding to the text features.
618
+ train_q_out (`bool`, defaults to `True`):
619
+ Whether to newly train query matrices corresponding to the latent image features.
620
+ hidden_size (`int`, *optional*, defaults to `None`):
621
+ The hidden size of the attention layer.
622
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
623
+ The number of channels in the `encoder_hidden_states`.
624
+ out_bias (`bool`, defaults to `True`):
625
+ Whether to include the bias parameter in `train_q_out`.
626
+ dropout (`float`, *optional*, defaults to 0.0):
627
+ The dropout probability to use.
628
+ """
629
+
630
+ def __init__(
631
+ self,
632
+ train_kv=True,
633
+ train_q_out=True,
634
+ hidden_size=None,
635
+ cross_attention_dim=None,
636
+ out_bias=True,
637
+ dropout=0.0,
638
+ ):
639
+ super().__init__()
640
+ self.train_kv = train_kv
641
+ self.train_q_out = train_q_out
642
+
643
+ self.hidden_size = hidden_size
644
+ self.cross_attention_dim = cross_attention_dim
645
+
646
+ # `_custom_diffusion` id for easy serialization and loading.
647
+ if self.train_kv:
648
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
649
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
650
+ if self.train_q_out:
651
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
652
+ self.to_out_custom_diffusion = nn.ModuleList([])
653
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
654
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
655
+
656
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
657
+ batch_size, sequence_length, _ = hidden_states.shape
658
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
659
+ if self.train_q_out:
660
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
661
+ else:
662
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
663
+
664
+ if encoder_hidden_states is None:
665
+ crossattn = False
666
+ encoder_hidden_states = hidden_states
667
+ else:
668
+ crossattn = True
669
+ if attn.norm_cross:
670
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
671
+
672
+ if self.train_kv:
673
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
674
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
675
+ key = key.to(attn.to_q.weight.dtype)
676
+ value = value.to(attn.to_q.weight.dtype)
677
+ else:
678
+ key = attn.to_k(encoder_hidden_states)
679
+ value = attn.to_v(encoder_hidden_states)
680
+
681
+ if crossattn:
682
+ detach = torch.ones_like(key)
683
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
684
+ key = detach * key + (1 - detach) * key.detach()
685
+ value = detach * value + (1 - detach) * value.detach()
686
+
687
+ query = attn.head_to_batch_dim(query)
688
+ key = attn.head_to_batch_dim(key)
689
+ value = attn.head_to_batch_dim(value)
690
+
691
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
692
+ hidden_states = torch.bmm(attention_probs, value)
693
+ hidden_states = attn.batch_to_head_dim(hidden_states)
694
+
695
+ if self.train_q_out:
696
+ # linear proj
697
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
698
+ # dropout
699
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
700
+ else:
701
+ # linear proj
702
+ hidden_states = attn.to_out[0](hidden_states)
703
+ # dropout
704
+ hidden_states = attn.to_out[1](hidden_states)
705
+
706
+ return hidden_states
707
+
708
+
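The `detach` mask in `CustomDiffusionAttnProcessor` (and its xFormers variant below) zeroes position 0, so in cross-attention the key/value gradients are blocked only for the first text token while the remaining tokens stay trainable. A toy illustration of that gradient-masking trick (sizes are arbitrary):

import torch

key = torch.randn(1, 3, 4, requires_grad=True)

detach = torch.ones_like(key)
detach[:, :1, :] = 0.0                                # mask out the first token position
mixed = detach * key + (1 - detach) * key.detach()

mixed.sum().backward()
print(key.grad[0, 0])  # zeros: no gradient through the first token
print(key.grad[0, 1])  # ones: gradient flows for the remaining tokens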
709
+ class AttnAddedKVProcessor:
710
+ r"""
711
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
712
+ encoder.
713
+ """
714
+
715
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
716
+ residual = hidden_states
717
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
718
+ batch_size, sequence_length, _ = hidden_states.shape
719
+
720
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
721
+
722
+ if encoder_hidden_states is None:
723
+ encoder_hidden_states = hidden_states
724
+ elif attn.norm_cross:
725
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
726
+
727
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
728
+
729
+ query = attn.to_q(hidden_states, scale=scale)
730
+ query = attn.head_to_batch_dim(query)
731
+
732
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
733
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
734
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
735
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
736
+
737
+ if not attn.only_cross_attention:
738
+ key = attn.to_k(hidden_states, scale=scale)
739
+ value = attn.to_v(hidden_states, scale=scale)
740
+ key = attn.head_to_batch_dim(key)
741
+ value = attn.head_to_batch_dim(value)
742
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
743
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
744
+ else:
745
+ key = encoder_hidden_states_key_proj
746
+ value = encoder_hidden_states_value_proj
747
+
748
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
749
+ hidden_states = torch.bmm(attention_probs, value)
750
+ hidden_states = attn.batch_to_head_dim(hidden_states)
751
+
752
+ # linear proj
753
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
754
+ # dropout
755
+ hidden_states = attn.to_out[1](hidden_states)
756
+
757
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
758
+ hidden_states = hidden_states + residual
759
+
760
+ return hidden_states
761
+
762
+
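In the added-KV processors, the latents attend over the concatenation of keys/values projected from the latents themselves and keys/values produced by `add_k_proj` / `add_v_proj` from the text-encoder states (unless `only_cross_attention` is set). A rough sketch of that key assembly, with illustrative shapes only:

import torch

text_kv = torch.randn(2, 77, 320)    # from add_k_proj / add_v_proj
latent_kv = torch.randn(2, 64, 320)  # from to_k / to_v on the latents

key = torch.cat([text_kv, latent_kv], dim=1)  # attend over text tokens + latent tokens
print(key.shape)                              # torch.Size([2, 141, 320])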
763
+ class AttnAddedKVProcessor2_0:
764
+ r"""
765
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
766
+ learnable key and value matrices for the text encoder.
767
+ """
768
+
769
+ def __init__(self):
770
+ if not hasattr(F, "scaled_dot_product_attention"):
771
+ raise ImportError(
772
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
773
+ )
774
+
775
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
776
+ residual = hidden_states
777
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
778
+ batch_size, sequence_length, _ = hidden_states.shape
779
+
780
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
781
+
782
+ if encoder_hidden_states is None:
783
+ encoder_hidden_states = hidden_states
784
+ elif attn.norm_cross:
785
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
786
+
787
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
788
+
789
+ query = attn.to_q(hidden_states, scale=scale)
790
+ query = attn.head_to_batch_dim(query, out_dim=4)
791
+
792
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
793
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
794
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
795
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
796
+
797
+ if not attn.only_cross_attention:
798
+ key = attn.to_k(hidden_states, scale=scale)
799
+ value = attn.to_v(hidden_states, scale=scale)
800
+ key = attn.head_to_batch_dim(key, out_dim=4)
801
+ value = attn.head_to_batch_dim(value, out_dim=4)
802
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
803
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
804
+ else:
805
+ key = encoder_hidden_states_key_proj
806
+ value = encoder_hidden_states_value_proj
807
+
808
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
809
+ # TODO: add support for attn.scale when we move to Torch 2.1
810
+ hidden_states = F.scaled_dot_product_attention(
811
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
812
+ )
813
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
814
+
815
+ # linear proj
816
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
817
+ # dropout
818
+ hidden_states = attn.to_out[1](hidden_states)
819
+
820
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
821
+ hidden_states = hidden_states + residual
822
+
823
+ return hidden_states
824
+
825
+
826
+ class XFormersAttnAddedKVProcessor:
827
+ r"""
828
+ Processor for implementing memory efficient attention using xFormers.
829
+
830
+ Args:
831
+ attention_op (`Callable`, *optional*, defaults to `None`):
832
+ The base
833
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
834
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
835
+ operator.
836
+ """
837
+
838
+ def __init__(self, attention_op: Optional[Callable] = None):
839
+ self.attention_op = attention_op
840
+
841
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
842
+ residual = hidden_states
843
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
844
+ batch_size, sequence_length, _ = hidden_states.shape
845
+
846
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
847
+
848
+ if encoder_hidden_states is None:
849
+ encoder_hidden_states = hidden_states
850
+ elif attn.norm_cross:
851
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
852
+
853
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
854
+
855
+ query = attn.to_q(hidden_states)
856
+ query = attn.head_to_batch_dim(query)
857
+
858
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
859
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
860
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
861
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
862
+
863
+ if not attn.only_cross_attention:
864
+ key = attn.to_k(hidden_states)
865
+ value = attn.to_v(hidden_states)
866
+ key = attn.head_to_batch_dim(key)
867
+ value = attn.head_to_batch_dim(value)
868
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
869
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
870
+ else:
871
+ key = encoder_hidden_states_key_proj
872
+ value = encoder_hidden_states_value_proj
873
+
874
+ hidden_states = xformers.ops.memory_efficient_attention(
875
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
876
+ )
877
+ hidden_states = hidden_states.to(query.dtype)
878
+ hidden_states = attn.batch_to_head_dim(hidden_states)
879
+
880
+ # linear proj
881
+ hidden_states = attn.to_out[0](hidden_states)
882
+ # dropout
883
+ hidden_states = attn.to_out[1](hidden_states)
884
+
885
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
886
+ hidden_states = hidden_states + residual
887
+
888
+ return hidden_states
889
+
890
+
891
+ class XFormersAttnProcessor:
892
+ r"""
893
+ Processor for implementing memory efficient attention using xFormers.
894
+
895
+ Args:
896
+ attention_op (`Callable`, *optional*, defaults to `None`):
897
+ The base
898
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
899
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
900
+ operator.
901
+ """
902
+
903
+ def __init__(self, attention_op: Optional[Callable] = None):
904
+ self.attention_op = attention_op
905
+
906
+ def __call__(
907
+ self,
908
+ attn: Attention,
909
+ hidden_states: torch.FloatTensor,
910
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
911
+ attention_mask: Optional[torch.FloatTensor] = None,
912
+ temb: Optional[torch.FloatTensor] = None,
913
+ scale: float = 1.0,
914
+ ):
915
+ residual = hidden_states
916
+
917
+ if attn.spatial_norm is not None:
918
+ hidden_states = attn.spatial_norm(hidden_states, temb)
919
+
920
+ input_ndim = hidden_states.ndim
921
+
922
+ if input_ndim == 4:
923
+ batch_size, channel, height, width = hidden_states.shape
924
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
925
+
926
+ batch_size, key_tokens, _ = (
927
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
928
+ )
929
+
930
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
931
+ if attention_mask is not None:
932
+ # expand our mask's singleton query_tokens dimension:
933
+ # [batch*heads, 1, key_tokens] ->
934
+ # [batch*heads, query_tokens, key_tokens]
935
+ # so that it can be added as a bias onto the attention scores that xformers computes:
936
+ # [batch*heads, query_tokens, key_tokens]
937
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
938
+ _, query_tokens, _ = hidden_states.shape
939
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
940
+
941
+ if attn.group_norm is not None:
942
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
943
+
944
+ query = attn.to_q(hidden_states, scale=scale)
945
+
946
+ if encoder_hidden_states is None:
947
+ encoder_hidden_states = hidden_states
948
+ elif attn.norm_cross:
949
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
950
+
951
+ key = attn.to_k(encoder_hidden_states, scale=scale)
952
+ value = attn.to_v(encoder_hidden_states, scale=scale)
953
+
954
+ query = attn.head_to_batch_dim(query).contiguous()
955
+ key = attn.head_to_batch_dim(key).contiguous()
956
+ value = attn.head_to_batch_dim(value).contiguous()
957
+
958
+ hidden_states = xformers.ops.memory_efficient_attention(
959
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
960
+ )
961
+ hidden_states = hidden_states.to(query.dtype)
962
+ hidden_states = attn.batch_to_head_dim(hidden_states)
963
+
964
+ # linear proj
965
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
966
+ # dropout
967
+ hidden_states = attn.to_out[1](hidden_states)
968
+
969
+ if input_ndim == 4:
970
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
971
+
972
+ if attn.residual_connection:
973
+ hidden_states = hidden_states + residual
974
+
975
+ hidden_states = hidden_states / attn.rescale_output_factor
976
+
977
+ return hidden_states
978
+
979
+
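`XFormersAttnProcessor` is normally selected through `Attention.set_use_memory_efficient_attention_xformers` rather than constructed by hand; diffusers pipelines expose the same switch as `enable_xformers_memory_efficient_attention()`. A hedged helper sketch (the `maybe_enable_xformers` function is illustrative, not part of this file) that only flips the switch when it can actually run:

import torch
from diffusers.utils.import_utils import is_xformers_available

def maybe_enable_xformers(attn_module):
    """Enable the xFormers processor on one Attention module when xformers and CUDA are available."""
    if is_xformers_available() and torch.cuda.is_available():
        attn_module.set_use_memory_efficient_attention_xformers(True)
    return attn_module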
980
+ class AttnProcessor2_0:
981
+ r"""
982
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
983
+ """
984
+
985
+ def __init__(self):
986
+ if not hasattr(F, "scaled_dot_product_attention"):
987
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
988
+
989
+ def __call__(
990
+ self,
991
+ attn: Attention,
992
+ hidden_states,
993
+ encoder_hidden_states=None,
994
+ attention_mask=None,
995
+ temb=None,
996
+ scale: float = 1.0,
997
+ ):
998
+ residual = hidden_states
999
+
1000
+ if attn.spatial_norm is not None:
1001
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1002
+
1003
+ input_ndim = hidden_states.ndim
1004
+
1005
+ if input_ndim == 4:
1006
+ batch_size, channel, height, width = hidden_states.shape
1007
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1008
+
1009
+ batch_size, sequence_length, _ = (
1010
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1011
+ )
1012
+
1013
+ if attention_mask is not None:
1014
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1015
+ # scaled_dot_product_attention expects attention_mask shape to be
1016
+ # (batch, heads, source_length, target_length)
1017
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1018
+
1019
+ if attn.group_norm is not None:
1020
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1021
+
1022
+ query = attn.to_q(hidden_states, scale=scale)
1023
+
1024
+ if encoder_hidden_states is None:
1025
+ encoder_hidden_states = hidden_states
1026
+ elif attn.norm_cross:
1027
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1028
+
1029
+ key = attn.to_k(encoder_hidden_states, scale=scale)
1030
+ value = attn.to_v(encoder_hidden_states, scale=scale)
1031
+
1032
+ inner_dim = key.shape[-1]
1033
+ head_dim = inner_dim // attn.heads
1034
+
1035
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1036
+
1037
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1038
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1039
+
1040
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1041
+ # TODO: add support for attn.scale when we move to Torch 2.1
1042
+ hidden_states = F.scaled_dot_product_attention(
1043
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1044
+ )
1045
+
1046
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1047
+ hidden_states = hidden_states.to(query.dtype)
1048
+
1049
+ # linear proj
1050
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
1051
+ # dropout
1052
+ hidden_states = attn.to_out[1](hidden_states)
1053
+
1054
+ if input_ndim == 4:
1055
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1056
+
1057
+ if attn.residual_connection:
1058
+ hidden_states = hidden_states + residual
1059
+
1060
+ hidden_states = hidden_states / attn.rescale_output_factor
1061
+
1062
+ return hidden_states
1063
+
1064
+
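`AttnProcessor2_0` keeps the heads in a (batch, heads, seq_len, head_dim) layout and lets PyTorch 2.0's fused `F.scaled_dot_product_attention` replace the explicit softmax(QK^T)V of the default processor. A small equivalence check with arbitrary shapes (requires PyTorch 2.0):

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 64, 77, 40
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

fused = F.scaled_dot_product_attention(q, k, v)                       # fused kernel
manual = (q @ k.transpose(-1, -2) / head_dim ** 0.5).softmax(-1) @ v  # explicit path
assert torch.allclose(fused, manual, atol=1e-5)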
1065
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1066
+ r"""
1067
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1068
+
1069
+ Args:
1070
+ train_kv (`bool`, defaults to `True`):
1071
+ Whether to newly train the key and value matrices corresponding to the text features.
1072
+ train_q_out (`bool`, defaults to `True`):
1073
+ Whether to newly train query matrices corresponding to the latent image features.
1074
+ hidden_size (`int`, *optional*, defaults to `None`):
1075
+ The hidden size of the attention layer.
1076
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1077
+ The number of channels in the `encoder_hidden_states`.
1078
+ out_bias (`bool`, defaults to `True`):
1079
+ Whether to include the bias parameter in `train_q_out`.
1080
+ dropout (`float`, *optional*, defaults to 0.0):
1081
+ The dropout probability to use.
1082
+ attention_op (`Callable`, *optional*, defaults to `None`):
1083
+ The base
1084
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1085
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
1086
+ """
1087
+
1088
+ def __init__(
1089
+ self,
1090
+ train_kv=True,
1091
+ train_q_out=False,
1092
+ hidden_size=None,
1093
+ cross_attention_dim=None,
1094
+ out_bias=True,
1095
+ dropout=0.0,
1096
+ attention_op: Optional[Callable] = None,
1097
+ ):
1098
+ super().__init__()
1099
+ self.train_kv = train_kv
1100
+ self.train_q_out = train_q_out
1101
+
1102
+ self.hidden_size = hidden_size
1103
+ self.cross_attention_dim = cross_attention_dim
1104
+ self.attention_op = attention_op
1105
+
1106
+ # `_custom_diffusion` id for easy serialization and loading.
1107
+ if self.train_kv:
1108
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1109
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1110
+ if self.train_q_out:
1111
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1112
+ self.to_out_custom_diffusion = nn.ModuleList([])
1113
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1114
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1115
+
1116
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1117
+ batch_size, sequence_length, _ = (
1118
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1119
+ )
1120
+
1121
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1122
+
1123
+ if self.train_q_out:
1124
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
1125
+ else:
1126
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
1127
+
1128
+ if encoder_hidden_states is None:
1129
+ crossattn = False
1130
+ encoder_hidden_states = hidden_states
1131
+ else:
1132
+ crossattn = True
1133
+ if attn.norm_cross:
1134
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1135
+
1136
+ if self.train_kv:
1137
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1138
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1139
+ key = key.to(attn.to_q.weight.dtype)
1140
+ value = value.to(attn.to_q.weight.dtype)
1141
+ else:
1142
+ key = attn.to_k(encoder_hidden_states)
1143
+ value = attn.to_v(encoder_hidden_states)
1144
+
1145
+ if crossattn:
1146
+ detach = torch.ones_like(key)
1147
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1148
+ key = detach * key + (1 - detach) * key.detach()
1149
+ value = detach * value + (1 - detach) * value.detach()
1150
+
1151
+ query = attn.head_to_batch_dim(query).contiguous()
1152
+ key = attn.head_to_batch_dim(key).contiguous()
1153
+ value = attn.head_to_batch_dim(value).contiguous()
1154
+
1155
+ hidden_states = xformers.ops.memory_efficient_attention(
1156
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1157
+ )
1158
+ hidden_states = hidden_states.to(query.dtype)
1159
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1160
+
1161
+ if self.train_q_out:
1162
+ # linear proj
1163
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1164
+ # dropout
1165
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1166
+ else:
1167
+ # linear proj
1168
+ hidden_states = attn.to_out[0](hidden_states)
1169
+ # dropout
1170
+ hidden_states = attn.to_out[1](hidden_states)
1171
+ return hidden_states
1172
+
1173
+
1174
+ class SlicedAttnProcessor:
1175
+ r"""
1176
+ Processor for implementing sliced attention.
1177
+
1178
+ Args:
1179
+ slice_size (`int`, *optional*):
1180
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1181
+ `attention_head_dim` must be a multiple of the `slice_size`.
1182
+ """
1183
+
1184
+ def __init__(self, slice_size):
1185
+ self.slice_size = slice_size
1186
+
1187
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1188
+ residual = hidden_states
1189
+
1190
+ input_ndim = hidden_states.ndim
1191
+
1192
+ if input_ndim == 4:
1193
+ batch_size, channel, height, width = hidden_states.shape
1194
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1195
+
1196
+ batch_size, sequence_length, _ = (
1197
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1198
+ )
1199
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1200
+
1201
+ if attn.group_norm is not None:
1202
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1203
+
1204
+ query = attn.to_q(hidden_states)
1205
+ dim = query.shape[-1]
1206
+ query = attn.head_to_batch_dim(query)
1207
+
1208
+ if encoder_hidden_states is None:
1209
+ encoder_hidden_states = hidden_states
1210
+ elif attn.norm_cross:
1211
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1212
+
1213
+ key = attn.to_k(encoder_hidden_states)
1214
+ value = attn.to_v(encoder_hidden_states)
1215
+ key = attn.head_to_batch_dim(key)
1216
+ value = attn.head_to_batch_dim(value)
1217
+
1218
+ batch_size_attention, query_tokens, _ = query.shape
1219
+ hidden_states = torch.zeros(
1220
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1221
+ )
1222
+
1223
+ for i in range(batch_size_attention // self.slice_size):
1224
+ start_idx = i * self.slice_size
1225
+ end_idx = (i + 1) * self.slice_size
1226
+
1227
+ query_slice = query[start_idx:end_idx]
1228
+ key_slice = key[start_idx:end_idx]
1229
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1230
+
1231
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1232
+
1233
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1234
+
1235
+ hidden_states[start_idx:end_idx] = attn_slice
1236
+
1237
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1238
+
1239
+ # linear proj
1240
+ hidden_states = attn.to_out[0](hidden_states)
1241
+ # dropout
1242
+ hidden_states = attn.to_out[1](hidden_states)
1243
+
1244
+ if input_ndim == 4:
1245
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1246
+
1247
+ if attn.residual_connection:
1248
+ hidden_states = hidden_states + residual
1249
+
1250
+ hidden_states = hidden_states / attn.rescale_output_factor
1251
+
1252
+ return hidden_states
1253
+
1254
+
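The sliced processors trade speed for memory: the (query_tokens, key_tokens) score matrix is materialized for only `slice_size` of the batch*heads rows at a time, so peak memory for that matrix scales with `slice_size` rather than with batch*heads (the loop also assumes batch*heads is divisible by `slice_size`). A back-of-envelope sketch of the saving, with made-up sizes:

# rough peak-memory estimate for the fp16 attention score matrix
batch_heads, q_len, kv_len = 16, 4096, 4096
full_bytes = batch_heads * q_len * kv_len * 2
for slice_size in (16, 4, 1):
    sliced_bytes = slice_size * q_len * kv_len * 2
    print(slice_size, f"{sliced_bytes / full_bytes:.0%} of the full score matrix")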
1255
+ class SlicedAttnAddedKVProcessor:
1256
+ r"""
1257
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1258
+
1259
+ Args:
1260
+ slice_size (`int`, *optional*):
1261
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1262
+ `attention_head_dim` must be a multiple of the `slice_size`.
1263
+ """
1264
+
1265
+ def __init__(self, slice_size):
1266
+ self.slice_size = slice_size
1267
+
1268
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1269
+ residual = hidden_states
1270
+
1271
+ if attn.spatial_norm is not None:
1272
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1273
+
1274
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1275
+
1276
+ batch_size, sequence_length, _ = hidden_states.shape
1277
+
1278
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1279
+
1280
+ if encoder_hidden_states is None:
1281
+ encoder_hidden_states = hidden_states
1282
+ elif attn.norm_cross:
1283
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1284
+
1285
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1286
+
1287
+ query = attn.to_q(hidden_states)
1288
+ dim = query.shape[-1]
1289
+ query = attn.head_to_batch_dim(query)
1290
+
1291
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1292
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1293
+
1294
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1295
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1296
+
1297
+ if not attn.only_cross_attention:
1298
+ key = attn.to_k(hidden_states)
1299
+ value = attn.to_v(hidden_states)
1300
+ key = attn.head_to_batch_dim(key)
1301
+ value = attn.head_to_batch_dim(value)
1302
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1303
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1304
+ else:
1305
+ key = encoder_hidden_states_key_proj
1306
+ value = encoder_hidden_states_value_proj
1307
+
1308
+ batch_size_attention, query_tokens, _ = query.shape
1309
+ hidden_states = torch.zeros(
1310
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1311
+ )
1312
+
1313
+ for i in range(batch_size_attention // self.slice_size):
1314
+ start_idx = i * self.slice_size
1315
+ end_idx = (i + 1) * self.slice_size
1316
+
1317
+ query_slice = query[start_idx:end_idx]
1318
+ key_slice = key[start_idx:end_idx]
1319
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1320
+
1321
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1322
+
1323
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1324
+
1325
+ hidden_states[start_idx:end_idx] = attn_slice
1326
+
1327
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1328
+
1329
+ # linear proj
1330
+ hidden_states = attn.to_out[0](hidden_states)
1331
+ # dropout
1332
+ hidden_states = attn.to_out[1](hidden_states)
1333
+
1334
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1335
+ hidden_states = hidden_states + residual
1336
+
1337
+ return hidden_states
1338
+
1339
+
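The sliced processors above are normally reached through the model's attention-slicing helpers rather than constructed by hand, but they can also be wired in directly. A minimal sketch, assuming a `unet` variable (a UNet model exposing the standard `attn_processors` / `set_attn_processor` API) and an arbitrary slice size of 2:

# Sketch only: `unet` and the slice size are assumptions for illustration.
new_procs = {}
for name, proc in unet.attn_processors.items():
    if isinstance(proc, (AttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
        new_procs[name] = SlicedAttnAddedKVProcessor(slice_size=2)  # 2 attention heads per slice
    else:
        new_procs[name] = SlicedAttnProcessor(slice_size=2)
unet.set_attn_processor(new_procs)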
1340
+ class SpatialNorm(nn.Module):
1341
+ """
1342
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1343
+ """
1344
+
1345
+ def __init__(
1346
+ self,
1347
+ f_channels,
1348
+ zq_channels,
1349
+ ):
1350
+ super().__init__()
1351
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1352
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1353
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1354
+
1355
+ def forward(self, f, zq):
1356
+ f_size = f.shape[-2:]
1357
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1358
+ norm_f = self.norm_layer(f)
1359
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1360
+ return new_f
1361
+
1362
+
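A quick shape sanity check for `SpatialNorm`: the conditioning tensor `zq` is resized to the feature map's spatial size with nearest-neighbour interpolation, then used to scale and shift the group-normalized features. Channel and spatial sizes below are arbitrary examples, not values used elsewhere in this repo.

# Illustrative sizes only: 64 feature channels (divisible by the 32 GroupNorm groups), 4 conditioning channels.
norm = SpatialNorm(f_channels=64, zq_channels=4)
f = torch.randn(1, 64, 32, 32)    # feature map
zq = torch.randn(1, 4, 8, 8)      # conditioning latent, upsampled to 32x32 inside forward()
out = norm(f, zq)
assert out.shape == f.shape       # the modulation preserves the feature map's shape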
1363
+ ## Deprecated
1364
+ class LoRAAttnProcessor(nn.Module):
1365
+ r"""
1366
+ Processor for implementing the LoRA attention mechanism.
1367
+
1368
+ Args:
1369
+ hidden_size (`int`, *optional*):
1370
+ The hidden size of the attention layer.
1371
+ cross_attention_dim (`int`, *optional*):
1372
+ The number of channels in the `encoder_hidden_states`.
1373
+ rank (`int`, defaults to 4):
1374
+ The dimension of the LoRA update matrices.
1375
+ network_alpha (`int`, *optional*):
1376
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1377
+ """
1378
+
1379
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
1380
+ super().__init__()
1381
+
1382
+ self.hidden_size = hidden_size
1383
+ self.cross_attention_dim = cross_attention_dim
1384
+ self.rank = rank
1385
+
1386
+ q_rank = kwargs.pop("q_rank", None)
1387
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1388
+ q_rank = q_rank if q_rank is not None else rank
1389
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1390
+
1391
+ v_rank = kwargs.pop("v_rank", None)
1392
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1393
+ v_rank = v_rank if v_rank is not None else rank
1394
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1395
+
1396
+ out_rank = kwargs.pop("out_rank", None)
1397
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1398
+ out_rank = out_rank if out_rank is not None else rank
1399
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1400
+
1401
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1402
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1403
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1404
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1405
+
1406
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
1407
+ self_cls_name = self.__class__.__name__
1408
+ deprecate(
1409
+ self_cls_name,
1410
+ "0.26.0",
1411
+ (
1412
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1413
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1414
+ " `LoraLoaderMixin.load_lora_weights`"
1415
+ ),
1416
+ )
1417
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1418
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1419
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1420
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1421
+
1422
+ attn._modules.pop("processor")
1423
+ attn.processor = AttnProcessor()
1424
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1425
+
1426
+
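As the deprecation message states, these processors no longer compute attention themselves: on their first call they move their LoRA layers onto the attention module's projections and then replace themselves with the corresponding plain processor. A rough sketch of that hand-off, where `attn_module` and `hidden_states` are assumed to exist and the sizes are placeholders:

# Hypothetical illustration of the deprecation path (sizes are placeholders, not repo defaults).
legacy = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
attn_module.set_processor(legacy)        # attach the legacy processor to an Attention module
_ = attn_module(hidden_states)           # first call migrates the LoRA weights and swaps in AttnProcessor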
1427
+ class LoRAAttnProcessor2_0(nn.Module):
1428
+ r"""
1429
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1430
+ attention.
1431
+
1432
+ Args:
1433
+ hidden_size (`int`):
1434
+ The hidden size of the attention layer.
1435
+ cross_attention_dim (`int`, *optional*):
1436
+ The number of channels in the `encoder_hidden_states`.
1437
+ rank (`int`, defaults to 4):
1438
+ The dimension of the LoRA update matrices.
1439
+ network_alpha (`int`, *optional*):
1440
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1441
+ """
1442
+
1443
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
1444
+ super().__init__()
1445
+ if not hasattr(F, "scaled_dot_product_attention"):
1446
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1447
+
1448
+ self.hidden_size = hidden_size
1449
+ self.cross_attention_dim = cross_attention_dim
1450
+ self.rank = rank
1451
+
1452
+ q_rank = kwargs.pop("q_rank", None)
1453
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1454
+ q_rank = q_rank if q_rank is not None else rank
1455
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1456
+
1457
+ v_rank = kwargs.pop("v_rank", None)
1458
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1459
+ v_rank = v_rank if v_rank is not None else rank
1460
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1461
+
1462
+ out_rank = kwargs.pop("out_rank", None)
1463
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1464
+ out_rank = out_rank if out_rank is not None else rank
1465
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1466
+
1467
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1468
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1469
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1470
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1471
+
1472
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
1473
+ self_cls_name = self.__class__.__name__
1474
+ deprecate(
1475
+ self_cls_name,
1476
+ "0.26.0",
1477
+ (
1478
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1479
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1480
+ " `LoraLoaderMixin.load_lora_weights`"
1481
+ ),
1482
+ )
1483
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1484
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1485
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1486
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1487
+
1488
+ attn._modules.pop("processor")
1489
+ attn.processor = AttnProcessor2_0()
1490
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1491
+
1492
+
1493
+ class LoRAXFormersAttnProcessor(nn.Module):
1494
+ r"""
1495
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1496
+
1497
+ Args:
1498
+ hidden_size (`int`, *optional*):
1499
+ The hidden size of the attention layer.
1500
+ cross_attention_dim (`int`, *optional*):
1501
+ The number of channels in the `encoder_hidden_states`.
1502
+ rank (`int`, defaults to 4):
1503
+ The dimension of the LoRA update matrices.
1504
+ attention_op (`Callable`, *optional*, defaults to `None`):
1505
+ The base
1506
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1507
+ use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
1508
+ operator.
1509
+ network_alpha (`int`, *optional*):
1510
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1511
+
1512
+ """
1513
+
1514
+ def __init__(
1515
+ self,
1516
+ hidden_size,
1517
+ cross_attention_dim,
1518
+ rank=4,
1519
+ attention_op: Optional[Callable] = None,
1520
+ network_alpha=None,
1521
+ **kwargs,
1522
+ ):
1523
+ super().__init__()
1524
+
1525
+ self.hidden_size = hidden_size
1526
+ self.cross_attention_dim = cross_attention_dim
1527
+ self.rank = rank
1528
+ self.attention_op = attention_op
1529
+
1530
+ q_rank = kwargs.pop("q_rank", None)
1531
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1532
+ q_rank = q_rank if q_rank is not None else rank
1533
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1534
+
1535
+ v_rank = kwargs.pop("v_rank", None)
1536
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1537
+ v_rank = v_rank if v_rank is not None else rank
1538
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1539
+
1540
+ out_rank = kwargs.pop("out_rank", None)
1541
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1542
+ out_rank = out_rank if out_rank is not None else rank
1543
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1544
+
1545
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1546
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1547
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1548
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1549
+
1550
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
1551
+ self_cls_name = self.__class__.__name__
1552
+ deprecate(
1553
+ self_cls_name,
1554
+ "0.26.0",
1555
+ (
1556
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1557
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1558
+ " `LoraLoaderMixin.load_lora_weights`"
1559
+ ),
1560
+ )
1561
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1562
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1563
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1564
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1565
+
1566
+ attn._modules.pop("processor")
1567
+ attn.processor = XFormersAttnProcessor()
1568
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1569
+
1570
+
1571
+ class LoRAAttnAddedKVProcessor(nn.Module):
1572
+ r"""
1573
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
1574
+ encoder.
1575
+
1576
+ Args:
1577
+ hidden_size (`int`, *optional*):
1578
+ The hidden size of the attention layer.
1579
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1580
+ The number of channels in the `encoder_hidden_states`.
1581
+ rank (`int`, defaults to 4):
1582
+ The dimension of the LoRA update matrices.
1583
+
1584
+ """
1585
+
1586
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1587
+ super().__init__()
1588
+
1589
+ self.hidden_size = hidden_size
1590
+ self.cross_attention_dim = cross_attention_dim
1591
+ self.rank = rank
1592
+
1593
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1594
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1595
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1596
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1597
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1598
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1599
+
1600
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
1601
+ self_cls_name = self.__class__.__name__
1602
+ deprecate(
1603
+ self_cls_name,
1604
+ "0.26.0",
1605
+ (
1606
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1607
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1608
+ " `LoraLoaderMixin.load_lora_weights`"
1609
+ ),
1610
+ )
1611
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1612
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1613
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1614
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1615
+
1616
+ attn._modules.pop("processor")
1617
+ attn.processor = AttnAddedKVProcessor()
1618
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1619
+
1620
+
1621
+ LORA_ATTENTION_PROCESSORS = (
1622
+ LoRAAttnProcessor,
1623
+ LoRAAttnProcessor2_0,
1624
+ LoRAXFormersAttnProcessor,
1625
+ LoRAAttnAddedKVProcessor,
1626
+ )
1627
+
1628
+ ADDED_KV_ATTENTION_PROCESSORS = (
1629
+ AttnAddedKVProcessor,
1630
+ SlicedAttnAddedKVProcessor,
1631
+ AttnAddedKVProcessor2_0,
1632
+ XFormersAttnAddedKVProcessor,
1633
+ LoRAAttnAddedKVProcessor,
1634
+ )
1635
+
1636
+ CROSS_ATTENTION_PROCESSORS = (
1637
+ AttnProcessor,
1638
+ AttnProcessor2_0,
1639
+ XFormersAttnProcessor,
1640
+ SlicedAttnProcessor,
1641
+ LoRAAttnProcessor,
1642
+ LoRAAttnProcessor2_0,
1643
+ LoRAXFormersAttnProcessor,
1644
+ )
1645
+
1646
+ AttentionProcessor = Union[
1647
+ AttnProcessor,
1648
+ AttnProcessor2_0,
1649
+ XFormersAttnProcessor,
1650
+ SlicedAttnProcessor,
1651
+ AttnAddedKVProcessor,
1652
+ SlicedAttnAddedKVProcessor,
1653
+ AttnAddedKVProcessor2_0,
1654
+ XFormersAttnAddedKVProcessor,
1655
+ CustomDiffusionAttnProcessor,
1656
+ CustomDiffusionXFormersAttnProcessor,
1657
+ # deprecated
1658
+ LoRAAttnProcessor,
1659
+ LoRAAttnProcessor2_0,
1660
+ LoRAXFormersAttnProcessor,
1661
+ LoRAAttnAddedKVProcessor,
1662
+ ]
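These tuples exist mainly for `isinstance` dispatch, e.g. to decide whether a model's current processors are the added-KV or the standard cross-attention variants. A small sketch, again assuming a `unet` that exposes the usual `attn_processors` mapping:

# Sketch: classify the attention setup of an assumed `unet` using the groupings above.
procs = list(unet.attn_processors.values())
uses_added_kv = all(isinstance(p, ADDED_KV_ATTENTION_PROCESSORS) for p in procs)
uses_cross_attn = all(isinstance(p, CROSS_ATTENTION_PROCESSORS) for p in procs)
print(f"added-KV attention: {uses_added_kv}, standard cross-attention: {uses_cross_attn}")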
src/models/pipelines.py ADDED
@@ -0,0 +1,1414 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer
21
+
22
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
23
+ from diffusers.models import AutoencoderKL
24
+ from .unet_3d_condition import UNet3DConditionModel
25
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
26
+ from diffusers.schedulers import KarrasDiffusionSchedulers
27
+ from diffusers.utils import (
28
+ deprecate,
29
+ logging,
30
+ replace_example_docstring,
31
+ )
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
34
+ # from . import TextToVideoSDPipelineOutput
35
+
36
+
37
+ from dataclasses import dataclass
38
+ from typing import List, Union
39
+ from typing import Optional, Callable, Dict, Any
40
+
41
+ import numpy as np
42
+ import torch
43
+ from diffusers.utils import (
44
+ BaseOutput,
45
+ )
46
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
47
+
48
+
49
+ @dataclass
50
+ class TextToVideoSDPipelineOutput(BaseOutput):
51
+ """
52
+ Output class for text-to-video pipelines.
53
+
54
+ Args:
55
+ frames (`List[np.ndarray]` or `torch.FloatTensor`):
56
+ List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
57
+ a `torch` tensor. The length of the list denotes the video length (the number of frames).
58
+ """
59
+
60
+ frames: Union[List[np.ndarray], torch.FloatTensor]
61
+
62
+
63
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
64
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
65
+ # reshape to ncfhw
66
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
67
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
68
+ # unnormalize back to [0,1]
69
+ video = video.mul_(std).add_(mean)
70
+ video.clamp_(0, 1)
71
+ # prepare the final outputs
72
+ i, c, f, h, w = video.shape
73
+ images = video.permute(2, 3, 0, 4, 1).reshape(
74
+ f, h, i * w, c
75
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
76
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
77
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
78
+ return images
79
+
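`tensor2vid` expects a `(batch, channels, frames, height, width)` tensor normalized with the given mean/std and returns one `uint8` array per frame, with batch items tiled side by side along the width. A usage sketch; `video_tensor` and the `imageio` dependency are assumptions, and the exact GIF-writing kwarg depends on the imageio version:

# Sketch: convert a decoded video tensor to frames and write them out.
import imageio  # assumed to be available; any frame writer works

frames = tensor2vid(video_tensor)            # list of (H, B*W, 3) uint8 arrays, one per frame
imageio.mimsave("sample.gif", frames, fps=8) # 16 frames at 8 fps -> 2 s clip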
80
+ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
81
+ r"""
82
+ Pipeline for text-to-video generation.
83
+
84
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
85
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
86
+
87
+ Args:
88
+ vae ([`AutoencoderKL`]):
89
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
90
+ text_encoder ([`CLIPTextModel`]):
91
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
92
+ tokenizer (`CLIPTokenizer`):
93
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
94
+ unet ([`UNet3DConditionModel`]):
95
+ A [`UNet3DConditionModel`] to denoise the encoded video latents.
96
+ scheduler ([`SchedulerMixin`]):
97
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
98
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
99
+ """
100
+ model_cpu_offload_seq = "text_encoder->unet->vae"
101
+
102
+ def __init__(
103
+ self,
104
+ vae: AutoencoderKL,
105
+ text_encoder: CLIPTextModel,
106
+ tokenizer: CLIPTokenizer,
107
+ unet: UNet3DConditionModel,
108
+ scheduler: KarrasDiffusionSchedulers,
109
+ ):
110
+ super().__init__()
111
+
112
+ self.register_modules(
113
+ vae=vae,
114
+ text_encoder=text_encoder,
115
+ tokenizer=tokenizer,
116
+ unet=unet,
117
+ scheduler=scheduler,
118
+ )
119
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
120
+
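With the usual four-block SD-style VAE this scale factor evaluates to 8, i.e. each latent pixel corresponds to an 8x8 image patch (the four-block config is an assumption about the checkpoint, not something fixed here):

# Assumed SD-style VAE config: block_out_channels = (128, 256, 512, 512)
vae_scale_factor = 2 ** (len((128, 256, 512, 512)) - 1)   # = 8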
121
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
122
+ def enable_vae_slicing(self):
123
+ r"""
124
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
125
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
126
+ """
127
+ self.vae.enable_slicing()
128
+
129
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
130
+ def disable_vae_slicing(self):
131
+ r"""
132
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
133
+ computing decoding in one step.
134
+ """
135
+ self.vae.disable_slicing()
136
+
137
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
138
+ def enable_vae_tiling(self):
139
+ r"""
140
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
141
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
142
+ processing larger images.
143
+ """
144
+ self.vae.enable_tiling()
145
+
146
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
147
+ def disable_vae_tiling(self):
148
+ r"""
149
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
150
+ computing decoding in one step.
151
+ """
152
+ self.vae.disable_tiling()
153
+
154
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
155
+ def _encode_prompt(
156
+ self,
157
+ prompt,
158
+ device,
159
+ num_images_per_prompt,
160
+ do_classifier_free_guidance,
161
+ negative_prompt=None,
162
+ prompt_embeds: Optional[torch.FloatTensor] = None,
163
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
164
+ lora_scale: Optional[float] = None,
165
+ ):
166
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
167
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
168
+
169
+ prompt_embeds_tuple = self.encode_prompt(
170
+ prompt=prompt,
171
+ device=device,
172
+ num_images_per_prompt=num_images_per_prompt,
173
+ do_classifier_free_guidance=do_classifier_free_guidance,
174
+ negative_prompt=negative_prompt,
175
+ prompt_embeds=prompt_embeds,
176
+ negative_prompt_embeds=negative_prompt_embeds,
177
+ lora_scale=lora_scale,
178
+ )
179
+
180
+ # concatenate for backwards compatibility
181
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
182
+
183
+ return prompt_embeds
184
+
185
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
186
+ def encode_prompt(
187
+ self,
188
+ prompt,
189
+ device,
190
+ num_images_per_prompt,
191
+ do_classifier_free_guidance,
192
+ negative_prompt=None,
193
+ prompt_embeds: Optional[torch.FloatTensor] = None,
194
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
195
+ lora_scale: Optional[float] = None,
196
+ ):
197
+ r"""
198
+ Encodes the prompt into text encoder hidden states.
199
+
200
+ Args:
201
+ prompt (`str` or `List[str]`, *optional*):
202
+ prompt to be encoded
203
+ device: (`torch.device`):
204
+ torch device
205
+ num_images_per_prompt (`int`):
206
+ number of images that should be generated per prompt
207
+ do_classifier_free_guidance (`bool`):
208
+ whether to use classifier free guidance or not
209
+ negative_prompt (`str` or `List[str]`, *optional*):
210
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
211
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
212
+ less than `1`).
213
+ prompt_embeds (`torch.FloatTensor`, *optional*):
214
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
215
+ provided, text embeddings will be generated from `prompt` input argument.
216
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
217
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
218
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
219
+ argument.
220
+ lora_scale (`float`, *optional*):
221
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
222
+ """
223
+ # set lora scale so that monkey patched LoRA
224
+ # function of text encoder can correctly access it
225
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
226
+ self._lora_scale = lora_scale
227
+
228
+ # dynamically adjust the LoRA scale
229
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
230
+
231
+ if prompt is not None and isinstance(prompt, str):
232
+ batch_size = 1
233
+ elif prompt is not None and isinstance(prompt, list):
234
+ batch_size = len(prompt)
235
+ else:
236
+ batch_size = prompt_embeds.shape[0]
237
+
238
+ if prompt_embeds is None:
239
+ # textual inversion: process multi-vector tokens if necessary
240
+ if isinstance(self, TextualInversionLoaderMixin):
241
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
242
+
243
+ text_inputs = self.tokenizer(
244
+ prompt,
245
+ padding="max_length",
246
+ max_length=self.tokenizer.model_max_length,
247
+ truncation=True,
248
+ return_tensors="pt",
249
+ )
250
+ text_input_ids = text_inputs.input_ids
251
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
252
+
253
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
254
+ text_input_ids, untruncated_ids
255
+ ):
256
+ removed_text = self.tokenizer.batch_decode(
257
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
258
+ )
259
+ # logger.warning(
260
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
261
+ # f" {self.tokenizer.model_max_length} tokens: {removed_text}"
262
+ # )
263
+
264
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
265
+ attention_mask = text_inputs.attention_mask.to(device)
266
+ else:
267
+ attention_mask = None
268
+
269
+ prompt_embeds = self.text_encoder(
270
+ text_input_ids.to(device),
271
+ attention_mask=attention_mask,
272
+ )
273
+ prompt_embeds = prompt_embeds[0]
274
+
275
+ if self.text_encoder is not None:
276
+ prompt_embeds_dtype = self.text_encoder.dtype
277
+ elif self.unet is not None:
278
+ prompt_embeds_dtype = self.unet.dtype
279
+ else:
280
+ prompt_embeds_dtype = prompt_embeds.dtype
281
+
282
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
283
+
284
+ bs_embed, seq_len, _ = prompt_embeds.shape
285
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
286
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
287
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
288
+
289
+ # get unconditional embeddings for classifier free guidance
290
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
291
+ uncond_tokens: List[str]
292
+ if negative_prompt is None:
293
+ uncond_tokens = [""] * batch_size
294
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
295
+ raise TypeError(
296
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
297
+ f" {type(prompt)}."
298
+ )
299
+ elif isinstance(negative_prompt, str):
300
+ uncond_tokens = [negative_prompt]
301
+ elif batch_size != len(negative_prompt):
302
+ raise ValueError(
303
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
304
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
305
+ " the batch size of `prompt`."
306
+ )
307
+ else:
308
+ uncond_tokens = negative_prompt
309
+
310
+ # textual inversion: process multi-vector tokens if necessary
311
+ if isinstance(self, TextualInversionLoaderMixin):
312
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
313
+
314
+ max_length = prompt_embeds.shape[1]
315
+ uncond_input = self.tokenizer(
316
+ uncond_tokens,
317
+ padding="max_length",
318
+ max_length=max_length,
319
+ truncation=True,
320
+ return_tensors="pt",
321
+ )
322
+
323
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
324
+ attention_mask = uncond_input.attention_mask.to(device)
325
+ else:
326
+ attention_mask = None
327
+
328
+ negative_prompt_embeds = self.text_encoder(
329
+ uncond_input.input_ids.to(device),
330
+ attention_mask=attention_mask,
331
+ )
332
+ negative_prompt_embeds = negative_prompt_embeds[0]
333
+
334
+ if do_classifier_free_guidance:
335
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
336
+ seq_len = negative_prompt_embeds.shape[1]
337
+
338
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
339
+
340
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
341
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
342
+
343
+ return prompt_embeds, negative_prompt_embeds
344
+
345
+ def decode_latents(self, latents):
346
+ latents = 1 / self.vae.config.scaling_factor * latents
347
+
348
+ batch_size, channels, num_frames, height, width = latents.shape
349
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
350
+
351
+ image = self.vae.decode(latents).sample
352
+ video = (
353
+ image[None, :]
354
+ .reshape(
355
+ (
356
+ batch_size,
357
+ num_frames,
358
+ -1,
359
+ )
360
+ + image.shape[2:]
361
+ )
362
+ .permute(0, 2, 1, 3, 4)
363
+ )
364
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
365
+ video = video.float()
366
+ return video
367
+
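`decode_latents` folds the frame axis into the batch so the VAE decodes `batch * frames` independent images, then restores the `(batch, channels, frames, height, width)` layout. The shape bookkeeping with illustrative sizes (assuming the usual 4 latent channels and 8x VAE upsampling):

# Illustrative shapes only, not tied to a specific checkpoint.
B, C, F, h, w = 1, 4, 16, 32, 32
latents = torch.randn(B, C, F, h, w)
flat = latents.permute(0, 2, 1, 3, 4).reshape(B * F, C, h, w)   # what the VAE actually decodes
# decode_latents(latents) would then return a float tensor of shape (B, 3, F, 8 * h, 8 * w)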
368
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
369
+ def prepare_extra_step_kwargs(self, generator, eta):
370
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
371
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
372
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
373
+ # and should be between [0, 1]
374
+
375
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
376
+ extra_step_kwargs = {}
377
+ if accepts_eta:
378
+ extra_step_kwargs["eta"] = eta
379
+
380
+ # check if the scheduler accepts generator
381
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
382
+ if accepts_generator:
383
+ extra_step_kwargs["generator"] = generator
384
+ return extra_step_kwargs
385
+
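The `eta`/`generator` filtering above leans on `inspect.signature` so that one pipeline works with schedulers whose `step()` signatures differ. The same idea in isolation, as a generic helper (`filter_step_kwargs` is a hypothetical name, not part of this repo):

# Generic sketch of signature-based kwarg filtering.
import inspect

def filter_step_kwargs(fn, **candidates):
    accepted = set(inspect.signature(fn).parameters.keys())
    return {k: v for k, v in candidates.items() if k in accepted}

# e.g. filter_step_kwargs(scheduler.step, eta=0.0, generator=g) keeps only what this scheduler accepts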
386
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
387
+ def check_inputs(
388
+ self,
389
+ prompt,
390
+ height,
391
+ width,
392
+ callback_steps,
393
+ negative_prompt=None,
394
+ prompt_embeds=None,
395
+ negative_prompt_embeds=None,
396
+ ):
397
+ if height % 8 != 0 or width % 8 != 0:
398
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
399
+
400
+ if (callback_steps is None) or (
401
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
402
+ ):
403
+ raise ValueError(
404
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
405
+ f" {type(callback_steps)}."
406
+ )
407
+
408
+ if prompt is not None and prompt_embeds is not None:
409
+ raise ValueError(
410
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
411
+ " only forward one of the two."
412
+ )
413
+ elif prompt is None and prompt_embeds is None:
414
+ raise ValueError(
415
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
416
+ )
417
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
418
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
419
+
420
+ if negative_prompt is not None and negative_prompt_embeds is not None:
421
+ raise ValueError(
422
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
423
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
424
+ )
425
+
426
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
427
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
428
+ raise ValueError(
429
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
430
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
431
+ f" {negative_prompt_embeds.shape}."
432
+ )
433
+
434
+ def prepare_latents(
435
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
436
+ ):
437
+ shape = (
438
+ batch_size,
439
+ num_channels_latents,
440
+ num_frames,
441
+ height // self.vae_scale_factor,
442
+ width // self.vae_scale_factor,
443
+ )
444
+ if isinstance(generator, list) and len(generator) != batch_size:
445
+ raise ValueError(
446
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
447
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
448
+ )
449
+
450
+ if latents is None:
451
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
452
+ else:
453
+ latents = latents.to(device)
454
+
455
+ # scale the initial noise by the standard deviation required by the scheduler
456
+ latents = latents * self.scheduler.init_noise_sigma
457
+ return latents
458
+
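For video the initial noise carries an extra frame axis compared with the image pipelines. With the default 16 frames, a 512x512 target resolution, 4 latent channels and the usual 8x VAE scale factor (the last two being assumptions about the checkpoint), the latents come out as:

# Assumed: 4 latent channels, vae_scale_factor = 8.
batch_size, num_channels_latents, num_frames, height, width = 1, 4, 16, 512, 512
shape = (batch_size, num_channels_latents, num_frames, height // 8, width // 8)
# -> (1, 4, 16, 64, 64); prepare_latents fills this with scheduler-scaled Gaussian noise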
459
+ @torch.no_grad()
460
+ # @replace_example_docstring(EXAMPLE_DOC_STRING)
461
+ def __call__(
462
+ self,
463
+ prompt: Union[str, List[str]] = None,
464
+ height: Optional[int] = None,
465
+ width: Optional[int] = None,
466
+ num_frames: int = 16,
467
+ num_inference_steps: int = 50,
468
+ guidance_scale: float = 9.0,
469
+ negative_prompt: Optional[Union[str, List[str]]] = None,
470
+ eta: float = 0.0,
471
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
472
+ latents: Optional[torch.FloatTensor] = None,
473
+ prompt_embeds: Optional[torch.FloatTensor] = None,
474
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
475
+ output_type: Optional[str] = "np",
476
+ return_dict: bool = True,
477
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
478
+ callback_steps: int = 1,
479
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
480
+ ):
481
+ r"""
482
+ The call function to the pipeline for generation.
483
+
484
+ Args:
485
+ prompt (`str` or `List[str]`, *optional*):
486
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
487
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
488
+ The height in pixels of the generated video.
489
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
490
+ The width in pixels of the generated video.
491
+ num_frames (`int`, *optional*, defaults to 16):
492
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
493
+ amounts to 2 seconds of video.
494
+ num_inference_steps (`int`, *optional*, defaults to 50):
495
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
496
+ expense of slower inference.
497
+ guidance_scale (`float`, *optional*, defaults to 9.0):
498
+ A higher guidance scale value encourages the model to generate images closely linked to the text
499
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
500
+ negative_prompt (`str` or `List[str]`, *optional*):
501
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
502
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
503
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
504
+ The number of images to generate per prompt.
505
+ eta (`float`, *optional*, defaults to 0.0):
506
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
507
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
508
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
509
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
510
+ generation deterministic.
511
+ latents (`torch.FloatTensor`, *optional*):
512
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
513
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
514
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
515
+ `(batch_size, num_channel, num_frames, height, width)`.
516
+ prompt_embeds (`torch.FloatTensor`, *optional*):
517
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
518
+ provided, text embeddings are generated from the `prompt` input argument.
519
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
520
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
521
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
522
+ output_type (`str`, *optional*, defaults to `"np"`):
523
+ The output format of the generated video. Choose between `torch.FloatTensor` and `np.array`.
524
+ return_dict (`bool`, *optional*, defaults to `True`):
525
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
526
+ of a plain tuple.
527
+ callback (`Callable`, *optional*):
528
+ A function that is called every `callback_steps` steps during inference. The function is called with the
529
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
530
+ callback_steps (`int`, *optional*, defaults to 1):
531
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
532
+ every step.
533
+ cross_attention_kwargs (`dict`, *optional*):
534
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
535
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
536
+
537
+ Examples:
538
+
539
+ Returns:
540
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
541
+ If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
542
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
543
+ """
544
+ # 0. Default height and width to unet
545
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
546
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
547
+
548
+ num_images_per_prompt = 1
549
+
550
+ # 1. Check inputs. Raise error if not correct
551
+ self.check_inputs(
552
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
553
+ )
554
+
555
+ # 2. Define call parameters
556
+ if prompt is not None and isinstance(prompt, str):
557
+ batch_size = 1
558
+ elif prompt is not None and isinstance(prompt, list):
559
+ batch_size = len(prompt)
560
+ else:
561
+ batch_size = prompt_embeds.shape[0]
562
+
563
+ device = self._execution_device
564
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
565
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
566
+ # corresponds to doing no classifier free guidance.
567
+ do_classifier_free_guidance = guidance_scale > 1.0
568
+
569
+ # 3. Encode input prompt
570
+ text_encoder_lora_scale = (
571
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
572
+ )
573
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
574
+ prompt,
575
+ device,
576
+ num_images_per_prompt,
577
+ do_classifier_free_guidance,
578
+ negative_prompt,
579
+ prompt_embeds=prompt_embeds,
580
+ negative_prompt_embeds=negative_prompt_embeds,
581
+ lora_scale=text_encoder_lora_scale,
582
+ )
583
+ # For classifier free guidance, we need to do two forward passes.
584
+ # Here we concatenate the unconditional and text embeddings into a single batch
585
+ # to avoid doing two forward passes
586
+ if do_classifier_free_guidance:
587
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
588
+
589
+ # 4. Prepare timesteps
590
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
591
+ timesteps = self.scheduler.timesteps
592
+
593
+ # 5. Prepare latent variables
594
+ num_channels_latents = self.unet.config.in_channels
595
+ latents = self.prepare_latents(
596
+ batch_size * num_images_per_prompt,
597
+ num_channels_latents,
598
+ num_frames,
599
+ height,
600
+ width,
601
+ prompt_embeds.dtype,
602
+ device,
603
+ generator,
604
+ latents,
605
+ )
606
+
607
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
608
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
609
+
610
+ # 7. Denoising loop
611
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
612
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
613
+ for i, t in enumerate(timesteps):
614
+ # expand the latents if we are doing classifier free guidance
615
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
616
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
617
+
618
+ # predict the noise residual
619
+ noise_pred = self.unet(
620
+ latent_model_input,
621
+ t,
622
+ encoder_hidden_states=prompt_embeds,
623
+ cross_attention_kwargs=cross_attention_kwargs,
624
+ return_dict=False,
625
+ )[0]
626
+
627
+ # perform guidance
628
+ if do_classifier_free_guidance:
629
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
630
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
631
+
632
+ # reshape latents
633
+ bsz, channel, frames, width, height = latents.shape
634
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
635
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
636
+
637
+ # compute the previous noisy sample x_t -> x_t-1
638
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
639
+
640
+ # reshape latents back
641
+ latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
642
+
643
+ # call the callback, if provided
644
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
645
+ progress_bar.update()
646
+ if callback is not None and i % callback_steps == 0:
647
+ callback(i, t, latents)
648
+
649
+ if output_type == "latent":
650
+ return TextToVideoSDPipelineOutput(frames=latents)
651
+
652
+ video_tensor = self.decode_latents(latents)
653
+
654
+ if output_type == "pt":
655
+ video = video_tensor
656
+ else:
657
+ video = tensor2vid(video_tensor)
658
+
659
+ # Offload all models
660
+ self.maybe_free_model_hooks()
661
+
662
+ if not return_dict:
663
+ return (video,)
664
+
665
+ return TextToVideoSDPipelineOutput(frames=video)
666
+
667
+
668
+
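End to end, the base pipeline is driven like any other diffusers text-to-video pipeline. A minimal sketch; how the individual components (`vae`, `text_encoder`, `tokenizer`, `unet`, `scheduler`) are loaded is an assumption, since this file does not pin a checkpoint:

# Sketch only: component loading depends on how the checkpoints are shipped.
pipe = TextToVideoSDPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
                             unet=unet, scheduler=scheduler).to("cuda")
result = pipe("A dog running on the beach",
              num_frames=16, num_inference_steps=50, guidance_scale=9.0)
frames = result.frames   # list of uint8 numpy frames when output_type="np"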
669
+ class TextToVideoSDPipelineSpatialAware(TextToVideoSDPipeline):
670
+ def __init__(
671
+ self,
672
+ vae,
673
+ text_encoder,
674
+ tokenizer,
675
+ unet,
676
+ scheduler,
677
+ ):
678
+ # print(f"Initializing this pipeline with {type(vae)}, {type(unet)}")
679
+ unet_new = UNet3DConditionModel()
680
+ unet_new.load_state_dict(unet.state_dict())
681
+ super().__init__(vae, text_encoder, tokenizer, unet_new, scheduler)
682
+
683
+ def _encode_prompt(
684
+ self,
685
+ prompt,
686
+ device,
687
+ num_images_per_prompt,
688
+ do_classifier_free_guidance,
689
+ negative_prompt=None,
690
+ prompt_embeds: Optional[torch.FloatTensor] = None,
691
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
692
+ lora_scale: Optional[float] = None,
693
+ fg_prompt: Optional[str] = None,
694
+ num_frames: int = 16,
695
+ ):
696
+ r"""
697
+ Encodes the prompt into text encoder hidden states.
698
+
699
+ Args:
700
+ prompt (`str` or `List[str]`, *optional*):
701
+ prompt to be encoded
702
+ device: (`torch.device`):
703
+ torch device
704
+ num_images_per_prompt (`int`):
705
+ number of images that should be generated per prompt
706
+ do_classifier_free_guidance (`bool`):
707
+ whether to use classifier free guidance or not
708
+ negative_prompt (`str` or `List[str]`, *optional*):
709
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
710
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
711
+ less than `1`).
712
+ prompt_embeds (`torch.FloatTensor`, *optional*):
713
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
714
+ provided, text embeddings will be generated from `prompt` input argument.
715
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
716
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
717
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
718
+ argument.
719
+ lora_scale (`float`, *optional*):
720
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
721
+ """
722
+ # set lora scale so that monkey patched LoRA
723
+ # function of text encoder can correctly access it
724
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
725
+ self._lora_scale = lora_scale
726
+
727
+ if prompt is not None and isinstance(prompt, str):
728
+ batch_size = 1
729
+ elif prompt is not None and isinstance(prompt, list):
730
+ batch_size = len(prompt)
731
+ else:
732
+ batch_size = prompt_embeds.shape[0]
733
+
734
+ if prompt_embeds is None:
735
+ # textual inversion: process multi-vector tokens if necessary
736
+ if isinstance(self, TextualInversionLoaderMixin):
737
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
738
+
739
+ text_inputs = self.tokenizer(
740
+ prompt,
741
+ padding="max_length",
742
+ max_length=self.tokenizer.model_max_length,
743
+ truncation=True,
744
+ return_tensors="pt",
745
+ )
746
+ text_input_ids = text_inputs.input_ids
747
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
748
+
749
+
750
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
751
+ text_input_ids, untruncated_ids
752
+ ):
753
+ removed_text = self.tokenizer.batch_decode(
754
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
755
+ )
756
+ # logger.warning(
757
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
758
+ # f" {self.tokenizer.model_max_length} tokens: {removed_text}"
759
+ # )
760
+
761
+
762
+
763
+ if fg_prompt is not None:
764
+ if not isinstance(fg_prompt, list):
765
+ fg_text_inputs = self.tokenizer(
766
+ fg_prompt,
767
+ # padding="max_length",
768
+ # max_length=self.tokenizer.model_max_length,
769
+ # truncation=True,
770
+ return_tensors="pt",
771
+ )
772
+ # breakpoint()
773
+ fg_text_input_ids = fg_text_inputs.input_ids
774
+
775
+ # drop the trailing end-of-sequence token (the leading start token is kept)
776
+ fg_text_input_ids = fg_text_input_ids[:,:-1]
777
+
778
+ # remove common tokens in fg_text_input_ids from text_input_ids
779
+ batch_size = text_input_ids.shape[0]
780
+ # Create a mask that is True wherever a token in text_input_ids matches a token in fg_text_input_ids
781
+ mask = (text_input_ids.unsqueeze(-1) == fg_text_input_ids.unsqueeze(1)).any(dim=-1)
782
+ # print(mask)
783
+ # breakpoint()
784
+ # Get the values from text_input_ids that are not in fg_text_input_ids
785
+ encoder_attention_mask = ~mask
786
+ encoder_attention_mask = encoder_attention_mask.repeat((2,1))
787
+ encoder_attention_mask = encoder_attention_mask.repeat((num_frames,1)) # To account for videos
788
+ encoder_attention_mask = encoder_attention_mask.to(device)
789
+
790
+ # text_input_ids_filtered = text_input_ids[~mask].view(1, -1)
791
+ # text_input_ids_filtered will now contain the values from text_input_ids that aren't in fg_text_input_ids
792
+ else:
793
+ encoder_attention_mask = []
794
+ for fg_prompt_i in fg_prompt:
795
+ fg_text_inputs = self.tokenizer(
796
+ fg_prompt_i,
797
+ return_tensors="pt",)
798
+ fg_text_input_ids = fg_text_inputs.input_ids
799
+
800
+ # drop the trailing end-of-sequence token (the leading start token is kept)
801
+ fg_text_input_ids = fg_text_input_ids[:,:-1]
802
+
803
+ # remove common tokens in fg_text_input_ids from text_input_ids
804
+ batch_size = text_input_ids.shape[0]
805
+ # Create a mask that is True wherever a token in text_input_ids matches a token in fg_text_input_ids
806
+ mask = (text_input_ids.unsqueeze(-1) == fg_text_input_ids.unsqueeze(1)).any(dim=-1)
807
+ encoder_attention_mask_i = ~mask
808
+ encoder_attention_mask_i = encoder_attention_mask_i.repeat((2,1))
809
+ encoder_attention_mask_i = encoder_attention_mask_i.repeat((num_frames,1)) # To account for videos
810
+ encoder_attention_mask_i = encoder_attention_mask_i.to(device)
811
+ encoder_attention_mask.append(encoder_attention_mask_i)
812
+
813
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
814
+ attention_mask = text_inputs.attention_mask.to(device)
815
+ else:
816
+ attention_mask = None
817
+
818
+ prompt_embeds = self.text_encoder(
819
+ text_input_ids.to(device),
820
+ attention_mask=attention_mask,
821
+ # attention_mask=encoder_attention_mask[:batch_size] if fg_prompt is not None else None,
822
+ )
823
+ prompt_embeds = prompt_embeds[0]
824
+
825
+ if self.text_encoder is not None:
826
+ prompt_embeds_dtype = self.text_encoder.dtype
827
+ elif self.unet is not None:
828
+ prompt_embeds_dtype = self.unet.dtype
829
+ else:
830
+ prompt_embeds_dtype = prompt_embeds.dtype
831
+
832
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
833
+
834
+ bs_embed, seq_len, _ = prompt_embeds.shape
835
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
836
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
837
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
838
+
839
+ # get unconditional embeddings for classifier free guidance
840
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
841
+ uncond_tokens: List[str]
842
+ if negative_prompt is None:
843
+ uncond_tokens = [""] * batch_size
844
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
845
+ raise TypeError(
846
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
847
+ f" {type(prompt)}."
848
+ )
849
+ elif isinstance(negative_prompt, str):
850
+ uncond_tokens = [negative_prompt]
851
+ elif batch_size != len(negative_prompt):
852
+ raise ValueError(
853
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
854
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
855
+ " the batch size of `prompt`."
856
+ )
857
+ else:
858
+ uncond_tokens = negative_prompt
859
+
860
+ # textual inversion: process multi-vector tokens if necessary
861
+ if isinstance(self, TextualInversionLoaderMixin):
862
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
863
+
864
+ max_length = prompt_embeds.shape[1]
865
+ uncond_input = self.tokenizer(
866
+ uncond_tokens,
867
+ padding="max_length",
868
+ max_length=max_length,
869
+ truncation=True,
870
+ return_tensors="pt",
871
+ )
872
+
873
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
874
+ attention_mask = uncond_input.attention_mask.to(device)
875
+ else:
876
+ attention_mask = None
877
+
878
+ negative_prompt_embeds = self.text_encoder(
879
+ uncond_input.input_ids.to(device),
880
+ attention_mask=attention_mask,
881
+ )
882
+ negative_prompt_embeds = negative_prompt_embeds[0]
883
+
884
+ if do_classifier_free_guidance:
885
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
886
+ seq_len = negative_prompt_embeds.shape[1]
887
+
888
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
889
+
890
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
891
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
892
+
893
+ # For classifier free guidance, we need to do two forward passes.
894
+ # Here we concatenate the unconditional and text embeddings into a single batch
895
+ # to avoid doing two forward passes
896
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
897
+
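+ # When a foreground prompt was supplied, the encoder attention mask assembled earlier is
+ # returned alongside the embeddings so the caller can restrict cross-attention to the
+ # foreground tokens; otherwise only the embeddings are returned.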
898
+ if fg_prompt is not None:
899
+ return prompt_embeds, encoder_attention_mask
900
+ return prompt_embeds, None
901
+
902
+
903
+ @torch.no_grad()
904
+ def __call__(
905
+ self,
906
+ prompt: Union[str, List[str]] = None,
907
+ height: Optional[int] = None,
908
+ width: Optional[int] = None,
909
+ num_frames: int = 16,
910
+ num_inference_steps: int = 50,
911
+ guidance_scale: float = 9.0,
912
+ negative_prompt: Optional[Union[str, List[str]]] = None,
913
+ eta: float = 0.0,
914
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
915
+ latents: Optional[torch.FloatTensor] = None,
916
+ prompt_embeds: Optional[torch.FloatTensor] = None,
917
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
918
+ output_type: Optional[str] = "np",
919
+ return_dict: bool = True,
920
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
921
+ callback_steps: int = 1,
922
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
923
+ frozen_mask: Optional[torch.FloatTensor] = None,
924
+ frozen_steps: Optional[int] = None,
925
+ frozen_text_mask: Optional[torch.FloatTensor] = None,
926
+ frozen_prompt: Optional[Union[str, List[str]]] = None,
927
+ custom_attention_mask: Optional[torch.FloatTensor] = None,
928
+ latents_all_input: Optional[torch.FloatTensor] = None,
929
+ fg_prompt: Optional[torch.FloatTensor]=None,
930
+ make_attention_mask_2d=False,
931
+ attention_mask_block_diagonal=False,):
932
+ r"""
933
+ The call function to the pipeline for generation.
934
+
935
+ Args:
936
+ prompt (`str` or `List[str]`, *optional*):
937
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
938
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
939
+ The height in pixels of the generated video.
940
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
941
+ The width in pixels of the generated video.
942
+ num_frames (`int`, *optional*, defaults to 16):
943
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
944
+ amounts to 2 seconds of video.
945
+ num_inference_steps (`int`, *optional*, defaults to 50):
946
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
947
+ expense of slower inference.
948
+ guidance_scale (`float`, *optional*, defaults to 9.0):
949
+ A higher guidance scale value encourages the model to generate images closely linked to the text
950
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
951
+ negative_prompt (`str` or `List[str]`, *optional*):
952
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
953
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
954
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
955
+ The number of images to generate per prompt.
956
+ eta (`float`, *optional*, defaults to 0.0):
957
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
958
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
959
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
960
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
961
+ generation deterministic.
962
+ latents (`torch.FloatTensor`, *optional*):
963
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
964
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
965
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
966
+ `(batch_size, num_channel, num_frames, height, width)`.
967
+ prompt_embeds (`torch.FloatTensor`, *optional*):
968
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
969
+ provided, text embeddings are generated from the `prompt` input argument.
970
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
971
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
972
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
973
+ output_type (`str`, *optional*, defaults to `"np"`):
974
+ The output format of the generated video. Choose between `"np"` (NumPy array), `"pt"` (`torch.FloatTensor`), or `"latent"`.
975
+ return_dict (`bool`, *optional*, defaults to `True`):
976
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
977
+ of a plain tuple.
978
+ callback (`Callable`, *optional*):
979
+ A function that is called every `callback_steps` steps during inference. The function is called with the
980
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
981
+ callback_steps (`int`, *optional*, defaults to 1):
982
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
983
+ every step.
984
+ cross_attention_kwargs (`dict`, *optional*):
985
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
986
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
987
+
988
+ Examples:
989
+
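+ A minimal usage sketch (illustrative only; the class name used with `from_pretrained` and the
+ checkpoint id below are assumptions, not fixed by this file):
+
+ >>> import torch
+ >>> # assumed name of the pipeline class defined in this module
+ >>> pipe = TextToVideoSDPipeline.from_pretrained(
+ ...     "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> out = pipe(
+ ...     "an astronaut riding a horse on the moon",
+ ...     num_frames=16,
+ ...     num_inference_steps=50,
+ ...     guidance_scale=9.0,
+ ... )
+ >>> frames = out.frames  # NumPy frames by default (`output_type="np"`)
+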
990
+ Returns:
991
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
992
+ If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
993
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
994
+ """
995
+ # 0. Default height and width to unet
996
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
997
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
998
+
999
+ num_images_per_prompt = 1
1000
+
1001
+ # 1. Check inputs. Raise error if not correct
1002
+ self.check_inputs(
1003
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
1004
+ )
1005
+
1006
+ # 2. Define call parameters
1007
+ if prompt is not None and isinstance(prompt, str):
1008
+ batch_size = 1
1009
+ elif prompt is not None and isinstance(prompt, list):
1010
+ batch_size = len(prompt)
1011
+ else:
1012
+ batch_size = prompt_embeds.shape[0]
1013
+
1014
+ device = self._execution_device
1015
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1016
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1017
+ # corresponds to doing no classifier free guidance.
1018
+ do_classifier_free_guidance = guidance_scale > 1.0
1019
+
1020
+ # 3. Encode input prompt
1021
+ text_encoder_lora_scale = (
1022
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1023
+ )
1024
+ prompt_embeds, custom_attention_mask = self._encode_prompt(
1025
+ prompt,
1026
+ device,
1027
+ num_images_per_prompt,
1028
+ do_classifier_free_guidance,
1029
+ negative_prompt,
1030
+ prompt_embeds=prompt_embeds,
1031
+ negative_prompt_embeds=negative_prompt_embeds,
1032
+ lora_scale=text_encoder_lora_scale,
1033
+ fg_prompt=fg_prompt,
1034
+ num_frames=num_frames,
1035
+ )
1036
+ if frozen_prompt is not None: # freeze the prompt
1037
+ prompt_embeds, _ = self._encode_prompt(
1038
+ frozen_prompt,
1039
+ device,
1040
+ num_images_per_prompt,
1041
+ do_classifier_free_guidance,
1042
+ negative_prompt,
1043
+ prompt_embeds=None,
1044
+ negative_prompt_embeds=None,
1045
+ lora_scale=text_encoder_lora_scale,
1046
+ )
1047
+ # if frozen_prompt is not None: # TODO see why different length of prompt and frozen_prompt causes error
1048
+ # frozen_prompt_embeds = self._encode_prompt(
1049
+ # frozen_prompt,
1050
+ # device,
1051
+ # num_images_per_prompt,
1052
+ # do_classifier_free_guidance,
1053
+ # negative_prompt,
1054
+ # prompt_embeds=None,
1055
+ # negative_prompt_embeds=None,
1056
+ # lora_scale=text_encoder_lora_scale,)
1057
+ # # breakpoint()
1058
+ # else:
1059
+ # frozen_prompt_embeds = None
1060
+
1061
+ # 4. Prepare timesteps
1062
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1063
+ timesteps = self.scheduler.timesteps
1064
+
1065
+ # 5. Prepare latent variables
1066
+ num_channels_latents = self.unet.config.in_channels
1067
+ latents = self.prepare_latents(
1068
+ batch_size * num_images_per_prompt,
1069
+ num_channels_latents,
1070
+ num_frames,
1071
+ height,
1072
+ width,
1073
+ prompt_embeds.dtype,
1074
+ device,
1075
+ generator,
1076
+ latents,
1077
+ )
1078
+
1079
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1080
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1081
+ attention_mask = None  # remains None unless a frozen-region mask is provided below
+ if frozen_mask is not None:
1082
+ if not isinstance(frozen_mask, list):
1083
+ attention_mask = frozen_mask.clone()
1084
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1).repeat((2,1)).to(frozen_mask.device) # For video
1085
+
1086
+ attention_mask = attention_mask.bool()
1087
+
1088
+ else:
1089
+ attention_mask = []
1090
+ for frozen_mask_i in frozen_mask:
1091
+ attention_mask_i = frozen_mask_i.clone()
1092
+ attention_mask_i = attention_mask_i.view(attention_mask_i.shape[0], -1).repeat((2,1)).to(frozen_mask_i.device)
1093
+ attention_mask_i = attention_mask_i.bool()
1094
+ attention_mask.append(attention_mask_i)
1095
+
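+ # NOTE: the `.repeat((2, 1))` above duplicates the flattened mask along the batch dimension so
+ # that it lines up with the classifier-free-guidance batch built below (unconditional + conditional halves).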
1096
+ # if make_attention_mask_2d:
1097
+ # # This converts attention mask into (num_frames*2, num_pixels, num_pixels)
1098
+ # attention_mask = attention_mask.unsqueeze(1)
1099
+ # tmp_mask = attention_mask.permute(0,2,1) # 32, 1024, 1
1100
+ # # The following line makes attention mask to have a block of ones
1101
+ # attention_mask_2d = torch.bitwise_and(attention_mask, tmp_mask)
1102
+ # if attention_mask_block_diagonal:
1103
+ # tmp_mask = ~attention_mask
1104
+ # # We now get ones where background attends to background
1105
+ # tmp_mask_2 = tmp_mask.permute(0, 2, 1)
1106
+ # tmp_mask = torch.bitwise_and(tmp_mask, tmp_mask_2)
1107
+ # # We now get a block diagonal structure
1108
+ # attention_mask_2d = torch.bitwise_or(attention_mask_2d, tmp_mask)
1109
+ # attention_mask = attention_mask_2d
1110
+ # attention_mask = ~attention_mask
1111
+ # 7. Denoising loop
1112
+ latents_all = []
1113
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1114
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1115
+ for i, t in enumerate(timesteps):
1116
+ # expand the latents if we are doing classifier free guidance
1117
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1118
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1119
+
1120
+ # predict the noise residual
1121
+ noise_pred = self.unet(
1122
+ latent_model_input,
1123
+ t,
1124
+ encoder_hidden_states=prompt_embeds,
1125
+ cross_attention_kwargs=cross_attention_kwargs,
1126
+ return_dict=False,
1127
+ encoder_attention_mask=custom_attention_mask if custom_attention_mask is not None and frozen_steps is not None and i < frozen_steps else None,
1128
+ attention_mask=attention_mask if frozen_steps is not None and i < frozen_steps else None,
1129
+ make_2d_attention_mask=make_attention_mask_2d,
1130
+ block_diagonal_attention=attention_mask_block_diagonal,
1131
+ )[0]
1132
+
1133
+ # perform guidance
1134
+ if do_classifier_free_guidance:
1135
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1136
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1137
+
1138
+ # reshape latents
1139
+ bsz, channel, frames, width, height = latents.shape
1140
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
1141
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
1142
+
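+ # The permutes above flatten the video latents from (batch, channels, frames, h, w) to
+ # (batch * frames, channels, h, w) so the image-space scheduler can step every frame in a
+ # single call; the inverse reshape below restores the video layout.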
1143
+ # compute the previous noisy sample x_t -> x_t-1
1144
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1145
+
1146
+ # reshape latents back
1147
+ latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
1148
+ latents_all.append(latents)
1149
+
1150
+ # update the prompt_embeds after the frozen_steps to consider the whole prompt, including fg_prompt
1151
+ if frozen_steps is not None and i == frozen_steps:
1152
+ prompt_embeds, _ = self._encode_prompt(
1153
+ prompt,
1154
+ device,
1155
+ num_images_per_prompt,
1156
+ do_classifier_free_guidance,
1157
+ negative_prompt,
1158
+ prompt_embeds=None,
1159
+ negative_prompt_embeds=None,
1160
+ lora_scale=text_encoder_lora_scale,
1161
+ fg_prompt=None,
1162
+ )
1163
+ # call the callback, if provided
1164
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1165
+ progress_bar.update()
1166
+ if callback is not None and i % callback_steps == 0:
1167
+ callback(i, t, latents)
1168
+
1169
+ if output_type == "latent":
1170
+ # return TextToVideoSDPipelineOutput(frames=latents)
1171
+ latents_all = torch.cat(latents_all, dim=0) # (num_inference_steps, num_channels_latents, num_frames, height, width) batch size is 1
1172
+ print(latents_all.shape)
1173
+ return TextToVideoSDPipelineOutput(frames=latents_all)
1174
+
1175
+ video_tensor = self.decode_latents(latents)
1176
+
1177
+ if output_type == "pt":
1178
+ video = video_tensor
1179
+ else:
1180
+ video = tensor2vid(video_tensor)
1181
+
1182
+ # Offload last model to CPU
1183
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1184
+ self.final_offload_hook.offload()
1185
+
1186
+ if not return_dict:
1187
+ return (video,)
1188
+
1189
+ return TextToVideoSDPipelineOutput(frames=video)
1190
+
1191
+ @torch.no_grad()
1192
+ def __call__latestNotCalledForNow(
1193
+ self,
1194
+ prompt: Union[str, List[str]] = None,
1195
+ height: Optional[int] = None,
1196
+ width: Optional[int] = None,
1197
+ num_frames: int = 16,
1198
+ num_inference_steps: int = 50,
1199
+ guidance_scale: float = 9.0,
1200
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1201
+ eta: float = 0.0,
1202
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1203
+ latents: Optional[torch.FloatTensor] = None,
1204
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1205
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1206
+ output_type: Optional[str] = "np",
1207
+ return_dict: bool = True,
1208
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1209
+ callback_steps: int = 1,
1210
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1211
+ clip_skip: Optional[int] = None,
1212
+ frozen_mask: Optional[torch.FloatTensor] = None,
1213
+ frozen_steps: Optional[int] = None,
1214
+ latents_all_input: Optional[torch.FloatTensor] = None,
1215
+ ):
1216
+ r"""
1217
+ The call function to the pipeline for generation.
1218
+
1219
+ Args:
1220
+ prompt (`str` or `List[str]`, *optional*):
1221
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1222
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1223
+ The height in pixels of the generated video.
1224
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1225
+ The width in pixels of the generated video.
1226
+ num_frames (`int`, *optional*, defaults to 16):
1227
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
1228
+ amounts to 2 seconds of video.
1229
+ num_inference_steps (`int`, *optional*, defaults to 50):
1230
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
1231
+ expense of slower inference.
1232
+ guidance_scale (`float`, *optional*, defaults to 9.0):
1233
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1234
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1235
+ negative_prompt (`str` or `List[str]`, *optional*):
1236
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1237
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1238
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1239
+ The number of images to generate per prompt.
1240
+ eta (`float`, *optional*, defaults to 0.0):
1241
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1242
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1243
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1244
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1245
+ generation deterministic.
1246
+ latents (`torch.FloatTensor`, *optional*):
1247
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
1248
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1249
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
1250
+ `(batch_size, num_channel, num_frames, height, width)`.
1251
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1252
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1253
+ provided, text embeddings are generated from the `prompt` input argument.
1254
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1255
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1256
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1257
+ output_type (`str`, *optional*, defaults to `"np"`):
1258
+ The output format of the generated video. Choose between `"np"` (NumPy array), `"pt"` (`torch.FloatTensor`), or `"latent"`.
1259
+ return_dict (`bool`, *optional*, defaults to `True`):
1260
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
1261
+ of a plain tuple.
1262
+ callback (`Callable`, *optional*):
1263
+ A function that is called every `callback_steps` steps during inference. The function is called with the
1264
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1265
+ callback_steps (`int`, *optional*, defaults to 1):
1266
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
1267
+ every step.
1268
+ cross_attention_kwargs (`dict`, *optional*):
1269
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1270
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1271
+ clip_skip (`int`, *optional*):
1272
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1273
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1274
+ Examples:
1275
+
1276
+ Returns:
1277
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
1278
+ If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
1279
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
1280
+ """
1281
+ # 0. Default height and width to unet
1282
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1283
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1284
+
1285
+ num_images_per_prompt = 1
1286
+
1287
+ # 1. Check inputs. Raise error if not correct
1288
+ self.check_inputs(
1289
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
1290
+ )
1291
+
1292
+ # 2. Define call parameters
1293
+ if prompt is not None and isinstance(prompt, str):
1294
+ batch_size = 1
1295
+ elif prompt is not None and isinstance(prompt, list):
1296
+ batch_size = len(prompt)
1297
+ else:
1298
+ batch_size = prompt_embeds.shape[0]
1299
+
1300
+ device = self._execution_device
1301
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1302
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1303
+ # corresponds to doing no classifier free guidance.
1304
+ do_classifier_free_guidance = guidance_scale > 1.0
1305
+
1306
+ # 3. Encode input prompt
1307
+ text_encoder_lora_scale = (
1308
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1309
+ )
1310
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1311
+ prompt,
1312
+ device,
1313
+ num_images_per_prompt,
1314
+ do_classifier_free_guidance,
1315
+ negative_prompt,
1316
+ prompt_embeds=prompt_embeds,
1317
+ negative_prompt_embeds=negative_prompt_embeds,
1318
+ lora_scale=text_encoder_lora_scale,
1319
+ clip_skip=clip_skip,
1320
+ )
1321
+ # For classifier free guidance, we need to do two forward passes.
1322
+ # Here we concatenate the unconditional and text embeddings into a single batch
1323
+ # to avoid doing two forward passes
1324
+ if do_classifier_free_guidance:
1325
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1326
+
1327
+ # 4. Prepare timesteps
1328
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1329
+ timesteps = self.scheduler.timesteps
1330
+
1331
+ # 5. Prepare latent variables
1332
+ num_channels_latents = self.unet.config.in_channels
1333
+ latents = self.prepare_latents(
1334
+ batch_size * num_images_per_prompt,
1335
+ num_channels_latents,
1336
+ num_frames,
1337
+ height,
1338
+ width,
1339
+ prompt_embeds.dtype,
1340
+ device,
1341
+ generator,
1342
+ latents,
1343
+ )
1344
+
1345
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1346
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1347
+
1348
+ # 7. Denoising loop
1349
+ latents_all = []
1350
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1351
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1352
+ for i, t in enumerate(timesteps):
1353
+ # expand the latents if we are doing classifier free guidance
1354
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1355
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1356
+
1357
+ # predict the noise residual
1358
+ noise_pred = self.unet(
1359
+ latent_model_input,
1360
+ t,
1361
+ encoder_hidden_states=prompt_embeds,
1362
+ cross_attention_kwargs=cross_attention_kwargs,
1363
+ return_dict=False,
1364
+ )[0]
1365
+
1366
+ # perform guidance
1367
+ if do_classifier_free_guidance:
1368
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1369
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1370
+
1371
+ # reshape latents
1372
+ bsz, channel, frames, width, height = latents.shape
1373
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
1374
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
1375
+
1376
+ # compute the previous noisy sample x_t -> x_t-1
1377
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1378
+
1379
+
1380
+ # reshape latents back
1381
+ latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
1382
+ latents_all.append(latents)
1383
+
1384
+ # put frozen latents back
1385
+ if frozen_mask is not None and i < frozen_steps:
1386
+ latents = latents_all_input[i+1:i+2,...] * frozen_mask + latents * (1. - frozen_mask)
1387
+ print(t, latents.shape, frozen_mask.shape)
1388
+
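+ # For the first `frozen_steps` steps the masked (frozen) region is copied from the reference
+ # trajectory `latents_all_input`, while the unmasked region keeps the freshly denoised latents.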
1389
+ # call the callback, if provided
1390
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1391
+ progress_bar.update()
1392
+ if callback is not None and i % callback_steps == 0:
1393
+ callback(i, t, latents)
1394
+
1395
+ if output_type == "latent":
1396
+ # return TextToVideoSDPipelineOutput(frames=latents)
1397
+ latents_all = torch.cat(latents_all, dim=0) # (num_inference_steps, num_channels_latents, num_frames, height, width) batch size is 1
1398
+ print(latents_all.shape)
1399
+ return TextToVideoSDPipelineOutput(frames=latents_all)
1400
+
1401
+ video_tensor = self.decode_latents(latents)
1402
+
1403
+ if output_type == "pt":
1404
+ video = video_tensor
1405
+ else:
1406
+ video = tensor2vid(video_tensor)
1407
+
1408
+ # Offload all models
1409
+ self.maybe_free_model_hooks()
1410
+
1411
+ if not return_dict:
1412
+ return (video,)
1413
+
1414
+ return TextToVideoSDPipelineOutput(frames=video)
src/models/sd_pipeline.py ADDED
@@ -0,0 +1,719 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+ import PIL
4
+ import numpy as np
5
+ import torch
6
+ from packaging import version
7
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
8
+
9
+ from diffusers.configuration_utils import FrozenDict
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
12
+ from diffusers.models import AutoencoderKL
13
+ from .unet_2d_condition import UNet2DConditionModel
14
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from diffusers.utils import (
17
+ deprecate,
18
+ logging,
19
+ replace_example_docstring,
20
+ )
21
+ from diffusers.utils import (
22
+ BaseOutput,
23
+ )
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
27
+
28
+ from dataclasses import dataclass
29
+
30
+
+ logger = logging.get_logger(__name__)  # module logger used by the warnings issued below
31
+ @dataclass
32
+ class StableDiffusionPipelineOutput(BaseOutput):
33
+ """
34
+ Output class for Stable Diffusion pipelines.
35
+
36
+ Args:
37
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
38
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
39
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
40
+ nsfw_content_detected (`List[bool]`)
41
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
42
+ (nsfw) content, or `None` if safety checking could not be performed.
43
+ """
44
+
45
+ images: Union[List[PIL.Image.Image], np.ndarray]
46
+ nsfw_content_detected: Optional[List[bool]]
47
+
48
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
49
+ """
50
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
51
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
52
+ """
53
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
54
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
55
+ # rescale the results from guidance (fixes overexposure)
56
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
57
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
58
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
59
+ return noise_cfg
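+ # In short: noise_rescaled = noise_cfg * std(noise_text) / std(noise_cfg), and the returned value is
+ # guidance_rescale * noise_rescaled + (1 - guidance_rescale) * noise_cfg, so guidance_rescale = 0
+ # leaves the CFG prediction unchanged while guidance_rescale = 1 fully matches the text branch's scale.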
60
+
61
+
62
+
63
+ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
64
+ r"""
65
+ Pipeline for text-to-image generation using Stable Diffusion.
66
+
67
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
68
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
69
+
70
+ The pipeline also inherits the following loading methods:
71
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
72
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
73
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
74
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
75
+
76
+ Args:
77
+ vae ([`AutoencoderKL`]):
78
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
79
+ text_encoder ([`~transformers.CLIPTextModel`]):
80
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
81
+ tokenizer ([`~transformers.CLIPTokenizer`]):
82
+ A `CLIPTokenizer` to tokenize text.
83
+ unet ([`UNet2DConditionModel`]):
84
+ A `UNet2DConditionModel` to denoise the encoded image latents.
85
+ scheduler ([`SchedulerMixin`]):
86
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
87
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
88
+ safety_checker ([`StableDiffusionSafetyChecker`]):
89
+ Classification module that estimates whether generated images could be considered offensive or harmful.
90
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
91
+ about a model's potential harms.
92
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
93
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
94
+ """
95
+ model_cpu_offload_seq = "text_encoder->unet->vae"
96
+ _optional_components = ["safety_checker", "feature_extractor"]
97
+ _exclude_from_cpu_offload = ["safety_checker"]
98
+
99
+ def __init__(
100
+ self,
101
+ vae: AutoencoderKL,
102
+ text_encoder: CLIPTextModel,
103
+ tokenizer: CLIPTokenizer,
104
+ unet: UNet2DConditionModel,
105
+ scheduler: KarrasDiffusionSchedulers,
106
+ safety_checker: StableDiffusionSafetyChecker,
107
+ feature_extractor: CLIPImageProcessor,
108
+ requires_safety_checker: bool = True,
109
+ ):
110
+ super().__init__()
111
+
112
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
113
+ deprecation_message = (
114
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
115
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
116
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
117
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
118
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
119
+ " file"
120
+ )
121
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
122
+ new_config = dict(scheduler.config)
123
+ new_config["steps_offset"] = 1
124
+ scheduler._internal_dict = FrozenDict(new_config)
125
+
126
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
127
+ deprecation_message = (
128
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
129
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
130
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
131
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
132
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
133
+ )
134
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
135
+ new_config = dict(scheduler.config)
136
+ new_config["clip_sample"] = False
137
+ scheduler._internal_dict = FrozenDict(new_config)
138
+
139
+ if safety_checker is None and requires_safety_checker:
140
+ logger.warning(
141
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
142
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
143
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
144
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
145
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
146
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
147
+ )
148
+
149
+ if safety_checker is not None and feature_extractor is None:
150
+ raise ValueError(
151
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
152
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
153
+ )
154
+
155
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
156
+ version.parse(unet.config._diffusers_version).base_version
157
+ ) < version.parse("0.9.0.dev0")
158
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
159
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
160
+ deprecation_message = (
161
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
162
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
163
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
164
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
165
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
166
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
167
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
168
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
169
+ " the `unet/config.json` file"
170
+ )
171
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
172
+ new_config = dict(unet.config)
173
+ new_config["sample_size"] = 64
174
+ unet._internal_dict = FrozenDict(new_config)
175
+
176
+ self.register_modules(
177
+ vae=vae,
178
+ text_encoder=text_encoder,
179
+ tokenizer=tokenizer,
180
+ unet=unet,
181
+ scheduler=scheduler,
182
+ safety_checker=safety_checker,
183
+ feature_extractor=feature_extractor,
184
+ )
185
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
186
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
187
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
188
+
189
+ def enable_vae_slicing(self):
190
+ r"""
191
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
192
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
193
+ """
194
+ self.vae.enable_slicing()
195
+
196
+ def disable_vae_slicing(self):
197
+ r"""
198
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
199
+ computing decoding in one step.
200
+ """
201
+ self.vae.disable_slicing()
202
+
203
+ def enable_vae_tiling(self):
204
+ r"""
205
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
206
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
207
+ processing larger images.
208
+ """
209
+ self.vae.enable_tiling()
210
+
211
+ def disable_vae_tiling(self):
212
+ r"""
213
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
214
+ computing decoding in one step.
215
+ """
216
+ self.vae.disable_tiling()
217
+
218
+ def _encode_prompt(
219
+ self,
220
+ prompt,
221
+ device,
222
+ num_images_per_prompt,
223
+ do_classifier_free_guidance,
224
+ negative_prompt=None,
225
+ prompt_embeds: Optional[torch.FloatTensor] = None,
226
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
227
+ lora_scale: Optional[float] = None,
228
+ ):
229
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
230
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
231
+
232
+ prompt_embeds_tuple = self.encode_prompt(
233
+ prompt=prompt,
234
+ device=device,
235
+ num_images_per_prompt=num_images_per_prompt,
236
+ do_classifier_free_guidance=do_classifier_free_guidance,
237
+ negative_prompt=negative_prompt,
238
+ prompt_embeds=prompt_embeds,
239
+ negative_prompt_embeds=negative_prompt_embeds,
240
+ lora_scale=lora_scale,
241
+ )
242
+
243
+ # concatenate for backwards compatibility
244
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
245
+
246
+ return prompt_embeds
247
+
248
+ def encode_prompt(
249
+ self,
250
+ prompt,
251
+ device,
252
+ num_images_per_prompt,
253
+ do_classifier_free_guidance,
254
+ negative_prompt=None,
255
+ prompt_embeds: Optional[torch.FloatTensor] = None,
256
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
257
+ lora_scale: Optional[float] = None,
258
+ ):
259
+ r"""
260
+ Encodes the prompt into text encoder hidden states.
261
+
262
+ Args:
263
+ prompt (`str` or `List[str]`, *optional*):
264
+ prompt to be encoded
265
+ device: (`torch.device`):
266
+ torch device
267
+ num_images_per_prompt (`int`):
268
+ number of images that should be generated per prompt
269
+ do_classifier_free_guidance (`bool`):
270
+ whether to use classifier free guidance or not
271
+ negative_prompt (`str` or `List[str]`, *optional*):
272
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
273
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
274
+ less than `1`).
275
+ prompt_embeds (`torch.FloatTensor`, *optional*):
276
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
277
+ provided, text embeddings will be generated from `prompt` input argument.
278
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
279
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
280
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
281
+ argument.
282
+ lora_scale (`float`, *optional*):
283
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
284
+ """
285
+ # set lora scale so that monkey patched LoRA
286
+ # function of text encoder can correctly access it
287
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
288
+ self._lora_scale = lora_scale
289
+
290
+ # dynamically adjust the LoRA scale
291
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
292
+
293
+ if prompt is not None and isinstance(prompt, str):
294
+ batch_size = 1
295
+ elif prompt is not None and isinstance(prompt, list):
296
+ batch_size = len(prompt)
297
+ else:
298
+ batch_size = prompt_embeds.shape[0]
299
+
300
+ if prompt_embeds is None:
301
+ # textual inversion: process multi-vector tokens if necessary
302
+ if isinstance(self, TextualInversionLoaderMixin):
303
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
304
+
305
+ text_inputs = self.tokenizer(
306
+ prompt,
307
+ padding="max_length",
308
+ max_length=self.tokenizer.model_max_length,
309
+ truncation=True,
310
+ return_tensors="pt",
311
+ )
312
+ text_input_ids = text_inputs.input_ids
313
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
314
+
315
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
316
+ text_input_ids, untruncated_ids
317
+ ):
318
+ removed_text = self.tokenizer.batch_decode(
319
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
320
+ )
321
+ logger.warning(
322
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
323
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
324
+ )
325
+
326
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
327
+ attention_mask = text_inputs.attention_mask.to(device)
328
+ else:
329
+ attention_mask = None
330
+
331
+ prompt_embeds = self.text_encoder(
332
+ text_input_ids.to(device),
333
+ attention_mask=attention_mask,
334
+ )
335
+ prompt_embeds = prompt_embeds[0]
336
+
337
+ if self.text_encoder is not None:
338
+ prompt_embeds_dtype = self.text_encoder.dtype
339
+ elif self.unet is not None:
340
+ prompt_embeds_dtype = self.unet.dtype
341
+ else:
342
+ prompt_embeds_dtype = prompt_embeds.dtype
343
+
344
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
345
+
346
+ bs_embed, seq_len, _ = prompt_embeds.shape
347
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
348
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
349
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
350
+
351
+ # get unconditional embeddings for classifier free guidance
352
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
353
+ uncond_tokens: List[str]
354
+ if negative_prompt is None:
355
+ uncond_tokens = [""] * batch_size
356
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
357
+ raise TypeError(
358
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
359
+ f" {type(prompt)}."
360
+ )
361
+ elif isinstance(negative_prompt, str):
362
+ uncond_tokens = [negative_prompt]
363
+ elif batch_size != len(negative_prompt):
364
+ raise ValueError(
365
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
366
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
367
+ " the batch size of `prompt`."
368
+ )
369
+ else:
370
+ uncond_tokens = negative_prompt
371
+
372
+ # textual inversion: process multi-vector tokens if necessary
373
+ if isinstance(self, TextualInversionLoaderMixin):
374
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
375
+
376
+ max_length = prompt_embeds.shape[1]
377
+ uncond_input = self.tokenizer(
378
+ uncond_tokens,
379
+ padding="max_length",
380
+ max_length=max_length,
381
+ truncation=True,
382
+ return_tensors="pt",
383
+ )
384
+
385
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
386
+ attention_mask = uncond_input.attention_mask.to(device)
387
+ else:
388
+ attention_mask = None
389
+
390
+ negative_prompt_embeds = self.text_encoder(
391
+ uncond_input.input_ids.to(device),
392
+ attention_mask=attention_mask,
393
+ )
394
+ negative_prompt_embeds = negative_prompt_embeds[0]
395
+
396
+ if do_classifier_free_guidance:
397
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
398
+ seq_len = negative_prompt_embeds.shape[1]
399
+
400
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
401
+
402
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
403
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
404
+
405
+ return prompt_embeds, negative_prompt_embeds
406
+
407
+ def run_safety_checker(self, image, device, dtype):
408
+ if self.safety_checker is None:
409
+ has_nsfw_concept = None
410
+ else:
411
+ if torch.is_tensor(image):
412
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
413
+ else:
414
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
415
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
416
+ image, has_nsfw_concept = self.safety_checker(
417
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
418
+ )
419
+ return image, has_nsfw_concept
420
+
421
+ def decode_latents(self, latents):
422
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
423
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
424
+
425
+ latents = 1 / self.vae.config.scaling_factor * latents
426
+ image = self.vae.decode(latents, return_dict=False)[0]
427
+ image = (image / 2 + 0.5).clamp(0, 1)
428
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
429
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
430
+ return image
431
+
432
+ def prepare_extra_step_kwargs(self, generator, eta):
433
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
434
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
435
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
436
+ # and should be between [0, 1]
437
+
438
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
439
+ extra_step_kwargs = {}
440
+ if accepts_eta:
441
+ extra_step_kwargs["eta"] = eta
442
+
443
+ # check if the scheduler accepts generator
444
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
445
+ if accepts_generator:
446
+ extra_step_kwargs["generator"] = generator
447
+ return extra_step_kwargs
448
+
449
+ def check_inputs(
450
+ self,
451
+ prompt,
452
+ height,
453
+ width,
454
+ callback_steps,
455
+ negative_prompt=None,
456
+ prompt_embeds=None,
457
+ negative_prompt_embeds=None,
458
+ ):
459
+ if height % 8 != 0 or width % 8 != 0:
460
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
461
+
462
+ if (callback_steps is None) or (
463
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
464
+ ):
465
+ raise ValueError(
466
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
467
+ f" {type(callback_steps)}."
468
+ )
469
+
470
+ if prompt is not None and prompt_embeds is not None:
471
+ raise ValueError(
472
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
473
+ " only forward one of the two."
474
+ )
475
+ elif prompt is None and prompt_embeds is None:
476
+ raise ValueError(
477
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
478
+ )
479
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
480
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
481
+
482
+ if negative_prompt is not None and negative_prompt_embeds is not None:
483
+ raise ValueError(
484
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
485
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
+ )
487
+
488
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
489
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
490
+ raise ValueError(
491
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
492
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
493
+ f" {negative_prompt_embeds.shape}."
494
+ )
495
+
496
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
497
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
498
+ if isinstance(generator, list) and len(generator) != batch_size:
499
+ raise ValueError(
500
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
501
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
502
+ )
503
+
504
+ if latents is None:
505
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
506
+ else:
507
+ latents = latents.to(device)
508
+
509
+ # scale the initial noise by the standard deviation required by the scheduler
510
+ latents = latents * self.scheduler.init_noise_sigma
511
+ return latents
512
+
513
+ @torch.no_grad()
514
+ # @replace_example_docstring(EXAMPLE_DOC_STRING)
515
+ def __call__(
516
+ self,
517
+ prompt: Union[str, List[str]] = None,
518
+ height: Optional[int] = None,
519
+ width: Optional[int] = None,
520
+ num_inference_steps: int = 50,
521
+ guidance_scale: float = 7.5,
522
+ negative_prompt: Optional[Union[str, List[str]]] = None,
523
+ num_images_per_prompt: Optional[int] = 1,
524
+ eta: float = 0.0,
525
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
526
+ latents: Optional[torch.FloatTensor] = None,
527
+ prompt_embeds: Optional[torch.FloatTensor] = None,
528
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
529
+ output_type: Optional[str] = "pil",
530
+ return_dict: bool = True,
531
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
532
+ callback_steps: int = 1,
533
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
534
+ guidance_rescale: float = 0.0,
535
+ ):
536
+ r"""
537
+ The call function to the pipeline for generation.
538
+
539
+ Args:
540
+ prompt (`str` or `List[str]`, *optional*):
541
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
542
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
543
+ The height in pixels of the generated image.
544
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
545
+ The width in pixels of the generated image.
546
+ num_inference_steps (`int`, *optional*, defaults to 50):
547
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
548
+ expense of slower inference.
549
+ guidance_scale (`float`, *optional*, defaults to 7.5):
550
+ A higher guidance scale value encourages the model to generate images closely linked to the text
551
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
552
+ negative_prompt (`str` or `List[str]`, *optional*):
553
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
554
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
555
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
556
+ The number of images to generate per prompt.
557
+ eta (`float`, *optional*, defaults to 0.0):
558
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
559
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
560
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
561
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
562
+ generation deterministic.
563
+ latents (`torch.FloatTensor`, *optional*):
564
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
565
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
566
+ tensor is generated by sampling using the supplied random `generator`.
567
+ prompt_embeds (`torch.FloatTensor`, *optional*):
568
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
569
+ provided, text embeddings are generated from the `prompt` input argument.
570
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
571
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
572
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
573
+ output_type (`str`, *optional*, defaults to `"pil"`):
574
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
575
+ return_dict (`bool`, *optional*, defaults to `True`):
576
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
577
+ plain tuple.
578
+ callback (`Callable`, *optional*):
579
+ A function that is called every `callback_steps` steps during inference. The function is called with the
580
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
581
+ callback_steps (`int`, *optional*, defaults to 1):
582
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
583
+ every step.
584
+ cross_attention_kwargs (`dict`, *optional*):
585
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
586
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
587
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
588
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
589
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
590
+ using zero terminal SNR.
591
+
592
+ Examples:
593
+
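+ A minimal usage sketch (illustrative; the checkpoint id is an assumption, and any Stable
+ Diffusion 1.x checkpoint compatible with this pipeline should work):
+
+ >>> import torch
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
+ ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> image = pipe("a photo of an astronaut riding a horse on mars", guidance_scale=7.5).images[0]
+ >>> image.save("astronaut.png")
+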
594
+ Returns:
595
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
596
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
597
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
598
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
599
+ "not-safe-for-work" (nsfw) content.
600
+ """
601
+ # 0. Default height and width to unet
602
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
603
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
604
+
605
+ # 1. Check inputs. Raise error if not correct
606
+ self.check_inputs(
607
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
608
+ )
609
+
610
+ # 2. Define call parameters
611
+ if prompt is not None and isinstance(prompt, str):
612
+ batch_size = 1
613
+ elif prompt is not None and isinstance(prompt, list):
614
+ batch_size = len(prompt)
615
+ else:
616
+ batch_size = prompt_embeds.shape[0]
617
+
618
+ device = self._execution_device
619
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
620
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
621
+ # corresponds to doing no classifier free guidance.
622
+ do_classifier_free_guidance = guidance_scale > 1.0
623
+
624
+ # 3. Encode input prompt
625
+ text_encoder_lora_scale = (
626
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
627
+ )
628
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
629
+ prompt,
630
+ device,
631
+ num_images_per_prompt,
632
+ do_classifier_free_guidance,
633
+ negative_prompt,
634
+ prompt_embeds=prompt_embeds,
635
+ negative_prompt_embeds=negative_prompt_embeds,
636
+ lora_scale=text_encoder_lora_scale,
637
+ )
638
+ # For classifier free guidance, we need to do two forward passes.
639
+ # Here we concatenate the unconditional and text embeddings into a single batch
640
+ # to avoid doing two forward passes
641
+ if do_classifier_free_guidance:
642
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
643
+
644
+ # 4. Prepare timesteps
645
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
646
+ timesteps = self.scheduler.timesteps
647
+
648
+ # 5. Prepare latent variables
649
+ num_channels_latents = self.unet.config.in_channels
650
+ latents = self.prepare_latents(
651
+ batch_size * num_images_per_prompt,
652
+ num_channels_latents,
653
+ height,
654
+ width,
655
+ prompt_embeds.dtype,
656
+ device,
657
+ generator,
658
+ latents,
659
+ )
660
+
661
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
662
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
663
+
664
+ # 7. Denoising loop
665
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
666
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
667
+ for i, t in enumerate(timesteps):
668
+ # expand the latents if we are doing classifier free guidance
669
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
670
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
671
+
672
+ # predict the noise residual
673
+ noise_pred = self.unet(
674
+ latent_model_input,
675
+ t,
676
+ encoder_hidden_states=prompt_embeds,
677
+ cross_attention_kwargs=cross_attention_kwargs,
678
+ return_dict=False,
679
+ )[0]
680
+
681
+ # perform guidance
682
+ if do_classifier_free_guidance:
683
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
684
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
685
+
686
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
687
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
688
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
689
+
690
+ # compute the previous noisy sample x_t -> x_t-1
691
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
692
+
693
+ # call the callback, if provided
694
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
695
+ progress_bar.update()
696
+ if callback is not None and i % callback_steps == 0:
697
+ callback(i, t, latents)
698
+
699
+ if not output_type == "latent":
700
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
701
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
702
+ else:
703
+ image = latents
704
+ has_nsfw_concept = None
705
+
706
+ if has_nsfw_concept is None:
707
+ do_denormalize = [True] * image.shape[0]
708
+ else:
709
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
710
+
711
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
712
+
713
+ # Offload all models
714
+ self.maybe_free_model_hooks()
715
+
716
+ if not return_dict:
717
+ return (image, has_nsfw_concept)
718
+
719
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
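For orientation, below is a minimal, hypothetical usage sketch of the `__call__` shown above. The checkpoint id, import path, and prompt are illustrative assumptions and are not part of this commit; it presumes the local class keeps the standard `diffusers` `from_pretrained` loading interface.

```python
# Hypothetical usage sketch; checkpoint id, import path, and prompt are assumptions.
import torch
from src.models.sd_pipeline import StableDiffusionPipeline  # local pipeline defined above

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)  # makes generation deterministic
result = pipe(
    prompt="a red bicycle leaning against a brick wall",
    num_inference_steps=50,
    guidance_scale=7.5,
    guidance_rescale=0.0,  # > 0.0 enables the zero-terminal-SNR rescaling described in the docstring
    generator=generator,
)
result.images[0].save("sample.png")
```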
src/models/t2i_pipeline.py ADDED
@@ -0,0 +1,770 @@
1
+ from .sd_pipeline import StableDiffusionPipeline
2
+ from dataclasses import dataclass
3
+ from typing import List, Union
4
+ from typing import Optional, Callable, Dict, Any
5
+ import PIL
6
+ from .unet_2d_condition import UNet2DConditionModel
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers.utils import (
11
+ BaseOutput,
12
+ )
13
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
14
+
15
+
16
+
17
+
18
+
19
+
20
+ @dataclass
21
+ class StableDiffusionPipelineOutput(BaseOutput):
22
+ """
23
+ Output class for Stable Diffusion pipelines.
24
+
25
+ Args:
26
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
27
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
28
+ num_channels)`. PIL images or NumPy arrays representing the denoised images of the diffusion pipeline.
29
+ nsfw_content_detected (`List[bool]`)
30
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
31
+ (nsfw) content, or `None` if safety checking could not be performed.
32
+ """
33
+
34
+ images: Union[List[PIL.Image.Image], np.ndarray]
35
+ nsfw_content_detected: Optional[List[bool]]
36
+
37
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
38
+ """
39
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
40
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
41
+ """
42
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
43
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
44
+ # rescale the results from guidance (fixes overexposure)
45
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
46
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
47
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
48
+ return noise_cfg
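As a sanity check on the rescaling above, the following self-contained snippet copies the function and runs it on toy tensors; it only illustrates that `guidance_rescale=1.0` pulls the per-sample standard deviation of the CFG output back to that of the text-conditioned prediction (the tensors are synthetic, not model outputs).

```python
# Toy check of rescale_noise_cfg; tensors are synthetic, not model outputs.
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)          # match the text-branch std
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

noise_pred_text = torch.randn(1, 4, 8, 8)                # stand-in text-conditioned prediction
noise_cfg = noise_pred_text * 3.0                        # CFG output with inflated magnitude
out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
print(noise_cfg.std().item(), out.std().item(), noise_pred_text.std().item())
# out's std matches noise_pred_text's std; with guidance_rescale=0.0 it would stay 3x larger.
```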
49
+
50
+ class StableDiffusionPipelineSpatialAware(StableDiffusionPipeline):
51
+ def __init__(
52
+ self,
53
+ vae,
54
+ text_encoder,
55
+ tokenizer,
56
+ unet,
57
+ scheduler,
58
+ safety_checker,
59
+ feature_extractor,
60
+ requires_safety_checker: bool = True,
61
+ ):
62
+ unet_new = UNet2DConditionModel(**unet.config)
63
+ unet_new.load_state_dict(unet.state_dict())
64
+
65
+ super().__init__(vae, text_encoder, tokenizer, unet_new, scheduler, safety_checker, feature_extractor, requires_safety_checker)
66
+
67
+ def _encode_prompt(
68
+ self,
69
+ prompt,
70
+ device,
71
+ num_images_per_prompt,
72
+ do_classifier_free_guidance,
73
+ negative_prompt=None,
74
+ prompt_embeds: Optional[torch.FloatTensor] = None,
75
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
76
+ lora_scale: Optional[float] = None,
77
+ fg_prompt: Optional[str] = None,
78
+ ):
79
+ r"""
80
+ Encodes the prompt into text encoder hidden states.
81
+
82
+ Args:
83
+ prompt (`str` or `List[str]`, *optional*):
84
+ prompt to be encoded
85
+ device: (`torch.device`):
86
+ torch device
87
+ num_images_per_prompt (`int`):
88
+ number of images that should be generated per prompt
89
+ do_classifier_free_guidance (`bool`):
90
+ whether to use classifier free guidance or not
91
+ negative_prompt (`str` or `List[str]`, *optional*):
92
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
93
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
94
+ less than `1`).
95
+ prompt_embeds (`torch.FloatTensor`, *optional*):
96
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
97
+ provided, text embeddings will be generated from `prompt` input argument.
98
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
99
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
100
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
101
+ argument.
102
+ lora_scale (`float`, *optional*):
103
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
104
+ """
105
+ # set lora scale so that monkey patched LoRA
106
+ # function of text encoder can correctly access it
107
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
108
+ self._lora_scale = lora_scale
109
+
110
+ if prompt is not None and isinstance(prompt, str):
111
+ batch_size = 1
112
+ elif prompt is not None and isinstance(prompt, list):
113
+ batch_size = len(prompt)
114
+ else:
115
+ batch_size = prompt_embeds.shape[0]
116
+
117
+ if prompt_embeds is None:
118
+ # textual inversion: process multi-vector tokens if necessary
119
+ if isinstance(self, TextualInversionLoaderMixin):
120
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
121
+
122
+ text_inputs = self.tokenizer(
123
+ prompt,
124
+ padding="max_length",
125
+ max_length=self.tokenizer.model_max_length,
126
+ truncation=True,
127
+ return_tensors="pt",
128
+ )
129
+ text_input_ids = text_inputs.input_ids
130
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
131
+
132
+
133
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
134
+ text_input_ids, untruncated_ids
135
+ ):
136
+ removed_text = self.tokenizer.batch_decode(
137
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
138
+ )
139
+ logger.warning(
140
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
141
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
142
+ )
143
+
144
+ if fg_prompt is not None:
145
+ fg_text_inputs = self.tokenizer(
146
+ fg_prompt,
147
+ # padding="max_length",
148
+ # max_length=self.tokenizer.model_max_length,
149
+ # truncation=True,
150
+ return_tensors="pt",
151
+ )
152
+ # breakpoint()
153
+ fg_text_input_ids = fg_text_inputs.input_ids
154
+
155
+ # remove first and last token
156
+ fg_text_input_ids = fg_text_input_ids[:,:-1]
157
+
158
+ # remove common tokens in fg_text_input_ids from text_input_ids
159
+ batch_size = text_input_ids.shape[0]
160
+ # Create a mask that is True wherever a token in text_input_ids matches a token in fg_text_input_ids
161
+ mask = (text_input_ids.unsqueeze(-1) == fg_text_input_ids.unsqueeze(1)).any(dim=-1)
162
+ # Get the values from text_input_ids that are not in fg_text_input_ids
163
+ encoder_attention_mask = ~mask
164
+ encoder_attention_mask = encoder_attention_mask.repeat((2,1))
165
+ encoder_attention_mask = encoder_attention_mask.to(device)
166
+
167
+ # text_input_ids_filtered = text_input_ids[~mask].view(1, -1)
168
+ # text_input_ids_filtered will now contain the values from text_input_ids that aren't in fg_text_input_ids
169
+
170
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
171
+ attention_mask = text_inputs.attention_mask.to(device)
172
+ else:
173
+ attention_mask = None
174
+
175
+ prompt_embeds = self.text_encoder(
176
+ text_input_ids.to(device),
177
+ attention_mask=attention_mask,
178
+ # attention_mask=encoder_attention_mask[:batch_size] if fg_prompt is not None else None,
179
+ )
180
+ prompt_embeds = prompt_embeds[0]
181
+
182
+ if self.text_encoder is not None:
183
+ prompt_embeds_dtype = self.text_encoder.dtype
184
+ elif self.unet is not None:
185
+ prompt_embeds_dtype = self.unet.dtype
186
+ else:
187
+ prompt_embeds_dtype = prompt_embeds.dtype
188
+
189
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
190
+
191
+ bs_embed, seq_len, _ = prompt_embeds.shape
192
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
193
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
194
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
195
+
196
+ # get unconditional embeddings for classifier free guidance
197
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
198
+ uncond_tokens: List[str]
199
+ if negative_prompt is None:
200
+ uncond_tokens = [""] * batch_size
201
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
202
+ raise TypeError(
203
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
204
+ f" {type(prompt)}."
205
+ )
206
+ elif isinstance(negative_prompt, str):
207
+ uncond_tokens = [negative_prompt]
208
+ elif batch_size != len(negative_prompt):
209
+ raise ValueError(
210
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
211
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
212
+ " the batch size of `prompt`."
213
+ )
214
+ else:
215
+ uncond_tokens = negative_prompt
216
+
217
+ # textual inversion: procecss multi-vector tokens if necessary
218
+ if isinstance(self, TextualInversionLoaderMixin):
219
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
220
+
221
+ max_length = prompt_embeds.shape[1]
222
+ uncond_input = self.tokenizer(
223
+ uncond_tokens,
224
+ padding="max_length",
225
+ max_length=max_length,
226
+ truncation=True,
227
+ return_tensors="pt",
228
+ )
229
+
230
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
231
+ attention_mask = uncond_input.attention_mask.to(device)
232
+ else:
233
+ attention_mask = None
234
+
235
+ negative_prompt_embeds = self.text_encoder(
236
+ uncond_input.input_ids.to(device),
237
+ attention_mask=attention_mask,
238
+ )
239
+ negative_prompt_embeds = negative_prompt_embeds[0]
240
+
241
+ if do_classifier_free_guidance:
242
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
243
+ seq_len = negative_prompt_embeds.shape[1]
244
+
245
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
246
+
247
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
248
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
249
+
250
+ # For classifier free guidance, we need to do two forward passes.
251
+ # Here we concatenate the unconditional and text embeddings into a single batch
252
+ # to avoid doing two forward passes
253
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
254
+
255
+ if fg_prompt is not None:
256
+ return prompt_embeds, encoder_attention_mask
257
+ return prompt_embeds, None
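To make the foreground-token masking above easier to follow, here is a tiny standalone illustration of the `mask`/`encoder_attention_mask` construction; the token ids are invented and do not come from a real tokenizer run.

```python
# Toy illustration of the foreground-token mask built in `_encode_prompt`; ids are made up.
import torch

text_input_ids = torch.tensor([[49406, 320, 2368, 525, 320, 4558, 49407]])  # full prompt tokens
fg_text_input_ids = torch.tensor([[49406, 2368]])                           # fg tokens after dropping EOS

# True wherever a prompt token also occurs among the foreground tokens
mask = (text_input_ids.unsqueeze(-1) == fg_text_input_ids.unsqueeze(1)).any(dim=-1)
encoder_attention_mask = ~mask                 # keep every token except the foreground ones
print(mask)                    # tensor([[ True, False,  True, False, False, False, False]])
print(encoder_attention_mask)  # the cross-attention then ignores the masked positions
```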
258
+
259
+ @torch.no_grad()
260
+ def __call__(
261
+ self,
262
+ prompt: Union[str, List[str]] = None,
263
+ height: Optional[int] = None,
264
+ width: Optional[int] = None,
265
+ num_inference_steps: int = 50,
266
+ guidance_scale: float = 7.5,
267
+ negative_prompt: Optional[Union[str, List[str]]] = None,
268
+ num_images_per_prompt: Optional[int] = 1,
269
+ eta: float = 0.0,
270
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
271
+ latents: Optional[torch.FloatTensor] = None,
272
+ prompt_embeds: Optional[torch.FloatTensor] = None,
273
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
274
+ output_type: Optional[str] = "pil",
275
+ return_dict: bool = True,
276
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
277
+ callback_steps: int = 1,
278
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
279
+ guidance_rescale: float = 0.0,
280
+ frozen_mask: Optional[torch.FloatTensor] = None,
281
+ frozen_steps: Optional[int] = None,
282
+ frozen_text_mask: Optional[torch.FloatTensor] = None,
283
+ frozen_prompt: Optional[Union[str, List[str]]] = None,
284
+ custom_attention_mask: Optional[torch.FloatTensor] = None,
285
+ latents_all_input: Optional[torch.FloatTensor] = None,
286
+ fg_prompt: Optional[str] = None,
287
+ make_attention_mask_2d=False,
288
+ attention_mask_block_diagonal=False,
289
+ ):
290
+ r"""
291
+ The call function to the pipeline for generation.
292
+
293
+ Args:
294
+ prompt (`str` or `List[str]`, *optional*):
295
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
296
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
297
+ The height in pixels of the generated image.
298
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
299
+ The width in pixels of the generated image.
300
+ num_inference_steps (`int`, *optional*, defaults to 50):
301
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
302
+ expense of slower inference.
303
+ guidance_scale (`float`, *optional*, defaults to 7.5):
304
+ A higher guidance scale value encourages the model to generate images closely linked to the text
305
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
306
+ negative_prompt (`str` or `List[str]`, *optional*):
307
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
308
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
309
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
310
+ The number of images to generate per prompt.
311
+ eta (`float`, *optional*, defaults to 0.0):
312
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
313
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
314
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
315
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
316
+ generation deterministic.
317
+ latents (`torch.FloatTensor`, *optional*):
318
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
319
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
320
+ tensor is generated by sampling using the supplied random `generator`.
321
+ prompt_embeds (`torch.FloatTensor`, *optional*):
322
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
323
+ provided, text embeddings are generated from the `prompt` input argument.
324
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
325
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
326
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
327
+ output_type (`str`, *optional*, defaults to `"pil"`):
328
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
329
+ return_dict (`bool`, *optional*, defaults to `True`):
330
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
331
+ plain tuple.
332
+ callback (`Callable`, *optional*):
333
+ A function that is called every `callback_steps` steps during inference with the
334
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
335
+ callback_steps (`int`, *optional*, defaults to 1):
336
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
337
+ every step.
338
+ cross_attention_kwargs (`dict`, *optional*):
339
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
340
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
341
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
342
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
343
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
344
+ using zero terminal SNR.
345
+
346
+ Examples:
347
+
348
+ Returns:
349
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
350
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
351
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
352
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
353
+ "not-safe-for-work" (nsfw) content.
354
+ """
355
+ # 0. Default height and width to unet
356
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
357
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
358
+
359
+ # 1. Check inputs. Raise error if not correct
360
+ self.check_inputs(
361
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
362
+ )
363
+
364
+ # 2. Define call parameters
365
+ if prompt is not None and isinstance(prompt, str):
366
+ batch_size = 1
367
+ elif prompt is not None and isinstance(prompt, list):
368
+ batch_size = len(prompt)
369
+ else:
370
+ batch_size = prompt_embeds.shape[0]
371
+
372
+ device = self._execution_device
373
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
374
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
375
+ # corresponds to doing no classifier free guidance.
376
+ do_classifier_free_guidance = guidance_scale > 1.0
377
+
378
+ # 3. Encode input prompt
379
+ text_encoder_lora_scale = (
380
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
381
+ )
382
+
383
+ prompt_embeds, custom_attention_mask = self._encode_prompt(
384
+ prompt,
385
+ device,
386
+ num_images_per_prompt,
387
+ do_classifier_free_guidance,
388
+ negative_prompt,
389
+ prompt_embeds=prompt_embeds,
390
+ negative_prompt_embeds=negative_prompt_embeds,
391
+ lora_scale=text_encoder_lora_scale,
392
+ fg_prompt=fg_prompt,
393
+ )
394
+ if frozen_prompt is not None: # freeze the prompt
395
+ prompt_embeds, _ = self._encode_prompt(
396
+ frozen_prompt,
397
+ device,
398
+ num_images_per_prompt,
399
+ do_classifier_free_guidance,
400
+ negative_prompt,
401
+ prompt_embeds=None,
402
+ negative_prompt_embeds=None,
403
+ lora_scale=text_encoder_lora_scale,
404
+ )
405
+
406
+ # 4. Prepare timesteps
407
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
408
+ timesteps = self.scheduler.timesteps
409
+
410
+ # 5. Prepare latent variables
411
+ num_channels_latents = self.unet.config.in_channels
412
+ latents = self.prepare_latents(
413
+ batch_size * num_images_per_prompt,
414
+ num_channels_latents,
415
+ height,
416
+ width,
417
+ prompt_embeds.dtype,
418
+ device,
419
+ generator,
420
+ latents,
421
+ )
422
+
423
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
424
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
425
+ if frozen_mask is not None:
426
+ attention_mask = frozen_mask.clone() # (1, 1, 96, 96)
427
+ attention_mask = attention_mask.view(attention_mask.shape[0], -1).repeat((2,1)).to(frozen_mask.device) # torch.Size([2, 9216])
428
+ attention_mask = attention_mask.bool()
429
+ # attention_mask = ~attention_mask
430
+ # if custom_attention_mask is not None:
431
+ # custom_attention_mask = ~custom_attention_mask
432
+ # 7. Denoising loop
433
+ latents_all = []
434
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
435
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
436
+ iter_t = iter(timesteps)
437
+ for i, t in enumerate(iter_t):
438
+ # expand the latents if we are doing classifier free guidance
439
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
440
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
441
+ # torch.save(i, 'i.pt')
442
+
443
+ # if attention_mask is not None:
444
+ # attention_mask=~attention_mask #fg deactivated at cnt%2==0
445
+ # if custom_attention_mask is not None:
446
+ # custom_attention_mask=~custom_attention_mask #fg deactivated at cnt%2==0
447
+
448
+ # predict the noise residual
449
+ noise_pred = self.unet(
450
+ latent_model_input,
451
+ t,
452
+ encoder_hidden_states=prompt_embeds,
453
+ cross_attention_kwargs=cross_attention_kwargs,
454
+ return_dict=False,
455
+ encoder_attention_mask=custom_attention_mask if custom_attention_mask is not None and i < frozen_steps else None,
456
+ attention_mask=attention_mask if frozen_steps is not None and i < frozen_steps else None,
457
+ make_2d_attention_mask=make_attention_mask_2d,
458
+ block_diagonal_attention=attention_mask_block_diagonal,
459
+ )[0]
460
+
461
+ # perform guidance
462
+ if do_classifier_free_guidance:
463
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
464
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
465
+
466
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
467
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
468
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
469
+
470
+ # compute the previous noisy sample x_t -> x_t-1
471
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
472
+
473
+ latents_all.append(latents)
474
+
475
+ # put frozen latents back
476
+ if frozen_mask is not None and i < frozen_steps:
477
+ # breakpoint()
478
+ # latents = latents_all_input[i+1:i+2,...] * frozen_mask + latents * (1. - frozen_mask)
479
+ pass
480
+
481
+ # update the prompt_embeds after the frozen_steps to consider the whole prompt, including fg_prompt
482
+ if frozen_steps is not None and i == frozen_steps:
483
+ prompt_embeds, _ = self._encode_prompt(
484
+ prompt,
485
+ device,
486
+ num_images_per_prompt,
487
+ do_classifier_free_guidance,
488
+ negative_prompt,
489
+ prompt_embeds=None,
490
+ negative_prompt_embeds=None,
491
+ lora_scale=text_encoder_lora_scale,
492
+ fg_prompt=None,
493
+ )
494
+ # call the callback, if provided
495
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
496
+ progress_bar.update()
497
+ if callback is not None and i % callback_steps == 0:
498
+ callback(i, t, latents)
499
+
500
+ # try:
501
+ # if i in [29,30,40,49]:
502
+ # tmp = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
503
+ # do_denormalize = [True] * tmp.shape[0]
504
+ # tmp = self.image_processor.postprocess(tmp, output_type=output_type, do_denormalize=do_denormalize)
505
+ # tmp_prompt = torch.load('prompt.pt')
506
+ # tmp[0].save(f'./demo15/im-{tmp_prompt}-{i}.png')
507
+ # except:
508
+ # pass
509
+
510
+ if not output_type == "latent":
511
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
512
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
513
+ else:
514
+ # image = latents
515
+ latents_all = torch.cat(latents_all, dim=0) # (num_inference_steps, num_channels_latents, height, width) assume batch_size=1
516
+ image = latents_all
517
+ has_nsfw_concept = None
518
+
519
+ if has_nsfw_concept is None:
520
+ do_denormalize = [True] * image.shape[0]
521
+ else:
522
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
523
+
524
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
525
+
526
+ # Offload last model to CPU
527
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
528
+ self.final_offload_hook.offload()
529
+
530
+ if not return_dict:
531
+ return (image, has_nsfw_concept)
532
+
533
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
534
+
535
+ '''
536
+ @torch.no_grad()
537
+ def __call__latestNotCalledForNow(
538
+ self,
539
+ prompt: Union[str, List[str]] = None,
540
+ height: Optional[int] = None,
541
+ width: Optional[int] = None,
542
+ num_inference_steps: int = 50,
543
+ guidance_scale: float = 7.5,
544
+ negative_prompt: Optional[Union[str, List[str]]] = None,
545
+ num_images_per_prompt: Optional[int] = 1,
546
+ eta: float = 0.0,
547
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
548
+ latents: Optional[torch.FloatTensor] = None,
549
+ prompt_embeds: Optional[torch.FloatTensor] = None,
550
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
551
+ output_type: Optional[str] = "pil",
552
+ return_dict: bool = True,
553
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
554
+ callback_steps: int = 1,
555
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
556
+ guidance_rescale: float = 0.0,
557
+ frozen_mask: Optional[torch.FloatTensor] = None,
558
+ frozen_steps: Optional[int] = None,
559
+ frozen_text_mask: Optional[torch.FloatTensor] = None,
560
+ frozen_prompt: Optional[Union[str, List[str]]] = None,
561
+ custom_attention_mask: Optional[torch.FloatTensor] = None,
562
+ latents_all_input: Optional[torch.FloatTensor] = None,
563
+ ):
564
+ r"""
565
+ The call function to the pipeline for generation.
566
+
567
+ Args:
568
+ prompt (`str` or `List[str]`, *optional*):
569
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
570
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
571
+ The height in pixels of the generated image.
572
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
573
+ The width in pixels of the generated image.
574
+ num_inference_steps (`int`, *optional*, defaults to 50):
575
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
576
+ expense of slower inference.
577
+ guidance_scale (`float`, *optional*, defaults to 7.5):
578
+ A higher guidance scale value encourages the model to generate images closely linked to the text
579
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
580
+ negative_prompt (`str` or `List[str]`, *optional*):
581
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
582
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
583
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
584
+ The number of images to generate per prompt.
585
+ eta (`float`, *optional*, defaults to 0.0):
586
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
587
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
588
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
589
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
590
+ generation deterministic.
591
+ latents (`torch.FloatTensor`, *optional*):
592
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
593
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
594
+ tensor is generated by sampling using the supplied random `generator`.
595
+ prompt_embeds (`torch.FloatTensor`, *optional*):
596
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
597
+ provided, text embeddings are generated from the `prompt` input argument.
598
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
599
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
600
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
601
+ output_type (`str`, *optional*, defaults to `"pil"`):
602
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
603
+ return_dict (`bool`, *optional*, defaults to `True`):
604
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
605
+ plain tuple.
606
+ callback (`Callable`, *optional*):
607
+ A function that is called every `callback_steps` steps during inference with the
608
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
609
+ callback_steps (`int`, *optional*, defaults to 1):
610
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
611
+ every step.
612
+ cross_attention_kwargs (`dict`, *optional*):
613
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
614
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
615
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
616
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
617
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
618
+ using zero terminal SNR.
619
+
620
+ Examples:
621
+
622
+ Returns:
623
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
624
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
625
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
626
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
627
+ "not-safe-for-work" (nsfw) content.
628
+ """
629
+ # 0. Default height and width to unet
630
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
631
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
632
+
633
+ # 1. Check inputs. Raise error if not correct
634
+ self.check_inputs(
635
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
636
+ )
637
+
638
+ # 2. Define call parameters
639
+ if prompt is not None and isinstance(prompt, str):
640
+ batch_size = 1
641
+ elif prompt is not None and isinstance(prompt, list):
642
+ batch_size = len(prompt)
643
+ else:
644
+ batch_size = prompt_embeds.shape[0]
645
+
646
+ device = self._execution_device
647
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
648
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
649
+ # corresponds to doing no classifier free guidance.
650
+ do_classifier_free_guidance = guidance_scale > 1.0
651
+
652
+ # 3. Encode input prompt
653
+ text_encoder_lora_scale = (
654
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
655
+ )
656
+ prompt_embeds = self._encode_prompt(
657
+ prompt,
658
+ device,
659
+ num_images_per_prompt,
660
+ do_classifier_free_guidance,
661
+ negative_prompt,
662
+ prompt_embeds=prompt_embeds,
663
+ negative_prompt_embeds=negative_prompt_embeds,
664
+ lora_scale=text_encoder_lora_scale,
665
+ )
666
+
667
+ if frozen_prompt is not None: # freeze the prompt
668
+ frozen_prompt_embeds = self._encode_prompt(
669
+ frozen_prompt,
670
+ device,
671
+ num_images_per_prompt,
672
+ do_classifier_free_guidance,
673
+ negative_prompt,
674
+ prompt_embeds=prompt_embeds,
675
+ negative_prompt_embeds=negative_prompt_embeds,
676
+ lora_scale=text_encoder_lora_scale,
677
+ )
678
+ else:
679
+ frozen_prompt_embeds = None
680
+ # For classifier free guidance, we need to do two forward passes.
681
+ # Here we concatenate the unconditional and text embeddings into a single batch
682
+ # to avoid doing two forward passes
683
+ if do_classifier_free_guidance:
684
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
685
+
686
+ # 4. Prepare timesteps
687
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
688
+ timesteps = self.scheduler.timesteps
689
+
690
+ # 5. Prepare latent variables
691
+ num_channels_latents = self.unet.config.in_channels
692
+ latents = self.prepare_latents(
693
+ batch_size * num_images_per_prompt,
694
+ num_channels_latents,
695
+ height,
696
+ width,
697
+ prompt_embeds.dtype,
698
+ device,
699
+ generator,
700
+ latents,
701
+ )
702
+ print(latents.shape)
703
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
704
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
705
+
706
+ # 7. Denoising loop
707
+ latents_all = []
708
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
709
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
710
+ for i, t in enumerate(timesteps):
711
+ # expand the latents if we are doing classifier free guidance
712
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
713
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
714
+
715
+ # predict the noise residual
716
+ noise_pred = self.unet(
717
+ latent_model_input,
718
+ t,
719
+ encoder_hidden_states=prompt_embeds,
720
+ cross_attention_kwargs=cross_attention_kwargs,
721
+ return_dict=False,
722
+ # attention_mask=custom_attention_mask if custom_attention_mask is not None and i < frozen_steps else None,
723
+ )[0]
724
+
725
+ # perform guidance
726
+ if do_classifier_free_guidance:
727
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
728
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
729
+
730
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
731
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
732
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
733
+
734
+ # compute the previous noisy sample x_t -> x_t-1
735
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
736
+ latents_all.append(latents)
737
+
738
+ # put frozen latents back
739
+ if frozen_mask is not None and i < frozen_steps:
740
+ latents = latents_all_input[i+1:i+2,...] * frozen_mask + latents * (1. - frozen_mask)
741
+
742
+ # call the callback, if provided
743
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
744
+ progress_bar.update()
745
+ if callback is not None and i % callback_steps == 0:
746
+ callback(i, t, latents)
747
+
748
+ if not output_type == "latent":
749
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
750
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
751
+ else:
752
+ # image = latents
753
+ latents_all = torch.cat(latents_all, dim=0) # (num_inference_steps, num_channels_latents, height, width) assume batch_size=1
754
+ has_nsfw_concept = None
755
+
756
+ if has_nsfw_concept is None:
757
+ do_denormalize = [True] * image.shape[0]
758
+ else:
759
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
760
+
761
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
762
+
763
+ # Offload all models
764
+ self.maybe_free_model_hooks()
765
+
766
+ if not return_dict:
767
+ return (image, has_nsfw_concept)
768
+
769
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
770
+ '''
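A hypothetical end-to-end sketch of how `StableDiffusionPipelineSpatialAware` might be driven, based only on the signature above; the checkpoint id, resolution, and mask layout are assumptions. The latent mask resolution must equal `height // vae_scale_factor` (96 matches the `(1, 1, 96, 96)` comment in the code, i.e. 768-pixel images with a scale factor of 8).

```python
# Hypothetical usage sketch; checkpoint id, resolution, and mask region are assumptions.
import torch
from src.models.t2i_pipeline import StableDiffusionPipelineSpatialAware

pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Binary layout mask over the 96x96 latent grid: 1 inside the region reserved for the subject.
frozen_mask = torch.zeros(1, 1, 96, 96, device="cuda", dtype=torch.float16)
frozen_mask[:, :, 24:72, 24:72] = 1.0

out = pipe(
    prompt="a corgi standing on a skateboard",
    fg_prompt="corgi",                   # its tokens are masked out of cross-attention early on
    frozen_mask=frozen_mask,
    frozen_steps=20,                     # masks only apply for the first 20 denoising steps
    height=768,
    width=768,
    num_inference_steps=50,
    make_attention_mask_2d=True,
    attention_mask_block_diagonal=True,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
image = out.images[0]
```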
src/models/transformer_2d.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.models.embeddings import PatchEmbed
25
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+
28
+ from .attention import BasicTransformerBlock
29
+
30
+ @dataclass
31
+ class Transformer2DModelOutput(BaseOutput):
32
+ """
33
+ The output of [`Transformer2DModel`].
34
+
35
+ Args:
36
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
37
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
38
+ distributions for the unnoised latent pixels.
39
+ """
40
+
41
+ sample: torch.FloatTensor
42
+
43
+
44
+ class Transformer2DModel(ModelMixin, ConfigMixin):
45
+ """
46
+ A 2D Transformer model for image-like data.
47
+
48
+ Parameters:
49
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
50
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
51
+ in_channels (`int`, *optional*):
52
+ The number of channels in the input and output (specify if the input is **continuous**).
53
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
54
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
55
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
56
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
57
+ This is fixed during training since it is used to learn a number of position embeddings.
58
+ num_vector_embeds (`int`, *optional*):
59
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
60
+ Includes the class for the masked latent pixel.
61
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
62
+ num_embeds_ada_norm ( `int`, *optional*):
63
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
64
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
65
+ added to the hidden states.
66
+
67
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
68
+ attention_bias (`bool`, *optional*):
69
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
70
+ """
71
+
72
+ @register_to_config
73
+ def __init__(
74
+ self,
75
+ num_attention_heads: int = 16,
76
+ attention_head_dim: int = 88,
77
+ in_channels: Optional[int] = None,
78
+ out_channels: Optional[int] = None,
79
+ num_layers: int = 1,
80
+ dropout: float = 0.0,
81
+ norm_num_groups: int = 32,
82
+ cross_attention_dim: Optional[int] = None,
83
+ attention_bias: bool = False,
84
+ sample_size: Optional[int] = None,
85
+ num_vector_embeds: Optional[int] = None,
86
+ patch_size: Optional[int] = None,
87
+ activation_fn: str = "geglu",
88
+ num_embeds_ada_norm: Optional[int] = None,
89
+ use_linear_projection: bool = False,
90
+ only_cross_attention: bool = False,
91
+ double_self_attention: bool = False,
92
+ upcast_attention: bool = False,
93
+ norm_type: str = "layer_norm",
94
+ norm_elementwise_affine: bool = True,
95
+ attention_type: str = "default",
96
+ ):
97
+ super().__init__()
98
+ self.use_linear_projection = use_linear_projection
99
+ self.num_attention_heads = num_attention_heads
100
+ self.attention_head_dim = attention_head_dim
101
+ inner_dim = num_attention_heads * attention_head_dim
102
+
103
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
104
+ # Define whether input is continuous or discrete depending on configuration
105
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
106
+ self.is_input_vectorized = num_vector_embeds is not None
107
+ self.is_input_patches = in_channels is not None and patch_size is not None
108
+
109
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
110
+ deprecation_message = (
111
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
112
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
113
+ " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
114
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
115
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
116
+ )
117
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
118
+ norm_type = "ada_norm"
119
+
120
+ if self.is_input_continuous and self.is_input_vectorized:
121
+ raise ValueError(
122
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123
+ " sure that either `in_channels` or `num_vector_embeds` is None."
124
+ )
125
+ elif self.is_input_vectorized and self.is_input_patches:
126
+ raise ValueError(
127
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128
+ " sure that either `num_vector_embeds` or `num_patches` is None."
129
+ )
130
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
131
+ raise ValueError(
132
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
133
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
134
+ )
135
+
136
+ # 2. Define input layers
137
+ if self.is_input_continuous:
138
+ self.in_channels = in_channels
139
+
140
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
141
+ if use_linear_projection:
142
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
143
+ else:
144
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
145
+ elif self.is_input_vectorized:
146
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
147
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
148
+
149
+ self.height = sample_size
150
+ self.width = sample_size
151
+ self.num_vector_embeds = num_vector_embeds
152
+ self.num_latent_pixels = self.height * self.width
153
+
154
+ self.latent_image_embedding = ImagePositionalEmbeddings(
155
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
156
+ )
157
+ elif self.is_input_patches:
158
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
159
+
160
+ self.height = sample_size
161
+ self.width = sample_size
162
+
163
+ self.patch_size = patch_size
164
+ self.pos_embed = PatchEmbed(
165
+ height=sample_size,
166
+ width=sample_size,
167
+ patch_size=patch_size,
168
+ in_channels=in_channels,
169
+ embed_dim=inner_dim,
170
+ )
171
+
172
+ # 3. Define transformers blocks
173
+ self.transformer_blocks = nn.ModuleList(
174
+ [
175
+ BasicTransformerBlock(
176
+ inner_dim,
177
+ num_attention_heads,
178
+ attention_head_dim,
179
+ dropout=dropout,
180
+ cross_attention_dim=cross_attention_dim,
181
+ activation_fn=activation_fn,
182
+ num_embeds_ada_norm=num_embeds_ada_norm,
183
+ attention_bias=attention_bias,
184
+ only_cross_attention=only_cross_attention,
185
+ double_self_attention=double_self_attention,
186
+ upcast_attention=upcast_attention,
187
+ norm_type=norm_type,
188
+ norm_elementwise_affine=norm_elementwise_affine,
189
+ attention_type=attention_type,
190
+ )
191
+ for d in range(num_layers)
192
+ ]
193
+ )
194
+
195
+ # 4. Define output layers
196
+ self.out_channels = in_channels if out_channels is None else out_channels
197
+ if self.is_input_continuous:
198
+ # TODO: should use out_channels for continuous projections
199
+ if use_linear_projection:
200
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
201
+ else:
202
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
203
+ elif self.is_input_vectorized:
204
+ self.norm_out = nn.LayerNorm(inner_dim)
205
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
206
+ elif self.is_input_patches:
207
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
208
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
209
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
210
+
211
+ self.gradient_checkpointing = False
212
+
213
+ def forward(
214
+ self,
215
+ hidden_states: torch.Tensor,
216
+ encoder_hidden_states: Optional[torch.Tensor] = None,
217
+ timestep: Optional[torch.LongTensor] = None,
218
+ class_labels: Optional[torch.LongTensor] = None,
219
+ cross_attention_kwargs: Dict[str, Any] = None,
220
+ attention_mask: Optional[torch.Tensor] = None,
221
+ encoder_attention_mask: Optional[torch.Tensor] = None,
222
+ return_dict: bool = True,
223
+ **kwargs,
224
+
225
+ ):
226
+ """
227
+ The [`Transformer2DModel`] forward method.
228
+
229
+ Args:
230
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
231
+ Input `hidden_states`.
232
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
233
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
234
+ self-attention.
235
+ timestep ( `torch.LongTensor`, *optional*):
236
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
237
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
238
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
239
+ `AdaLayerZeroNorm`.
240
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
241
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
242
+
243
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
244
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
245
+
246
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
247
+ above. This bias will be added to the cross-attention scores.
248
+ return_dict (`bool`, *optional*, defaults to `True`):
249
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
250
+ tuple.
251
+
252
+ Returns:
253
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
254
+ `tuple` where the first element is the sample tensor.
255
+ """
256
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
257
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
258
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
259
+ # expects mask of shape:
260
+ # [batch, key_tokens]
261
+ # adds singleton query_tokens dimension:
262
+ # [batch, 1, key_tokens]
263
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
264
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
265
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
266
+ if attention_mask is not None and not isinstance(attention_mask, list):
267
+ if attention_mask is not None and attention_mask.ndim == 2:
268
+ # assume that mask is expressed as:
269
+ # (1 = keep, 0 = discard)
270
+ # convert mask into a bias that can be added to attention scores:
271
+ # (keep = +0, discard = -10000.0)
272
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
273
+ attention_mask = attention_mask.unsqueeze(1)
274
+
275
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
276
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
277
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
278
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
279
+ elif attention_mask is not None:
280
+ if attention_mask[0].ndim == 2:
281
+ attention_mask = [(1 - mask.to(hidden_states.dtype)) * -10000.0 for mask in attention_mask]
282
+ attention_mask = [mask.unsqueeze(1) for mask in attention_mask]
283
+ if encoder_attention_mask is not None and encoder_attention_mask[0].ndim == 2:
284
+ encoder_attention_mask = [(1 - mask.to(hidden_states.dtype)) * -10000.0 for mask in encoder_attention_mask]
285
+ encoder_attention_mask = [mask.unsqueeze(1) for mask in encoder_attention_mask]
286
+
287
+ # Retrieve lora scale.
288
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
289
+
290
+ # 1. Input
291
+ if self.is_input_continuous:
292
+ batch, _, height, width = hidden_states.shape
293
+ residual = hidden_states
294
+
295
+ hidden_states = self.norm(hidden_states)
296
+ if not self.use_linear_projection:
297
+ hidden_states = self.proj_in(hidden_states, lora_scale)
298
+ inner_dim = hidden_states.shape[1]
299
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
300
+ else:
301
+ inner_dim = hidden_states.shape[1]
302
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
303
+ hidden_states = self.proj_in(hidden_states, scale=lora_scale)
304
+
305
+ elif self.is_input_vectorized:
306
+ hidden_states = self.latent_image_embedding(hidden_states)
307
+ elif self.is_input_patches:
308
+ hidden_states = self.pos_embed(hidden_states)
309
+
310
+ # 2. Blocks
311
+ for block in self.transformer_blocks:
312
+ if self.training and self.gradient_checkpointing:
313
+ hidden_states = torch.utils.checkpoint.checkpoint(
314
+ block,
315
+ hidden_states,
316
+ attention_mask,
317
+ encoder_hidden_states,
318
+ encoder_attention_mask,
319
+ timestep,
320
+ cross_attention_kwargs,
321
+ class_labels,
322
+ use_reentrant=False,
323
+ )
324
+ else:
325
+
326
+ hidden_states = block(
327
+ hidden_states,
328
+ attention_mask=attention_mask,
329
+ encoder_hidden_states=encoder_hidden_states,
330
+ encoder_attention_mask=encoder_attention_mask,
331
+ timestep=timestep,
332
+ cross_attention_kwargs=cross_attention_kwargs,
333
+ class_labels=class_labels,
334
+ **kwargs,
335
+
336
+ )
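+ # Note on the loop above: when gradient checkpointing is active during training, each block's
+ # activations are recomputed in the backward pass rather than stored, trading extra compute
+ # for lower memory; `use_reentrant=False` selects the non-reentrant checkpoint implementation.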
337
+
338
+ # 3. Output
339
+ if self.is_input_continuous:
340
+ if not self.use_linear_projection:
341
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
342
+ hidden_states = self.proj_out(hidden_states, scale=lora_scale)
343
+ else:
344
+ hidden_states = self.proj_out(hidden_states, scale=lora_scale)
345
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
346
+
347
+ output = hidden_states + residual
348
+ elif self.is_input_vectorized:
349
+ hidden_states = self.norm_out(hidden_states)
350
+ logits = self.out(hidden_states)
351
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
352
+ logits = logits.permute(0, 2, 1)
353
+
354
+ # log(p(x_0))
355
+ output = F.log_softmax(logits.double(), dim=1).float()
356
+ elif self.is_input_patches:
357
+ # TODO: cleanup!
358
+ conditioning = self.transformer_blocks[0].norm1.emb(
359
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
360
+ )
361
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
362
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
363
+ hidden_states = self.proj_out_2(hidden_states)
364
+
365
+ # unpatchify
366
+ height = width = int(hidden_states.shape[1] ** 0.5)
367
+ hidden_states = hidden_states.reshape(
368
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
369
+ )
370
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
371
+ output = hidden_states.reshape(
372
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
373
+ )
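+ # Shape walkthrough (illustrative, assuming square inputs): with 256 patch tokens,
+ # patch_size=2 and out_channels=4, height = width = 16, so the tensor goes
+ # (N, 256, 16) -> (N, 16, 16, 2, 2, 4) -> (N, 4, 16, 2, 16, 2) -> (N, 4, 32, 32),
+ # stitching the patches back into an image-shaped latent.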
374
+
375
+ if not return_dict:
376
+ return (output,)
377
+
378
+ return Transformer2DModelOutput(sample=output)
src/models/transformer_temporal.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+ import math
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils import BaseOutput
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+
25
+ from .attention import BasicTransformerBlock
26
+
27
+ @dataclass
28
+ class TransformerTemporalModelOutput(BaseOutput):
29
+ """
30
+ The output of [`TransformerTemporalModel`].
31
+
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
34
+ The hidden states output conditioned on `encoder_hidden_states` input.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
41
+ """
42
+ A Transformer model for video-like data.
43
+
44
+ Parameters:
45
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
46
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
47
+ in_channels (`int`, *optional*):
48
+ The number of channels in the input and output (specify if the input is **continuous**).
49
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
50
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
51
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
52
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
53
+ This is fixed during training since it is used to learn a number of position embeddings.
54
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
55
+ attention_bias (`bool`, *optional*):
56
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
57
+ double_self_attention (`bool`, *optional*):
58
+ Configure if each `TransformerBlock` should contain two self-attention layers.
59
+ """
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ num_attention_heads: int = 16,
65
+ attention_head_dim: int = 88,
66
+ in_channels: Optional[int] = None,
67
+ out_channels: Optional[int] = None,
68
+ num_layers: int = 1,
69
+ dropout: float = 0.0,
70
+ norm_num_groups: int = 32,
71
+ cross_attention_dim: Optional[int] = None,
72
+ attention_bias: bool = False,
73
+ sample_size: Optional[int] = None,
74
+ activation_fn: str = "geglu",
75
+ norm_elementwise_affine: bool = True,
76
+ double_self_attention: bool = True,
77
+ ):
78
+ super().__init__()
79
+ self.num_attention_heads = num_attention_heads
80
+ self.attention_head_dim = attention_head_dim
81
+ inner_dim = num_attention_heads * attention_head_dim
82
+
83
+ self.in_channels = in_channels
84
+
85
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
86
+ self.proj_in = nn.Linear(in_channels, inner_dim)
87
+
88
+ # 3. Define transformers blocks
89
+ self.transformer_blocks = nn.ModuleList(
90
+ [
91
+ BasicTransformerBlock(
92
+ inner_dim,
93
+ num_attention_heads,
94
+ attention_head_dim,
95
+ dropout=dropout,
96
+ cross_attention_dim=cross_attention_dim,
97
+ activation_fn=activation_fn,
98
+ attention_bias=attention_bias,
99
+ double_self_attention=double_self_attention,
100
+ norm_elementwise_affine=norm_elementwise_affine,
101
+ )
102
+ for d in range(num_layers)
103
+ ]
104
+ )
105
+
106
+ self.proj_out = nn.Linear(inner_dim, in_channels)
107
+
108
+ def forward(
109
+ self,
110
+ hidden_states,
111
+ encoder_hidden_states=None,
112
+ timestep=None,
113
+ class_labels=None,
114
+ num_frames=1,
115
+ cross_attention_kwargs=None,
116
+ return_dict: bool = True,
117
+ attention_mask=None,
118
+ encoder_attention_mask=None,
119
+ **kwargs,
120
+ ):
121
+ """
122
+ The [`TransformerTemporalModel`] forward method.
123
+
124
+ Args:
125
+ hidden_states (`torch.FloatTensor` of shape `(batch size * num frames, channel, height, width)`):
126
+ Input hidden_states.
127
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence length, embed dim)`, *optional*):
128
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
129
+ self-attention.
130
+ timestep ( `torch.long`, *optional*):
131
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
132
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
133
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
134
+ `AdaLayerNormZero`.
135
+ return_dict (`bool`, *optional*, defaults to `True`):
136
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
137
+ tuple.
138
+
139
+ Returns:
140
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
141
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
142
+ returned, otherwise a `tuple` where the first element is the sample tensor.
143
+ """
144
+ # 1. Input
145
+ batch_frames, channel, height, width = hidden_states.shape
146
+ batch_size = batch_frames // num_frames
147
+ if attention_mask is not None:
148
+
149
+ if not isinstance(attention_mask, list):
150
+ # Attention mask arrives with shape (batch * num_frames, 1, mask_tokens), e.g. (32, 1, 1024)
151
+ new_attn_mask = attention_mask.clone()
152
+ # Regroup to (batch, num_frames, mask_tokens), e.g. (2, 16, 1024)
153
+ new_attn_mask = new_attn_mask.permute(1,0,2).reshape(-1,num_frames, new_attn_mask.shape[2])
154
+ # spatial_dim_attn_mask = int(math.sqrt(new_attn_mask.shape[-1]))
155
+ scaling_factor = int(math.sqrt(new_attn_mask.shape[2] / (height*width)))
156
+
157
+ mask_x = int(height * scaling_factor)
158
+ mask_y = int(width * scaling_factor)
159
+
160
+
161
+ # Downsample the attention mask to the current spatial resolution if it is larger
162
+ new_attn_mask = new_attn_mask.reshape(-1, num_frames, mask_x, mask_y)[:,:,::scaling_factor, ::scaling_factor]
163
+ # Flatten to (batch, num_frames, h*w) and transpose to (batch, h*w, num_frames), e.g. (2, 64, 16)
164
+ new_attn_mask = new_attn_mask.reshape(-1, num_frames, height*width).permute(0,2,1)
165
+ # Convert to (128, 1, 16) when hidden states are (128, 16, 1280)
166
+ new_attn_mask = new_attn_mask.reshape(-1,1,num_frames)
167
+
168
+ # Invert the bias so that only the background (previously masked-out) positions stay active
169
+ new_attn_mask = torch.where(new_attn_mask < 0., 0., -10000.).type(new_attn_mask.dtype).to(new_attn_mask.device)
170
+ else:
171
+ new_attn_mask_list = []
172
+ for attn_mask in attention_mask:
173
+ new_attn_mask = attn_mask.clone()
174
+ new_attn_mask = new_attn_mask.permute(1,0,2).reshape(-1,num_frames, new_attn_mask.shape[2])
175
+ scaling_factor = int(math.sqrt(new_attn_mask.shape[2] / (height*width)))
176
+
177
+ mask_x = int(height * scaling_factor)
178
+ mask_y = int(width * scaling_factor)
179
+
180
+
181
+ # Downsample the attention mask to the current spatial resolution if it is larger
182
+ new_attn_mask = new_attn_mask.reshape(-1, num_frames, mask_x, mask_y)[:,:,::scaling_factor, ::scaling_factor]
183
+ new_attn_mask = new_attn_mask.reshape(-1, num_frames, height*width).permute(0,2,1)
184
+ new_attn_mask = new_attn_mask.reshape(-1,1,num_frames)
185
+ new_attn_mask = torch.where(new_attn_mask < 0., 0., -10000.).type(new_attn_mask.dtype).to(new_attn_mask.device)
186
+ new_attn_mask_list.append(new_attn_mask)
187
+
188
+ new_attn_mask = new_attn_mask_list
189
+ else:
190
+ new_attn_mask = None
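+ # Illustrative sketch of the reshaping above (comment only): a per-frame spatial bias of shape
+ # (batch * num_frames, 1, mask_tokens), e.g. (32, 1, 1024), is downsampled to the current
+ # feature resolution and regrouped per spatial location into (batch * height * width, 1, num_frames),
+ # e.g. (128, 1, 16), with the keep/discard bias flipped so that temporal attention acts on the
+ # background region.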
191
+
192
+ residual = hidden_states
193
+
194
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
195
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
196
+
197
+ hidden_states = self.norm(hidden_states)
198
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
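+ # Illustrative note: every spatial location now forms its own temporal sequence, e.g.
+ # (2 * 16, 320, 32, 32) -> (2 * 32 * 32, 16, 320), so the transformer blocks below attend
+ # across frames independently at each pixel.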
199
+
200
+ hidden_states = self.proj_in(hidden_states)
201
+
202
+
203
+ # 2. Blocks
204
+ for block in self.transformer_blocks:
205
+ hidden_states = block(
206
+ hidden_states,
207
+ encoder_hidden_states=encoder_hidden_states,
208
+ timestep=timestep,
209
+ cross_attention_kwargs=cross_attention_kwargs,
210
+ class_labels=class_labels,
211
+ attention_mask=new_attn_mask,
212
+ encoder_attention_mask=encoder_attention_mask,
213
+ # make_2d_attention_mask=True, # Check this
214
+ # block_diagonal_attention=True, # TODO - Check this
215
+ **kwargs,
216
+ )
217
+
218
+ # 3. Output
219
+ hidden_states = self.proj_out(hidden_states)
220
+ hidden_states = (
221
+ hidden_states[None, None, :]
222
+ .reshape(batch_size, height, width, channel, num_frames)
223
+ .permute(0, 3, 4, 1, 2)
224
+ .contiguous()
225
+ )
226
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
227
+
228
+ output = hidden_states + residual
229
+
230
+ if not return_dict:
231
+ return (output,)
232
+
233
+ return TransformerTemporalModelOutput(sample=output)
src/models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
src/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1052 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.activations import get_activation
25
+ from .attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import (
33
+ GaussianFourierProjection,
34
+ ImageHintTimeEmbedding,
35
+ ImageProjection,
36
+ ImageTimeEmbedding,
37
+ PositionNet,
38
+ TextImageProjection,
39
+ TextImageTimeEmbedding,
40
+ TextTimeEmbedding,
41
+ TimestepEmbedding,
42
+ Timesteps,
43
+ )
44
+ from diffusers.models.modeling_utils import ModelMixin
45
+ from .unet_2d_blocks import (
46
+ UNetMidBlock2DCrossAttn,
47
+ UNetMidBlock2DSimpleCrossAttn,
48
+ get_down_block,
49
+ get_up_block,
50
+ )
51
+
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ @dataclass
57
+ class UNet2DConditionOutput(BaseOutput):
58
+ """
59
+ The output of [`UNet2DConditionModel`].
60
+
61
+ Args:
62
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
64
+ """
65
+
66
+ sample: torch.FloatTensor = None
67
+
68
+
69
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
70
+ r"""
71
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
72
+ shaped output.
73
+
74
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
75
+ for all models (such as downloading or saving).
76
+
77
+ Parameters:
78
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
+ Height and width of input/output sample.
80
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
84
+ Whether to flip the sin to cos in the time embedding.
85
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
+ The tuple of downsample blocks to use.
88
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
90
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
+ The tuple of upsample blocks to use.
93
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
94
+ Whether to include self-attention in the basic transformer blocks, see
95
+ [`~models.attention.BasicTransformerBlock`].
96
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
+ The tuple of output channels for each block.
98
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
99
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
102
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
103
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
104
+ If `None`, normalization and activation layers is skipped in post-processing.
105
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
106
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
107
+ The dimension of the cross attention features.
108
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
109
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ encoder_hid_dim (`int`, *optional*, defaults to None):
113
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114
+ dimension to `cross_attention_dim`.
115
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
116
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
117
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
118
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
119
+ num_attention_heads (`int`, *optional*):
120
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
121
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
122
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
123
+ class_embed_type (`str`, *optional*, defaults to `None`):
124
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
125
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
126
+ addition_embed_type (`str`, *optional*, defaults to `None`):
127
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
128
+ "text". "text" will use the `TextTimeEmbedding` layer.
129
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
130
+ Dimension for the timestep embeddings.
131
+ num_class_embeds (`int`, *optional*, defaults to `None`):
132
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
133
+ class conditioning with `class_embed_type` equal to `None`.
134
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
135
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
136
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
137
+ An optional override for the dimension of the projected time embedding.
138
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
139
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
140
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
141
+ timestep_post_act (`str`, *optional*, defaults to `None`):
142
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144
+ The dimension of `cond_proj` layer in the timestep embedding.
145
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
146
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
147
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
148
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
149
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150
+ embeddings with the class embeddings.
151
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
152
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
153
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
154
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
155
+ otherwise.
156
+ """
157
+
158
+ _supports_gradient_checkpointing = True
159
+
160
+ @register_to_config
161
+ def __init__(
162
+ self,
163
+ sample_size: Optional[int] = None,
164
+ in_channels: int = 4,
165
+ out_channels: int = 4,
166
+ center_input_sample: bool = False,
167
+ flip_sin_to_cos: bool = True,
168
+ freq_shift: int = 0,
169
+ down_block_types: Tuple[str] = (
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "CrossAttnDownBlock2D",
173
+ "DownBlock2D",
174
+ ),
175
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
177
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
178
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
179
+ layers_per_block: Union[int, Tuple[int]] = 2,
180
+ downsample_padding: int = 1,
181
+ mid_block_scale_factor: float = 1,
182
+ dropout: float = 0.0,
183
+ act_fn: str = "silu",
184
+ norm_num_groups: Optional[int] = 32,
185
+ norm_eps: float = 1e-5,
186
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
187
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
188
+ encoder_hid_dim: Optional[int] = None,
189
+ encoder_hid_dim_type: Optional[str] = None,
190
+ attention_head_dim: Union[int, Tuple[int]] = 8,
191
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
192
+ dual_cross_attention: bool = False,
193
+ use_linear_projection: bool = False,
194
+ class_embed_type: Optional[str] = None,
195
+ addition_embed_type: Optional[str] = None,
196
+ addition_time_embed_dim: Optional[int] = None,
197
+ num_class_embeds: Optional[int] = None,
198
+ upcast_attention: bool = False,
199
+ resnet_time_scale_shift: str = "default",
200
+ resnet_skip_time_act: bool = False,
201
+ resnet_out_scale_factor: int = 1.0,
202
+ time_embedding_type: str = "positional",
203
+ time_embedding_dim: Optional[int] = None,
204
+ time_embedding_act_fn: Optional[str] = None,
205
+ timestep_post_act: Optional[str] = None,
206
+ time_cond_proj_dim: Optional[int] = None,
207
+ conv_in_kernel: int = 3,
208
+ conv_out_kernel: int = 3,
209
+ projection_class_embeddings_input_dim: Optional[int] = None,
210
+ attention_type: str = "default",
211
+ class_embeddings_concat: bool = False,
212
+ mid_block_only_cross_attention: Optional[bool] = None,
213
+ cross_attention_norm: Optional[str] = None,
214
+ addition_embed_type_num_heads=64,
215
+ ):
216
+ super().__init__()
217
+
218
+ self.sample_size = sample_size
219
+
220
+ if num_attention_heads is not None:
221
+ raise ValueError(
222
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
223
+ )
224
+
225
+ # If `num_attention_heads` is not defined (which is the case for most models)
226
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
227
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
228
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
229
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
230
+ # which is why we correct for the naming here.
231
+ num_attention_heads = num_attention_heads or attention_head_dim
232
+
233
+ # Check inputs
234
+ if len(down_block_types) != len(up_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
237
+ )
238
+
239
+ if len(block_out_channels) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
250
+ raise ValueError(
251
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
252
+ )
253
+
254
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
255
+ raise ValueError(
256
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
257
+ )
258
+
259
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
260
+ raise ValueError(
261
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
265
+ raise ValueError(
266
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
267
+ )
268
+
269
+ # input
270
+ conv_in_padding = (conv_in_kernel - 1) // 2
271
+ self.conv_in = nn.Conv2d(
272
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
273
+ )
274
+
275
+ # time
276
+ if time_embedding_type == "fourier":
277
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
278
+ if time_embed_dim % 2 != 0:
279
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
280
+ self.time_proj = GaussianFourierProjection(
281
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
282
+ )
283
+ timestep_input_dim = time_embed_dim
284
+ elif time_embedding_type == "positional":
285
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
286
+
287
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
288
+ timestep_input_dim = block_out_channels[0]
289
+ else:
290
+ raise ValueError(
291
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
292
+ )
293
+
294
+ self.time_embedding = TimestepEmbedding(
295
+ timestep_input_dim,
296
+ time_embed_dim,
297
+ act_fn=act_fn,
298
+ post_act_fn=timestep_post_act,
299
+ cond_proj_dim=time_cond_proj_dim,
300
+ )
301
+
302
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
303
+ encoder_hid_dim_type = "text_proj"
304
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
305
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
306
+
307
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
308
+ raise ValueError(
309
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
310
+ )
311
+
312
+ if encoder_hid_dim_type == "text_proj":
313
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
314
+ elif encoder_hid_dim_type == "text_image_proj":
315
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
316
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
317
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
318
+ self.encoder_hid_proj = TextImageProjection(
319
+ text_embed_dim=encoder_hid_dim,
320
+ image_embed_dim=cross_attention_dim,
321
+ cross_attention_dim=cross_attention_dim,
322
+ )
323
+ elif encoder_hid_dim_type == "image_proj":
324
+ # Kandinsky 2.2
325
+ self.encoder_hid_proj = ImageProjection(
326
+ image_embed_dim=encoder_hid_dim,
327
+ cross_attention_dim=cross_attention_dim,
328
+ )
329
+ elif encoder_hid_dim_type is not None:
330
+ raise ValueError(
331
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
332
+ )
333
+ else:
334
+ self.encoder_hid_proj = None
335
+
336
+ # class embedding
337
+ if class_embed_type is None and num_class_embeds is not None:
338
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
339
+ elif class_embed_type == "timestep":
340
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
341
+ elif class_embed_type == "identity":
342
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
343
+ elif class_embed_type == "projection":
344
+ if projection_class_embeddings_input_dim is None:
345
+ raise ValueError(
346
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
347
+ )
348
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
349
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
350
+ # 2. it projects from an arbitrary input dimension.
351
+ #
352
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
353
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
354
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
355
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
356
+ elif class_embed_type == "simple_projection":
357
+ if projection_class_embeddings_input_dim is None:
358
+ raise ValueError(
359
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
360
+ )
361
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
362
+ else:
363
+ self.class_embedding = None
364
+
365
+ if addition_embed_type == "text":
366
+ if encoder_hid_dim is not None:
367
+ text_time_embedding_from_dim = encoder_hid_dim
368
+ else:
369
+ text_time_embedding_from_dim = cross_attention_dim
370
+
371
+ self.add_embedding = TextTimeEmbedding(
372
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
373
+ )
374
+ elif addition_embed_type == "text_image":
375
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
376
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
377
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
378
+ self.add_embedding = TextImageTimeEmbedding(
379
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
380
+ )
381
+ elif addition_embed_type == "text_time":
382
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
383
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
384
+ elif addition_embed_type == "image":
385
+ # Kandinsky 2.2
386
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
387
+ elif addition_embed_type == "image_hint":
388
+ # Kandinsky 2.2 ControlNet
389
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
390
+ elif addition_embed_type is not None:
391
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
392
+
393
+ if time_embedding_act_fn is None:
394
+ self.time_embed_act = None
395
+ else:
396
+ self.time_embed_act = get_activation(time_embedding_act_fn)
397
+
398
+ self.down_blocks = nn.ModuleList([])
399
+ self.up_blocks = nn.ModuleList([])
400
+
401
+ if isinstance(only_cross_attention, bool):
402
+ if mid_block_only_cross_attention is None:
403
+ mid_block_only_cross_attention = only_cross_attention
404
+
405
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
406
+
407
+ if mid_block_only_cross_attention is None:
408
+ mid_block_only_cross_attention = False
409
+
410
+ if isinstance(num_attention_heads, int):
411
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
412
+
413
+ if isinstance(attention_head_dim, int):
414
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
415
+
416
+ if isinstance(cross_attention_dim, int):
417
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
418
+
419
+ if isinstance(layers_per_block, int):
420
+ layers_per_block = [layers_per_block] * len(down_block_types)
421
+
422
+ if isinstance(transformer_layers_per_block, int):
423
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
424
+
425
+ if class_embeddings_concat:
426
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
427
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
428
+ # regular time embeddings
429
+ blocks_time_embed_dim = time_embed_dim * 2
430
+ else:
431
+ blocks_time_embed_dim = time_embed_dim
432
+
433
+ # down
434
+ output_channel = block_out_channels[0]
435
+ for i, down_block_type in enumerate(down_block_types):
436
+ input_channel = output_channel
437
+ output_channel = block_out_channels[i]
438
+ is_final_block = i == len(block_out_channels) - 1
439
+
440
+ down_block = get_down_block(
441
+ down_block_type,
442
+ num_layers=layers_per_block[i],
443
+ transformer_layers_per_block=transformer_layers_per_block[i],
444
+ in_channels=input_channel,
445
+ out_channels=output_channel,
446
+ temb_channels=blocks_time_embed_dim,
447
+ add_downsample=not is_final_block,
448
+ resnet_eps=norm_eps,
449
+ resnet_act_fn=act_fn,
450
+ resnet_groups=norm_num_groups,
451
+ cross_attention_dim=cross_attention_dim[i],
452
+ num_attention_heads=num_attention_heads[i],
453
+ downsample_padding=downsample_padding,
454
+ dual_cross_attention=dual_cross_attention,
455
+ use_linear_projection=use_linear_projection,
456
+ only_cross_attention=only_cross_attention[i],
457
+ upcast_attention=upcast_attention,
458
+ resnet_time_scale_shift=resnet_time_scale_shift,
459
+ attention_type=attention_type,
460
+ resnet_skip_time_act=resnet_skip_time_act,
461
+ resnet_out_scale_factor=resnet_out_scale_factor,
462
+ cross_attention_norm=cross_attention_norm,
463
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
464
+ dropout=dropout,
465
+ )
466
+ self.down_blocks.append(down_block)
467
+
468
+ # mid
469
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
470
+ self.mid_block = UNetMidBlock2DCrossAttn(
471
+ transformer_layers_per_block=transformer_layers_per_block[-1],
472
+ in_channels=block_out_channels[-1],
473
+ temb_channels=blocks_time_embed_dim,
474
+ dropout=dropout,
475
+ resnet_eps=norm_eps,
476
+ resnet_act_fn=act_fn,
477
+ output_scale_factor=mid_block_scale_factor,
478
+ resnet_time_scale_shift=resnet_time_scale_shift,
479
+ cross_attention_dim=cross_attention_dim[-1],
480
+ num_attention_heads=num_attention_heads[-1],
481
+ resnet_groups=norm_num_groups,
482
+ dual_cross_attention=dual_cross_attention,
483
+ use_linear_projection=use_linear_projection,
484
+ upcast_attention=upcast_attention,
485
+ attention_type=attention_type,
486
+ )
487
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
488
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
489
+ in_channels=block_out_channels[-1],
490
+ temb_channels=blocks_time_embed_dim,
491
+ dropout=dropout,
492
+ resnet_eps=norm_eps,
493
+ resnet_act_fn=act_fn,
494
+ output_scale_factor=mid_block_scale_factor,
495
+ cross_attention_dim=cross_attention_dim[-1],
496
+ attention_head_dim=attention_head_dim[-1],
497
+ resnet_groups=norm_num_groups,
498
+ resnet_time_scale_shift=resnet_time_scale_shift,
499
+ skip_time_act=resnet_skip_time_act,
500
+ only_cross_attention=mid_block_only_cross_attention,
501
+ cross_attention_norm=cross_attention_norm,
502
+ )
503
+ elif mid_block_type is None:
504
+ self.mid_block = None
505
+ else:
506
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
507
+
508
+ # count how many layers upsample the images
509
+ self.num_upsamplers = 0
510
+
511
+ # up
512
+ reversed_block_out_channels = list(reversed(block_out_channels))
513
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
514
+ reversed_layers_per_block = list(reversed(layers_per_block))
515
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
516
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
517
+ only_cross_attention = list(reversed(only_cross_attention))
518
+
519
+ output_channel = reversed_block_out_channels[0]
520
+ for i, up_block_type in enumerate(up_block_types):
521
+ is_final_block = i == len(block_out_channels) - 1
522
+
523
+ prev_output_channel = output_channel
524
+ output_channel = reversed_block_out_channels[i]
525
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
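+ # Illustrative example (assuming the default block_out_channels=(320, 640, 1280, 1280)):
+ # the reversed channels are (1280, 1280, 640, 320), so the up blocks receive
+ # (prev, out, in) = (1280, 1280, 1280), (1280, 1280, 640), (1280, 640, 320), (640, 320, 320).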
526
+
527
+ # add upsample block for all BUT final layer
528
+ if not is_final_block:
529
+ add_upsample = True
530
+ self.num_upsamplers += 1
531
+ else:
532
+ add_upsample = False
533
+
534
+ up_block = get_up_block(
535
+ up_block_type,
536
+ num_layers=reversed_layers_per_block[i] + 1,
537
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
538
+ in_channels=input_channel,
539
+ out_channels=output_channel,
540
+ prev_output_channel=prev_output_channel,
541
+ temb_channels=blocks_time_embed_dim,
542
+ add_upsample=add_upsample,
543
+ resnet_eps=norm_eps,
544
+ resnet_act_fn=act_fn,
545
+ resnet_groups=norm_num_groups,
546
+ cross_attention_dim=reversed_cross_attention_dim[i],
547
+ num_attention_heads=reversed_num_attention_heads[i],
548
+ dual_cross_attention=dual_cross_attention,
549
+ use_linear_projection=use_linear_projection,
550
+ only_cross_attention=only_cross_attention[i],
551
+ upcast_attention=upcast_attention,
552
+ resnet_time_scale_shift=resnet_time_scale_shift,
553
+ attention_type=attention_type,
554
+ resnet_skip_time_act=resnet_skip_time_act,
555
+ resnet_out_scale_factor=resnet_out_scale_factor,
556
+ cross_attention_norm=cross_attention_norm,
557
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
558
+ dropout=dropout,
559
+ )
560
+ self.up_blocks.append(up_block)
561
+ prev_output_channel = output_channel
562
+
563
+ # out
564
+ if norm_num_groups is not None:
565
+ self.conv_norm_out = nn.GroupNorm(
566
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
567
+ )
568
+
569
+ self.conv_act = get_activation(act_fn)
570
+
571
+ else:
572
+ self.conv_norm_out = None
573
+ self.conv_act = None
574
+
575
+ conv_out_padding = (conv_out_kernel - 1) // 2
576
+ self.conv_out = nn.Conv2d(
577
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
578
+ )
579
+
580
+ if attention_type in ["gated", "gated-text-image"]:
581
+ positive_len = 768
582
+ if isinstance(cross_attention_dim, int):
583
+ positive_len = cross_attention_dim
584
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
585
+ positive_len = cross_attention_dim[0]
586
+
587
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
588
+ self.position_net = PositionNet(
589
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
590
+ )
591
+
592
+ @property
593
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
594
+ r"""
595
+ Returns:
596
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
597
+ indexed by their weight names.
598
+ """
599
+ # set recursively
600
+ processors = {}
601
+
602
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
603
+ if hasattr(module, "get_processor"):
604
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
605
+
606
+ for sub_name, child in module.named_children():
607
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
608
+
609
+ return processors
610
+
611
+ for name, module in self.named_children():
612
+ fn_recursive_add_processors(name, module, processors)
613
+
614
+ return processors
615
+
616
+ def set_attn_processor(
617
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
618
+ ):
619
+ r"""
620
+ Sets the attention processor to use to compute attention.
621
+
622
+ Parameters:
623
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
624
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
625
+ for **all** `Attention` layers.
626
+
627
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
628
+ processor. This is strongly recommended when setting trainable attention processors.
629
+
630
+ """
631
+ count = len(self.attn_processors.keys())
632
+
633
+ if isinstance(processor, dict) and len(processor) != count:
634
+ raise ValueError(
635
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
636
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
637
+ )
638
+
639
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
640
+ if hasattr(module, "set_processor"):
641
+ if not isinstance(processor, dict):
642
+ module.set_processor(processor, _remove_lora=_remove_lora)
643
+ else:
644
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
645
+
646
+ for sub_name, child in module.named_children():
647
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
648
+
649
+ for name, module in self.named_children():
650
+ fn_recursive_attn_processor(name, module, processor)
651
+
652
+ def set_default_attn_processor(self):
653
+ """
654
+ Disables custom attention processors and sets the default attention implementation.
655
+ """
656
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
657
+ processor = AttnAddedKVProcessor()
658
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
659
+ processor = AttnProcessor()
660
+ else:
661
+ raise ValueError(
662
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
663
+ )
664
+
665
+ self.set_attn_processor(processor, _remove_lora=True)
666
+
667
+ def set_attention_slice(self, slice_size):
668
+ r"""
669
+ Enable sliced attention computation.
670
+
671
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
672
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
673
+
674
+ Args:
675
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
676
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
677
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
678
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
679
+ must be a multiple of `slice_size`.
680
+ """
681
+ sliceable_head_dims = []
682
+
683
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
684
+ if hasattr(module, "set_attention_slice"):
685
+ sliceable_head_dims.append(module.sliceable_head_dim)
686
+
687
+ for child in module.children():
688
+ fn_recursive_retrieve_sliceable_dims(child)
689
+
690
+ # retrieve number of attention layers
691
+ for module in self.children():
692
+ fn_recursive_retrieve_sliceable_dims(module)
693
+
694
+ num_sliceable_layers = len(sliceable_head_dims)
695
+
696
+ if slice_size == "auto":
697
+ # half the attention head size is usually a good trade-off between
698
+ # speed and memory
699
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
700
+ elif slice_size == "max":
701
+ # make smallest slice possible
702
+ slice_size = num_sliceable_layers * [1]
703
+
704
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
705
+
706
+ if len(slice_size) != len(sliceable_head_dims):
707
+ raise ValueError(
708
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
709
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
710
+ )
711
+
712
+ for i in range(len(slice_size)):
713
+ size = slice_size[i]
714
+ dim = sliceable_head_dims[i]
715
+ if size is not None and size > dim:
716
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
717
+
718
+ # Recursively walk through all the children.
719
+ # Any children which exposes the set_attention_slice method
720
+ # gets the message
721
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
722
+ if hasattr(module, "set_attention_slice"):
723
+ module.set_attention_slice(slice_size.pop())
724
+
725
+ for child in module.children():
726
+ fn_recursive_set_attention_slice(child, slice_size)
727
+
728
+ reversed_slice_size = list(reversed(slice_size))
729
+ for module in self.children():
730
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
731
+
732
+ def _set_gradient_checkpointing(self, module, value=False):
733
+ if hasattr(module, "gradient_checkpointing"):
734
+ module.gradient_checkpointing = value
735
+
736
+ def forward(
737
+ self,
738
+ sample: torch.FloatTensor,
739
+ timestep: Union[torch.Tensor, float, int],
740
+ encoder_hidden_states: torch.Tensor,
741
+ class_labels: Optional[torch.Tensor] = None,
742
+ timestep_cond: Optional[torch.Tensor] = None,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
745
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
746
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
747
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
748
+ encoder_attention_mask: Optional[torch.Tensor] = None,
749
+ return_dict: bool = True,
750
+ **kwargs,
751
+ ) -> Union[UNet2DConditionOutput, Tuple]:
752
+ r"""
753
+ The [`UNet2DConditionModel`] forward method.
754
+
755
+ Args:
756
+ sample (`torch.FloatTensor`):
757
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
758
+ timestep (`torch.FloatTensor` or `float` or `int`): The denoising timestep at which to evaluate the model.
759
+ encoder_hidden_states (`torch.FloatTensor`):
760
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
761
+ encoder_attention_mask (`torch.Tensor`):
762
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
763
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
764
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
765
+ return_dict (`bool`, *optional*, defaults to `True`):
766
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
767
+ tuple.
768
+ cross_attention_kwargs (`dict`, *optional*):
769
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
770
+ added_cond_kwargs: (`dict`, *optional*):
771
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
772
+ are passed along to the UNet blocks.
773
+
774
+ Returns:
775
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
776
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
777
+ a `tuple` is returned where the first element is the sample tensor.
778
+ """
779
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
780
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
781
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
782
+ # on the fly if necessary.
783
+ default_overall_up_factor = 2**self.num_upsamplers
784
+
785
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
786
+ forward_upsample_size = False
787
+ upsample_size = None
788
+
789
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
790
+ logger.info("Forward upsample size to force interpolation output size.")
791
+ forward_upsample_size = True
792
+
793
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
794
+ # expects mask of shape:
795
+ # [batch, key_tokens]
796
+ # adds singleton query_tokens dimension:
797
+ # [batch, 1, key_tokens]
798
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
799
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
800
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
801
+ if attention_mask is not None:
802
+ # assume that mask is expressed as:
803
+ # (1 = keep, 0 = discard)
804
+ # convert mask into a bias that can be added to attention scores:
805
+ # (keep = +0, discard = -10000.0)
806
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
807
+ attention_mask = attention_mask.unsqueeze(1)
808
+
809
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
810
+ if encoder_attention_mask is not None:
811
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
812
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
813
+
814
+ # 0. center input if necessary
815
+ if self.config.center_input_sample:
816
+ sample = 2 * sample - 1.0
817
+
818
+ # 1. time
819
+ timesteps = timestep
820
+ if not torch.is_tensor(timesteps):
821
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
822
+ # This would be a good case for the `match` statement (Python 3.10+)
823
+ is_mps = sample.device.type == "mps"
824
+ if isinstance(timestep, float):
825
+ dtype = torch.float32 if is_mps else torch.float64
826
+ else:
827
+ dtype = torch.int32 if is_mps else torch.int64
828
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
829
+ elif len(timesteps.shape) == 0:
830
+ timesteps = timesteps[None].to(sample.device)
831
+
832
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
833
+ timesteps = timesteps.expand(sample.shape[0])
834
+
835
+ t_emb = self.time_proj(timesteps)
836
+
837
+ # `Timesteps` does not contain any weights and will always return f32 tensors
838
+ # but time_embedding might actually be running in fp16. so we need to cast here.
839
+ # there might be better ways to encapsulate this.
840
+ t_emb = t_emb.to(dtype=sample.dtype)
841
+
842
+ emb = self.time_embedding(t_emb, timestep_cond)
843
+ aug_emb = None
844
+
845
+ if self.class_embedding is not None:
846
+ if class_labels is None:
847
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
848
+
849
+ if self.config.class_embed_type == "timestep":
850
+ class_labels = self.time_proj(class_labels)
851
+
852
+ # `Timesteps` does not contain any weights and will always return f32 tensors
853
+ # there might be better ways to encapsulate this.
854
+ class_labels = class_labels.to(dtype=sample.dtype)
855
+
856
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
857
+
858
+ if self.config.class_embeddings_concat:
859
+ emb = torch.cat([emb, class_emb], dim=-1)
860
+ else:
861
+ emb = emb + class_emb
862
+
863
+ if self.config.addition_embed_type == "text":
864
+ aug_emb = self.add_embedding(encoder_hidden_states)
865
+ elif self.config.addition_embed_type == "text_image":
866
+ # Kandinsky 2.1 - style
867
+ if "image_embeds" not in added_cond_kwargs:
868
+ raise ValueError(
869
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
870
+ )
871
+
872
+ image_embs = added_cond_kwargs.get("image_embeds")
873
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
874
+ aug_emb = self.add_embedding(text_embs, image_embs)
875
+ elif self.config.addition_embed_type == "text_time":
876
+ # SDXL - style
877
+ if "text_embeds" not in added_cond_kwargs:
878
+ raise ValueError(
879
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
880
+ )
881
+ text_embeds = added_cond_kwargs.get("text_embeds")
882
+ if "time_ids" not in added_cond_kwargs:
883
+ raise ValueError(
884
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
885
+ )
886
+ time_ids = added_cond_kwargs.get("time_ids")
887
+ time_embeds = self.add_time_proj(time_ids.flatten())
888
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
889
+
890
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
891
+ add_embeds = add_embeds.to(emb.dtype)
892
+ aug_emb = self.add_embedding(add_embeds)
893
+ elif self.config.addition_embed_type == "image":
894
+ # Kandinsky 2.2 - style
895
+ if "image_embeds" not in added_cond_kwargs:
896
+ raise ValueError(
897
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
898
+ )
899
+ image_embs = added_cond_kwargs.get("image_embeds")
900
+ aug_emb = self.add_embedding(image_embs)
901
+ elif self.config.addition_embed_type == "image_hint":
902
+ # Kandinsky 2.2 - style
903
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
904
+ raise ValueError(
905
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
906
+ )
907
+ image_embs = added_cond_kwargs.get("image_embeds")
908
+ hint = added_cond_kwargs.get("hint")
909
+ aug_emb, hint = self.add_embedding(image_embs, hint)
910
+ sample = torch.cat([sample, hint], dim=1)
911
+
912
+ emb = emb + aug_emb if aug_emb is not None else emb
913
+
914
+ if self.time_embed_act is not None:
915
+ emb = self.time_embed_act(emb)
916
+
917
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
918
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
919
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
920
+ # Kandinsky 2.1 - style
921
+ if "image_embeds" not in added_cond_kwargs:
922
+ raise ValueError(
923
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
924
+ )
925
+
926
+ image_embeds = added_cond_kwargs.get("image_embeds")
927
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
928
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
929
+ # Kandinsky 2.2 - style
930
+ if "image_embeds" not in added_cond_kwargs:
931
+ raise ValueError(
932
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
933
+ )
934
+ image_embeds = added_cond_kwargs.get("image_embeds")
935
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
936
+ # 2. pre-process
937
+ sample = self.conv_in(sample)
938
+
939
+ # 2.5 GLIGEN position net
940
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
941
+ cross_attention_kwargs = cross_attention_kwargs.copy()
942
+ gligen_args = cross_attention_kwargs.pop("gligen")
943
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
944
+
945
+ # 3. down
946
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
947
+
948
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
949
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
950
+
951
+ down_block_res_samples = (sample,)
952
+ for downsample_block in self.down_blocks:
953
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
954
+ # For t2i-adapter CrossAttnDownBlock2D
955
+ additional_residuals = {}
956
+ if is_adapter and len(down_block_additional_residuals) > 0:
957
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
958
+
959
+ sample, res_samples = downsample_block(
960
+ hidden_states=sample,
961
+ temb=emb,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ attention_mask=attention_mask,
964
+ cross_attention_kwargs=cross_attention_kwargs,
965
+ encoder_attention_mask=encoder_attention_mask,
966
+ **additional_residuals,
967
+ **kwargs,
968
+ )
969
+ else:
970
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
971
+
972
+ if is_adapter and len(down_block_additional_residuals) > 0:
973
+ sample += down_block_additional_residuals.pop(0)
974
+
975
+ down_block_res_samples += res_samples
976
+
977
+ if is_controlnet:
978
+ new_down_block_res_samples = ()
979
+
980
+ for down_block_res_sample, down_block_additional_residual in zip(
981
+ down_block_res_samples, down_block_additional_residuals
982
+ ):
983
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
984
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
985
+
986
+ down_block_res_samples = new_down_block_res_samples
987
+
988
+ # 4. mid
989
+ if self.mid_block is not None:
990
+ sample = self.mid_block(
991
+ sample,
992
+ emb,
993
+ encoder_hidden_states=encoder_hidden_states,
994
+ attention_mask=attention_mask,
995
+ cross_attention_kwargs=cross_attention_kwargs,
996
+ encoder_attention_mask=encoder_attention_mask,
997
+ **kwargs,
998
+ )
999
+ # To support T2I-Adapter-XL
1000
+ if (
1001
+ is_adapter
1002
+ and len(down_block_additional_residuals) > 0
1003
+ and sample.shape == down_block_additional_residuals[0].shape
1004
+ ):
1005
+ sample += down_block_additional_residuals.pop(0)
1006
+
1007
+ if is_controlnet:
1008
+ sample = sample + mid_block_additional_residual
1009
+
1010
+ # 5. up
1011
+ for i, upsample_block in enumerate(self.up_blocks):
1012
+ is_final_block = i == len(self.up_blocks) - 1
1013
+
1014
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1015
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1016
+
1017
+ # if we have not reached the final block and need to forward the
1018
+ # upsample size, we do it here
1019
+ if not is_final_block and forward_upsample_size:
1020
+ upsample_size = down_block_res_samples[-1].shape[2:]
1021
+
1022
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1023
+ sample = upsample_block(
1024
+ hidden_states=sample,
1025
+ temb=emb,
1026
+ res_hidden_states_tuple=res_samples,
1027
+ encoder_hidden_states=encoder_hidden_states,
1028
+ cross_attention_kwargs=cross_attention_kwargs,
1029
+ upsample_size=upsample_size,
1030
+ attention_mask=attention_mask,
1031
+ encoder_attention_mask=encoder_attention_mask,
1032
+ **kwargs,
1033
+ )
1034
+ else:
1035
+ sample = upsample_block(
1036
+ hidden_states=sample,
1037
+ temb=emb,
1038
+ res_hidden_states_tuple=res_samples,
1039
+ upsample_size=upsample_size,
1040
+ scale=lora_scale,
1041
+ )
1042
+
1043
+ # 6. post-process
1044
+ if self.conv_norm_out:
1045
+ sample = self.conv_norm_out(sample)
1046
+ sample = self.conv_act(sample)
1047
+ sample = self.conv_out(sample)
1048
+
1049
+ if not return_dict:
1050
+ return (sample,)
1051
+
1052
+ return UNet2DConditionOutput(sample=sample)
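A minimal usage sketch for the forward pass above (editorial addition, not part of the commit). The class name, import path, and default config values are assumptions for illustration:

import torch

from src.models.unet_2d_condition import UNet2DConditionModel  # assumed class name and path

unet = UNet2DConditionModel()  # default config assumed
sample = torch.randn(2, unet.config.in_channels, 64, 64)  # (batch, channels, height, width)
timestep = torch.tensor([981, 981])
context = torch.randn(2, 77, unet.config.cross_attention_dim)  # text-encoder hidden states
out = unet(sample, timestep, encoder_hidden_states=context)
print(out.sample.shape)  # spatial size is preserved; channel count equals the configured out_channels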
src/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,698 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import math
17
+ from torch import nn
18
+
19
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
20
+ from .transformer_2d import Transformer2DModel
21
+ from .transformer_temporal import TransformerTemporalModel
22
+
23
+
24
+ def get_down_block(
25
+ down_block_type,
26
+ num_layers,
27
+ in_channels,
28
+ out_channels,
29
+ temb_channels,
30
+ add_downsample,
31
+ resnet_eps,
32
+ resnet_act_fn,
33
+ num_attention_heads,
34
+ resnet_groups=None,
35
+ cross_attention_dim=None,
36
+ downsample_padding=None,
37
+ dual_cross_attention=False,
38
+ use_linear_projection=True,
39
+ only_cross_attention=False,
40
+ upcast_attention=False,
41
+ resnet_time_scale_shift="default",
42
+ ):
43
+ if down_block_type == "DownBlock3D":
44
+ return DownBlock3D(
45
+ num_layers=num_layers,
46
+ in_channels=in_channels,
47
+ out_channels=out_channels,
48
+ temb_channels=temb_channels,
49
+ add_downsample=add_downsample,
50
+ resnet_eps=resnet_eps,
51
+ resnet_act_fn=resnet_act_fn,
52
+ resnet_groups=resnet_groups,
53
+ downsample_padding=downsample_padding,
54
+ resnet_time_scale_shift=resnet_time_scale_shift,
55
+ )
56
+ elif down_block_type == "CrossAttnDownBlock3D":
57
+ if cross_attention_dim is None:
58
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
59
+ return CrossAttnDownBlock3D(
60
+ num_layers=num_layers,
61
+ in_channels=in_channels,
62
+ out_channels=out_channels,
63
+ temb_channels=temb_channels,
64
+ add_downsample=add_downsample,
65
+ resnet_eps=resnet_eps,
66
+ resnet_act_fn=resnet_act_fn,
67
+ resnet_groups=resnet_groups,
68
+ downsample_padding=downsample_padding,
69
+ cross_attention_dim=cross_attention_dim,
70
+ num_attention_heads=num_attention_heads,
71
+ dual_cross_attention=dual_cross_attention,
72
+ use_linear_projection=use_linear_projection,
73
+ only_cross_attention=only_cross_attention,
74
+ upcast_attention=upcast_attention,
75
+ resnet_time_scale_shift=resnet_time_scale_shift,
76
+ )
77
+ raise ValueError(f"{down_block_type} does not exist.")
78
+
79
+
80
+ def get_up_block(
81
+ up_block_type,
82
+ num_layers,
83
+ in_channels,
84
+ out_channels,
85
+ prev_output_channel,
86
+ temb_channels,
87
+ add_upsample,
88
+ resnet_eps,
89
+ resnet_act_fn,
90
+ num_attention_heads,
91
+ resnet_groups=None,
92
+ cross_attention_dim=None,
93
+ dual_cross_attention=False,
94
+ use_linear_projection=True,
95
+ only_cross_attention=False,
96
+ upcast_attention=False,
97
+ resnet_time_scale_shift="default",
98
+ ):
99
+ if up_block_type == "UpBlock3D":
100
+ return UpBlock3D(
101
+ num_layers=num_layers,
102
+ in_channels=in_channels,
103
+ out_channels=out_channels,
104
+ prev_output_channel=prev_output_channel,
105
+ temb_channels=temb_channels,
106
+ add_upsample=add_upsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ resnet_time_scale_shift=resnet_time_scale_shift,
111
+ )
112
+ elif up_block_type == "CrossAttnUpBlock3D":
113
+ if cross_attention_dim is None:
114
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
115
+ return CrossAttnUpBlock3D(
116
+ num_layers=num_layers,
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ prev_output_channel=prev_output_channel,
120
+ temb_channels=temb_channels,
121
+ add_upsample=add_upsample,
122
+ resnet_eps=resnet_eps,
123
+ resnet_act_fn=resnet_act_fn,
124
+ resnet_groups=resnet_groups,
125
+ cross_attention_dim=cross_attention_dim,
126
+ num_attention_heads=num_attention_heads,
127
+ dual_cross_attention=dual_cross_attention,
128
+ use_linear_projection=use_linear_projection,
129
+ only_cross_attention=only_cross_attention,
130
+ upcast_attention=upcast_attention,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+ )
133
+ raise ValueError(f"{up_block_type} does not exist.")
134
+
135
+
136
+ class UNetMidBlock3DCrossAttn(nn.Module):
137
+ def __init__(
138
+ self,
139
+ in_channels: int,
140
+ temb_channels: int,
141
+ dropout: float = 0.0,
142
+ num_layers: int = 1,
143
+ resnet_eps: float = 1e-6,
144
+ resnet_time_scale_shift: str = "default",
145
+ resnet_act_fn: str = "swish",
146
+ resnet_groups: int = 32,
147
+ resnet_pre_norm: bool = True,
148
+ num_attention_heads=1,
149
+ output_scale_factor=1.0,
150
+ cross_attention_dim=1280,
151
+ dual_cross_attention=False,
152
+ use_linear_projection=True,
153
+ upcast_attention=False,
154
+ ):
155
+ super().__init__()
156
+
157
+ self.has_cross_attention = True
158
+ self.num_attention_heads = num_attention_heads
159
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
160
+
161
+ # there is always at least one resnet
162
+ resnets = [
163
+ ResnetBlock2D(
164
+ in_channels=in_channels,
165
+ out_channels=in_channels,
166
+ temb_channels=temb_channels,
167
+ eps=resnet_eps,
168
+ groups=resnet_groups,
169
+ dropout=dropout,
170
+ time_embedding_norm=resnet_time_scale_shift,
171
+ non_linearity=resnet_act_fn,
172
+ output_scale_factor=output_scale_factor,
173
+ pre_norm=resnet_pre_norm,
174
+ )
175
+ ]
176
+ temp_convs = [
177
+ TemporalConvLayer(
178
+ in_channels,
179
+ in_channels,
180
+ dropout=0.1,
181
+ )
182
+ ]
183
+ attentions = []
184
+ temp_attentions = []
185
+
186
+ for _ in range(num_layers):
187
+ attentions.append(
188
+ Transformer2DModel(
189
+ in_channels // num_attention_heads,
190
+ num_attention_heads,
191
+ in_channels=in_channels,
192
+ num_layers=1,
193
+ cross_attention_dim=cross_attention_dim,
194
+ norm_num_groups=resnet_groups,
195
+ use_linear_projection=use_linear_projection,
196
+ upcast_attention=upcast_attention,
197
+ )
198
+ )
199
+ temp_attentions.append(
200
+ TransformerTemporalModel(
201
+ in_channels // num_attention_heads,
202
+ num_attention_heads,
203
+ in_channels=in_channels,
204
+ num_layers=1,
205
+ cross_attention_dim=cross_attention_dim,
206
+ norm_num_groups=resnet_groups,
207
+ )
208
+ )
209
+ resnets.append(
210
+ ResnetBlock2D(
211
+ in_channels=in_channels,
212
+ out_channels=in_channels,
213
+ temb_channels=temb_channels,
214
+ eps=resnet_eps,
215
+ groups=resnet_groups,
216
+ dropout=dropout,
217
+ time_embedding_norm=resnet_time_scale_shift,
218
+ non_linearity=resnet_act_fn,
219
+ output_scale_factor=output_scale_factor,
220
+ pre_norm=resnet_pre_norm,
221
+ )
222
+ )
223
+ temp_convs.append(
224
+ TemporalConvLayer(
225
+ in_channels,
226
+ in_channels,
227
+ dropout=0.1,
228
+ )
229
+ )
230
+
231
+ self.resnets = nn.ModuleList(resnets)
232
+ self.temp_convs = nn.ModuleList(temp_convs)
233
+ self.attentions = nn.ModuleList(attentions)
234
+ self.temp_attentions = nn.ModuleList(temp_attentions)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ temb=None,
240
+ encoder_hidden_states=None,
241
+ encoder_attention_mask=None,
242
+ attention_mask=None,
243
+ num_frames=1,
244
+ cross_attention_kwargs=None,
245
+ **kwargs,
246
+
247
+ ):
248
+ hidden_states = self.resnets[0](hidden_states, temb)
249
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
250
+ for attn, temp_attn, resnet, temp_conv in zip(
251
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
252
+ ):
253
+ hidden_states = attn(
254
+ hidden_states,
255
+ encoder_hidden_states=encoder_hidden_states,
256
+ encoder_attention_mask=encoder_attention_mask,
257
+ attention_mask=attention_mask,
258
+ cross_attention_kwargs=cross_attention_kwargs,
259
+ return_dict=False,
260
+ **kwargs,
261
+
262
+ )[0]
263
+ hidden_states = temp_attn(
264
+ hidden_states, num_frames=num_frames, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, **kwargs
265
+ )[0]
266
+ hidden_states = resnet(hidden_states, temb)
267
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
268
+
269
+ return hidden_states
270
+
271
+
272
+ class CrossAttnDownBlock3D(nn.Module):
273
+ def __init__(
274
+ self,
275
+ in_channels: int,
276
+ out_channels: int,
277
+ temb_channels: int,
278
+ dropout: float = 0.0,
279
+ num_layers: int = 1,
280
+ resnet_eps: float = 1e-6,
281
+ resnet_time_scale_shift: str = "default",
282
+ resnet_act_fn: str = "swish",
283
+ resnet_groups: int = 32,
284
+ resnet_pre_norm: bool = True,
285
+ num_attention_heads=1,
286
+ cross_attention_dim=1280,
287
+ output_scale_factor=1.0,
288
+ downsample_padding=1,
289
+ add_downsample=True,
290
+ dual_cross_attention=False,
291
+ use_linear_projection=False,
292
+ only_cross_attention=False,
293
+ upcast_attention=False,
294
+ ):
295
+ super().__init__()
296
+ resnets = []
297
+ attentions = []
298
+ temp_attentions = []
299
+ temp_convs = []
300
+
301
+ self.has_cross_attention = True
302
+ self.num_attention_heads = num_attention_heads
303
+
304
+ for i in range(num_layers):
305
+ in_channels = in_channels if i == 0 else out_channels
306
+ resnets.append(
307
+ ResnetBlock2D(
308
+ in_channels=in_channels,
309
+ out_channels=out_channels,
310
+ temb_channels=temb_channels,
311
+ eps=resnet_eps,
312
+ groups=resnet_groups,
313
+ dropout=dropout,
314
+ time_embedding_norm=resnet_time_scale_shift,
315
+ non_linearity=resnet_act_fn,
316
+ output_scale_factor=output_scale_factor,
317
+ pre_norm=resnet_pre_norm,
318
+ )
319
+ )
320
+ temp_convs.append(
321
+ TemporalConvLayer(
322
+ out_channels,
323
+ out_channels,
324
+ dropout=0.1,
325
+ )
326
+ )
327
+ attentions.append(
328
+ Transformer2DModel(
329
+ out_channels // num_attention_heads,
330
+ num_attention_heads,
331
+ in_channels=out_channels,
332
+ num_layers=1,
333
+ cross_attention_dim=cross_attention_dim,
334
+ norm_num_groups=resnet_groups,
335
+ use_linear_projection=use_linear_projection,
336
+ only_cross_attention=only_cross_attention,
337
+ upcast_attention=upcast_attention,
338
+ )
339
+ )
340
+ temp_attentions.append(
341
+ TransformerTemporalModel(
342
+ out_channels // num_attention_heads,
343
+ num_attention_heads,
344
+ in_channels=out_channels,
345
+ num_layers=1,
346
+ cross_attention_dim=cross_attention_dim,
347
+ norm_num_groups=resnet_groups,
348
+ )
349
+ )
350
+ self.resnets = nn.ModuleList(resnets)
351
+ self.temp_convs = nn.ModuleList(temp_convs)
352
+ self.attentions = nn.ModuleList(attentions)
353
+ self.temp_attentions = nn.ModuleList(temp_attentions)
354
+
355
+ if add_downsample:
356
+ self.downsamplers = nn.ModuleList(
357
+ [
358
+ Downsample2D(
359
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
360
+ )
361
+ ]
362
+ )
363
+ else:
364
+ self.downsamplers = None
365
+
366
+ self.gradient_checkpointing = False
367
+
368
+ def forward(
369
+ self,
370
+ hidden_states,
371
+ temb=None,
372
+ encoder_hidden_states=None,
373
+ encoder_attention_mask=None,
374
+ attention_mask=None,
375
+ num_frames=1,
376
+ cross_attention_kwargs=None,
377
+ **kwargs,
378
+
379
+ ):
380
+ # TODO(Patrick, William) - attention mask is not used
381
+ output_states = ()
382
+ for resnet, temp_conv, attn, temp_attn in zip(
383
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
384
+ ):
385
+ hidden_states = resnet(hidden_states, temb)
386
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
387
+ hidden_states = attn(
388
+ hidden_states,
389
+ attention_mask=attention_mask,
390
+ encoder_hidden_states=encoder_hidden_states,
391
+ encoder_attention_mask=encoder_attention_mask,
392
+ cross_attention_kwargs=cross_attention_kwargs,
393
+ return_dict=False,
394
+ **kwargs,
395
+
396
+ )[0]
397
+ hidden_states = temp_attn(
398
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, **kwargs
399
+ )[0]
400
+
401
+ output_states += (hidden_states,)
402
+
403
+ if self.downsamplers is not None:
404
+ for downsampler in self.downsamplers:
405
+ hidden_states = downsampler(hidden_states)
406
+
407
+ output_states += (hidden_states,)
408
+
409
+ return hidden_states, output_states
410
+
411
+
412
+ class DownBlock3D(nn.Module):
413
+ def __init__(
414
+ self,
415
+ in_channels: int,
416
+ out_channels: int,
417
+ temb_channels: int,
418
+ dropout: float = 0.0,
419
+ num_layers: int = 1,
420
+ resnet_eps: float = 1e-6,
421
+ resnet_time_scale_shift: str = "default",
422
+ resnet_act_fn: str = "swish",
423
+ resnet_groups: int = 32,
424
+ resnet_pre_norm: bool = True,
425
+ output_scale_factor=1.0,
426
+ add_downsample=True,
427
+ downsample_padding=1,
428
+ ):
429
+ super().__init__()
430
+ resnets = []
431
+ temp_convs = []
432
+
433
+ for i in range(num_layers):
434
+ in_channels = in_channels if i == 0 else out_channels
435
+ resnets.append(
436
+ ResnetBlock2D(
437
+ in_channels=in_channels,
438
+ out_channels=out_channels,
439
+ temb_channels=temb_channels,
440
+ eps=resnet_eps,
441
+ groups=resnet_groups,
442
+ dropout=dropout,
443
+ time_embedding_norm=resnet_time_scale_shift,
444
+ non_linearity=resnet_act_fn,
445
+ output_scale_factor=output_scale_factor,
446
+ pre_norm=resnet_pre_norm,
447
+ )
448
+ )
449
+ temp_convs.append(
450
+ TemporalConvLayer(
451
+ out_channels,
452
+ out_channels,
453
+ dropout=0.1,
454
+ )
455
+ )
456
+
457
+ self.resnets = nn.ModuleList(resnets)
458
+ self.temp_convs = nn.ModuleList(temp_convs)
459
+
460
+ if add_downsample:
461
+ self.downsamplers = nn.ModuleList(
462
+ [
463
+ Downsample2D(
464
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
465
+ )
466
+ ]
467
+ )
468
+ else:
469
+ self.downsamplers = None
470
+
471
+ self.gradient_checkpointing = False
472
+
473
+ def forward(self, hidden_states, temb=None, num_frames=1):
474
+ output_states = ()
475
+
476
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
477
+ hidden_states = resnet(hidden_states, temb)
478
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
479
+
480
+ output_states += (hidden_states,)
481
+
482
+ if self.downsamplers is not None:
483
+ for downsampler in self.downsamplers:
484
+ hidden_states = downsampler(hidden_states)
485
+
486
+ output_states += (hidden_states,)
487
+
488
+ return hidden_states, output_states
489
+
490
+
491
+ class CrossAttnUpBlock3D(nn.Module):
492
+ def __init__(
493
+ self,
494
+ in_channels: int,
495
+ out_channels: int,
496
+ prev_output_channel: int,
497
+ temb_channels: int,
498
+ dropout: float = 0.0,
499
+ num_layers: int = 1,
500
+ resnet_eps: float = 1e-6,
501
+ resnet_time_scale_shift: str = "default",
502
+ resnet_act_fn: str = "swish",
503
+ resnet_groups: int = 32,
504
+ resnet_pre_norm: bool = True,
505
+ num_attention_heads=1,
506
+ cross_attention_dim=1280,
507
+ output_scale_factor=1.0,
508
+ add_upsample=True,
509
+ dual_cross_attention=False,
510
+ use_linear_projection=False,
511
+ only_cross_attention=False,
512
+ upcast_attention=False,
513
+ ):
514
+ super().__init__()
515
+ resnets = []
516
+ temp_convs = []
517
+ attentions = []
518
+ temp_attentions = []
519
+
520
+ self.has_cross_attention = True
521
+ self.num_attention_heads = num_attention_heads
522
+
523
+ for i in range(num_layers):
524
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
525
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
526
+
527
+ resnets.append(
528
+ ResnetBlock2D(
529
+ in_channels=resnet_in_channels + res_skip_channels,
530
+ out_channels=out_channels,
531
+ temb_channels=temb_channels,
532
+ eps=resnet_eps,
533
+ groups=resnet_groups,
534
+ dropout=dropout,
535
+ time_embedding_norm=resnet_time_scale_shift,
536
+ non_linearity=resnet_act_fn,
537
+ output_scale_factor=output_scale_factor,
538
+ pre_norm=resnet_pre_norm,
539
+ )
540
+ )
541
+ temp_convs.append(
542
+ TemporalConvLayer(
543
+ out_channels,
544
+ out_channels,
545
+ dropout=0.1,
546
+ )
547
+ )
548
+ attentions.append(
549
+ Transformer2DModel(
550
+ out_channels // num_attention_heads,
551
+ num_attention_heads,
552
+ in_channels=out_channels,
553
+ num_layers=1,
554
+ cross_attention_dim=cross_attention_dim,
555
+ norm_num_groups=resnet_groups,
556
+ use_linear_projection=use_linear_projection,
557
+ only_cross_attention=only_cross_attention,
558
+ upcast_attention=upcast_attention,
559
+ )
560
+ )
561
+ temp_attentions.append(
562
+ TransformerTemporalModel(
563
+ out_channels // num_attention_heads,
564
+ num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=1,
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ )
570
+ )
571
+ self.resnets = nn.ModuleList(resnets)
572
+ self.temp_convs = nn.ModuleList(temp_convs)
573
+ self.attentions = nn.ModuleList(attentions)
574
+ self.temp_attentions = nn.ModuleList(temp_attentions)
575
+
576
+ if add_upsample:
577
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
578
+ else:
579
+ self.upsamplers = None
580
+
581
+ self.gradient_checkpointing = False
582
+
583
+ def forward(
584
+ self,
585
+ hidden_states,
586
+ res_hidden_states_tuple,
587
+ temb=None,
588
+ encoder_hidden_states=None,
589
+ encoder_attention_mask=None,
590
+ upsample_size=None,
591
+ attention_mask=None,
592
+ num_frames=1,
593
+ cross_attention_kwargs=None,
594
+ **kwargs,
595
+ ):
596
+ # TODO(Patrick, William) - attention mask is not used
597
+ for resnet, temp_conv, attn, temp_attn in zip(
598
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
599
+ ):
600
+ # pop res hidden states
601
+ res_hidden_states = res_hidden_states_tuple[-1]
602
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
603
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
604
+
605
+ hidden_states = resnet(hidden_states, temb)
606
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames) # This gives the 1280 dim
607
+ hidden_states = attn(
608
+ hidden_states,
609
+ encoder_hidden_states=encoder_hidden_states,
610
+ attention_mask=attention_mask, # TODO: check if this is correct
611
+ encoder_attention_mask=encoder_attention_mask,
612
+ cross_attention_kwargs=cross_attention_kwargs,
613
+ return_dict=False,
614
+ **kwargs,
615
+ )[0]
616
+ hidden_states = temp_attn(
617
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, return_dict=False, **kwargs
618
+ )[0]
619
+
620
+ if self.upsamplers is not None:
621
+ for upsampler in self.upsamplers:
622
+ hidden_states = upsampler(hidden_states, upsample_size)
623
+
624
+ return hidden_states
625
+
626
+
627
+ class UpBlock3D(nn.Module):
628
+ def __init__(
629
+ self,
630
+ in_channels: int,
631
+ prev_output_channel: int,
632
+ out_channels: int,
633
+ temb_channels: int,
634
+ dropout: float = 0.0,
635
+ num_layers: int = 1,
636
+ resnet_eps: float = 1e-6,
637
+ resnet_time_scale_shift: str = "default",
638
+ resnet_act_fn: str = "swish",
639
+ resnet_groups: int = 32,
640
+ resnet_pre_norm: bool = True,
641
+ output_scale_factor=1.0,
642
+ add_upsample=True,
643
+ ):
644
+ super().__init__()
645
+ resnets = []
646
+ temp_convs = []
647
+
648
+ for i in range(num_layers):
649
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
650
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
651
+
652
+ resnets.append(
653
+ ResnetBlock2D(
654
+ in_channels=resnet_in_channels + res_skip_channels,
655
+ out_channels=out_channels,
656
+ temb_channels=temb_channels,
657
+ eps=resnet_eps,
658
+ groups=resnet_groups,
659
+ dropout=dropout,
660
+ time_embedding_norm=resnet_time_scale_shift,
661
+ non_linearity=resnet_act_fn,
662
+ output_scale_factor=output_scale_factor,
663
+ pre_norm=resnet_pre_norm,
664
+ )
665
+ )
666
+ temp_convs.append(
667
+ TemporalConvLayer(
668
+ out_channels,
669
+ out_channels,
670
+ dropout=0.1,
671
+ )
672
+ )
673
+
674
+ self.resnets = nn.ModuleList(resnets)
675
+ self.temp_convs = nn.ModuleList(temp_convs)
676
+
677
+ if add_upsample:
678
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
679
+ else:
680
+ self.upsamplers = None
681
+
682
+ self.gradient_checkpointing = False
683
+
684
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
685
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
686
+ # pop res hidden states
687
+ res_hidden_states = res_hidden_states_tuple[-1]
688
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
689
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
690
+
691
+ hidden_states = resnet(hidden_states, temb)
692
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
693
+
694
+ if self.upsamplers is not None:
695
+ for upsampler in self.upsamplers:
696
+ hidden_states = upsampler(hidden_states, upsample_size)
697
+
698
+ return hidden_states
src/models/unet_3d_condition.py ADDED
@@ -0,0 +1,673 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from .attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from .transformer_temporal import TransformerTemporalModel
35
+ from .unet_3d_blocks import (
36
+ CrossAttnDownBlock3D,
37
+ CrossAttnUpBlock3D,
38
+ DownBlock3D,
39
+ UNetMidBlock3DCrossAttn,
40
+ UpBlock3D,
41
+ get_down_block,
42
+ get_up_block,
43
+ )
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet3DConditionOutput(BaseOutput):
51
+ """
52
+ The output of [`UNet3DConditionModel`].
53
+
54
+ Args:
55
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
56
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
57
+ """
58
+
59
+ sample: torch.FloatTensor
60
+
61
+
62
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
63
+ r"""
64
+ A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
65
+ shaped output.
66
+
67
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
68
+ for all models (such as downloading or saving).
69
+
70
+ Parameters:
71
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
72
+ Height and width of input/output sample.
73
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
74
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
75
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
76
+ The tuple of downsample blocks to use.
77
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
78
+ The tuple of upsample blocks to use.
79
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
80
+ The tuple of output channels for each block.
81
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
82
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
83
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
84
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
85
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
86
+ If `None`, normalization and activation layers are skipped in post-processing.
87
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
88
+ cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
89
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
90
+ num_attention_heads (`int`, *optional*): The number of attention heads.
91
+ """
92
+
93
+ _supports_gradient_checkpointing = False
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ sample_size: Optional[int] = None,
99
+ in_channels: int = 4,
100
+ out_channels: int = 4,
101
+ down_block_types: Tuple[str] = (
102
+ "CrossAttnDownBlock3D",
103
+ "CrossAttnDownBlock3D",
104
+ "CrossAttnDownBlock3D",
105
+ "DownBlock3D",
106
+ ),
107
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
108
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
109
+ layers_per_block: int = 2,
110
+ downsample_padding: int = 1,
111
+ mid_block_scale_factor: float = 1,
112
+ act_fn: str = "silu",
113
+ norm_num_groups: Optional[int] = 32,
114
+ norm_eps: float = 1e-5,
115
+ cross_attention_dim: int = 1024,
116
+ attention_head_dim: Union[int, Tuple[int]] = 64,
117
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
118
+ ):
119
+ super().__init__()
120
+
121
+ self.sample_size = sample_size
122
+
123
+ if num_attention_heads is not None:
124
+ raise NotImplementedError(
125
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
126
+ )
127
+
128
+ # If `num_attention_heads` is not defined (which is the case for most models)
129
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
130
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
131
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
132
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
133
+ # which is why we correct for the naming here.
134
+ num_attention_heads = num_attention_heads or attention_head_dim
135
+
136
+ # Check inputs
137
+ if len(down_block_types) != len(up_block_types):
138
+ raise ValueError(
139
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
140
+ )
141
+
142
+ if len(block_out_channels) != len(down_block_types):
143
+ raise ValueError(
144
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
145
+ )
146
+
147
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
148
+ raise ValueError(
149
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
150
+ )
151
+
152
+ # input
153
+ conv_in_kernel = 3
154
+ conv_out_kernel = 3
155
+ conv_in_padding = (conv_in_kernel - 1) // 2
156
+ self.conv_in = nn.Conv2d(
157
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
158
+ )
159
+
160
+ # time
161
+ time_embed_dim = block_out_channels[0] * 4
162
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
163
+ timestep_input_dim = block_out_channels[0]
164
+
165
+ self.time_embedding = TimestepEmbedding(
166
+ timestep_input_dim,
167
+ time_embed_dim,
168
+ act_fn=act_fn,
169
+ )
170
+
171
+ self.transformer_in = TransformerTemporalModel(
172
+ num_attention_heads=8,
173
+ attention_head_dim=attention_head_dim,
174
+ in_channels=block_out_channels[0],
175
+ num_layers=1,
176
+ )
177
+
178
+ # class embedding
179
+ self.down_blocks = nn.ModuleList([])
180
+ self.up_blocks = nn.ModuleList([])
181
+
182
+ if isinstance(num_attention_heads, int):
183
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
184
+
185
+ # down
186
+ output_channel = block_out_channels[0]
187
+ for i, down_block_type in enumerate(down_block_types):
188
+ input_channel = output_channel
189
+ output_channel = block_out_channels[i]
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ down_block = get_down_block(
193
+ down_block_type,
194
+ num_layers=layers_per_block,
195
+ in_channels=input_channel,
196
+ out_channels=output_channel,
197
+ temb_channels=time_embed_dim,
198
+ add_downsample=not is_final_block,
199
+ resnet_eps=norm_eps,
200
+ resnet_act_fn=act_fn,
201
+ resnet_groups=norm_num_groups,
202
+ cross_attention_dim=cross_attention_dim,
203
+ num_attention_heads=num_attention_heads[i],
204
+ downsample_padding=downsample_padding,
205
+ dual_cross_attention=False,
206
+ )
207
+ self.down_blocks.append(down_block)
208
+
209
+ # mid
210
+ self.mid_block = UNetMidBlock3DCrossAttn(
211
+ in_channels=block_out_channels[-1],
212
+ temb_channels=time_embed_dim,
213
+ resnet_eps=norm_eps,
214
+ resnet_act_fn=act_fn,
215
+ output_scale_factor=mid_block_scale_factor,
216
+ cross_attention_dim=cross_attention_dim,
217
+ num_attention_heads=num_attention_heads[-1],
218
+ resnet_groups=norm_num_groups,
219
+ dual_cross_attention=False,
220
+ )
221
+
222
+ # count how many layers upsample the images
223
+ self.num_upsamplers = 0
224
+
225
+ # up
226
+ reversed_block_out_channels = list(reversed(block_out_channels))
227
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
228
+
229
+ output_channel = reversed_block_out_channels[0]
230
+ for i, up_block_type in enumerate(up_block_types):
231
+ is_final_block = i == len(block_out_channels) - 1
232
+
233
+ prev_output_channel = output_channel
234
+ output_channel = reversed_block_out_channels[i]
235
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
236
+
237
+ # add upsample block for all BUT final layer
238
+ if not is_final_block:
239
+ add_upsample = True
240
+ self.num_upsamplers += 1
241
+ else:
242
+ add_upsample = False
243
+
244
+ up_block = get_up_block(
245
+ up_block_type,
246
+ num_layers=layers_per_block + 1,
247
+ in_channels=input_channel,
248
+ out_channels=output_channel,
249
+ prev_output_channel=prev_output_channel,
250
+ temb_channels=time_embed_dim,
251
+ add_upsample=add_upsample,
252
+ resnet_eps=norm_eps,
253
+ resnet_act_fn=act_fn,
254
+ resnet_groups=norm_num_groups,
255
+ cross_attention_dim=cross_attention_dim,
256
+ num_attention_heads=reversed_num_attention_heads[i],
257
+ dual_cross_attention=False,
258
+ )
259
+ self.up_blocks.append(up_block)
260
+ prev_output_channel = output_channel
261
+
262
+ # out
263
+ if norm_num_groups is not None:
264
+ self.conv_norm_out = nn.GroupNorm(
265
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
266
+ )
267
+ self.conv_act = nn.SiLU()
268
+ else:
269
+ self.conv_norm_out = None
270
+ self.conv_act = None
271
+
272
+ conv_out_padding = (conv_out_kernel - 1) // 2
273
+ self.conv_out = nn.Conv2d(
274
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
275
+ )
276
+
277
+ @property
278
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
279
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
280
+ r"""
281
+ Returns:
282
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
283
+ indexed by its weight name.
284
+ """
285
+ # set recursively
286
+ processors = {}
287
+
288
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
289
+ if hasattr(module, "get_processor"):
290
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
291
+
292
+ for sub_name, child in module.named_children():
293
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
294
+
295
+ return processors
296
+
297
+ for name, module in self.named_children():
298
+ fn_recursive_add_processors(name, module, processors)
299
+
300
+ return processors
301
+
302
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
303
+ def set_attention_slice(self, slice_size):
304
+ r"""
305
+ Enable sliced attention computation.
306
+
307
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
308
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
309
+
310
+ Args:
311
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
312
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
313
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
314
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
315
+ must be a multiple of `slice_size`.
316
+ """
317
+ sliceable_head_dims = []
318
+
319
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
320
+ if hasattr(module, "set_attention_slice"):
321
+ sliceable_head_dims.append(module.sliceable_head_dim)
322
+
323
+ for child in module.children():
324
+ fn_recursive_retrieve_sliceable_dims(child)
325
+
326
+ # retrieve number of attention layers
327
+ for module in self.children():
328
+ fn_recursive_retrieve_sliceable_dims(module)
329
+
330
+ num_sliceable_layers = len(sliceable_head_dims)
331
+
332
+ if slice_size == "auto":
333
+ # half the attention head size is usually a good trade-off between
334
+ # speed and memory
335
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
336
+ elif slice_size == "max":
337
+ # make smallest slice possible
338
+ slice_size = num_sliceable_layers * [1]
339
+
340
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
341
+
342
+ if len(slice_size) != len(sliceable_head_dims):
343
+ raise ValueError(
344
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
345
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
346
+ )
347
+
348
+ for i in range(len(slice_size)):
349
+ size = slice_size[i]
350
+ dim = sliceable_head_dims[i]
351
+ if size is not None and size > dim:
352
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
353
+
354
+ # Recursively walk through all the children.
355
+ # Any children which exposes the set_attention_slice method
356
+ # gets the message
357
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
358
+ if hasattr(module, "set_attention_slice"):
359
+ module.set_attention_slice(slice_size.pop())
360
+
361
+ for child in module.children():
362
+ fn_recursive_set_attention_slice(child, slice_size)
363
+
364
+ reversed_slice_size = list(reversed(slice_size))
365
+ for module in self.children():
366
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
367
+
368
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
369
+ def set_attn_processor(
370
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
371
+ ):
372
+ r"""
373
+ Sets the attention processor to use to compute attention.
374
+
375
+ Parameters:
376
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
377
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
378
+ for **all** `Attention` layers.
379
+
380
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
381
+ processor. This is strongly recommended when setting trainable attention processors.
382
+
383
+ """
384
+ count = len(self.attn_processors.keys())
385
+
386
+ if isinstance(processor, dict) and len(processor) != count:
387
+ raise ValueError(
388
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
389
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
390
+ )
391
+
392
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
393
+ if hasattr(module, "set_processor"):
394
+ if not isinstance(processor, dict):
395
+ module.set_processor(processor, _remove_lora=_remove_lora)
396
+ else:
397
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
398
+
399
+ for sub_name, child in module.named_children():
400
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
401
+
402
+ for name, module in self.named_children():
403
+ fn_recursive_attn_processor(name, module, processor)
404
+
405
+ def enable_forward_chunking(self, chunk_size=None, dim=0):
406
+ """
407
+ Sets the attention processor to use [feed forward
408
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
409
+
410
+ Parameters:
411
+ chunk_size (`int`, *optional*):
412
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
413
+ over each tensor of dim=`dim`.
414
+ dim (`int`, *optional*, defaults to `0`):
415
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
416
+ or dim=1 (sequence length).
417
+ """
418
+ if dim not in [0, 1]:
419
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
420
+
421
+ # By default chunk size is 1
422
+ chunk_size = chunk_size or 1
423
+
424
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
425
+ if hasattr(module, "set_chunk_feed_forward"):
426
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
427
+
428
+ for child in module.children():
429
+ fn_recursive_feed_forward(child, chunk_size, dim)
430
+
431
+ for module in self.children():
432
+ fn_recursive_feed_forward(module, chunk_size, dim)
433
+
434
+ def disable_forward_chunking(self):
435
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
436
+ if hasattr(module, "set_chunk_feed_forward"):
437
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
438
+
439
+ for child in module.children():
440
+ fn_recursive_feed_forward(child, chunk_size, dim)
441
+
442
+ for module in self.children():
443
+ fn_recursive_feed_forward(module, None, 0)
444
+
445
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
446
+ def set_default_attn_processor(self):
447
+ """
448
+ Disables custom attention processors and sets the default attention implementation.
449
+ """
450
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
451
+ processor = AttnAddedKVProcessor()
452
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
453
+ processor = AttnProcessor()
454
+ else:
455
+ raise ValueError(
456
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
457
+ )
458
+
459
+ self.set_attn_processor(processor, _remove_lora=True)
460
+
461
+ def _set_gradient_checkpointing(self, module, value=False):
462
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
463
+ module.gradient_checkpointing = value
464
+
465
+ def forward(
466
+ self,
467
+ sample: torch.FloatTensor,
468
+ timestep: Union[torch.Tensor, float, int],
469
+ encoder_hidden_states: torch.Tensor,
470
+ class_labels: Optional[torch.Tensor] = None,
471
+ timestep_cond: Optional[torch.Tensor] = None,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ encoder_attention_mask: Optional[torch.Tensor] = None,
474
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
475
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
476
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
477
+ return_dict: bool = True,
478
+ **kwargs,
479
+ ) -> Union[UNet3DConditionOutput, Tuple]:
480
+ r"""
481
+ The [`UNet3DConditionModel`] forward method.
482
+
483
+ Args:
484
+ sample (`torch.FloatTensor`):
485
+ The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
486
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
487
+ encoder_hidden_states (`torch.FloatTensor`):
488
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
489
+ encoder_attention_mask (`torch.FloatTensor`, *optional*): Masks out the encoder hidden states for cross
490
+ attention.
491
+ return_dict (`bool`, *optional*, defaults to `True`):
492
+ Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
493
+ tuple.
494
+ cross_attention_kwargs (`dict`, *optional*):
495
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
496
+
497
+ Returns:
498
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
499
+ If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
500
+ a `tuple` is returned where the first element is the sample tensor.
501
+ """
502
+ # By default samples have to be at least a multiple of the overall upsampling factor.
503
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
504
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
505
+ # on the fly if necessary.
506
+ default_overall_up_factor = 2**self.num_upsamplers
507
+
508
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
509
+ forward_upsample_size = False
510
+ upsample_size = None
511
+
512
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
513
+ logger.info("Forward upsample size to force interpolation output size.")
514
+ forward_upsample_size = True
515
+
516
+ # prepare attention_mask
517
+ if attention_mask is not None:
518
+ if not isinstance(attention_mask, list):
519
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
520
+ if len(attention_mask.shape) == 2: # Else we are already passing a 2d mask
521
+ attention_mask = attention_mask.unsqueeze(1)
522
+ else:
523
+ attention_mask = [(1 - mask.to(sample.dtype)) * -10000.0 for mask in attention_mask]
524
+ if len(attention_mask[0].shape) == 2:
525
+ attention_mask = [mask.unsqueeze(1) for mask in attention_mask]
526
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
527
+ if encoder_attention_mask is not None:
528
+ if not isinstance(encoder_attention_mask, list):
529
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
530
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
531
+ else:
532
+ encoder_attention_mask = [(1 - mask.to(sample.dtype)) * -10000.0 for mask in encoder_attention_mask]
533
+ if len(encoder_attention_mask[0].shape) == 2:
534
+ encoder_attention_mask = [mask.unsqueeze(1) for mask in encoder_attention_mask]
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ num_frames = sample.shape[2]
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
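+ # The time embedding and the text conditioning are repeated per frame because the frame
+ # axis is folded into the batch axis below, so every frame needs its own copy.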
+
+ # 2. pre-process
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+ sample = self.conv_in(sample)
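+ # `sample` is now (batch * num_frames, channels, height, width): the spatial blocks below run
+ # once per frame, while the temporal layers regroup frames using `num_frames`.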
+
+ sample = self.transformer_in(
+ sample,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ attention_mask=None,
+ encoder_attention_mask=None,
+ )[0]
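+ # Note: neither mask is passed to `transformer_in` or to the down blocks below; the prepared
+ # masks are only applied in the mid and up blocks.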
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=None,
+ attention_mask=None,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+ down_block_res_samples += res_samples
+
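+ # `down_block_additional_residuals` and `mid_block_additional_residual` follow the ControlNet-style
+ # conditioning interface: the extra residuals are simply summed onto the skip connections / mid output.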
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **kwargs,
+ )
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
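+ # the skip connections belonging to this resolution level are popped off the end of the tuple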
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **kwargs,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ num_frames=num_frames,
+ )
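+ # after the up blocks, `sample` is still laid out as (batch * num_frames, channels, height, width)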
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = self.conv_out(sample)
+
+ # reshape back to (batch, channel, num_frames, height, width)
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)