kxhit committed
Commit 5f093a6 · Parent(s): 6d86936
update

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff.
- 6DoF/CN_encoder.py +36 -0
 - 6DoF/dataset.py +176 -0
 - 6DoF/diffusers/__init__.py +281 -0
 - 6DoF/diffusers/commands/__init__.py +27 -0
 - 6DoF/diffusers/commands/diffusers_cli.py +41 -0
 - 6DoF/diffusers/commands/env.py +84 -0
 - 6DoF/diffusers/configuration_utils.py +664 -0
 - 6DoF/diffusers/dependency_versions_check.py +47 -0
 - 6DoF/diffusers/dependency_versions_table.py +44 -0
 - 6DoF/diffusers/experimental/__init__.py +1 -0
 - 6DoF/diffusers/experimental/rl/__init__.py +1 -0
 - 6DoF/diffusers/experimental/rl/value_guided_sampling.py +152 -0
 - 6DoF/diffusers/image_processor.py +366 -0
 - 6DoF/diffusers/loaders.py +1492 -0
 - 6DoF/diffusers/models/__init__.py +35 -0
 - 6DoF/diffusers/models/activations.py +12 -0
 - 6DoF/diffusers/models/attention.py +392 -0
 - 6DoF/diffusers/models/attention_flax.py +446 -0
 - 6DoF/diffusers/models/attention_processor.py +1684 -0
 - 6DoF/diffusers/models/autoencoder_kl.py +411 -0
 - 6DoF/diffusers/models/controlnet.py +705 -0
 - 6DoF/diffusers/models/controlnet_flax.py +394 -0
 - 6DoF/diffusers/models/cross_attention.py +94 -0
 - 6DoF/diffusers/models/dual_transformer_2d.py +151 -0
 - 6DoF/diffusers/models/embeddings.py +546 -0
 - 6DoF/diffusers/models/embeddings_flax.py +95 -0
 - 6DoF/diffusers/models/modeling_flax_pytorch_utils.py +118 -0
 - 6DoF/diffusers/models/modeling_flax_utils.py +534 -0
 - 6DoF/diffusers/models/modeling_pytorch_flax_utils.py +161 -0
 - 6DoF/diffusers/models/modeling_utils.py +980 -0
 - 6DoF/diffusers/models/prior_transformer.py +364 -0
 - 6DoF/diffusers/models/resnet.py +877 -0
 - 6DoF/diffusers/models/resnet_flax.py +124 -0
 - 6DoF/diffusers/models/t5_film_transformer.py +321 -0
 - 6DoF/diffusers/models/transformer_2d.py +343 -0
 - 6DoF/diffusers/models/transformer_temporal.py +179 -0
 - 6DoF/diffusers/models/unet_1d.py +255 -0
 - 6DoF/diffusers/models/unet_1d_blocks.py +656 -0
 - 6DoF/diffusers/models/unet_2d.py +329 -0
 - 6DoF/diffusers/models/unet_2d_blocks.py +0 -0
 - 6DoF/diffusers/models/unet_2d_blocks_flax.py +377 -0
 - 6DoF/diffusers/models/unet_2d_condition.py +980 -0
 - 6DoF/diffusers/models/unet_2d_condition_flax.py +357 -0
 - 6DoF/diffusers/models/unet_3d_blocks.py +679 -0
 - 6DoF/diffusers/models/unet_3d_condition.py +627 -0
 - 6DoF/diffusers/models/vae.py +441 -0
 - 6DoF/diffusers/models/vae_flax.py +869 -0
 - 6DoF/diffusers/models/vq_model.py +167 -0
 - 6DoF/diffusers/optimization.py +354 -0
 - 6DoF/diffusers/pipeline_utils.py +29 -0
 
    	
6DoF/CN_encoder.py
ADDED
@@ -0,0 +1,36 @@
+from transformers import ConvNextV2Model
+import torch
+from typing import Optional
+import einops
+
+class CN_encoder(ConvNextV2Model):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(
+            self,
+            pixel_values: torch.FloatTensor = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ):
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        image_embeddings = einops.rearrange(last_hidden_state, 'b c h w -> b (h w) c')
+        image_embeddings = self.layernorm(image_embeddings)
+
+        return image_embeddings
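CN_encoder subclasses ConvNextV2Model but skips pooling: it flattens the final (b, c, h, w) feature map into a (b, h*w, c) token sequence and applies the model's LayerNorm, so the output can serve as per-patch conditioning context. The following minimal usage sketch is not part of the commit; the checkpoint id and input resolution are illustrative assumptions, and it assumes it is run from the 6DoF directory.

import torch
from CN_encoder import CN_encoder  # the module added above (6DoF/CN_encoder.py)

# Illustrative checkpoint id; any checkpoint compatible with transformers'
# ConvNextV2Model should load here via the inherited from_pretrained().
encoder = CN_encoder.from_pretrained("facebook/convnextv2-tiny-22k-224")
encoder.eval()

pixel_values = torch.randn(2, 3, 256, 256)   # a batch of conditioning views
with torch.no_grad():
    tokens = encoder(pixel_values)           # image tokens of shape (B, H*W, C)
print(tokens.shape)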
    	
6DoF/dataset.py
ADDED
@@ -0,0 +1,176 @@
+import os
+import math
+from pathlib import Path
+import torch
+import torchvision
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+import webdataset as wds
+from torch.utils.data.distributed import DistributedSampler
+import matplotlib.pyplot as plt
+import sys
+
+class ObjaverseDataLoader():
+    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
+        self.root_dir = root_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.total_view = total_view
+
+        image_transforms = [torchvision.transforms.Resize((256, 256)),
+                            transforms.ToTensor(),
+                            transforms.Normalize([0.5], [0.5])]
+        self.image_transforms = torchvision.transforms.Compose(image_transforms)
+
+    def train_dataloader(self):
+        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
+                                image_transforms=self.image_transforms)
+        # sampler = DistributedSampler(dataset)
+        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+                             # sampler=sampler)
+
+    def val_dataloader(self):
+        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
+                                image_transforms=self.image_transforms)
+        sampler = DistributedSampler(dataset)
+        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+def get_pose(transformation):
+    # transformation: 4x4
+    return transformation
+
+class ObjaverseData(Dataset):
+    def __init__(self,
+                 root_dir='.objaverse/hf-objaverse-v1/views',
+                 image_transforms=None,
+                 total_view=12,
+                 validation=False,
+                 T_in=1,
+                 T_out=1,
+                 fix_sample=False,
+                 ) -> None:
+        """Create a dataset from a folder of images.
+        If you pass in a root directory it will be searched for images
+        ending in ext (ext can be a list)
+        """
+        self.root_dir = Path(root_dir)
+        self.total_view = total_view
+        self.T_in = T_in
+        self.T_out = T_out
+        self.fix_sample = fix_sample
+
+        self.paths = []
+        # # include all folders
+        # for folder in os.listdir(self.root_dir):
+        #     if os.path.isdir(os.path.join(self.root_dir, folder)):
+        #         self.paths.append(folder)
+        # load ids from .npy so we have exactly the same ids/order
+        self.paths = np.load("../scripts/obj_ids.npy")
+        # # only use 100K objects for ablation study
+        # self.paths = self.paths[:100000]
+        total_objects = len(self.paths)
+        assert total_objects == 790152, 'total objects %d' % total_objects
+        if validation:
+            self.paths = self.paths[math.floor(total_objects / 100. * 99.):]  # used last 1% as validation
+        else:
+            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)]  # used first 99% as training
+        print('============= length of dataset %d =============' % len(self.paths))
+        self.tform = image_transforms
+
+        downscale = 512 / 256.
+        self.fx = 560. / downscale
+        self.fy = 560. / downscale
+        self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
+
+    def __len__(self):
+        return len(self.paths)
+
+    def get_pose(self, transformation):
+        # transformation: 4x4
+        return transformation
+
+
+    def load_im(self, path, color):
+        '''
+        replace background pixel with random color in rendering
+        '''
+        try:
+            img = plt.imread(path)
+        except:
+            print(path)
+            sys.exit()
+        img[img[:, :, -1] == 0.] = color
+        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
+        return img
+
+    def __getitem__(self, index):
+        data = {}
+        total_view = 12
+
+        if self.fix_sample:
+            if self.T_out > 1:
+                indexes = range(total_view)
+                index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
+                index_inputs = indexes[1:self.T_in+1]   # one overlap identity
+            else:
+                indexes = range(total_view)
+                index_targets = indexes[:self.T_out]
+                index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1] # one overlap identity
+        else:
+            assert self.T_in + self.T_out <= total_view
+            # training with replace, including identity
+            indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
+            index_inputs = indexes[:self.T_in]
+            index_targets = indexes[self.T_in:]
+        filename = os.path.join(self.root_dir, self.paths[index])
+
+        color = [1., 1., 1., 1.]
+
+        try:
+            input_ims = []
+            target_ims = []
+            target_Ts = []
+            cond_Ts = []
+            for i, index_input in enumerate(index_inputs):
+                input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
+                input_ims.append(input_im)
+                input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
+                cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
+            for i, index_target in enumerate(index_targets):
+                target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
+                target_ims.append(target_im)
+                target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
+                target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
+        except:
+            print('error loading data ', filename)
+            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
+            input_ims = []
+            target_ims = []
+            target_Ts = []
+            cond_Ts = []
+            # very hacky solution, sorry about this
+            for i, index_input in enumerate(index_inputs):
+                input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
+                input_ims.append(input_im)
+                input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
+                cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
+            for i, index_target in enumerate(index_targets):
+                target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
+                target_ims.append(target_im)
+                target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
+                target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
+
+        # stack to batch
+        data['image_input'] = torch.stack(input_ims, dim=0)
+        data['image_target'] = torch.stack(target_ims, dim=0)
+        data['pose_out'] = np.stack(target_Ts)
+        data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
+        data['pose_in'] = np.stack(cond_Ts)
+        data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
+        return data
+
+    def process_im(self, im):
+        im = im.convert("RGB")
+        return self.tform(im)
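For orientation, here is a minimal sketch of how the loader above might be driven. It is not part of the commit; it assumes the rendered Objaverse views and ../scripts/obj_ids.npy exist on disk as ObjaverseData expects, and the paths and batch size are illustrative.

from dataset import ObjaverseDataLoader  # the module added above (6DoF/dataset.py)

loader = ObjaverseDataLoader(root_dir=".objaverse/hf-objaverse-v1/views",
                             batch_size=4, total_view=12, num_workers=4)
train_loader = loader.train_dataloader()

for batch in train_loader:
    # batch['image_input'] / batch['image_target']: stacked views, (T, 3, 256, 256) per sample
    # batch['pose_in'] / batch['pose_out'] (plus *_inv): 4x4 camera poses for the sampled views
    break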
    	
        6DoF/diffusers/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,281 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            __version__ = "0.18.2"
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from .configuration_utils import ConfigMixin
         
     | 
| 4 | 
         
            +
            from .utils import (
         
     | 
| 5 | 
         
            +
                OptionalDependencyNotAvailable,
         
     | 
| 6 | 
         
            +
                is_flax_available,
         
     | 
| 7 | 
         
            +
                is_inflect_available,
         
     | 
| 8 | 
         
            +
                is_invisible_watermark_available,
         
     | 
| 9 | 
         
            +
                is_k_diffusion_available,
         
     | 
| 10 | 
         
            +
                is_k_diffusion_version,
         
     | 
| 11 | 
         
            +
                is_librosa_available,
         
     | 
| 12 | 
         
            +
                is_note_seq_available,
         
     | 
| 13 | 
         
            +
                is_onnx_available,
         
     | 
| 14 | 
         
            +
                is_scipy_available,
         
     | 
| 15 | 
         
            +
                is_torch_available,
         
     | 
| 16 | 
         
            +
                is_torchsde_available,
         
     | 
| 17 | 
         
            +
                is_transformers_available,
         
     | 
| 18 | 
         
            +
                is_transformers_version,
         
     | 
| 19 | 
         
            +
                is_unidecode_available,
         
     | 
| 20 | 
         
            +
                logging,
         
     | 
| 21 | 
         
            +
            )
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            try:
         
     | 
| 25 | 
         
            +
                if not is_onnx_available():
         
     | 
| 26 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 27 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 28 | 
         
            +
                from .utils.dummy_onnx_objects import *  # noqa F403
         
     | 
| 29 | 
         
            +
            else:
         
     | 
| 30 | 
         
            +
                from .pipelines import OnnxRuntimeModel
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            try:
         
     | 
| 33 | 
         
            +
                if not is_torch_available():
         
     | 
| 34 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 35 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 36 | 
         
            +
                from .utils.dummy_pt_objects import *  # noqa F403
         
     | 
| 37 | 
         
            +
            else:
         
     | 
| 38 | 
         
            +
                from .models import (
         
     | 
| 39 | 
         
            +
                    AutoencoderKL,
         
     | 
| 40 | 
         
            +
                    ControlNetModel,
         
     | 
| 41 | 
         
            +
                    ModelMixin,
         
     | 
| 42 | 
         
            +
                    PriorTransformer,
         
     | 
| 43 | 
         
            +
                    T5FilmDecoder,
         
     | 
| 44 | 
         
            +
                    Transformer2DModel,
         
     | 
| 45 | 
         
            +
                    UNet1DModel,
         
     | 
| 46 | 
         
            +
                    UNet2DConditionModel,
         
     | 
| 47 | 
         
            +
                    UNet2DModel,
         
     | 
| 48 | 
         
            +
                    UNet3DConditionModel,
         
     | 
| 49 | 
         
            +
                    VQModel,
         
     | 
| 50 | 
         
            +
                )
         
     | 
| 51 | 
         
            +
                from .optimization import (
         
     | 
| 52 | 
         
            +
                    get_constant_schedule,
         
     | 
| 53 | 
         
            +
                    get_constant_schedule_with_warmup,
         
     | 
| 54 | 
         
            +
                    get_cosine_schedule_with_warmup,
         
     | 
| 55 | 
         
            +
                    get_cosine_with_hard_restarts_schedule_with_warmup,
         
     | 
| 56 | 
         
            +
                    get_linear_schedule_with_warmup,
         
     | 
| 57 | 
         
            +
                    get_polynomial_decay_schedule_with_warmup,
         
     | 
| 58 | 
         
            +
                    get_scheduler,
         
     | 
| 59 | 
         
            +
                )
         
     | 
| 60 | 
         
            +
                from .pipelines import (
         
     | 
| 61 | 
         
            +
                    AudioPipelineOutput,
         
     | 
| 62 | 
         
            +
                    ConsistencyModelPipeline,
         
     | 
| 63 | 
         
            +
                    DanceDiffusionPipeline,
         
     | 
| 64 | 
         
            +
                    DDIMPipeline,
         
     | 
| 65 | 
         
            +
                    DDPMPipeline,
         
     | 
| 66 | 
         
            +
                    DiffusionPipeline,
         
     | 
| 67 | 
         
            +
                    DiTPipeline,
         
     | 
| 68 | 
         
            +
                    ImagePipelineOutput,
         
     | 
| 69 | 
         
            +
                    KarrasVePipeline,
         
     | 
| 70 | 
         
            +
                    LDMPipeline,
         
     | 
| 71 | 
         
            +
                    LDMSuperResolutionPipeline,
         
     | 
| 72 | 
         
            +
                    PNDMPipeline,
         
     | 
| 73 | 
         
            +
                    RePaintPipeline,
         
     | 
| 74 | 
         
            +
                    ScoreSdeVePipeline,
         
     | 
| 75 | 
         
            +
                )
         
     | 
| 76 | 
         
            +
                from .schedulers import (
         
     | 
| 77 | 
         
            +
                    CMStochasticIterativeScheduler,
         
     | 
| 78 | 
         
            +
                    DDIMInverseScheduler,
         
     | 
| 79 | 
         
            +
                    DDIMParallelScheduler,
         
     | 
| 80 | 
         
            +
                    DDIMScheduler,
         
     | 
| 81 | 
         
            +
                    DDPMParallelScheduler,
         
     | 
| 82 | 
         
            +
                    DDPMScheduler,
         
     | 
| 83 | 
         
            +
                    DEISMultistepScheduler,
         
     | 
| 84 | 
         
            +
                    DPMSolverMultistepInverseScheduler,
         
     | 
| 85 | 
         
            +
                    DPMSolverMultistepScheduler,
         
     | 
| 86 | 
         
            +
                    DPMSolverSinglestepScheduler,
         
     | 
| 87 | 
         
            +
                    EulerAncestralDiscreteScheduler,
         
     | 
| 88 | 
         
            +
                    EulerDiscreteScheduler,
         
     | 
| 89 | 
         
            +
                    HeunDiscreteScheduler,
         
     | 
| 90 | 
         
            +
                    IPNDMScheduler,
         
     | 
| 91 | 
         
            +
                    KarrasVeScheduler,
         
     | 
| 92 | 
         
            +
                    KDPM2AncestralDiscreteScheduler,
         
     | 
| 93 | 
         
            +
                    KDPM2DiscreteScheduler,
         
     | 
| 94 | 
         
            +
                    PNDMScheduler,
         
     | 
| 95 | 
         
            +
                    RePaintScheduler,
         
     | 
| 96 | 
         
            +
                    SchedulerMixin,
         
     | 
| 97 | 
         
            +
                    ScoreSdeVeScheduler,
         
     | 
| 98 | 
         
            +
                    UnCLIPScheduler,
         
     | 
| 99 | 
         
            +
                    UniPCMultistepScheduler,
         
     | 
| 100 | 
         
            +
                    VQDiffusionScheduler,
         
     | 
| 101 | 
         
            +
                )
         
     | 
| 102 | 
         
            +
                from .training_utils import EMAModel
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
            try:
         
     | 
| 105 | 
         
            +
                if not (is_torch_available() and is_scipy_available()):
         
     | 
| 106 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 107 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 108 | 
         
            +
                from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
         
     | 
| 109 | 
         
            +
            else:
         
     | 
| 110 | 
         
            +
                from .schedulers import LMSDiscreteScheduler
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
            try:
         
     | 
| 113 | 
         
            +
                if not (is_torch_available() and is_torchsde_available()):
         
     | 
| 114 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 115 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 116 | 
         
            +
                from .utils.dummy_torch_and_torchsde_objects import *  # noqa F403
         
     | 
| 117 | 
         
            +
            else:
         
     | 
| 118 | 
         
            +
                from .schedulers import DPMSolverSDEScheduler
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
            try:
         
     | 
| 121 | 
         
            +
                if not (is_torch_available() and is_transformers_available()):
         
     | 
| 122 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 123 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 124 | 
         
            +
                from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
         
     | 
| 125 | 
         
            +
            else:
         
     | 
| 126 | 
         
            +
                from .pipelines import (
         
     | 
| 127 | 
         
            +
                    AltDiffusionImg2ImgPipeline,
         
     | 
| 128 | 
         
            +
                    AltDiffusionPipeline,
         
     | 
| 129 | 
         
            +
                    AudioLDMPipeline,
         
     | 
| 130 | 
         
            +
                    CycleDiffusionPipeline,
         
     | 
| 131 | 
         
            +
                    IFImg2ImgPipeline,
         
     | 
| 132 | 
         
            +
                    IFImg2ImgSuperResolutionPipeline,
         
     | 
| 133 | 
         
            +
                    IFInpaintingPipeline,
         
     | 
| 134 | 
         
            +
                    IFInpaintingSuperResolutionPipeline,
         
     | 
| 135 | 
         
            +
                    IFPipeline,
         
     | 
| 136 | 
         
            +
                    IFSuperResolutionPipeline,
         
     | 
| 137 | 
         
            +
                    ImageTextPipelineOutput,
         
     | 
| 138 | 
         
            +
                    KandinskyImg2ImgPipeline,
         
     | 
| 139 | 
         
            +
                    KandinskyInpaintPipeline,
         
     | 
| 140 | 
         
            +
                    KandinskyPipeline,
         
     | 
| 141 | 
         
            +
                    KandinskyPriorPipeline,
         
     | 
| 142 | 
         
            +
                    KandinskyV22ControlnetImg2ImgPipeline,
         
     | 
| 143 | 
         
            +
                    KandinskyV22ControlnetPipeline,
         
     | 
| 144 | 
         
            +
                    KandinskyV22Img2ImgPipeline,
         
     | 
| 145 | 
         
            +
                    KandinskyV22InpaintPipeline,
         
     | 
| 146 | 
         
            +
                    KandinskyV22Pipeline,
         
     | 
| 147 | 
         
            +
                    KandinskyV22PriorEmb2EmbPipeline,
         
     | 
| 148 | 
         
            +
                    KandinskyV22PriorPipeline,
         
     | 
| 149 | 
         
            +
                    LDMTextToImagePipeline,
         
     | 
| 150 | 
         
            +
                    PaintByExamplePipeline,
         
     | 
| 151 | 
         
            +
                    SemanticStableDiffusionPipeline,
         
     | 
| 152 | 
         
            +
                    ShapEImg2ImgPipeline,
         
     | 
| 153 | 
         
            +
                    ShapEPipeline,
         
     | 
| 154 | 
         
            +
                    StableDiffusionAttendAndExcitePipeline,
         
     | 
| 155 | 
         
            +
                    StableDiffusionControlNetImg2ImgPipeline,
         
     | 
| 156 | 
         
            +
                    StableDiffusionControlNetInpaintPipeline,
         
     | 
| 157 | 
         
            +
                    StableDiffusionControlNetPipeline,
         
     | 
| 158 | 
         
            +
                    StableDiffusionDepth2ImgPipeline,
         
     | 
| 159 | 
         
            +
                    StableDiffusionDiffEditPipeline,
         
     | 
| 160 | 
         
            +
                    StableDiffusionImageVariationPipeline,
         
     | 
| 161 | 
         
            +
                    StableDiffusionImg2ImgPipeline,
         
     | 
| 162 | 
         
            +
                    StableDiffusionInpaintPipeline,
         
     | 
| 163 | 
         
            +
                    StableDiffusionInpaintPipelineLegacy,
         
     | 
| 164 | 
         
            +
                    StableDiffusionInstructPix2PixPipeline,
         
     | 
| 165 | 
         
            +
                    StableDiffusionLatentUpscalePipeline,
         
     | 
| 166 | 
         
            +
                    StableDiffusionLDM3DPipeline,
         
     | 
| 167 | 
         
            +
                    StableDiffusionModelEditingPipeline,
         
     | 
| 168 | 
         
            +
                    StableDiffusionPanoramaPipeline,
         
     | 
| 169 | 
         
            +
                    StableDiffusionParadigmsPipeline,
         
     | 
| 170 | 
         
            +
                    StableDiffusionPipeline,
         
     | 
| 171 | 
         
            +
                    StableDiffusionPipelineSafe,
         
     | 
| 172 | 
         
            +
                    StableDiffusionPix2PixZeroPipeline,
         
     | 
| 173 | 
         
            +
                    StableDiffusionSAGPipeline,
         
     | 
| 174 | 
         
            +
                    StableDiffusionUpscalePipeline,
         
     | 
| 175 | 
         
            +
                    StableUnCLIPImg2ImgPipeline,
         
     | 
| 176 | 
         
            +
                    StableUnCLIPPipeline,
         
     | 
| 177 | 
         
            +
                    TextToVideoSDPipeline,
         
     | 
| 178 | 
         
            +
                    TextToVideoZeroPipeline,
         
     | 
| 179 | 
         
            +
                    UnCLIPImageVariationPipeline,
         
     | 
| 180 | 
         
            +
                    UnCLIPPipeline,
         
     | 
| 181 | 
         
            +
                    UniDiffuserModel,
         
     | 
| 182 | 
         
            +
                    UniDiffuserPipeline,
         
     | 
| 183 | 
         
            +
                    UniDiffuserTextDecoder,
         
     | 
| 184 | 
         
            +
                    VersatileDiffusionDualGuidedPipeline,
         
     | 
| 185 | 
         
            +
                    VersatileDiffusionImageVariationPipeline,
         
     | 
| 186 | 
         
            +
                    VersatileDiffusionPipeline,
         
     | 
| 187 | 
         
            +
                    VersatileDiffusionTextToImagePipeline,
         
     | 
| 188 | 
         
            +
                    VideoToVideoSDPipeline,
         
     | 
| 189 | 
         
            +
                    VQDiffusionPipeline,
         
     | 
| 190 | 
         
            +
                )
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
            try:
         
     | 
| 193 | 
         
            +
                if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
         
     | 
| 194 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 195 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 196 | 
         
            +
                from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
         
     | 
| 197 | 
         
            +
            else:
         
     | 
| 198 | 
         
            +
                from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
            try:
         
     | 
| 201 | 
         
            +
                if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
         
     | 
| 202 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 203 | 
         
            +
            except OptionalDependencyNotAvailable:
         
     | 
| 204 | 
         
            +
                from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
         
     | 
| 205 | 
         
            +
            else:
         
     | 
| 206 | 
         
            +
                from .pipelines import StableDiffusionKDiffusionPipeline
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
            try:
         
     | 
| 209 | 
         
            +
                if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
         
     | 
| 210 | 
         
            +
                    raise OptionalDependencyNotAvailable()
         
     | 
| 211 | 
         
            +
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
else:
    from .pipelines import (
        OnnxStableDiffusionImg2ImgPipeline,
        OnnxStableDiffusionInpaintPipeline,
        OnnxStableDiffusionInpaintPipelineLegacy,
        OnnxStableDiffusionPipeline,
        OnnxStableDiffusionUpscalePipeline,
        StableDiffusionOnnxPipeline,
    )

try:
    if not (is_torch_available() and is_librosa_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
else:
    from .pipelines import AudioDiffusionPipeline, Mel

try:
    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403
else:
    from .pipelines import SpectrogramDiffusionPipeline

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_flax_objects import *  # noqa F403
else:
    from .models.controlnet_flax import FlaxControlNetModel
    from .models.modeling_flax_utils import FlaxModelMixin
    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .models.vae_flax import FlaxAutoencoderKL
    from .pipelines import FlaxDiffusionPipeline
    from .schedulers import (
        FlaxDDIMScheduler,
        FlaxDDPMScheduler,
        FlaxDPMSolverMultistepScheduler,
        FlaxKarrasVeScheduler,
        FlaxLMSDiscreteScheduler,
        FlaxPNDMScheduler,
        FlaxSchedulerMixin,
        FlaxScoreSdeVeScheduler,
    )


try:
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
else:
    from .pipelines import (
        FlaxStableDiffusionControlNetPipeline,
        FlaxStableDiffusionImg2ImgPipeline,
        FlaxStableDiffusionInpaintPipeline,
        FlaxStableDiffusionPipeline,
    )

try:
    if not (is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_note_seq_objects import *  # noqa F403
else:
    from .pipelines import MidiProcessor
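Editor's note on the pattern above: each try/except/else block probes for an optional backend and, when it is missing, star-imports the matching `dummy_*_objects` module instead, so `import diffusers` itself never fails. A minimal sketch of the resulting behaviour, assuming flax and transformers are not installed (the checkpoint id is only an example):

    # Sketch (illustration, not part of the commit): the guarded imports keep the
    # top-level namespace importable even when optional backends are missing.
    import diffusers

    pipe_cls = diffusers.FlaxStableDiffusionPipeline  # always resolvable
    try:
        pipe_cls.from_pretrained("runwayml/stable-diffusion-v1-5")
    except ImportError as err:
        # With flax/transformers absent, the placeholder class from
        # dummy_flax_and_transformers_objects is expected to raise here and
        # name the backend that needs to be installed.
        print(err)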
        6DoF/diffusers/commands/__init__.py
    ADDED
@@ -0,0 +1,27 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseDiffusersCLICommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()
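Every CLI subcommand implements this two-method interface: `register_subcommand` wires the command and its flags into the shared argparse sub-parser, and `run` does the work. A minimal sketch of a hypothetical extra command (the `HelloCommand` name and its `--name` flag are illustrative only, and the import path assumes the package is installed as `diffusers`):

    from argparse import ArgumentParser, Namespace

    from diffusers.commands import BaseDiffusersCLICommand


    def hello_command_factory(args: Namespace):
        # Mirrors how env.py builds its command object from the parsed arguments.
        return HelloCommand(args.name)


    class HelloCommand(BaseDiffusersCLICommand):
        @staticmethod
        def register_subcommand(parser: ArgumentParser):
            hello_parser = parser.add_parser("hello")
            hello_parser.add_argument("--name", type=str, default="world")
            hello_parser.set_defaults(func=hello_command_factory)

        def __init__(self, name: str):
            self.name = name

        def run(self):
            print(f"Hello, {self.name}!")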
        6DoF/diffusers/commands/diffusers_cli.py
    ADDED
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from .env import EnvironmentCommand


def main():
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
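`main()` builds the top-level parser, lets each registered command attach its own sub-parser, then dispatches to the factory stored in `args.func`. A short sketch of driving it programmatically, equivalent to running `diffusers-cli env` in a shell (the `diffusers-cli` console-script name is assumed to be declared in the package metadata):

    import sys

    from diffusers.commands.diffusers_cli import main

    # argparse only reads sys.argv[1:], so "env" selects the EnvironmentCommand sub-parser.
    sys.argv = ["diffusers-cli", "env"]
    main()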
        6DoF/diffusers/commands/env.py
    ADDED
@@ -0,0 +1,84 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from argparse import ArgumentParser

import huggingface_hub

from .. import __version__ as version
from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
from . import BaseDiffusersCLICommand


def info_command_factory(_):
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self):
        hub_version = huggingface_hub.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        info = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "xFormers version": xformers_version,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
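The command collects version strings only for backends whose availability check passes, then renders them with `format_dict` as the bullet list users paste into bug reports. A small illustration of that helper (the version values here are made up):

    from diffusers.commands.env import EnvironmentCommand

    report = EnvironmentCommand.format_dict({"Platform": "Linux-5.15", "Python version": "3.10.12"})
    # Dicts keep insertion order, so this yields:
    # "- Platform: Linux-5.15\n- Python version: 3.10.12\n"
    print(report)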
        6DoF/diffusers/configuration_utils.py
    ADDED
         @@ -0,0 +1,664 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # coding=utf-8
         
     | 
| 2 | 
         
            +
            # Copyright 2023 The HuggingFace Inc. team.
         
     | 
| 3 | 
         
            +
            # Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 4 | 
         
            +
            #
         
     | 
| 5 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 6 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 7 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 10 | 
         
            +
            #
         
     | 
| 11 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 12 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 13 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 14 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 15 | 
         
            +
            # limitations under the License.
         
     | 
| 16 | 
         
            +
            """ ConfigMixin base class and utilities."""
         
     | 
| 17 | 
         
            +
            import dataclasses
         
     | 
| 18 | 
         
            +
            import functools
         
     | 
| 19 | 
         
            +
            import importlib
         
     | 
| 20 | 
         
            +
            import inspect
         
     | 
| 21 | 
         
            +
            import json
         
     | 
| 22 | 
         
            +
            import os
         
     | 
| 23 | 
         
            +
            import re
         
     | 
| 24 | 
         
            +
            from collections import OrderedDict
         
     | 
| 25 | 
         
            +
            from pathlib import PosixPath
         
     | 
| 26 | 
         
            +
            from typing import Any, Dict, Tuple, Union
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            import numpy as np
         
     | 
| 29 | 
         
            +
            from huggingface_hub import hf_hub_download
         
     | 
| 30 | 
         
            +
            from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
         
     | 
| 31 | 
         
            +
            from requests import HTTPError
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            from . import __version__
         
     | 
| 34 | 
         
            +
            from .utils import (
         
     | 
| 35 | 
         
            +
                DIFFUSERS_CACHE,
         
     | 
| 36 | 
         
            +
                HUGGINGFACE_CO_RESOLVE_ENDPOINT,
         
     | 
| 37 | 
         
            +
                DummyObject,
         
     | 
| 38 | 
         
            +
                deprecate,
         
     | 
| 39 | 
         
            +
                extract_commit_hash,
         
     | 
| 40 | 
         
            +
                http_user_agent,
         
     | 
| 41 | 
         
            +
                logging,
         
     | 
| 42 | 
         
            +
            )
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
            logger = logging.get_logger(__name__)
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
            _re_configuration_file = re.compile(r"config\.(.*)\.json")
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            class FrozenDict(OrderedDict):
         
     | 
| 51 | 
         
            +
                def __init__(self, *args, **kwargs):
         
     | 
| 52 | 
         
            +
                    super().__init__(*args, **kwargs)
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                    for key, value in self.items():
         
     | 
| 55 | 
         
            +
                        setattr(self, key, value)
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
                    self.__frozen = True
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                def __delitem__(self, *args, **kwargs):
         
     | 
| 60 | 
         
            +
                    raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                def setdefault(self, *args, **kwargs):
         
     | 
| 63 | 
         
            +
                    raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                def pop(self, *args, **kwargs):
         
     | 
| 66 | 
         
            +
                    raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                def update(self, *args, **kwargs):
         
     | 
| 69 | 
         
            +
                    raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                def __setattr__(self, name, value):
         
     | 
| 72 | 
         
            +
                    if hasattr(self, "__frozen") and self.__frozen:
         
     | 
| 73 | 
         
            +
                        raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
         
     | 
| 74 | 
         
            +
                    super().__setattr__(name, value)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                def __setitem__(self, name, value):
         
     | 
| 77 | 
         
            +
                    if hasattr(self, "__frozen") and self.__frozen:
         
     | 
| 78 | 
         
            +
                        raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
         
     | 
| 79 | 
         
            +
                    super().__setitem__(name, value)
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
            class ConfigMixin:
         
     | 
| 83 | 
         
            +
                r"""
         
     | 
| 84 | 
         
            +
                Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
         
     | 
| 85 | 
         
            +
                provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
         
     | 
| 86 | 
         
            +
                saving classes that inherit from [`ConfigMixin`].
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                Class attributes:
         
     | 
| 89 | 
         
            +
                    - **config_name** (`str`) -- A filename under which the config should stored when calling
         
     | 
| 90 | 
         
            +
                      [`~ConfigMixin.save_config`] (should be overridden by parent class).
         
     | 
| 91 | 
         
            +
                    - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
         
     | 
| 92 | 
         
            +
                      overridden by subclass).
         
     | 
| 93 | 
         
            +
                    - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
         
     | 
| 94 | 
         
            +
                    - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
         
     | 
| 95 | 
         
            +
                      should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
         
     | 
| 96 | 
         
            +
                      subclass).
         
     | 
| 97 | 
         
            +
                """
         
     | 
| 98 | 
         
            +
                config_name = None
         
     | 
| 99 | 
         
            +
                ignore_for_config = []
         
     | 
| 100 | 
         
            +
                has_compatibles = False
         
     | 
| 101 | 
         
            +
             
     | 
| 102 | 
         
            +
                _deprecated_kwargs = []
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                def register_to_config(self, **kwargs):
         
     | 
| 105 | 
         
            +
                    if self.config_name is None:
         
     | 
| 106 | 
         
            +
                        raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
         
     | 
| 107 | 
         
            +
                    # Special case for `kwargs` used in deprecation warning added to schedulers
         
     | 
| 108 | 
         
            +
                    # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         
     | 
| 109 | 
         
            +
                    # or solve in a more general way.
         
     | 
| 110 | 
         
            +
                    kwargs.pop("kwargs", None)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                    if not hasattr(self, "_internal_dict"):
         
     | 
| 113 | 
         
            +
                        internal_dict = kwargs
         
     | 
| 114 | 
         
            +
                    else:
         
     | 
| 115 | 
         
            +
                        previous_dict = dict(self._internal_dict)
         
     | 
| 116 | 
         
            +
                        internal_dict = {**self._internal_dict, **kwargs}
         
     | 
| 117 | 
         
            +
                        logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                    self._internal_dict = FrozenDict(internal_dict)
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                def __getattr__(self, name: str) -> Any:
         
     | 
| 122 | 
         
            +
                    """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         
     | 
| 123 | 
         
            +
                    config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                    Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
         
     | 
| 126 | 
         
            +
                    https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
         
     | 
| 127 | 
         
            +
                    """
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                    is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
         
     | 
| 130 | 
         
            +
                    is_attribute = name in self.__dict__
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                    if is_in_config and not is_attribute:
         
     | 
| 133 | 
         
            +
                        deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
         
     | 
| 134 | 
         
            +
                        deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
         
     | 
| 135 | 
         
            +
                        return self._internal_dict[name]
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         
     | 
| 140 | 
         
            +
                    """
         
     | 
| 141 | 
         
            +
                    Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
         
     | 
| 142 | 
         
            +
                    [`~ConfigMixin.from_config`] class method.
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    Args:
         
     | 
| 145 | 
         
            +
                        save_directory (`str` or `os.PathLike`):
         
     | 
| 146 | 
         
            +
                            Directory where the configuration JSON file is saved (will be created if it does not exist).
         
     | 
| 147 | 
         
            +
                    """
         
     | 
| 148 | 
         
            +
                    if os.path.isfile(save_directory):
         
     | 
| 149 | 
         
            +
                        raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
                    os.makedirs(save_directory, exist_ok=True)
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                    # If we save using the predefined names, we can load using `from_config`
         
     | 
| 154 | 
         
            +
                    output_config_file = os.path.join(save_directory, self.config_name)
         
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
                    self.to_json_file(output_config_file)
         
     | 
| 157 | 
         
            +
                    logger.info(f"Configuration saved in {output_config_file}")
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                @classmethod
         
     | 
| 160 | 
         
            +
                def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         
     | 
| 161 | 
         
            +
                    r"""
         
     | 
| 162 | 
         
            +
                    Instantiate a Python class from a config dictionary.
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
                    Parameters:
         
     | 
| 165 | 
         
            +
                        config (`Dict[str, Any]`):
         
     | 
| 166 | 
         
            +
                            A config dictionary from which the Python class is instantiated. Make sure to only load configuration
         
     | 
| 167 | 
         
            +
                            files of compatible classes.
         
     | 
| 168 | 
         
            +
                        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
         
     | 
| 169 | 
         
            +
                            Whether kwargs that are not consumed by the Python class should be returned or not.
         
     | 
| 170 | 
         
            +
                        kwargs (remaining dictionary of keyword arguments, *optional*):
         
     | 
| 171 | 
         
            +
                            Can be used to update the configuration object (after it is loaded) and initiate the Python class.
         
     | 
| 172 | 
         
            +
                            `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
         
     | 
| 173 | 
         
            +
                            overwrite the same named arguments in `config`.
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
            +
                    Returns:
         
     | 
| 176 | 
         
            +
                        [`ModelMixin`] or [`SchedulerMixin`]:
         
     | 
| 177 | 
         
            +
                            A model or scheduler object instantiated from a config dictionary.
         
     | 
| 178 | 
         
            +
             
     | 
| 179 | 
         
            +
                    Examples:
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
                    ```python
         
     | 
| 182 | 
         
            +
                    >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
         
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
                    >>> # Download scheduler from huggingface.co and cache.
         
     | 
| 185 | 
         
            +
                    >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                    >>> # Instantiate DDIM scheduler class with same config as DDPM
         
     | 
| 188 | 
         
            +
                    >>> scheduler = DDIMScheduler.from_config(scheduler.config)
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
                    >>> # Instantiate PNDM scheduler class with same config as DDPM
         
     | 
| 191 | 
         
            +
                    >>> scheduler = PNDMScheduler.from_config(scheduler.config)
         
     | 
| 192 | 
         
            +
                    ```
         
     | 
| 193 | 
         
            +
                    """
         
     | 
| 194 | 
         
            +
                    # <===== TO BE REMOVED WITH DEPRECATION
         
     | 
| 195 | 
         
            +
                    # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
         
     | 
| 196 | 
         
            +
                    if "pretrained_model_name_or_path" in kwargs:
         
     | 
| 197 | 
         
            +
                        config = kwargs.pop("pretrained_model_name_or_path")
         
     | 
| 198 | 
         
            +
             
     | 
| 199 | 
         
            +
                    if config is None:
         
     | 
| 200 | 
         
            +
                        raise ValueError("Please make sure to provide a config as the first positional argument.")
         
     | 
| 201 | 
         
            +
                    # ======>
         
     | 
| 202 | 
         
            +
             
     | 
| 203 | 
         
            +
                    if not isinstance(config, dict):
         
     | 
| 204 | 
         
            +
                        deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
         
     | 
| 205 | 
         
            +
                        if "Scheduler" in cls.__name__:
         
     | 
| 206 | 
         
            +
                            deprecation_message += (
         
     | 
| 207 | 
         
            +
                                f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
         
     | 
| 208 | 
         
            +
                                " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
         
     | 
| 209 | 
         
            +
                                " be removed in v1.0.0."
         
     | 
| 210 | 
         
            +
                            )
         
     | 
| 211 | 
         
            +
                        elif "Model" in cls.__name__:
         
     | 
| 212 | 
         
            +
                            deprecation_message += (
         
     | 
| 213 | 
         
            +
                                f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
         
     | 
| 214 | 
         
            +
                                f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
         
     | 
| 215 | 
         
            +
                                " instead. This functionality will be removed in v1.0.0."
         
     | 
| 216 | 
         
            +
                            )
         
     | 
| 217 | 
         
            +
                        deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
         
     | 
| 218 | 
         
            +
                        config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
         
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
                    init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
                    # Allow dtype to be specified on initialization
         
     | 
| 223 | 
         
            +
                    if "dtype" in unused_kwargs:
         
     | 
| 224 | 
         
            +
                        init_dict["dtype"] = unused_kwargs.pop("dtype")
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
                    # add possible deprecated kwargs
         
     | 
| 227 | 
         
            +
                    for deprecated_kwarg in cls._deprecated_kwargs:
         
     | 
| 228 | 
         
            +
                        if deprecated_kwarg in unused_kwargs:
         
     | 
| 229 | 
         
            +
                            init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
         
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
                    # Return model and optionally state and/or unused_kwargs
         
     | 
| 232 | 
         
            +
                    model = cls(**init_dict)
         
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
                    # make sure to also save config parameters that might be used for compatible classes
         
     | 
| 235 | 
         
            +
                    model.register_to_config(**hidden_dict)
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                    # add hidden kwargs of compatible classes to unused_kwargs
         
     | 
| 238 | 
         
            +
                    unused_kwargs = {**unused_kwargs, **hidden_dict}
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
                    if return_unused_kwargs:
         
     | 
| 241 | 
         
            +
                        return (model, unused_kwargs)
         
     | 
| 242 | 
         
            +
                    else:
         
     | 
| 243 | 
         
            +
                        return model
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
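    # A minimal usage sketch for `from_config` (the scheduler class and config source below
    # are illustrative assumptions, not part of this file):
    #
    #     config = DDPMScheduler.load_config("./my_scheduler_directory")
    #     scheduler, unused = DDPMScheduler.from_config(config, return_unused_kwargs=True)
    #
    # Unexpected config entries end up in `unused`, while "hidden" entries kept for
    # compatible classes are re-registered on the new instance via `register_to_config`.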
    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        return cls.load_config(*args, **kwargs)

    @classmethod
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                      [`~ConfigMixin.save_config`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False`):
                Whether the `commit_hash` of the loaded configuration is returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file.

        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            elif subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
                    " login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                )

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs

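    # Return-shape sketch for `load_config` (the model class and repo id are illustrative
    # assumptions, not part of this file):
    #
    #     config = UNet2DConditionModel.load_config("some-org/some-repo", subfolder="unet")
    #     config, unused_kwargs, commit_hash = UNet2DConditionModel.load_config(
    #         "some-org/some-repo", subfolder="unet", return_unused_kwargs=True, return_commit_hash=True
    #     )
    #
    # With neither flag set only the config dict is returned; each flag appends one element
    # to the returned tuple, in the order built above.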
    @staticmethod
    def _get_init_keys(cls):
        return set(dict(inspect.signature(cls.__init__).parameters).keys())

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary.

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        return self._internal_dict

    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, PosixPath):
                value = str(value)
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save the configuration instance's parameters to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file to save a configuration instance's parameters.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())


def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable.

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init


def flax_register_to_config(cls):
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            if type(field.default) == dataclasses._MISSING_TYPE:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
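Taken together, `ConfigMixin`, `@register_to_config`, and the JSON helpers above implement a simple round trip: init arguments are recorded at construction time, serialized with `save_config`/`to_json_string`, and restored with `load_config`/`from_config`. The following is a minimal sketch of that round trip (the `ToyScheduler` class and the `./toy_scheduler` directory are illustrative assumptions, not part of this commit), assuming the vendored package is importable as `diffusers`:

from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # file name used by save_config / load_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        # the decorator has already recorded both arguments in self.config
        pass


toy = ToyScheduler(num_train_timesteps=500)
print(toy.config.num_train_timesteps)        # 500, read back from the registered config
toy.save_config("./toy_scheduler")           # writes ./toy_scheduler/scheduler_config.json
config = ToyScheduler.load_config("./toy_scheduler")
restored = ToyScheduler.from_config(config)  # re-instantiated with the saved arguments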
    	
6DoF/diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core


# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed

        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


def dep_version_check(pkg, hint=None):
    require_version(deps[pkg], hint)
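A minimal sketch of how the helper defined above would be called (assuming the vendored package is importable as `diffusers`; the package name passed in is just an example):

from diffusers.dependency_versions_check import dep_version_check

# raises if the installed torch does not satisfy the pin recorded in the deps table ("torch>=1.4")
dep_version_check("torch")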
    	
6DoF/diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,44 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "compel": "compel==0.1.8",
    "black": "black~=23.1",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.13.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2",
    "jaxlib": "jaxlib>=0.1.65",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "omegaconf": "omegaconf",
    "parameterized": "parameterized",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "ruff": "ruff>=0.0.241",
    "safetensors": "safetensors",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
    "urllib3": "urllib3<=2.0.0",
}
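Each entry maps a bare package name to its full pip requirement string, so version checks can also be written directly against the table; a small sketch under the same import assumption as above:

from diffusers.dependency_versions_table import deps
from diffusers.utils.versions import require_version

require_version(deps["transformers"])  # enforces "transformers>=4.25.1"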
    	
6DoF/diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
from .rl import ValueGuidedRLPipeline

6DoF/diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
from .value_guided_sampling import ValueGuidedRLPipeline
    	
6DoF/diffusers/experimental/rl/value_guided_sampling.py
ADDED
@@ -0,0 +1,152 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import tqdm

from ...models.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils import randn_tensor
from ...utils.dummy_pt_objects import DDPMScheduler


class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except:  # noqa: E722
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except:  # noqa: E722
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        if type(x_in) is dict:
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        # normalize the observations and create a batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions

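A minimal usage sketch for the pipeline above, assuming a D4RL-style Hopper environment and a pretrained value-guided checkpoint; the environment id and checkpoint name below are assumptions and not part of this commit.

# Hedged sketch only: env id, checkpoint id, and import path are assumed, not taken from this diff.
import gym
import d4rl  # noqa: F401  (assumed available; registers the hopper-medium-v2 env)
from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")  # assumed env id
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",  # assumed checkpoint id
    env=env,
)

obs = env.reset()
for _ in range(10):
    # plan a trajectory and execute the first (highest-value) action
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, info = env.step(action)
    if done:
        break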
        6DoF/diffusers/image_processor.py
    ADDED
    
    | 
         @@ -0,0 +1,366 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            import warnings
         
     | 
| 16 | 
         
            +
            from typing import List, Optional, Union
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            import numpy as np
         
     | 
| 19 | 
         
            +
            import PIL
         
     | 
| 20 | 
         
            +
            import torch
         
     | 
| 21 | 
         
            +
            from PIL import Image
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            from .configuration_utils import ConfigMixin, register_to_config
         
     | 
| 24 | 
         
            +
            from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            class VaeImageProcessor(ConfigMixin):
         
     | 
| 28 | 
         
            +
                """
         
     | 
| 29 | 
         
            +
                Image processor for VAE.
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                Args:
         
     | 
| 32 | 
         
            +
                    do_resize (`bool`, *optional*, defaults to `True`):
         
     | 
| 33 | 
         
            +
                        Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
         
     | 
| 34 | 
         
            +
                        `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
         
     | 
| 35 | 
         
            +
                    vae_scale_factor (`int`, *optional*, defaults to `8`):
         
     | 
| 36 | 
         
            +
                        VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
         
     | 
| 37 | 
         
            +
                    resample (`str`, *optional*, defaults to `lanczos`):
         
     | 
| 38 | 
         
            +
                        Resampling filter to use when resizing the image.
         
     | 
| 39 | 
         
            +
                    do_normalize (`bool`, *optional*, defaults to `True`):
         
     | 
| 40 | 
         
            +
                        Whether to normalize the image to [-1,1].
         
     | 
| 41 | 
         
            +
                    do_convert_rgb (`bool`, *optional*, defaults to be `False`):
         
     | 
| 42 | 
         
            +
                        Whether to convert the images to RGB format.
         
     | 
| 43 | 
         
            +
                """
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                config_name = CONFIG_NAME
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                @register_to_config
         
     | 
| 48 | 
         
            +
                def __init__(
         
     | 
| 49 | 
         
            +
                    self,
         
     | 
| 50 | 
         
            +
                    do_resize: bool = True,
         
     | 
| 51 | 
         
            +
                    vae_scale_factor: int = 8,
         
     | 
| 52 | 
         
            +
                    resample: str = "lanczos",
         
     | 
| 53 | 
         
            +
                    do_normalize: bool = True,
         
     | 
| 54 | 
         
            +
                    do_convert_rgb: bool = False,
         
     | 
| 55 | 
         
            +
                ):
         
     | 
| 56 | 
         
            +
                    super().__init__()
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                @staticmethod
         
     | 
| 59 | 
         
            +
                def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
         
     | 
| 60 | 
         
            +
                    """
         
     | 
| 61 | 
         
            +
                    Convert a numpy image or a batch of images to a PIL image.
         
     | 
| 62 | 
         
            +
                    """
         
     | 
| 63 | 
         
            +
                    if images.ndim == 3:
         
     | 
| 64 | 
         
            +
                        images = images[None, ...]
         
     | 
| 65 | 
         
            +
                    images = (images * 255).round().astype("uint8")
         
     | 
| 66 | 
         
            +
                    if images.shape[-1] == 1:
         
     | 
| 67 | 
         
            +
                        # special case for grayscale (single channel) images
         
     | 
| 68 | 
         
            +
                        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
         
     | 
| 69 | 
         
            +
                    else:
         
     | 
| 70 | 
         
            +
                        pil_images = [Image.fromarray(image) for image in images]
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                    return pil_images
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
                @staticmethod
         
     | 
| 75 | 
         
            +
                def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
         
     | 
| 76 | 
         
            +
                    """
         
     | 
| 77 | 
         
            +
                    Convert a PIL image or a list of PIL images to NumPy arrays.
         
     | 
| 78 | 
         
            +
                    """
         
     | 
| 79 | 
         
            +
                    if not isinstance(images, list):
         
     | 
| 80 | 
         
            +
                        images = [images]
         
     | 
| 81 | 
         
            +
                    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
         
     | 
| 82 | 
         
            +
                    images = np.stack(images, axis=0)
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                    return images
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                @staticmethod
         
     | 
| 87 | 
         
            +
                def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
         
     | 
| 88 | 
         
            +
                    """
         
     | 
| 89 | 
         
            +
                    Convert a NumPy image to a PyTorch tensor.
         
     | 
| 90 | 
         
            +
                    """
         
     | 
| 91 | 
         
            +
                    if images.ndim == 3:
         
     | 
| 92 | 
         
            +
                        images = images[..., None]
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
         
     | 
| 95 | 
         
            +
                    return images
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                @staticmethod
         
     | 
| 98 | 
         
            +
                def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
         
     | 
| 99 | 
         
            +
                    """
         
     | 
| 100 | 
         
            +
                    Convert a PyTorch tensor to a NumPy image.
         
     | 
| 101 | 
         
            +
                    """
         
     | 
| 102 | 
         
            +
                    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         
     | 
| 103 | 
         
            +
                    return images
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
                @staticmethod
         
     | 
| 106 | 
         
            +
                def normalize(images):
         
     | 
| 107 | 
         
            +
                    """
         
     | 
| 108 | 
         
            +
                    Normalize an image array to [-1,1].
         
     | 
| 109 | 
         
            +
                    """
         
     | 
| 110 | 
         
            +
                    return 2.0 * images - 1.0
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                @staticmethod
         
     | 
| 113 | 
         
            +
                def denormalize(images):
         
     | 
| 114 | 
         
            +
                    """
         
     | 
| 115 | 
         
            +
                    Denormalize an image array to [0,1].
         
     | 
| 116 | 
         
            +
                    """
         
     | 
| 117 | 
         
            +
                    return (images / 2 + 0.5).clamp(0, 1)
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                @staticmethod
         
     | 
| 120 | 
         
            +
                def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
         
     | 
| 121 | 
         
            +
                    """
         
     | 
| 122 | 
         
            +
                    Converts an image to RGB format.
         
     | 
| 123 | 
         
            +
                    """
         
     | 
| 124 | 
         
            +
                    image = image.convert("RGB")
         
     | 
| 125 | 
         
            +
                    return image
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                def resize(
         
     | 
| 128 | 
         
            +
                    self,
         
     | 
| 129 | 
         
            +
                    image: PIL.Image.Image,
         
     | 
| 130 | 
         
            +
                    height: Optional[int] = None,
         
     | 
| 131 | 
         
            +
                    width: Optional[int] = None,
         
     | 
| 132 | 
         
            +
                ) -> PIL.Image.Image:
         
     | 
| 133 | 
         
            +
                    """
         
     | 
| 134 | 
         
            +
                    Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
         
     | 
| 135 | 
         
            +
                    """
         
     | 
| 136 | 
         
            +
                    if height is None:
         
     | 
| 137 | 
         
            +
                        height = image.height
         
     | 
| 138 | 
         
            +
                    if width is None:
         
     | 
| 139 | 
         
            +
                        width = image.width
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                    width, height = (
         
     | 
| 142 | 
         
            +
                        x - x % self.config.vae_scale_factor for x in (width, height)
         
     | 
| 143 | 
         
            +
                    )  # resize to integer multiple of vae_scale_factor
         
     | 
| 144 | 
         
            +
                    image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
         
     | 
| 145 | 
         
            +
                    return image
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                def preprocess(
         
     | 
| 148 | 
         
            +
                    self,
         
     | 
| 149 | 
         
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
         
     | 
| 150 | 
         
            +
                    height: Optional[int] = None,
         
     | 
| 151 | 
         
            +
                    width: Optional[int] = None,
         
     | 
| 152 | 
         
            +
                ) -> torch.Tensor:
         
     | 
| 153 | 
         
            +
                    """
         
     | 
| 154 | 
         
            +
                    Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
         
     | 
| 155 | 
         
            +
                    """
         
     | 
| 156 | 
         
            +
                    supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
         
     | 
| 157 | 
         
            +
                    if isinstance(image, supported_formats):
         
     | 
| 158 | 
         
            +
                        image = [image]
         
     | 
| 159 | 
         
            +
                    elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
         
     | 
| 160 | 
         
            +
                        raise ValueError(
         
     | 
| 161 | 
         
            +
                            f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
         
     | 
| 162 | 
         
            +
                        )
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
                    if isinstance(image[0], PIL.Image.Image):
         
     | 
| 165 | 
         
            +
                        if self.config.do_convert_rgb:
         
     | 
| 166 | 
         
            +
                            image = [self.convert_to_rgb(i) for i in image]
         
     | 
| 167 | 
         
            +
                        if self.config.do_resize:
         
     | 
| 168 | 
         
            +
                            image = [self.resize(i, height, width) for i in image]
         
     | 
| 169 | 
         
            +
                        image = self.pil_to_numpy(image)  # to np
         
     | 
| 170 | 
         
            +
                        image = self.numpy_to_pt(image)  # to pt
         
     | 
| 171 | 
         
            +
             
     | 
| 172 | 
         
            +
                    elif isinstance(image[0], np.ndarray):
         
     | 
| 173 | 
         
            +
                        image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
         
     | 
| 174 | 
         
            +
                        image = self.numpy_to_pt(image)
         
     | 
| 175 | 
         
            +
                        _, _, height, width = image.shape
         
     | 
| 176 | 
         
            +
                        if self.config.do_resize and (
         
     | 
| 177 | 
         
            +
                            height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
         
     | 
| 178 | 
         
            +
                        ):
         
     | 
| 179 | 
         
            +
                            raise ValueError(
         
     | 
| 180 | 
         
            +
                                f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
         
     | 
| 181 | 
         
            +
                                f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
         
     | 
| 182 | 
         
            +
                            )
         
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
                    elif isinstance(image[0], torch.Tensor):
         
     | 
| 185 | 
         
            +
                        image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
         
     | 
| 186 | 
         
            +
                        _, channel, height, width = image.shape
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                        # don't need any preprocess if the image is latents
         
     | 
| 189 | 
         
            +
                        if channel == 4:
         
     | 
| 190 | 
         
            +
                            return image
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
                        if self.config.do_resize and (
         
     | 
| 193 | 
         
            +
                            height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
         
     | 
| 194 | 
         
            +
                        ):
         
     | 
| 195 | 
         
            +
                            raise ValueError(
         
     | 
| 196 | 
         
            +
                                f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
         
     | 
| 197 | 
         
            +
                                f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
         
     | 
| 198 | 
         
            +
                            )
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                    # expected range [0,1], normalize to [-1,1]
         
     | 
| 201 | 
         
            +
                    do_normalize = self.config.do_normalize
         
     | 
| 202 | 
         
            +
                    if image.min() < 0:
         
     | 
| 203 | 
         
            +
                        warnings.warn(
         
     | 
| 204 | 
         
            +
                            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
         
     | 
| 205 | 
         
            +
                            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
         
     | 
| 206 | 
         
            +
                            FutureWarning,
         
     | 
| 207 | 
         
            +
                        )
         
     | 
| 208 | 
         
            +
                        do_normalize = False
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                    if do_normalize:
         
     | 
| 211 | 
         
            +
                        image = self.normalize(image)
         
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
                    return image
         
     | 
| 214 | 
         
            +
             
     | 
| 215 | 
         
            +
                def postprocess(
         
     | 
| 216 | 
         
            +
                    self,
         
     | 
| 217 | 
         
            +
                    image: torch.FloatTensor,
         
     | 
| 218 | 
         
            +
                    output_type: str = "pil",
         
     | 
| 219 | 
         
            +
                    do_denormalize: Optional[List[bool]] = None,
         
     | 
| 220 | 
         
            +
                ):
         
     | 
| 221 | 
         
            +
                    if not isinstance(image, torch.Tensor):
         
     | 
| 222 | 
         
            +
                        raise ValueError(
         
     | 
| 223 | 
         
            +
                            f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
         
     | 
| 224 | 
         
            +
                        )
         
     | 
| 225 | 
         
            +
                    if output_type not in ["latent", "pt", "np", "pil"]:
         
     | 
| 226 | 
         
            +
                        deprecation_message = (
         
     | 
| 227 | 
         
            +
                            f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
         
     | 
| 228 | 
         
            +
                            "`pil`, `np`, `pt`, `latent`"
         
     | 
| 229 | 
         
            +
                        )
         
     | 
| 230 | 
         
            +
                        deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
         
     | 
| 231 | 
         
            +
                        output_type = "np"
         
     | 
| 232 | 
         
            +
             
     | 
| 233 | 
         
            +
                    if output_type == "latent":
         
     | 
| 234 | 
         
            +
                        return image
         
     | 
| 235 | 
         
            +
             
     | 
| 236 | 
         
            +
                    if do_denormalize is None:
         
     | 
| 237 | 
         
            +
                        do_denormalize = [self.config.do_normalize] * image.shape[0]
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
                    image = torch.stack(
         
     | 
| 240 | 
         
            +
                        [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
         
     | 
| 241 | 
         
            +
                    )
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                    if output_type == "pt":
         
     | 
| 244 | 
         
            +
                        return image
         
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
                    image = self.pt_to_numpy(image)
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
                    if output_type == "np":
         
     | 
| 249 | 
         
            +
                        return image
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                    if output_type == "pil":
         
     | 
| 252 | 
         
            +
                        return self.numpy_to_pil(image)
         
     | 
| 253 | 
         
            +
             
     | 
| 254 | 
         
            +
             
     | 
| 255 | 
         
            +
            class VaeImageProcessorLDM3D(VaeImageProcessor):
         
     | 
| 256 | 
         
            +
                """
         
     | 
| 257 | 
         
            +
                Image processor for VAE LDM3D.
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                Args:
         
     | 
| 260 | 
         
            +
                    do_resize (`bool`, *optional*, defaults to `True`):
         
     | 
| 261 | 
         
            +
                        Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
         
     | 
| 262 | 
         
            +
                    vae_scale_factor (`int`, *optional*, defaults to `8`):
         
     | 
| 263 | 
         
            +
                        VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
         
     | 
| 264 | 
         
            +
                    resample (`str`, *optional*, defaults to `lanczos`):
         
     | 
| 265 | 
         
            +
                        Resampling filter to use when resizing the image.
         
     | 
| 266 | 
         
            +
                    do_normalize (`bool`, *optional*, defaults to `True`):
         
     | 
| 267 | 
         
            +
                        Whether to normalize the image to [-1,1].
         
     | 
| 268 | 
         
            +
                """
         
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
                config_name = CONFIG_NAME
         
     | 
| 271 | 
         
            +
             
     | 
| 272 | 
         
            +
                @register_to_config
         
     | 
| 273 | 
         
            +
                def __init__(
         
     | 
| 274 | 
         
            +
                    self,
         
     | 
| 275 | 
         
            +
                    do_resize: bool = True,
         
     | 
| 276 | 
         
            +
                    vae_scale_factor: int = 8,
         
     | 
| 277 | 
         
            +
                    resample: str = "lanczos",
         
     | 
| 278 | 
         
            +
                    do_normalize: bool = True,
         
     | 
| 279 | 
         
            +
                ):
         
     | 
| 280 | 
         
            +
                    super().__init__()
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
                @staticmethod
         
     | 
| 283 | 
         
            +
                def numpy_to_pil(images):
         
     | 
| 284 | 
         
            +
                    """
         
     | 
| 285 | 
         
            +
                    Convert a NumPy image or a batch of images to a PIL image.
         
     | 
| 286 | 
         
            +
                    """
         
     | 
| 287 | 
         
            +
                    if images.ndim == 3:
         
     | 
| 288 | 
         
            +
                        images = images[None, ...]
         
     | 
| 289 | 
         
            +
                    images = (images * 255).round().astype("uint8")
         
     | 
| 290 | 
         
            +
                    if images.shape[-1] == 1:
         
     | 
| 291 | 
         
            +
                        # special case for grayscale (single channel) images
         
     | 
| 292 | 
         
            +
                        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
         
     | 
| 293 | 
         
            +
                    else:
         
     | 
| 294 | 
         
            +
                        pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                    return pil_images
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                @staticmethod
         
     | 
| 299 | 
         
            +
                def rgblike_to_depthmap(image):
         
     | 
| 300 | 
         
            +
                    """
         
     | 
| 301 | 
         
            +
                    Args:
         
     | 
| 302 | 
         
            +
                        image: RGB-like depth image
         
     | 
| 303 | 
         
            +
             
     | 
| 304 | 
         
            +
                    Returns: depth map
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                    """
         
     | 
| 307 | 
         
            +
                    return image[:, :, 1] * 2**8 + image[:, :, 2]
         
     | 
| 308 | 
         
            +
             
     | 
    def numpy_to_depth(self, images):
        """
        Convert a NumPy depth image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            images_depth = (images_depth * 255).round().astype("uint8")
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        else:
            raise Exception("Not supported")

        return pil_images

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        image = self.pt_to_numpy(image)

        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            raise Exception(f"This type {output_type} is not supported")
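
For reference, a minimal usage sketch of the postprocessing path above (illustrative only; the names `image_processor` and `decoded` are assumptions, not part of this commit):

    # decoded: torch.FloatTensor of shape (batch, 6, H, W) with RGB in the first
    # three channels and an RGB-packed 16-bit depth map in the last three.
    rgb_pil, depth_pil = image_processor.postprocess(decoded, output_type="pil")
    rgb_np, depth_np = image_processor.postprocess(decoded, output_type="np")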
    	
6DoF/diffusers/loaders.py ADDED
@@ -0,0 +1,1492 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from .models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    XFormersAttnProcessor,
)
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
    TEXT_ENCODER_ATTN_MODULE,
    _get_model_file,
    deprecate,
    is_safetensors_available,
    is_transformers_available,
    logging,
)


if is_safetensors_available():
    import safetensors

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer


logger = logging.get_logger(__name__)

TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"

CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"


class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
            )

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
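
# Note: with the hooks above, a parameter stored internally as "layers.N.<...>"
# is exported under its original attention-processor key (for example something
# like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_k_lora.down.weight")
# and renamed back to the "layers.N." form on load via `rev_mapping`.
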
class UNet2DConditionLoadersMixin:
    text_encoder_name = TEXT_ENCODER_NAME
    unet_name = UNET_NAME

    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
        defined in
        [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
        and be a `torch.nn.Module` class.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a directory (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.

        """

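        # Usage sketch (illustrative; the path below is a placeholder):
        #
        #     unet.load_attn_procs("path/to/lora", weight_name="pytorch_lora_weights.safetensors")
        #
        # An already-loaded state dict may be passed instead of a path.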
| 164 | 
         
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         
     | 
| 165 | 
         
            +
                    force_download = kwargs.pop("force_download", False)
         
     | 
| 166 | 
         
            +
                    resume_download = kwargs.pop("resume_download", False)
         
     | 
| 167 | 
         
            +
                    proxies = kwargs.pop("proxies", None)
         
     | 
| 168 | 
         
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         
     | 
| 169 | 
         
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         
     | 
| 170 | 
         
            +
                    revision = kwargs.pop("revision", None)
         
     | 
| 171 | 
         
            +
                    subfolder = kwargs.pop("subfolder", None)
         
     | 
| 172 | 
         
            +
                    weight_name = kwargs.pop("weight_name", None)
         
     | 
| 173 | 
         
            +
                    use_safetensors = kwargs.pop("use_safetensors", None)
         
     | 
| 174 | 
         
            +
                    # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         
     | 
| 175 | 
         
            +
                    # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         
     | 
| 176 | 
         
            +
                    network_alpha = kwargs.pop("network_alpha", None)
         
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
                    if use_safetensors and not is_safetensors_available():
         
     | 
| 179 | 
         
            +
                        raise ValueError(
         
     | 
| 180 | 
         
            +
                            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
         
     | 
| 181 | 
         
            +
                        )
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                    allow_pickle = False
         
     | 
| 184 | 
         
            +
                    if use_safetensors is None:
         
     | 
| 185 | 
         
            +
                        use_safetensors = is_safetensors_available()
         
     | 
| 186 | 
         
            +
                        allow_pickle = True
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                    user_agent = {
         
     | 
| 189 | 
         
            +
                        "file_type": "attn_procs_weights",
         
     | 
| 190 | 
         
            +
                        "framework": "pytorch",
         
     | 
| 191 | 
         
            +
                    }
         
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
                    model_file = None
         
     | 
| 194 | 
         
            +
                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
         
     | 
| 195 | 
         
            +
                        # Let's first try to load .safetensors weights
         
     | 
| 196 | 
         
            +
                        if (use_safetensors and weight_name is None) or (
         
     | 
| 197 | 
         
            +
                            weight_name is not None and weight_name.endswith(".safetensors")
         
     | 
| 198 | 
         
            +
                        ):
         
     | 
| 199 | 
         
            +
                            try:
         
     | 
| 200 | 
         
            +
                                model_file = _get_model_file(
         
     | 
| 201 | 
         
            +
                                    pretrained_model_name_or_path_or_dict,
         
     | 
| 202 | 
         
            +
                                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
         
     | 
| 203 | 
         
            +
                                    cache_dir=cache_dir,
         
     | 
| 204 | 
         
            +
                                    force_download=force_download,
         
     | 
| 205 | 
         
            +
                                    resume_download=resume_download,
         
     | 
| 206 | 
         
            +
                                    proxies=proxies,
         
     | 
| 207 | 
         
            +
                                    local_files_only=local_files_only,
         
     | 
| 208 | 
         
            +
                                    use_auth_token=use_auth_token,
         
     | 
| 209 | 
         
            +
                                    revision=revision,
         
     | 
| 210 | 
         
            +
                                    subfolder=subfolder,
         
     | 
| 211 | 
         
            +
                                    user_agent=user_agent,
         
     | 
| 212 | 
         
            +
                                )
         
     | 
| 213 | 
         
            +
                                state_dict = safetensors.torch.load_file(model_file, device="cpu")
         
     | 
| 214 | 
         
            +
                            except IOError as e:
         
     | 
| 215 | 
         
            +
                                if not allow_pickle:
         
     | 
| 216 | 
         
            +
                                    raise e
         
     | 
| 217 | 
         
            +
                                # try loading non-safetensors weights
         
     | 
| 218 | 
         
            +
                                pass
         
     | 
| 219 | 
         
            +
                        if model_file is None:
         
     | 
| 220 | 
         
            +
                            model_file = _get_model_file(
         
     | 
| 221 | 
         
            +
                                pretrained_model_name_or_path_or_dict,
         
     | 
| 222 | 
         
            +
                                weights_name=weight_name or LORA_WEIGHT_NAME,
         
     | 
| 223 | 
         
            +
                                cache_dir=cache_dir,
         
     | 
| 224 | 
         
            +
                                force_download=force_download,
         
     | 
| 225 | 
         
            +
                                resume_download=resume_download,
         
     | 
| 226 | 
         
            +
                                proxies=proxies,
         
     | 
| 227 | 
         
            +
                                local_files_only=local_files_only,
         
     | 
| 228 | 
         
            +
                                use_auth_token=use_auth_token,
         
     | 
| 229 | 
         
            +
                                revision=revision,
         
     | 
| 230 | 
         
            +
                                subfolder=subfolder,
         
     | 
| 231 | 
         
            +
                                user_agent=user_agent,
         
     | 
| 232 | 
         
            +
                            )
         
     | 
| 233 | 
         
            +
                            state_dict = torch.load(model_file, map_location="cpu")
         
     | 
| 234 | 
         
            +
                    else:
         
     | 
| 235 | 
         
            +
                        state_dict = pretrained_model_name_or_path_or_dict
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                    # fill attn processors
         
     | 
| 238 | 
         
            +
                    attn_processors = {}
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
                    is_lora = all("lora" in k for k in state_dict.keys())
         
     | 
| 241 | 
         
            +
                    is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                    if is_lora:
         
     | 
| 244 | 
         
            +
                        is_new_lora_format = all(
         
     | 
| 245 | 
         
            +
                            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         
     | 
| 246 | 
         
            +
                        )
         
     | 
| 247 | 
         
            +
                        if is_new_lora_format:
         
     | 
| 248 | 
         
            +
                            # Strip the `"unet"` prefix.
         
     | 
| 249 | 
         
            +
                            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
         
     | 
| 250 | 
         
            +
                            if is_text_encoder_present:
         
     | 
| 251 | 
         
            +
                                warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
         
     | 
| 252 | 
         
            +
                                warnings.warn(warn_message)
         
     | 
| 253 | 
         
            +
                            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
         
     | 
| 254 | 
         
            +
                            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
         
     | 
| 255 | 
         
            +
             
     | 
            lora_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

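            # The LoRA rank can be read off the down-projection's output dim and the attention
            # hidden size off the up-projection's output dim, so neither needs to be configured.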
            for key, value_dict in lora_grouped_dict.items():
                rank = value_dict["to_k_lora.down.weight"].shape[0]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)

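                # Choose the LoRA processor class that matches the processor currently set on
                # this attention module (added-KV, xFormers, or the default implementation).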
                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
                    cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
                    attn_processor_class = LoRAAttnAddedKVProcessor
                else:
                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
                        attn_processor_class = LoRAXFormersAttnProcessor
                    else:
                        attn_processor_class = (
                            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                        )

                attn_processors[key] = attn_processor_class(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    rank=rank,
                    network_alpha=network_alpha,
                )
                attn_processors[key].load_state_dict(value_dict)
        elif is_custom_diffusion:
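            # Custom Diffusion checkpoints store full projection weights rather than LoRA factors;
            # keys with empty values are restored below as processors with train_kv=False and train_q_out=False.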
            custom_diffusion_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                if len(value) == 0:
                    custom_diffusion_grouped_dict[key] = {}
                else:
                    if "to_out" in key:
                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                    else:
                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in custom_diffusion_grouped_dict.items():
                if len(value_dict) == 0:
                    attn_processors[key] = CustomDiffusionAttnProcessor(
                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
                    )
                else:
                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
                    attn_processors[key] = CustomDiffusionAttnProcessor(
                        train_kv=True,
                        train_q_out=train_q_out,
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                    )
                    attn_processors[key].load_state_dict(value_dict)
        else:
            raise ValueError(
                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
            )

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}

        # set layers
        self.set_attn_processor(attn_processors)

    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = False,
        **kwargs,
    ):
        r"""
        Save an attention processor to a directory so that it can be reloaded using the
        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save an attention processor to. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training when
                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
                main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
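
        Example:

        A minimal usage sketch; the base checkpoint id matches the other examples in this file, while the
        LoRA weights path and output directory are only placeholders:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        # load (or train) LoRA attention processors on the UNet, then write them back out
        pipe.unet.load_attn_procs("path-to-lora-weights")
        pipe.unet.save_attn_procs("./saved-attn-procs")
        ```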

        """
        weight_name = weight_name or deprecate(
            "weights_name",
            "0.20.0",
            "`weights_name` is deprecated, please use `weight_name` instead.",
            take_from=kwargs,
        )
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            if safe_serialization:

                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            else:
                save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

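        # LoRA and Custom Diffusion processors are serialized differently, so check which kind
        # is attached before building the state dict that gets written to disk.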
        is_custom_diffusion = any(
            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
            for (_, x) in self.attn_processors.items()
        )
        if is_custom_diffusion:
            model_to_save = AttnProcsLayers(
                {
                    y: x
                    for (y, x) in self.attn_processors.items()
                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
                }
            )
            state_dict = model_to_save.state_dict()
            for name, attn in self.attn_processors.items():
                if len(attn.state_dict()) == 0:
                    state_dict[name] = {}
        else:
            model_to_save = AttnProcsLayers(self.attn_processors)
            state_dict = model_to_save.state_dict()

        if weight_name is None:
            if safe_serialization:
                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
            else:
                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME

        # Save the model
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")


class TextualInversionLoaderMixin:
    r"""
    Load textual inversion tokens and embeddings to the tokenizer and text encoder.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
        r"""
        Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding so
        that the token is replaced with multiple special tokens, each corresponding to one of the vectors. If the
        prompt has no textual inversion token or if the textual inversion token is a single vector, the input prompt
        is returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
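
        Example:

        A short sketch of the expected behaviour, assuming a multi-vector embedding was loaded under the
        token `<cat-toy>` (so the tokenizer also contains `<cat-toy>_1`, `<cat-toy>_2`, ...); with a
        single-vector embedding the prompt comes back unchanged:

        ```py
        from diffusers import StableDiffusionPipeline

        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        pipe.load_textual_inversion("sd-concepts-library/cat-toy")

        # e.g. "A <cat-toy> backpack" -> "A <cat-toy> <cat-toy>_1 <cat-toy>_2 backpack"
        prompt = pipe.maybe_convert_prompt("A <cat-toy> backpack", pipe.tokenizer)
        ```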

        """
        if not isinstance(prompt, List):
            prompts = [prompt]
        else:
            prompts = prompt

        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]

        if not isinstance(prompt, List):
            return prompts[0]

        return prompts

    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        tokens = tokenizer.tokenize(prompt)
        unique_tokens = set(tokens)
        for token in unique_tokens:
            if token in tokenizer.added_tokens_encoder:
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    replacement += f" {token}_{i}"
                    i += 1

                prompt = prompt.replace(token, replacement)

        return prompt

    def load_textual_inversion(
        self,
        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
        token: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ):
        r"""
        Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
        Automatic1111 formats are supported).

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                Can be either one of the following or a list of them:

                    - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
                      pretrained model hosted on the Hub.
                    - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
                      inversion weights.
                    - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            token (`str` or `List[str]`, *optional*):
                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
                list, then `token` must also be a list of equal length.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used when:

                    - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
                      name such as `text_inv.bin`.
                    - The saved textual inversion file is in the Automatic1111 format.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.

        Example:

        To load a textual inversion embedding vector in 🤗 Diffusers format:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("sd-concepts-library/cat-toy")

        prompt = "A <cat-toy> backpack"

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("cat-backpack.png")
        ```

        To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
        (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
        locally:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

        prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("character.png")
        ```

        """
        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`"
            )

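        # Prefer safetensors when available; only fall back to pickled torch checkpoints when the
        # user did not explicitly request safetensors.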
         
            +
                    allow_pickle = False
         
     | 
| 596 | 
         
            +
                    if use_safetensors is None:
         
     | 
| 597 | 
         
            +
                        use_safetensors = is_safetensors_available()
         
     | 
| 598 | 
         
            +
                        allow_pickle = True
         
     | 
| 599 | 
         
            +
             
     | 
| 600 | 
         
            +
                    user_agent = {
         
     | 
| 601 | 
         
            +
                        "file_type": "text_inversion",
         
     | 
| 602 | 
         
            +
                        "framework": "pytorch",
         
     | 
| 603 | 
         
            +
                    }
         
     | 
| 604 | 
         
            +
             
     | 
| 605 | 
         
            +
                    if not isinstance(pretrained_model_name_or_path, list):
         
     | 
| 606 | 
         
            +
                        pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         
     | 
| 607 | 
         
            +
                    else:
         
     | 
| 608 | 
         
            +
                        pretrained_model_name_or_paths = pretrained_model_name_or_path
         
     | 
| 609 | 
         
            +
             
     | 
| 610 | 
         
            +
                    if isinstance(token, str):
         
     | 
| 611 | 
         
            +
                        tokens = [token]
         
     | 
| 612 | 
         
            +
                    elif token is None:
         
     | 
| 613 | 
         
            +
                        tokens = [None] * len(pretrained_model_name_or_paths)
         
     | 
| 614 | 
         
            +
                    else:
         
     | 
| 615 | 
         
            +
                        tokens = token
         
     | 
| 616 | 
         
            +
             
     | 
| 617 | 
         
            +
                    if len(pretrained_model_name_or_paths) != len(tokens):
         
     | 
| 618 | 
         
            +
                        raise ValueError(
         
     | 
| 619 | 
         
            +
                            f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
         
     | 
| 620 | 
         
            +
                            f"Make sure both lists have the same length."
         
     | 
| 621 | 
         
            +
                        )
         
     | 
| 622 | 
         
            +
             
     | 
| 623 | 
         
            +
                    valid_tokens = [t for t in tokens if t is not None]
         
     | 
| 624 | 
         
            +
                    if len(set(valid_tokens)) < len(valid_tokens):
         
     | 
| 625 | 
         
            +
                        raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
         
     | 
| 626 | 
         
            +
             
     | 
| 627 | 
         
            +
                    token_ids_and_embeddings = []
         
     | 
| 628 | 
         
            +
             
     | 
| 629 | 
         
            +
                    for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
         
     | 
| 630 | 
         
            +
                        if not isinstance(pretrained_model_name_or_path, dict):
         
     | 
| 631 | 
         
            +
                            # 1. Load textual inversion file
         
     | 
| 632 | 
         
            +
                            model_file = None
         
     | 
| 633 | 
         
            +
                            # Let's first try to load .safetensors weights
         
     | 
| 634 | 
         
            +
                            if (use_safetensors and weight_name is None) or (
         
     | 
| 635 | 
         
            +
                                weight_name is not None and weight_name.endswith(".safetensors")
         
     | 
| 636 | 
         
            +
                            ):
         
     | 
| 637 | 
         
            +
                                try:
         
     | 
| 638 | 
         
            +
                                    model_file = _get_model_file(
         
     | 
| 639 | 
         
            +
                                        pretrained_model_name_or_path,
         
     | 
| 640 | 
         
            +
                                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
         
     | 
| 641 | 
         
            +
                                        cache_dir=cache_dir,
         
     | 
| 642 | 
         
            +
                                        force_download=force_download,
         
     | 
| 643 | 
         
            +
                                        resume_download=resume_download,
         
     | 
| 644 | 
         
            +
                                        proxies=proxies,
         
     | 
| 645 | 
         
            +
                                        local_files_only=local_files_only,
         
     | 
| 646 | 
         
            +
                                        use_auth_token=use_auth_token,
         
     | 
| 647 | 
         
            +
                                        revision=revision,
         
     | 
| 648 | 
         
            +
                                        subfolder=subfolder,
         
     | 
| 649 | 
         
            +
                                        user_agent=user_agent,
         
     | 
| 650 | 
         
            +
                                    )
         
     | 
| 651 | 
         
            +
                                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
         
     | 
| 652 | 
         
            +
                                except Exception as e:
         
     | 
| 653 | 
         
            +
                                    if not allow_pickle:
         
     | 
| 654 | 
         
            +
                                        raise e
         
     | 
| 655 | 
         
            +
             
     | 
| 656 | 
         
            +
                                    model_file = None
         
     | 
| 657 | 
         
            +
             
     | 
| 658 | 
         
            +
                            if model_file is None:
         
     | 
| 659 | 
         
            +
                                model_file = _get_model_file(
         
     | 
| 660 | 
         
            +
                                    pretrained_model_name_or_path,
         
     | 
| 661 | 
         
            +
                                    weights_name=weight_name or TEXT_INVERSION_NAME,
         
     | 
| 662 | 
         
            +
                                    cache_dir=cache_dir,
         
     | 
| 663 | 
         
            +
                                    force_download=force_download,
         
     | 
| 664 | 
         
            +
                                    resume_download=resume_download,
         
     | 
| 665 | 
         
            +
                                    proxies=proxies,
         
     | 
| 666 | 
         
            +
                                    local_files_only=local_files_only,
         
     | 
| 667 | 
         
            +
                                    use_auth_token=use_auth_token,
         
     | 
| 668 | 
         
            +
                                    revision=revision,
         
     | 
| 669 | 
         
            +
                                    subfolder=subfolder,
         
     | 
| 670 | 
         
            +
                                    user_agent=user_agent,
         
     | 
| 671 | 
         
            +
                                )
         
     | 
| 672 | 
         
            +
                                state_dict = torch.load(model_file, map_location="cpu")
         
     | 
| 673 | 
         
            +
                        else:
         
     | 
| 674 | 
         
            +
                            state_dict = pretrained_model_name_or_path
         
     | 
| 675 | 
         
            +
             
     | 
| 676 | 
         
            +
                        # 2. Load token and embedding correcly from file
         
     | 
| 677 | 
         
            +
                        loaded_token = None
         
     | 
| 678 | 
         
            +
                        if isinstance(state_dict, torch.Tensor):
         
     | 
| 679 | 
         
            +
                            if token is None:
         
     | 
| 680 | 
         
            +
                                raise ValueError(
         
     | 
| 681 | 
         
            +
                                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
         
     | 
| 682 | 
         
            +
                                )
         
     | 
| 683 | 
         
            +
                            embedding = state_dict
         
     | 
| 684 | 
         
            +
                        elif len(state_dict) == 1:
         
     | 
| 685 | 
         
            +
                            # diffusers
         
     | 
| 686 | 
         
            +
                            loaded_token, embedding = next(iter(state_dict.items()))
         
     | 
| 687 | 
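            # A1111 embeddings are dicts carrying the token under "name" and the embedding
            # matrix under state_dict["string_to_param"]["*"].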
         
            +
                        elif "string_to_param" in state_dict:
         
     | 
| 688 | 
         
            +
                            # A1111
         
     | 
| 689 | 
         
            +
                            loaded_token = state_dict["name"]
         
     | 
| 690 | 
         
            +
                            embedding = state_dict["string_to_param"]["*"]
         
     | 
| 691 | 
         
            +
             
     | 
| 692 | 
         
            +
                        if token is not None and loaded_token != token:
         
     | 
| 693 | 
         
            +
                            logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
         
     | 
| 694 | 
         
            +
                        else:
         
     | 
| 695 | 
         
            +
                            token = loaded_token
         
     | 
| 696 | 
         
            +
             
     | 
| 697 | 
         
            +
                        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
         
     | 
| 698 | 
         
            +
             
     | 
| 699 | 
         
            +
                        # 3. Make sure we don't mess up the tokenizer or text encoder
         
     | 
| 700 | 
         
            +
                        vocab = self.tokenizer.get_vocab()
         
     | 
| 701 | 
         
            +
                        if token in vocab:
         
     | 
| 702 | 
         
            +
                            raise ValueError(
         
     | 
| 703 | 
         
            +
                                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
         
     | 
| 704 | 
         
            +
                            )
         
     | 
| 705 | 
         
            +
                        elif f"{token}_1" in vocab:
         
     | 
| 706 | 
         
            +
                            multi_vector_tokens = [token]
         
     | 
| 707 | 
         
            +
                            i = 1
         
     | 
| 708 | 
         
            +
                            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
         
     | 
| 709 | 
         
            +
                                multi_vector_tokens.append(f"{token}_{i}")
         
     | 
| 710 | 
         
            +
                                i += 1
         
     | 
| 711 | 
         
            +
             
     | 
| 712 | 
         
            +
                            raise ValueError(
         
     | 
| 713 | 
         
            +
                                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
         
     | 
| 714 | 
         
            +
                            )
         
     | 
| 715 | 
         
            +
             
     | 
| 716 | 
         
            +
                        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
         
     | 
| 717 | 
         
            +
             
     | 
| 718 | 
         
            +
                        if is_multi_vector:
         
     | 
| 719 | 
         
            +
                            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
         
     | 
| 720 | 
         
            +
                            embeddings = [e for e in embedding]  # noqa: C416
         
     | 
| 721 | 
         
            +
                        else:
         
     | 
| 722 | 
         
            +
                            tokens = [token]
         
     | 
| 723 | 
         
            +
                            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
         
     | 
| 724 | 
         
            +
             
     | 
| 725 | 
         
            +
                        # add tokens and get ids
         
     | 
| 726 | 
         
            +
                        self.tokenizer.add_tokens(tokens)
         
     | 
| 727 | 
         
            +
                        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
         
     | 
| 728 | 
         
            +
                        token_ids_and_embeddings += zip(token_ids, embeddings)
         
     | 
| 729 | 
         
            +
             
     | 
| 730 | 
         
            +
                        logger.info(f"Loaded textual inversion embedding for {token}.")
         
     | 
| 731 | 
         
            +
             
     | 
| 732 | 
         
            +
                    # resize token embeddings and set all new embeddings
         
     | 
| 733 | 
         
            +
                    self.text_encoder.resize_token_embeddings(len(self.tokenizer))
         
     | 
| 734 | 
         
            +
                    for token_id, embedding in token_ids_and_embeddings:
         
     | 
| 735 | 
         
            +
                        self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
         
     | 
| 736 | 
         
            +
             
     | 
| 737 | 
         
            +
             
     | 
| 738 | 
         
            +
            class LoraLoaderMixin:
         
     | 
| 739 | 
         
            +
                r"""
         
     | 
| 740 | 
         
            +
                Load LoRA layers into [`UNet2DConditionModel`] and
         
     | 
| 741 | 
         
            +
                [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
         
     | 
| 742 | 
         
            +
                """
         
     | 
| 743 | 
         
            +
                text_encoder_name = TEXT_ENCODER_NAME
         
     | 
| 744 | 
         
            +
                unet_name = UNET_NAME
         
     | 
| 745 | 
         
            +
             
     | 
| 746 | 
         
            +
                def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         
     | 
| 747 | 
         
            +
                    r"""
         
     | 
| 748 | 
         
            +
                    Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
         
     | 
| 749 | 
         
            +
                    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
         
     | 
| 750 | 
         
            +
             
     | 
| 751 | 
         
            +
                    Parameters:
         
     | 
| 752 | 
         
            +
                        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
         
     | 
| 753 | 
         
            +
                            Can be either:
         
     | 
| 754 | 
         
            +
             
     | 
| 755 | 
         
            +
                                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
         
     | 
| 756 | 
         
            +
                                  the Hub.
         
     | 
| 757 | 
         
            +
                                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
         
     | 
| 758 | 
         
            +
                                  with [`ModelMixin.save_pretrained`].
         
     | 
| 759 | 
         
            +
                                - A [torch state
         
     | 
| 760 | 
         
            +
                                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
         
     | 
| 761 | 
         
            +
             
     | 
| 762 | 
         
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         
     | 
| 763 | 
         
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         
     | 
| 764 | 
         
            +
                            is not used.
         
     | 
| 765 | 
         
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         
     | 
| 766 | 
         
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         
     | 
| 767 | 
         
            +
                            cached versions if they exist.
         
     | 
| 768 | 
         
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         
     | 
| 769 | 
         
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         
     | 
| 770 | 
         
            +
                            incompletely downloaded files are deleted.
         
     | 
| 771 | 
         
            +
                        proxies (`Dict[str, str]`, *optional*):
         
     | 
| 772 | 
         
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         
     | 
| 773 | 
         
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         
     | 
| 774 | 
         
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         
     | 
| 775 | 
         
            +
                            Whether to only load local model weights and configuration files or not. If set to `True`, the model
         
     | 
| 776 | 
         
            +
                            won't be downloaded from the Hub.
         
     | 
| 777 | 
         
            +
                        use_auth_token (`str` or *bool*, *optional*):
         
     | 
| 778 | 
         
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         
     | 
| 779 | 
         
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         
     | 
| 780 | 
         
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         
     | 
| 781 | 
         
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         
     | 
| 782 | 
         
            +
                            allowed by Git.
         
     | 
| 783 | 
         
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         
     | 
| 784 | 
         
            +
                            The subfolder location of a model file within a larger model repository on the Hub or locally.
         
     | 
| 785 | 
         
            +
                        mirror (`str`, *optional*):
         
     | 
| 786 | 
         
            +
                            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
         
     | 
| 787 | 
         
            +
                            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
         
     | 
| 788 | 
         
            +
                            information.
         
     | 
| 789 | 
         
            +
             
     | 
| 790 | 
         
            +
                    """
         
     | 
| 791 | 
         
            +
                    # Load the main state dict first which has the LoRA layers for either of
         
     | 
| 792 | 
         
            +
                    # UNet and text encoder or both.
         
     | 
| 793 | 
         
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         
     | 
| 794 | 
         
            +
                    force_download = kwargs.pop("force_download", False)
         
     | 
| 795 | 
         
            +
                    resume_download = kwargs.pop("resume_download", False)
         
     | 
| 796 | 
         
            +
                    proxies = kwargs.pop("proxies", None)
         
     | 
| 797 | 
         
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         
     | 
| 798 | 
         
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         
     | 
| 799 | 
         
            +
                    revision = kwargs.pop("revision", None)
         
     | 
| 800 | 
         
            +
                    subfolder = kwargs.pop("subfolder", None)
         
     | 
| 801 | 
         
            +
                    weight_name = kwargs.pop("weight_name", None)
         
     | 
| 802 | 
         
            +
                    use_safetensors = kwargs.pop("use_safetensors", None)
         
     | 
| 803 | 
         
            +
             
     | 
| 804 | 
         
            +
                    # set lora scale to a reasonable default
         
     | 
| 805 | 
         
            +
                    self._lora_scale = 1.0
         
     | 
| 806 | 
         
            +
             
     | 
| 807 | 
         
            +
                    if use_safetensors and not is_safetensors_available():
         
     | 
| 808 | 
         
            +
                        raise ValueError(
         
     | 
| 809 | 
         
            +
                            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
         
     | 
| 810 | 
         
            +
                        )
         
     | 
| 811 | 
         
            +
             
     | 
| 812 | 
         
            +
                    allow_pickle = False
         
     | 
| 813 | 
         
            +
                    if use_safetensors is None:
         
     | 
| 814 | 
         
            +
                        use_safetensors = is_safetensors_available()
         
     | 
| 815 | 
         
            +
                        allow_pickle = True
         
     | 
| 816 | 
         
            +
             
     | 
| 817 | 
         
            +
                    user_agent = {
         
     | 
| 818 | 
         
            +
                        "file_type": "attn_procs_weights",
         
     | 
| 819 | 
         
            +
                        "framework": "pytorch",
         
     | 
| 820 | 
         
            +
                    }
         
     | 
| 821 | 
         
            +
             
     | 
| 822 | 
         
            +
                    model_file = None
         
     | 
| 823 | 
         
            +
                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
         
     | 
| 824 | 
         
            +
                        # Let's first try to load .safetensors weights
         
     | 
| 825 | 
         
            +
                        if (use_safetensors and weight_name is None) or (
         
     | 
| 826 | 
         
            +
                            weight_name is not None and weight_name.endswith(".safetensors")
         
     | 
| 827 | 
         
            +
                        ):
         
     | 
| 828 | 
         
            +
                            try:
         
     | 
| 829 | 
         
            +
                                model_file = _get_model_file(
         
     | 
| 830 | 
         
            +
                                    pretrained_model_name_or_path_or_dict,
         
     | 
| 831 | 
         
            +
                                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
         
     | 
| 832 | 
         
            +
                                    cache_dir=cache_dir,
         
     | 
| 833 | 
         
            +
                                    force_download=force_download,
         
     | 
| 834 | 
         
            +
                                    resume_download=resume_download,
         
     | 
| 835 | 
         
            +
                                    proxies=proxies,
         
     | 
| 836 | 
         
            +
                                    local_files_only=local_files_only,
         
     | 
| 837 | 
         
            +
                                    use_auth_token=use_auth_token,
         
     | 
| 838 | 
         
            +
                                    revision=revision,
         
     | 
| 839 | 
         
            +
                                    subfolder=subfolder,
         
     | 
| 840 | 
         
            +
                                    user_agent=user_agent,
         
     | 
| 841 | 
         
            +
                                )
         
     | 
| 842 | 
         
            +
                                state_dict = safetensors.torch.load_file(model_file, device="cpu")
         
     | 
| 843 | 
         
            +
                            except IOError as e:
         
     | 
| 844 | 
         
            +
                                if not allow_pickle:
         
     | 
| 845 | 
         
            +
                                    raise e
         
     | 
| 846 | 
         
            +
                                # try loading non-safetensors weights
         
     | 
| 847 | 
         
            +
                                pass
         
     | 
| 848 | 
         
            +
                        if model_file is None:
         
     | 
| 849 | 
         
            +
                            model_file = _get_model_file(
         
     | 
| 850 | 
         
            +
                                pretrained_model_name_or_path_or_dict,
         
     | 
| 851 | 
         
            +
                                weights_name=weight_name or LORA_WEIGHT_NAME,
         
     | 
| 852 | 
         
            +
                                cache_dir=cache_dir,
         
     | 
| 853 | 
         
            +
                                force_download=force_download,
         
     | 
| 854 | 
         
            +
                                resume_download=resume_download,
         
     | 
| 855 | 
         
            +
                                proxies=proxies,
         
     | 
| 856 | 
         
            +
                                local_files_only=local_files_only,
         
     | 
| 857 | 
         
            +
                                use_auth_token=use_auth_token,
         
     | 
| 858 | 
         
            +
                                revision=revision,
         
     | 
| 859 | 
         
            +
                                subfolder=subfolder,
         
     | 
| 860 | 
         
            +
                                user_agent=user_agent,
         
     | 
| 861 | 
         
            +
                            )
         
     | 
| 862 | 
         
            +
                            state_dict = torch.load(model_file, map_location="cpu")
         
     | 
| 863 | 
         
            +
                    else:
         
     | 
| 864 | 
         
            +
                        state_dict = pretrained_model_name_or_path_or_dict
         
     | 
| 865 | 
         
            +
             
     | 
| 866 | 
         
            +
                    # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
         
     | 
| 867 | 
         
            +
                    network_alpha = None
         
     | 
| 868 | 
         
            +
                    if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
         
     | 
| 869 | 
         
            +
                        state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
         
     | 
| 870 | 
         
            +
             
     | 
| 871 | 
         
            +
                    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         
     | 
| 872 | 
         
            +
                    # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         
     | 
| 873 | 
         
            +
                    # their prefixes.
         
     | 
| 874 | 
         
            +
                    keys = list(state_dict.keys())
         
     | 
| 875 | 
         
            +
                    if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
         
     | 
| 876 | 
         
            +
                        # Load the layers corresponding to UNet.
         
     | 
| 877 | 
         
            +
                        unet_keys = [k for k in keys if k.startswith(self.unet_name)]
         
     | 
| 878 | 
         
            +
                        logger.info(f"Loading {self.unet_name}.")
         
     | 
| 879 | 
         
            +
                        unet_lora_state_dict = {
         
     | 
| 880 | 
         
            +
                            k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
         
     | 
| 881 | 
         
            +
                        }
         
     | 
| 882 | 
         
            +
                        self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
         
     | 
| 883 | 
         
            +
             
     | 
| 884 | 
         
            +
                        # Load the layers corresponding to text encoder and make necessary adjustments.
         
     | 
| 885 | 
         
            +
                        text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
         
     | 
| 886 | 
         
            +
                        text_encoder_lora_state_dict = {
         
     | 
| 887 | 
         
            +
                            k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         
     | 
| 888 | 
         
            +
                        }
         
     | 
| 889 | 
         
            +
                        if len(text_encoder_lora_state_dict) > 0:
         
     | 
| 890 | 
         
            +
                            logger.info(f"Loading {self.text_encoder_name}.")
         
     | 
| 891 | 
         
            +
                            attn_procs_text_encoder = self._load_text_encoder_attn_procs(
         
     | 
| 892 | 
         
            +
                                text_encoder_lora_state_dict, network_alpha=network_alpha
         
     | 
| 893 | 
         
            +
                            )
         
     | 
| 894 | 
         
            +
                            self._modify_text_encoder(attn_procs_text_encoder)
         
     | 
| 895 | 
         
            +
             
     | 
| 896 | 
         
            +
                            # save lora attn procs of text encoder so that it can be easily retrieved
         
     | 
| 897 | 
         
            +
                            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
         
     | 
| 898 | 
         
            +
             
     | 
| 899 | 
         
            +
                    # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         
     | 
| 900 | 
         
            +
                    # contain the module names of the `unet` as its keys WITHOUT any prefix.
         
     | 
| 901 | 
         
            +
                    elif not all(
         
     | 
| 902 | 
         
            +
                        key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         
     | 
| 903 | 
         
            +
                    ):
         
     | 
| 904 | 
         
            +
                        self.unet.load_attn_procs(state_dict)
         
     | 
| 905 | 
         
            +
                        warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
         
     | 
| 906 | 
         
            +
                        warnings.warn(warn_message)
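
    # A minimal usage sketch for `load_lora_weights` (illustrative only; the repo ids below are
    # assumed examples and are not part of this file):
    #
    #     import torch
    #     from diffusers import DiffusionPipeline
    #
    #     pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    #     pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")
    #
    # Per the docstring above, the same call also accepts a local directory or an
    # already-loaded state dict.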
         
    @property
    def lora_scale(self) -> float:
        # Property that returns the LoRA scale, which can be set at run time by the pipeline.
        # If `_lora_scale` has not been set, return 1.
        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

    @property
    def text_encoder_lora_attn_procs(self):
        if hasattr(self, "_text_encoder_lora_attn_procs"):
            return self._text_encoder_lora_attn_procs
        return

    def _remove_text_encoder_monkey_patch(self):
        # Loop over the CLIPAttention modules of the text encoder
        for name, attn_module in self.text_encoder.named_modules():
            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                # Loop over the LoRA layers
                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
                    # Retrieve the q/k/v/out projection of CLIPAttention
                    module = attn_module.get_submodule(text_encoder_attr)
                    if hasattr(module, "old_forward"):
                        # restore the original `forward` to remove the monkey-patch
                        module.forward = module.old_forward
                        delattr(module, "old_forward")

    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.

        Parameters:
            attn_processors: Dict[str, `LoRAAttnProcessor`]:
                A dictionary mapping module names to their corresponding [`~LoRAAttnProcessor`].
        """

        # First, remove any monkey-patch that might have been applied before
        self._remove_text_encoder_monkey_patch()

        # Loop over the CLIPAttention modules of the text encoder
        for name, attn_module in self.text_encoder.named_modules():
            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                # Loop over the LoRA layers
                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
                    module = attn_module.get_submodule(text_encoder_attr)
                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)

                    # save old_forward on the module so that the monkey-patch can be removed later
                    old_forward = module.old_forward = module.forward

                    # create a new scope that locks in the old_forward, lora_layer values for each new_forward function
                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
                    def make_new_forward(old_forward, lora_layer):
                        def new_forward(x):
                            result = old_forward(x) + self.lora_scale * lora_layer(x)
                            return result

                        return new_forward

                    # Monkey-patch.
                    module.forward = make_new_forward(old_forward, lora_layer)

    @property
    def _lora_attn_processor_attr_to_text_encoder_attr(self):
        return {
            "to_q_lora": "q_proj",
            "to_k_lora": "k_proj",
            "to_v_lora": "v_proj",
            "to_out_lora": "out_proj",
        }
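
    # Note on `make_new_forward` in `_modify_text_encoder` above: Python closures capture variables
    # by reference, so defining `new_forward` directly inside the loop would leave every patched
    # projection calling the *last* `old_forward`/`lora_layer` the loop saw. The factory function
    # pins the pair for each module. A tiny sketch of the difference (toy code, not part of this file):
    #
    #     fns = [lambda: i for i in range(3)]        # late binding: each call returns 2
    #     fns = [(lambda i=i: i) for i in range(3)]  # bound per iteration: calls return 0, 1, 2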
         
    def _load_text_encoder_attn_procs(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
    ):
        r"""
        Load pretrained attention processor layers for
        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
                      `./my_model_directory/`.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.

        Returns:
            `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
            [`LoRAAttnProcessor`].

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>
        """

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        network_alpha = kwargs.pop("network_alpha", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
            )

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        model_file = None
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path_or_dict,
                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except IOError as e:
                    if not allow_pickle:
                        raise e
                    # try loading non-safetensors weights
                    pass
            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        # fill attn processors
        attn_processors = {}

        is_lora = all("lora" in k for k in state_dict.keys())

        if is_lora:
            lora_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
                rank = value_dict["to_k_lora.down.weight"].shape[0]
                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                attn_processors[key] = attn_processor_class(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    rank=rank,
                    network_alpha=network_alpha,
                )
                attn_processors[key].load_state_dict(value_dict)

        else:
            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")

        # set the correct dtype & device
        attn_processors = {
            k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
        }
        return attn_processors
     | 
    @classmethod
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = False,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the UNet.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training when
                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
                main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            if safe_serialization:

                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            else:
                save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        # Create a flat dictionary.
        state_dict = {}
        if unet_lora_layers is not None:
            weights = (
                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
            )

            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
            state_dict.update(unet_lora_state_dict)

        if text_encoder_lora_layers is not None:
            weights = (
                text_encoder_lora_layers.state_dict()
                if isinstance(text_encoder_lora_layers, torch.nn.Module)
                else text_encoder_lora_layers
            )

            text_encoder_lora_state_dict = {
                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
            }
            state_dict.update(text_encoder_lora_state_dict)

        # Save the model
        if weight_name is None:
            if safe_serialization:
                weight_name = LORA_WEIGHT_NAME_SAFE
            else:
                weight_name = LORA_WEIGHT_NAME

        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
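
    # A minimal usage sketch for `save_lora_weights` (illustrative only; `unet_lora_layers` and
    # `text_encoder_lora_layers` stand for layer/state dicts produced by a LoRA training loop and
    # are assumed to exist):
    #
    #     pipe.save_lora_weights(
    #         save_directory="./lora-out",
    #         unet_lora_layers=unet_lora_layers,
    #         text_encoder_lora_layers=text_encoder_lora_layers,
    #         safe_serialization=True,
    #     )
    #
    # The saved file can later be reloaded with `pipe.load_lora_weights("./lora-out")`.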
         
    def _convert_kohya_lora_to_diffusers(self, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None

        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_name_up = lora_name + ".lora_up.weight"
                lora_name_alpha = lora_name + ".alpha"
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                    if network_alpha is None:
                        network_alpha = alpha
                    elif network_alpha != alpha:
                        raise ValueError("Network alpha is not consistent")

                if lora_name.startswith("lora_unet_"):
                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
                    if "transformer_blocks" in diffusers_name:
                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                            unet_state_dict[diffusers_name] = value
                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                elif lora_name.startswith("lora_te_"):
                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                    diffusers_name = diffusers_name.replace("text.model", "text_model")
                    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
                    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
                    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
                    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
                    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
                    if "self_attn" in diffusers_name:
                        te_state_dict[diffusers_name] = value
                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
        new_state_dict = {**unet_state_dict, **te_state_dict}
        return new_state_dict, network_alpha
         
     | 
| 1277 | 
         
            +
             
     | 
| 1278 | 
         
            +
             
     | 
| 1279 | 
         
            +
            class FromSingleFileMixin:
         
     | 
| 1280 | 
         
            +
                """
         
     | 
| 1281 | 
         
            +
                Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
         
     | 
| 1282 | 
         
            +
                """
         
     | 
| 1283 | 
         
            +
             
     | 
| 1284 | 
         
            +
                @classmethod
         
     | 
| 1285 | 
         
            +
                def from_ckpt(cls, *args, **kwargs):
         
     | 
| 1286 | 
         
            +
                    deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
         
     | 
| 1287 | 
         
            +
                    deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
         
     | 
| 1288 | 
         
            +
                    return cls.from_single_file(*args, **kwargs)
         
     | 
| 1289 | 
         
            +
             
     | 
| 1290 | 
         
            +
                @classmethod
         
     | 
| 1291 | 
         
            +
                def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         
     | 
| 1292 | 
         
            +
                    r"""
         
     | 
| 1293 | 
         
            +
                    Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
         
     | 
| 1294 | 
         
            +
                    is set in evaluation mode (`model.eval()`) by default.
         
     | 
| 1295 | 
         
            +
             
     | 
| 1296 | 
         
            +
                    Parameters:
         
     | 
| 1297 | 
         
            +
                        pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
         
     | 
| 1298 | 
         
            +
                            Can be either:
         
     | 
| 1299 | 
         
            +
                                - A link to the `.ckpt` file (for example
         
     | 
| 1300 | 
         
            +
                                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
         
     | 
| 1301 | 
         
            +
                                - A path to a *file* containing all pipeline weights.
         
     | 
| 1302 | 
         
            +
                        torch_dtype (`str` or `torch.dtype`, *optional*):
         
     | 
| 1303 | 
         
            +
                            Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
         
     | 
| 1304 | 
         
            +
                            dtype is automatically derived from the model's weights.
         
     | 
| 1305 | 
         
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         
     | 
| 1306 | 
         
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         
     | 
| 1307 | 
         
            +
                            cached versions if they exist.
         
     | 
| 1308 | 
         
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         
     | 
| 1309 | 
         
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         
     | 
| 1310 | 
         
            +
                            is not used.
         
     | 
| 1311 | 
         
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         
     | 
| 1312 | 
         
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         
     | 
| 1313 | 
         
            +
                            incompletely downloaded files are deleted.
         
     | 
| 1314 | 
         
            +
                        proxies (`Dict[str, str]`, *optional*):
         
     | 
| 1315 | 
         
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         
     | 
| 1316 | 
         
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         
     | 
| 1317 | 
         
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         
     | 
| 1318 | 
         
            +
                            Whether to only load local model weights and configuration files or not. If set to True, the model
         
     | 
| 1319 | 
         
            +
                            won't be downloaded from the Hub.
         
     | 
| 1320 | 
         
            +
                        use_auth_token (`str` or *bool*, *optional*):
         
     | 
| 1321 | 
         
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         
     | 
| 1322 | 
         
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         
     | 
| 1323 | 
         
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         
     | 
| 1324 | 
         
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         
     | 
| 1325 | 
         
            +
                            allowed by Git.
         
     | 
| 1326 | 
         
            +
                        use_safetensors (`bool`, *optional*, defaults to `None`):
         
     | 
| 1327 | 
         
            +
                            If set to `None`, the safetensors weights are downloaded if they're available **and** if the
         
     | 
| 1328 | 
         
            +
                            safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
         
     | 
| 1329 | 
         
            +
                            weights. If set to `False`, safetensors weights are not loaded.
         
     | 
| 1330 | 
         
            +
                        extract_ema (`bool`, *optional*, defaults to `False`):
         
     | 
| 1331 | 
         
            +
                            Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
         
     | 
| 1332 | 
         
            +
                            higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
         
     | 
| 1333 | 
         
            +
                        upcast_attention (`bool`, *optional*, defaults to `None`):
         
     | 
| 1334 | 
         
            +
                            Whether the attention computation should always be upcasted.
         
     | 
| 1335 | 
         
            +
                        image_size (`int`, *optional*, defaults to 512):
         
     | 
| 1336 | 
         
            +
                            The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
         
     | 
| 1337 | 
         
            +
                            Diffusion v2 base model. Use 768 for Stable Diffusion v2.
         
     | 
| 1338 | 
         
            +
                        prediction_type (`str`, *optional*):
         
     | 
| 1339 | 
         
            +
                            The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
         
     | 
| 1340 | 
         
            +
                            the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
         
     | 
| 1341 | 
         
            +
                        num_in_channels (`int`, *optional*, defaults to `None`):
         
     | 
| 1342 | 
         
            +
                            The number of input channels. If `None`, it will be automatically inferred.
         
     | 
| 1343 | 
         
            +
                        scheduler_type (`str`, *optional*, defaults to `"pndm"`):
         
     | 
| 1344 | 
         
            +
                            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
         
     | 
| 1345 | 
         
            +
                            "ddim"]`.
         
     | 
| 1346 | 
         
            +
                        load_safety_checker (`bool`, *optional*, defaults to `True`):
         
     | 
| 1347 | 
         
            +
                            Whether to load the safety checker or not.
         
     | 
| 1348 | 
         
            +
                        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
         
     | 
| 1349 | 
         
            +
                            An instance of
         
     | 
| 1350 | 
         
            +
                            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
         
     | 
| 1351 | 
         
            +
                            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
         
     | 
| 1352 | 
         
            +
                            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
         
     | 
| 1353 | 
         
            +
                            needed.
         
     | 
| 1354 | 
         
            +
                        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
         
     | 
| 1355 | 
         
            +
                            An instance of
         
     | 
| 1356 | 
         
            +
                            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
         
     | 
| 1357 | 
         
            +
                            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
         
     | 
| 1358 | 
         
            +
                            itself, if needed.
         
     | 
| 1359 | 
         
            +
                        kwargs (remaining dictionary of keyword arguments, *optional*):
         
     | 
| 1360 | 
         
            +
                            Can be used to overwrite load and saveable variables (for example the pipeline components of the
         
     | 
| 1361 | 
         
            +
                            specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
         
     | 
| 1362 | 
         
            +
                            method. See example below for more information.
         
     | 
| 1363 | 
         
            +
             
     | 
| 1364 | 
         
            +
                    Examples:
         
     | 
| 1365 | 
         
            +
             
     | 
| 1366 | 
         
            +
                    ```py
         
     | 
| 1367 | 
         
            +
                    >>> from diffusers import StableDiffusionPipeline
         
     | 
| 1368 | 
         
            +
             
     | 
| 1369 | 
         
            +
                    >>> # Download pipeline from huggingface.co and cache.
         
     | 
| 1370 | 
         
            +
                    >>> pipeline = StableDiffusionPipeline.from_single_file(
         
     | 
| 1371 | 
         
            +
                    ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
         
     | 
| 1372 | 
         
            +
                    ... )
         
     | 
| 1373 | 
         
            +
             
     | 
| 1374 | 
         
            +
                    >>> # Download pipeline from local file
         
     | 
| 1375 | 
         
            +
                    >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
         
     | 
| 1376 | 
         
            +
                    >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
         
     | 
| 1377 | 
         
            +
             
     | 
| 1378 | 
         
            +
                    >>> # Enable float16 and move to GPU
         
     | 
| 1379 | 
         
            +
                    >>> pipeline = StableDiffusionPipeline.from_single_file(
         
     | 
| 1380 | 
         
            +
                    ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
         
     | 
| 1381 | 
         
            +
                    ...     torch_dtype=torch.float16,
         
     | 
| 1382 | 
         
            +
                    ... )
         
     | 
| 1383 | 
         
            +
                    >>> pipeline.to("cuda")
         
     | 
| 1384 | 
         
            +
                    ```
         
     | 
| 1385 | 
         
            +
                    """
         
     | 
| 1386 | 
         
            +
                    # import here to avoid circular dependency
         
     | 
| 1387 | 
         
            +
                    from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
         
     | 
| 1388 | 
         
            +
             
     | 
| 1389 | 
         
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         
     | 
| 1390 | 
         
            +
                    resume_download = kwargs.pop("resume_download", False)
         
     | 
| 1391 | 
         
            +
                    force_download = kwargs.pop("force_download", False)
         
     | 
| 1392 | 
         
            +
                    proxies = kwargs.pop("proxies", None)
         
     | 
| 1393 | 
         
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         
     | 
| 1394 | 
         
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         
     | 
| 1395 | 
         
            +
                    revision = kwargs.pop("revision", None)
         
     | 
| 1396 | 
         
            +
                    extract_ema = kwargs.pop("extract_ema", False)
         
     | 
| 1397 | 
         
            +
                    image_size = kwargs.pop("image_size", None)
         
     | 
| 1398 | 
         
            +
                    scheduler_type = kwargs.pop("scheduler_type", "pndm")
         
     | 
| 1399 | 
         
            +
                    num_in_channels = kwargs.pop("num_in_channels", None)
         
     | 
| 1400 | 
         
            +
                    upcast_attention = kwargs.pop("upcast_attention", None)
         
     | 
| 1401 | 
         
            +
                    load_safety_checker = kwargs.pop("load_safety_checker", True)
         
     | 
| 1402 | 
         
            +
                    prediction_type = kwargs.pop("prediction_type", None)
         
     | 
| 1403 | 
         
            +
                    text_encoder = kwargs.pop("text_encoder", None)
         
     | 
| 1404 | 
         
            +
                    tokenizer = kwargs.pop("tokenizer", None)
         
     | 
| 1405 | 
         
            +
             
     | 
| 1406 | 
         
            +
                    torch_dtype = kwargs.pop("torch_dtype", None)
         
     | 
| 1407 | 
         
            +
             
     | 
| 1408 | 
         
            +
                    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
         
     | 
| 1409 | 
         
            +
             
     | 
| 1410 | 
         
            +
                    pipeline_name = cls.__name__
         
     | 
| 1411 | 
         
            +
                    file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         
     | 
| 1412 | 
         
            +
                    from_safetensors = file_extension == "safetensors"
         
     | 
| 1413 | 
         
            +
             
     | 
| 1414 | 
         
            +
                    if from_safetensors and use_safetensors is False:
         
     | 
| 1415 | 
         
            +
                        raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
         
     | 
| 1416 | 
         
            +
             
     | 
| 1417 | 
         
            +
                    # TODO: For now we only support stable diffusion
         
     | 
| 1418 | 
         
            +
                    stable_unclip = None
         
     | 
| 1419 | 
         
            +
                    model_type = None
         
     | 
| 1420 | 
         
            +
                    controlnet = False
         
     | 
| 1421 | 
         
            +
             
     | 
| 1422 | 
         
            +
                    if pipeline_name == "StableDiffusionControlNetPipeline":
         
     | 
| 1423 | 
         
            +
                        # Model type will be inferred from the checkpoint.
         
     | 
| 1424 | 
         
            +
                        controlnet = True
         
     | 
| 1425 | 
         
            +
                    elif "StableDiffusion" in pipeline_name:
         
     | 
| 1426 | 
         
            +
                        # Model type will be inferred from the checkpoint.
         
     | 
| 1427 | 
         
            +
                        pass
         
     | 
| 1428 | 
         
            +
                    elif pipeline_name == "StableUnCLIPPipeline":
         
     | 
| 1429 | 
         
            +
                        model_type = "FrozenOpenCLIPEmbedder"
         
     | 
| 1430 | 
         
            +
                        stable_unclip = "txt2img"
         
     | 
| 1431 | 
         
            +
                    elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
         
     | 
| 1432 | 
         
            +
                        model_type = "FrozenOpenCLIPEmbedder"
         
     | 
| 1433 | 
         
            +
                        stable_unclip = "img2img"
         
     | 
| 1434 | 
         
            +
                    elif pipeline_name == "PaintByExamplePipeline":
         
     | 
| 1435 | 
         
            +
                        model_type = "PaintByExample"
         
     | 
| 1436 | 
         
            +
                    elif pipeline_name == "LDMTextToImagePipeline":
         
     | 
| 1437 | 
         
            +
                        model_type = "LDMTextToImage"
         
     | 
| 1438 | 
         
            +
                    else:
         
     | 
| 1439 | 
         
            +
                        raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
         
     | 
| 1440 | 
         
            +
             
     | 
| 1441 | 
         
            +
                    # remove huggingface url
         
     | 
| 1442 | 
         
            +
                    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
         
     | 
| 1443 | 
         
            +
                        if pretrained_model_link_or_path.startswith(prefix):
         
     | 
| 1444 | 
         
            +
                            pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
         
     | 
| 1445 | 
         
            +
             
     | 
| 1446 | 
         
            +
                    # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
         
     | 
| 1447 | 
         
            +
                    ckpt_path = Path(pretrained_model_link_or_path)
         
     | 
| 1448 | 
         
            +
                    if not ckpt_path.is_file():
         
     | 
| 1449 | 
         
            +
                        # get repo_id and (potentially nested) file path of ckpt in repo
         
     | 
| 1450 | 
         
            +
                        repo_id = "/".join(ckpt_path.parts[:2])
         
     | 
| 1451 | 
         
            +
                        file_path = "/".join(ckpt_path.parts[2:])
         
     | 
| 1452 | 
         
            +
             
     | 
| 1453 | 
         
            +
                        if file_path.startswith("blob/"):
         
     | 
| 1454 | 
         
            +
                            file_path = file_path[len("blob/") :]
         
     | 
| 1455 | 
         
            +
             
     | 
| 1456 | 
         
            +
                        if file_path.startswith("main/"):
         
     | 
| 1457 | 
         
            +
                            file_path = file_path[len("main/") :]
         
     | 
| 1458 | 
         
            +
             
     | 
| 1459 | 
         
            +
                        pretrained_model_link_or_path = hf_hub_download(
         
     | 
| 1460 | 
         
            +
                            repo_id,
         
     | 
| 1461 | 
         
            +
                            filename=file_path,
         
     | 
| 1462 | 
         
            +
                            cache_dir=cache_dir,
         
     | 
| 1463 | 
         
            +
                            resume_download=resume_download,
         
     | 
| 1464 | 
         
            +
                            proxies=proxies,
         
     | 
| 1465 | 
         
            +
                            local_files_only=local_files_only,
         
     | 
| 1466 | 
         
            +
                            use_auth_token=use_auth_token,
         
     | 
| 1467 | 
         
            +
                            revision=revision,
         
     | 
| 1468 | 
         
            +
                            force_download=force_download,
         
     | 
| 1469 | 
         
            +
                        )
         
     | 
| 1470 | 
         
            +
             
     | 
| 1471 | 
         
            +
                    pipe = download_from_original_stable_diffusion_ckpt(
         
     | 
| 1472 | 
         
            +
                        pretrained_model_link_or_path,
         
     | 
| 1473 | 
         
            +
                        pipeline_class=cls,
         
     | 
| 1474 | 
         
            +
                        model_type=model_type,
         
     | 
| 1475 | 
         
            +
                        stable_unclip=stable_unclip,
         
     | 
| 1476 | 
         
            +
                        controlnet=controlnet,
         
     | 
| 1477 | 
         
            +
                        from_safetensors=from_safetensors,
         
     | 
| 1478 | 
         
            +
                        extract_ema=extract_ema,
         
     | 
| 1479 | 
         
            +
                        image_size=image_size,
         
     | 
| 1480 | 
         
            +
                        scheduler_type=scheduler_type,
         
     | 
| 1481 | 
         
            +
                        num_in_channels=num_in_channels,
         
     | 
| 1482 | 
         
            +
                        upcast_attention=upcast_attention,
         
     | 
| 1483 | 
         
            +
                        load_safety_checker=load_safety_checker,
         
     | 
| 1484 | 
         
            +
                        prediction_type=prediction_type,
         
     | 
| 1485 | 
         
            +
                        text_encoder=text_encoder,
         
     | 
| 1486 | 
         
            +
                        tokenizer=tokenizer,
         
     | 
| 1487 | 
         
            +
                    )
         
     | 
| 1488 | 
         
            +
             
     | 
| 1489 | 
         
            +
                    if torch_dtype is not None:
         
     | 
| 1490 | 
         
            +
                        pipe.to(torch_dtype=torch_dtype)
         
     | 
| 1491 | 
         
            +
             
     | 
| 1492 | 
         
            +
                    return pipe
         
     | 
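The block above converts Kohya-style LoRA keys (`lora_unet_*` / `lora_te_*`) into diffusers-style module paths by plain string substitution. A minimal, self-contained sketch of the UNet half of that mapping; the helper name and the sample key below are illustrative choices, not taken from the file or from a real checkpoint:

```py
# Sketch of the key remapping shown above. The function name and the sample
# key are hypothetical, chosen only to demonstrate the substitutions.
def kohya_unet_key_to_diffusers(key: str) -> str:
    name = key.replace("lora_unet_", "").replace("_", ".")
    for old, new in [
        ("down.blocks", "down_blocks"),
        ("mid.block", "mid_block"),
        ("up.blocks", "up_blocks"),
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.k.lora", "to_k_lora"),
        ("to.v.lora", "to_v_lora"),
        ("to.out.0.lora", "to_out_lora"),
    ]:
        name = name.replace(old, new)
    return name


sample = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
print(kohya_unet_key_to_diffusers(sample))
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q_lora.down.weight
```

The loader additionally rewrites `attn1`/`attn2` to their `.processor` counterparts and stores the matching `lora_up` weight under the same name with `.down.` swapped for `.up.`, as the code above shows.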
    	
6DoF/diffusers/models/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_flax_available, is_torch_available


if is_torch_available():
    from .autoencoder_kl import AutoencoderKL
    from .controlnet import ControlNetModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .modeling_utils import ModelMixin
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
    from .unet_1d import UNet1DModel
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
    from .unet_3d_condition import UNet3DConditionModel
    from .vq_model import VQModel

if is_flax_available():
    from .controlnet_flax import FlaxControlNetModel
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .vae_flax import FlaxAutoencoderKL
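`models/__init__.py` gates its re-exports on which backend can actually be imported, so a torch-only environment never touches the Flax modules and vice versa. A minimal sketch of that gating pattern using only the standard library (the real `is_torch_available` / `is_flax_available` helpers live in `..utils` and are not reproduced here):

```py
# Availability-gated imports, sketched with importlib instead of the library's own helpers.
import importlib.util


def _backend_available(name: str) -> bool:
    # True when `import <name>` would succeed in the current environment.
    return importlib.util.find_spec(name) is not None


if _backend_available("torch"):
    pass  # re-export the PyTorch model classes here
if _backend_available("flax"):
    pass  # re-export the Flax model classes here
```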
    	
6DoF/diffusers/models/activations.py
ADDED
@@ -0,0 +1,12 @@
from torch import nn


def get_activation(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
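`get_activation` maps a config string to a `torch.nn` activation module. A quick self-contained check of that dispatch (the function body is copied into the snippet only so it runs on its own, without importing the vendored package):

```py
import torch
from torch import nn


# Same dispatch as get_activation above, duplicated here so the snippet is standalone.
def get_activation(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


x = torch.randn(2, 4)
for name in ("swish", "silu", "mish", "gelu"):
    y = get_activation(name)(x)
    print(name, tuple(y.shape))  # every activation preserves the input shape, (2, 4)
```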
    	
6DoF/diffusers/models/attention.py
ADDED
@@ -0,0 +1,392 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        posemb: Optional = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            posemb=posemb,  # todo in self attn, posemb should be [pose_in, pose_in]?
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                posemb=posemb,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
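# Illustrative sketch (not part of the committed file): chunking the feed-forward over the
# sequence dimension, as done above when `_chunk_size` is set, reproduces the unchunked result
# because the feed-forward acts on each token independently. A toy nn.Linear stands in for self.ff.
import torch
import torch.nn as nn

ff = nn.Linear(8, 8)                           # stand-in for self.ff
x = torch.randn(2, 16, 8)                      # (batch, seq_len, dim)
chunk_size, chunk_dim = 4, 1
num_chunks = x.shape[chunk_dim] // chunk_size
chunked = torch.cat([ff(s) for s in x.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim)
assert torch.allclose(chunked, ff(x), atol=1e-6)   # same output, smaller peak activation memory
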
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states

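# Usage sketch (illustrative sizes, not from the diff): with the default "geglu" activation the
# hidden width is dim * mult, and the output width falls back to dim when dim_out is not given.
import torch

ff = FeedForward(dim=320)                      # GEGLU projection in, dropout, linear projection out
x = torch.randn(2, 64, 320)                    # (batch, tokens, channels)
print(ff(x).shape)                             # torch.Size([2, 64, 320])
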
class GELU(nn.Module):
    r"""
    GELU activation function, with optional tanh approximation via `approximate="tanh"`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states

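# Rough numerical check (illustrative): the tanh approximation stays close to the exact GELU.
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)
print((F.gelu(x) - F.gelu(x, approximate="tanh")).abs().max())   # on the order of 1e-3 or smaller
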
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)

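# Functional sketch (hypothetical shapes): GEGLU is a single Linear of width 2 * dim_out whose
# output is split into a value half and a gate half, with the gate passed through GELU.
import torch
import torch.nn.functional as F

proj = torch.nn.Linear(320, 2 * 1280)
h = torch.randn(2, 64, 320)
value, gate = proj(h).chunk(2, dim=-1)         # each (2, 64, 1280)
out = value * F.gelu(gate)                     # gated hidden states, as in GEGLU.forward
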
class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)

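# Rough numerical check (illustrative): x * sigmoid(1.702 * x) tracks the exact GELU, though
# more loosely than the tanh form (maximum deviation roughly on the order of 1e-2).
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)
print((x * torch.sigmoid(1.702 * x) - F.gelu(x)).abs().max())
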
class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x

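# Usage sketch (hypothetical sizes; a scalar timestep keeps the default torch.chunk split along
# dim 0 well-formed): the timestep embedding yields a per-channel scale and shift that modulate
# the parameter-free LayerNorm output.
import torch

norm = AdaLayerNorm(embedding_dim=64, num_embeddings=1000)
x = torch.randn(2, 16, 64)                     # (batch, tokens, channels)
out = norm(x, torch.tensor(10))                # scalar timestep -> (64,) scale and shift
print(out.shape)                               # torch.Size([2, 16, 64])
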
class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

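# Plain-tensor sketch of the adaLN-Zero pattern consumed by BasicTransformerBlock.forward above
# (names and sizes are illustrative): one conditioning vector is split six ways into
# shift/scale/gate for the attention branch and the MLP branch; the gates multiply each branch
# before the residual add.
import torch
import torch.nn.functional as F

dim = 64
cond = torch.randn(2, 6 * dim)                 # e.g. timestep + class embedding after a Linear
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = cond.chunk(6, dim=1)

x = torch.randn(2, 16, dim)
normed = F.layer_norm(x, (dim,)) * (1 + scale_msa[:, None]) + shift_msa[:, None]
branch_out = normed                            # stand-in for the attention branch output
x = x + gate_msa.unsqueeze(1) * branch_out     # gated residual, as in the block forward
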
class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        if act_fn is None:
            self.act = None
        else:
            self.act = get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x

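# Usage sketch (hypothetical sizes): the embedding is projected to 2 * out_dim, reshaped to
# broadcast over the spatial dimensions, and applied as a per-channel scale and shift on top of
# a parameter-free group norm.
import torch

norm = AdaGroupNorm(embedding_dim=128, out_dim=64, num_groups=8)
x = torch.randn(2, 64, 16, 16)                 # (batch, channels, height, width)
emb = torch.randn(2, 128)                      # conditioning embedding
print(norm(x, emb).shape)                      # torch.Size([2, 64, 16, 16])
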
    	
6DoF/diffusers/models/attention_flax.py ADDED
@@ -0,0 +1,446 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp

def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights

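# Standalone sketch (toy numbers) of the log-sum-exp merge used above: rescaling each chunk's
# partial numerator and denominator by exp(chunk_max - global_max) lets the chunked softmax be
# combined exactly.
import jax.numpy as jnp

scores = jnp.array([1.0, 3.0, 0.5, 2.0])                 # attention logits for one query
full = jnp.exp(scores - scores.max()) / jnp.exp(scores - scores.max()).sum()

chunks = [scores[:2], scores[2:]]
maxes = jnp.array([c.max() for c in chunks])
nums = [jnp.exp(c - m) for c, m in zip(chunks, maxes)]   # per-chunk numerators
dens = jnp.array([n.sum() for n in nums])                # per-chunk denominators

g = maxes.max()                                          # global max across chunks
rescale = jnp.exp(maxes - g)
merged = jnp.concatenate([n * r for n, r in zip(nums, rescale)]) / (dens * rescale).sum()
print(jnp.allclose(full, merged))                        # True
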
def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide the query array; the value must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide the key and value arrays; the value must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # unused; ignore it
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back

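# Usage sketch with made-up shapes: the chunked computation matches plain softmax attention for
# the (..., length, head, depth) layout documented above, up to numerical tolerance.
import jax
import jax.numpy as jnp

q = jax.random.normal(jax.random.PRNGKey(0), (1, 128, 8, 64))    # (batch, q_len, heads, dim_head)
k = jax.random.normal(jax.random.PRNGKey(1), (1, 256, 8, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (1, 256, 8, 64))

chunked = jax_memory_efficient_attention(q, k, v, query_chunk_size=64, key_chunk_size=128)

scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(64.0)
reference = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)
print(jnp.allclose(chunked, reference, atol=1e-4))               # True
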
class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

    Parameters:
        query_dim (:obj:`int`):
            Input hidden states dimension
        heads (:obj:`int`, *optional*, defaults to 8):
            Number of heads
        dim_head (:obj:`int`, *optional*, defaults to 64):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def __call__(self, hidden_states, context=None, deterministic=True):
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        query_states = self.reshape_heads_to_batch_dim(query_proj)
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)

        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if statement creates a chunk size for each layer of the unet
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet

            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=2)

            # attend to values
            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)

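# Init/apply sketch (hypothetical sizes; dropout stays deterministic, so no dropout RNG is needed).
import jax
import jax.numpy as jnp

attn = FlaxAttention(query_dim=320, heads=8, dim_head=40)
hidden_states = jnp.ones((1, 64, 320))                    # (batch, tokens, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)                   # self-attention; pass `context` for cross-attention
print(out.shape)                                          # (1, 64, 320)
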
            class FlaxBasicTransformerBlock(nn.Module):
         
     | 
| 222 | 
         
            +
                r"""
         
     | 
| 223 | 
         
            +
                A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
         
     | 
| 224 | 
         
            +
                https://arxiv.org/abs/1706.03762
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
             
     | 
| 227 | 
         
            +
                Parameters:
         
     | 
| 228 | 
         
            +
                    dim (:obj:`int`):
         
     | 
| 229 | 
         
            +
                        Inner hidden states dimension
         
     | 
| 230 | 
         
            +
                    n_heads (:obj:`int`):
         
     | 
| 231 | 
         
            +
                        Number of heads
         
     | 
| 232 | 
         
            +
                    d_head (:obj:`int`):
         
     | 
| 233 | 
         
            +
                        Hidden states dimension inside each head
         
     | 
| 234 | 
         
            +
                    dropout (:obj:`float`, *optional*, defaults to 0.0):
         
     | 
| 235 | 
         
            +
                        Dropout rate
         
     | 
| 236 | 
         
            +
                    only_cross_attention (`bool`, defaults to `False`):
         
     | 
| 237 | 
         
            +
                        Whether to only apply cross attention.
         
     | 
| 238 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 239 | 
         
            +
                        Parameters `dtype`
         
     | 
| 240 | 
         
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        # cross attention
        self.attn2 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
        else:
            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        return self.dropout_layer(hidden_states, deterministic=deterministic)
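A minimal usage sketch of the block above, not part of the commit: it assumes the `FlaxAttention` class defined earlier in this file is in scope, and the shapes, feature sizes, and PRNG seed below are illustrative choices only (with `dim == n_heads * d_head`).

# Illustrative only: exercise a single FlaxBasicTransformerBlock in isolation.
import jax
import jax.numpy as jnp

block = FlaxBasicTransformerBlock(dim=320, n_heads=8, d_head=40)
hidden = jnp.zeros((1, 64, 320))    # (batch, tokens, dim)
context = jnp.zeros((1, 77, 768))   # (batch, context tokens, context dim), e.g. text embeddings
params = block.init(jax.random.PRNGKey(0), hidden, context)
out = block.apply(params, hidden, context)  # same shape as `hidden`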
         
class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformer blocks
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`):
            Whether to use a linear (`nn.Dense`) input/output projection instead of a 1x1 convolution
        only_cross_attention (`bool`, defaults to `False`):
            Whether the transformer blocks should attend only to `context` (no self-attention)
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        if self.use_linear_projection:
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_in = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
            )
            for _ in range(self.depth)
        ]

        if self.use_linear_projection:
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_out = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height * width, channels)
            hidden_states = self.proj_in(hidden_states)
        else:
            hidden_states = self.proj_in(hidden_states)
            hidden_states = hidden_states.reshape(batch, height * width, channels)

        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channels)
        else:
            hidden_states = hidden_states.reshape(batch, height, width, channels)
            hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states + residual
        return self.dropout_layer(hidden_states, deterministic=deterministic)
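A hedged sketch, not from the commit, of how this model consumes channels-last feature maps. The sizes below are assumptions chosen so that `in_channels` equals `n_heads * d_head` and the channel count is divisible by the 32 groups of the GroupNorm.

# Illustrative only: the Flax blocks expect NHWC inputs plus a (batch, seq, dim) context.
import jax
import jax.numpy as jnp

model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1)
sample = jnp.zeros((1, 32, 32, 320))  # (batch, height, width, channels)
context = jnp.zeros((1, 77, 768))     # e.g. text-encoder hidden states
params = model.init(jax.random.PRNGKey(0), sample, context)
out = model.apply(params, sample, context)  # (1, 32, 32, 320), residual added at the end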
         
class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
    https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        hidden_states = self.net_2(hidden_states)
        return hidden_states
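For completeness, a tiny sketch with assumed shapes, not part of the commit, of the feed-forward path: input and output widths match `dim`, while the hidden width is expanded by the GEGLU defined below.

# Illustrative only: dim -> dim * 4 (gated) -> dim.
import jax
import jax.numpy as jnp

ff = FlaxFeedForward(dim=320)
x = jnp.zeros((1, 64, 320))
params = ff.init(jax.random.PRNGKey(0), x)
y = ff.apply(params, x)  # (1, 64, 320)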
         
class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
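A small numeric sketch with illustrative shapes only, not part of the commit, of what the projection-and-gate above computes: the projection produces twice the hidden width, the result is split in two, and one half gates the GELU of the other.

# Illustrative only: GEGLU output width is dim * 4.
import jax
import jax.numpy as jnp

geglu = FlaxGEGLU(dim=320)
x = jnp.zeros((1, 64, 320))
params = geglu.init(jax.random.PRNGKey(0), x)
y = geglu.apply(params, x)  # (1, 64, 1280) == hidden_linear * gelu(hidden_gelu)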
    	
6DoF/diffusers/models/attention_processor.py ADDED
@@ -0,0 +1,1684 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


# 6DoF CaPE
import einops


def cape_embed(f, P):
    # f is a feature vector of shape [..., d]; its last dimension is viewed as blocks of 4
    # P is a 4x4 transformation matrix (e.g. a camera pose)
    f = einops.rearrange(f, '... (d k) -> ... d k', k=4)
    return einops.rearrange(f @ P, '... d k -> ... (d k)', k=4)
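A quick sanity-check sketch, not part of the commit, of the shape contract of `cape_embed` above: the feature's last dimension must be a multiple of 4, the output shape matches the input, and the identity pose leaves the features unchanged. The tensor sizes are illustrative assumptions.

# Illustrative only: cape_embed keeps the feature shape and applies P block-wise.
import torch

f = torch.randn(2, 8, 320)  # last dim 320 = 80 blocks of 4
P = torch.eye(4)            # identity pose
out = cape_embed(f, P)
assert out.shape == f.shape
assert torch.allclose(out, f)  # identity transform is a no-op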
         

@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
         
     | 
| 143 | 
         
            +
                        )
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                    self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                    if not self.only_cross_attention:
         
     | 
| 148 | 
         
            +
                        # only relevant for the `AddedKVProcessor` classes
         
     | 
| 149 | 
         
            +
                        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
         
     | 
| 150 | 
         
            +
                        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
         
     | 
| 151 | 
         
            +
                    else:
         
     | 
| 152 | 
         
            +
                        self.to_k = None
         
     | 
| 153 | 
         
            +
                        self.to_v = None
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
                    if self.added_kv_proj_dim is not None:
         
     | 
| 156 | 
         
            +
                        self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
         
     | 
| 157 | 
         
            +
                        self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                    self.to_out = nn.ModuleList([])
         
     | 
| 160 | 
         
            +
                    self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
         
     | 
| 161 | 
         
            +
                    self.to_out.append(nn.Dropout(dropout))
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                    # set attention processor
         
     | 
| 164 | 
         
            +
                    # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         
     | 
| 165 | 
         
            +
                    # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         
     | 
| 166 | 
         
            +
                    # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         
     | 
| 167 | 
         
            +
                    if processor is None:
         
     | 
| 168 | 
         
            +
                        processor = (
         
     | 
| 169 | 
         
            +
                            AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
         
     | 
| 170 | 
         
            +
                        )
         
     | 
| 171 | 
         
            +
                    self.set_processor(processor)
         
     | 
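    # Illustrative sketch of the default-processor choice above (an added illustration, not part of the
    # upstream file): the same capability check can be run standalone to see which backend a given torch
    # install will pick, e.g.
    #
    #     import torch.nn.functional as F
    #     uses_sdpa = hasattr(F, "scaled_dot_product_attention")  # True on torch >= 2.0
    #     processor = AttnProcessor2_0() if uses_sdpa else AttnProcessor()  # mirrors the check in __init__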

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor,
            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
        )
        is_added_kv_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessor2_0,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
                LoRAAttnAddedKVProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                )
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                # set attention processor
                # We use AttnProcessor2_0 by default when torch 2.x is used, which relies on
                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory-efficient attention,
                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
                processor = (
                    AttnProcessor2_0()
                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                    else AttnProcessor()
                )

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            # set attention processor
            # We use AttnProcessor2_0 by default when torch 2.x is used, which relies on
            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory-efficient attention,
            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

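    # Illustrative shape walk-through for the two helpers above (an added sketch, not part of the upstream
    # file), assuming heads = 8 and dim_head = 40 so the projected dim is 320:
    #
    #     x = torch.randn(2, 77, 320)        # (batch, seq_len, heads * dim_head)
    #     b = attn.head_to_batch_dim(x)      # -> (2 * 8, 77, 40): heads folded into the batch axis
    #     y = attn.batch_to_head_dim(b)      # -> (2, 77, 320): the inverse operation
    #
    # `get_attention_scores` and `torch.bmm` below therefore operate on (batch * heads, seq, dim_head) tensors.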
    def get_attention_scores(self, query, key, attention_mask=None):
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs
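    # In other words (an added sketch of the math above, not part of the upstream file), with
    # alpha = self.scale and beta = 1 only when a mask is supplied, torch.baddbmm computes
    #
    #     scores = beta * attention_mask + alpha * (query @ key.transpose(-1, -2))
    #            = attention_mask + scale * Q K^T      # mask entries are 0 or large negative biases
    #
    # and the returned probabilities are softmax(scores, dim=-1), cast back to the original dtype.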

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states

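# A minimal usage sketch for the `Attention` module defined above (added illustration with assumed
# hyper-parameters, not part of the upstream file): self-attention over a (batch, seq, channels) tensor.
#
#     attn = Attention(query_dim=320, heads=8, dim_head=40)     # inner_dim = 8 * 40 = 320
#     hidden_states = torch.randn(2, 4096, 320)                 # e.g. a flattened 64x64 feature map
#     out = attn(hidden_states)                                 # routed through attn.processor
#     assert out.shape == hidden_states.shape
#
# Passing `encoder_hidden_states` of shape (batch, seq_kv, cross_attention_dim) instead switches the same
# module to cross-attention, since keys and values are then projected from the encoder states.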
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)

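# The layer above implements the usual low-rank update (added sketch, not part of the upstream file):
# for an input x,
#
#     LoRALinearLayer(x) = up(down(x)) * (network_alpha / rank if network_alpha is not None else 1)
#
# i.e. a rank-`rank` update B A x with A = down.weight (init N(0, 1/rank)) and B = up.weight (init 0),
# so at initialization the update is exactly zero and the frozen base projection is unchanged.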
class LoRAAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

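# How the processor above is typically wired in (an added, assumed sketch, not part of the upstream file;
# `hidden_states` and `text_embeds` are placeholder tensors):
#
#     attn = Attention(query_dim=320, heads=8, dim_head=40, cross_attention_dim=768)
#     attn.set_processor(LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4))
#     out = attn(hidden_states, encoder_hidden_states=text_embeds, scale=1.0)
#
# `scale` blends the LoRA delta with the frozen projections: query = to_q(x) + scale * to_q_lora(x),
# and likewise for the key, value and output projections.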
| 629 | 
         
            +
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=True,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states)
            value = self.to_v_custom_diffusion(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

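# --- Illustrative sketch (not part of the original file) -------------------
# A hypothetical example of wiring the processor above into a single
# `Attention` block via `set_processor`; the sizes are made up and the exact
# `Attention` constructor arguments may differ across diffusers versions.
def _example_custom_diffusion_processor():
    import torch

    attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
    processor = CustomDiffusionAttnProcessor(
        train_kv=True,
        train_q_out=False,
        hidden_size=320,
        cross_attention_dim=768,
    )
    attn.set_processor(processor)

    hidden_states = torch.randn(1, 4096, 320)          # latent image tokens
    encoder_hidden_states = torch.randn(1, 77, 768)    # text tokens
    return attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
# ---------------------------------------------------------------------------
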
class AttnAddedKVProcessor:
    r"""
    Processor for performing attention-related computations with extra learnable key and value matrices for the text
    encoder.
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states

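# --- Illustrative sketch (not part of the original file) -------------------
# In the processor above, the projected text-encoder tokens are prepended to
# the latent keys/values along the token axis (`dim=1` in the batch*heads
# layout), so attention runs jointly over text and image tokens. A minimal
# shape check of that concatenation, with made-up sizes:
def _added_kv_concat_shapes_sketch():
    import torch

    text_kv = torch.randn(16, 77, 40)     # (batch*heads, text tokens, head dim)
    image_kv = torch.randn(16, 256, 40)   # (batch*heads, image tokens, head dim)
    key = torch.cat([text_kv, image_kv], dim=1)
    assert key.shape == (16, 77 + 256, 40)
    return key
# ---------------------------------------------------------------------------
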
class AttnAddedKVProcessor2_0:
    r"""
    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
    learnable key and value matrices for the text encoder.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query, out_dim=4)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states

class LoRAAttnAddedKVProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
    encoder.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
            encoder_hidden_states
        )
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
            encoder_hidden_states
        )
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
            value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states

class XFormersAttnAddedKVProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states

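# --- Illustrative note (not part of the original file) ---------------------
# These xFormers processors are usually not instantiated by hand. A model is
# normally switched over to them through diffusers' helper, roughly like the
# hypothetical call below (requires the `xformers` package to be installed):
#
#     unet.enable_xformers_memory_efficient_attention()
#
# which replaces the default attention processors with their xFormers
# counterparts on every attention block.
# ---------------------------------------------------------------------------
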
class XFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        posemb: Optional = None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if posemb is not None:
            # turn 2d attention into multiview attention
            self_attn = encoder_hidden_states is None  # check if self attn or cross attn
            [p_out, p_out_inv], [p_in, p_in_inv] = posemb
            t_out, t_in = p_out.shape[1], p_in.shape[1]  # t size
            hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # apply 6DoF; TODO: currently only implemented for the xFormers processor
        if posemb is not None:
            p_out_inv = einops.repeat(p_out_inv, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out)  # query shape
            if self_attn:
                p_in = einops.repeat(p_out, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out)  # query shape
            else:
                p_in = einops.repeat(p_in, 'b t_in f g -> b (t_in l) f g', l=key.shape[1] // t_in)  # key shape
            query = cape_embed(query, p_out_inv)  # query f_q @ (p_out)^(-T) .permute(0, 1, 3, 2)
            key = cape_embed(key, p_in)  # key f_k @ p_in

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        # per-view 2d attention:
        #   self-attn:  (bm) l c  x  (bm) l c -> (bm) l c
        #   cross-attn: (bm) l c  x  b (nl) c -> (bm) l c
        # reuse 2d attention for multiview attention:
        #   self-attn:  b (ml) c  x  b (ml) c -> b (ml) c
        #   cross-attn: b (ml) c  x  b (nl) c -> b (ml) c
        hidden_states = xformers.ops.memory_efficient_attention(  # query: (bm) l c -> b (ml) c;  key: b (nl) c
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if posemb is not None:
            # reshape back
            hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

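# --- Illustrative sketch (not part of the original file) -------------------
# The `posemb` branch above turns per-view 2D attention into multi-view
# attention: `posemb` carries two (pose, inverse-pose) pairs of per-view
# camera matrices shaped (b, t_out, f, g) and (b, t_in, f, g) (the inline
# comments suggest 4x4 transforms). The tokens of the t_out views are folded
# into one long sequence before attention and unfolded afterwards, while
# `cape_embed` (defined earlier in this file) injects the poses into the
# query/key features in between. A minimal sketch of just the fold/unfold
# step, with made-up sizes:
def _multiview_token_fold_sketch():
    import torch
    import einops

    b, t_out, l, d = 2, 4, 256, 320        # batch, views, tokens per view, channels
    hidden_states = torch.randn(b * t_out, l, d)

    # fold views into one joint sequence, as done before the attention call
    folded = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)
    assert folded.shape == (b, t_out * l, d)

    # ... attention over the joint sequence would happen here ...

    # unfold back to per-view tokens, as done after the output projection
    unfolded = einops.rearrange(folded, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)
    assert unfolded.shape == hidden_states.shape
    return unfolded
# ---------------------------------------------------------------------------
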
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

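# --- Illustrative sketch (not part of the original file) -------------------
# AttnProcessor2_0 above reshapes each (batch, seq, inner_dim) projection into
# (batch, heads, seq, head_dim) for `F.scaled_dot_product_attention` and then
# merges the heads back. A minimal standalone sketch of that round trip
# (requires PyTorch 2.0+), with made-up sizes:
def _sdpa_head_reshape_sketch():
    import torch
    import torch.nn.functional as F

    batch, seq, heads, head_dim = 2, 64, 8, 40
    inner_dim = heads * head_dim

    q = torch.randn(batch, seq, inner_dim).view(batch, -1, heads, head_dim).transpose(1, 2)
    k = torch.randn(batch, seq, inner_dim).view(batch, -1, heads, head_dim).transpose(1, 2)
    v = torch.randn(batch, seq, inner_dim).view(batch, -1, heads, head_dim).transpose(1, 2)

    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(batch, -1, inner_dim)   # merge heads back
    assert out.shape == (batch, seq, inner_dim)
    return out
# ---------------------------------------------------------------------------
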
class LoRAXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(
        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         
            +
             
     | 
| 1225 | 
         
            +
                def __call__(
         
     | 
| 1226 | 
         
            +
                    self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
         
     | 
| 1227 | 
         
            +
                ):
         
     | 
| 1228 | 
         
            +
                    residual = hidden_states
         
     | 
| 1229 | 
         
            +
             
     | 
| 1230 | 
         
            +
                    if attn.spatial_norm is not None:
         
     | 
| 1231 | 
         
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         
     | 
| 1232 | 
         
            +
             
     | 
| 1233 | 
         
            +
                    input_ndim = hidden_states.ndim
         
     | 
| 1234 | 
         
            +
             
     | 
| 1235 | 
         
            +
                    if input_ndim == 4:
         
     | 
| 1236 | 
         
            +
                        batch_size, channel, height, width = hidden_states.shape
         
     | 
| 1237 | 
         
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         
     | 
| 1238 | 
         
            +
             
     | 
| 1239 | 
         
            +
                    batch_size, sequence_length, _ = (
         
     | 
| 1240 | 
         
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         
     | 
| 1241 | 
         
            +
                    )
         
     | 
| 1242 | 
         
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         
     | 
| 1243 | 
         
            +
             
     | 
| 1244 | 
         
            +
                    if attn.group_norm is not None:
         
     | 
| 1245 | 
         
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         
     | 
| 1246 | 
         
            +
             
     | 
| 1247 | 
         
            +
                    query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         
     | 
| 1248 | 
         
            +
                    query = attn.head_to_batch_dim(query).contiguous()
         
     | 
| 1249 | 
         
            +
             
     | 
| 1250 | 
         
            +
                    if encoder_hidden_states is None:
         
     | 
| 1251 | 
         
            +
                        encoder_hidden_states = hidden_states
         
     | 
| 1252 | 
         
            +
                    elif attn.norm_cross:
         
     | 
| 1253 | 
         
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         
     | 
| 1254 | 
         
            +
             
     | 
| 1255 | 
         
            +
                    key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         
     | 
| 1256 | 
         
            +
                    value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
         
     | 
| 1257 | 
         
            +
             
     | 
| 1258 | 
         
            +
                    key = attn.head_to_batch_dim(key).contiguous()
         
     | 
| 1259 | 
         
            +
                    value = attn.head_to_batch_dim(value).contiguous()
         
     | 
| 1260 | 
         
            +
             
     | 
| 1261 | 
         
            +
                    hidden_states = xformers.ops.memory_efficient_attention(
         
     | 
| 1262 | 
         
            +
                        query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         
     | 
| 1263 | 
         
            +
                    )
         
     | 
| 1264 | 
         
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         
     | 
| 1265 | 
         
            +
             
     | 
| 1266 | 
         
            +
                    # linear proj
         
     | 
| 1267 | 
         
            +
                    hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         
     | 
| 1268 | 
         
            +
                    # dropout
         
     | 
| 1269 | 
         
            +
                    hidden_states = attn.to_out[1](hidden_states)
         
     | 
| 1270 | 
         
            +
             
     | 
| 1271 | 
         
            +
                    if input_ndim == 4:
         
     | 
| 1272 | 
         
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         
     | 
| 1273 | 
         
            +
             
     | 
| 1274 | 
         
            +
                    if attn.residual_connection:
         
     | 
| 1275 | 
         
            +
                        hidden_states = hidden_states + residual
         
     | 
| 1276 | 
         
            +
             
     | 
| 1277 | 
         
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         
     | 
| 1278 | 
         
            +
             
     | 
| 1279 | 
         
            +
                    return hidden_states
         
     | 
| 1280 | 
         
            +
             
     | 
| 1281 | 
         
            +
             
     | 
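
# Usage sketch for LoRAXFormersAttnProcessor (illustrative, not exercised in this
# file): the low-rank layers are added on top of the frozen q/k/v/out projections,
# and the attention itself runs through xformers. Assumes xformers is installed and
# that the hosting layer exposes the `set_processor` helper defined on `Attention`
# earlier in this file; the channel sizes below are made up.
#
#     proc = LoRAXFormersAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
#     attn_module.set_processor(proc)
#
# The LoRA branch is blended via the `scale` argument of `__call__`
# (query = to_q(x) + scale * to_q_lora(x)), so scale=0.0 recovers the frozen model.
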
class LoRAAttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
    attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
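
# LoRAAttnProcessor2_0 is the same LoRA scheme as above, but the attention itself
# goes through PyTorch 2.0's F.scaled_dot_product_attention instead of xformers,
# so no extra dependency is needed. Hedged sketch of swapping every processor on a
# UNet-style module (assumes it exposes `attn_processors` / `set_attn_processor`
# as in unet_2d_condition.py in this repo; `layer_dims` is a hypothetical mapping
# from processor name to that layer's (hidden_size, cross_attention_dim)):
#
#     procs = {
#         name: LoRAAttnProcessor2_0(hidden_size=hs, cross_attention_dim=cad, rank=4)
#         for name, (hs, cad) in layer_dims.items()
#     }
#     unet.set_attn_processor(procs)
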
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.

    Args:
    train_kv (`bool`, defaults to `True`):
        Whether to newly train the key and value matrices corresponding to the text features.
    train_q_out (`bool`, defaults to `True`):
        Whether to newly train query matrices corresponding to the latent image features.
    hidden_size (`int`, *optional*, defaults to `None`):
        The hidden size of the attention layer.
    cross_attention_dim (`int`, *optional*, defaults to `None`):
        The number of channels in the `encoder_hidden_states`.
    out_bias (`bool`, defaults to `True`):
        Whether to include the bias parameter in `train_q_out`.
    dropout (`float`, *optional*, defaults to 0.0):
        The dropout probability to use.
    attention_op (`Callable`, *optional*, defaults to `None`):
        The base
        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
        as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=False,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states)
            value = self.to_v_custom_diffusion(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
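
# In Custom Diffusion only the newly added key/value projections for the text
# features (and optionally the query/out projections, via train_q_out) are meant
# to be trained; the `detach` mask above blocks gradients through the first text
# token's key/value while keeping them for the remaining tokens. Illustrative
# sketch (channel sizes are made up; `set_processor` is the helper defined on
# `Attention` earlier in this file):
#
#     proc = CustomDiffusionXFormersAttnProcessor(
#         train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=768
#     )
#     attn_module.set_processor(proc)  # intent: only to_k/to_v_custom_diffusion are trained
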
class SlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
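
# Sliced attention trades speed for memory: the attention maps are computed in
# chunks of `slice_size` along the first (batch * heads) axis produced by
# head_to_batch_dim, so only one slice's score matrix is alive at a time. The loop
# above runs batch_size_attention // self.slice_size iterations, so `slice_size`
# is expected to divide batch_size * heads. Illustrative sketch (slice_size=1 is
# the most memory-frugal setting):
#
#     attn_module.set_processor(SlicedAttnProcessor(slice_size=1))
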
class SlicedAttnAddedKVProcessor:
    r"""
    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
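
# Same slicing scheme as SlicedAttnProcessor, but with the extra add_k_proj /
# add_v_proj projections of the text states concatenated in front of (or, when
# attn.only_cross_attention is set, replacing) the self-attention keys/values.
# Illustrative sketch:
#
#     attn_module.set_processor(SlicedAttnAddedKVProcessor(slice_size=2))
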
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
]
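
# `AttentionProcessor` is the type alias models use when accepting a processor,
# either a single instance or a per-layer dict keyed by processor name. Hedged
# signature sketch matching the pattern used by the UNet/ControlNet models
# elsewhere in this repo:
#
#     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
#         ...
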
class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq):
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f
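
# SpatialNorm (from the MoVQ paper linked above) group-normalizes the feature map
# `f` and then rescales/shifts it with 1x1 convolutions of the quantized latent
# `zq`, after upsampling `zq` to f's spatial size. Minimal shape sketch (channel
# counts are arbitrary; f_channels must be divisible by the 32 GroupNorm groups):
#
#     norm = SpatialNorm(f_channels=512, zq_channels=256)
#     f = torch.randn(1, 512, 32, 32)
#     zq = torch.randn(1, 256, 8, 8)
#     assert norm(f, zq).shape == f.shape
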
    	
6DoF/diffusers/models/autoencoder_kl.py ADDED
@@ -0,0 +1,411 @@
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"


class AutoencoderKL(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a, b, blend_extent):
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
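The class above follows the usual diffusers VAE round trip: `encode` returns a latent distribution, latents are multiplied by `scaling_factor` before being handed to a diffusion model and divided by it again before `decode`, and `enable_tiling()` / `enable_slicing()` are opt-in memory savers for large inputs. A minimal usage sketch, assuming the vendored `6DoF/diffusers` package is importable as `diffusers`; the tiny configuration below is hypothetical and only exercises the API:

import torch

# Assumption: 6DoF/ is on PYTHONPATH so the vendored package imports as `diffusers`.
from diffusers.models.autoencoder_kl import AutoencoderKL

# Toy configuration for illustration; real weights would come from a checkpoint.
vae = AutoencoderKL(
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(32, 64),
    latent_channels=4,
    sample_size=64,
)
vae.eval()

images = torch.randn(1, 3, 64, 64)  # batch of images scaled to [-1, 1]

with torch.no_grad():
    posterior = vae.encode(images).latent_dist
    # z = z * scaling_factor before the diffusion model, per the docstring above
    latents = posterior.sample() * vae.config.scaling_factor
    # z = 1 / scaling_factor * z before decoding
    decoded = vae.decode(latents / vae.config.scaling_factor).sample

# For large images, the switches defined above trade speed for memory.
vae.enable_tiling()
vae.enable_slicing()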
6DoF/diffusers/models/controlnet.py ADDED
@@ -0,0 +1,705 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from .unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor

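# Usage sketch (illustrative, not part of this file's logic): the two fields above are typically
# added to the corresponding skip connections of the frozen UNet. Assuming diffusers'
# `UNet2DConditionModel`, whose `forward` accepts `down_block_additional_residuals` /
# `mid_block_additional_residual` (as in the upstream Stable Diffusion + ControlNet pipelines):
#
#     down_res, mid_res = controlnet(
#         sample, t, encoder_hidden_states, controlnet_cond=cond_image, return_dict=False
#     )
#     noise_pred = unet(
#         sample, t, encoder_hidden_states,
#         down_block_additional_residuals=down_res,
#         mid_block_additional_residual=mid_res,
#     ).sample
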
class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

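# Shape check with the defaults above (illustrative): a (B, 3, 512, 512) condition image keeps its
# resolution through `conv_in`, is halved by each of the three stride-2 convolutions
# (512 -> 256 -> 128 -> 64), and leaves `conv_out` as (B, conditioning_embedding_channels, 64, 64),
# i.e. the same 64 × 64 latent resolution that the Stable Diffusion UNet operates on.
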
class ControlNetModel(ModelMixin, ConfigMixin):
    """
    A ControlNet model.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers are skipped
            in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        use_linear_projection (`bool`, defaults to `False`):
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of the conditioning image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block in the `conditioning_embedding` layer.
        global_pool_conditions (`bool`, defaults to `False`):
    """

    _supports_gradient_checkpointing = True

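    # Note on the constructor below: `@register_to_config` (from diffusers' `ConfigMixin`) records
    # every keyword argument into the model's config, so `save_pretrained` / `from_pretrained`
    # can rebuild an identical ControlNet from the saved `config.json`.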
    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

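    # Design note (summarizing the construction above): every `controlnet_down_blocks` /
    # `controlnet_mid_block` entry is a 1x1 convolution wrapped in `zero_module`, which, as in
    # upstream diffusers, zero-initializes its weights and bias. At the start of training the
    # ControlNet therefore contributes exactly zero to the frozen UNet, and the control signal is
    # learned gradually. With the default four down block types and `layers_per_block=2`, the list
    # holds 1 + 4 * 2 + 3 = 12 zero convolutions, one per down-block residual of the Stable
    # Diffusion UNet.
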
    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
    ):
        r"""
        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
                where applicable.
        """
        controlnet = cls(
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if controlnet.class_embedding:
                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet

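    # Usage sketch (illustrative; assumes a Stable Diffusion checkpoint such as
    # "runwayml/stable-diffusion-v1-5" is available):
    #
    #     unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    #     controlnet = ControlNetModel.from_unet(unet)  # copies conv_in / time / down / mid weights
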
    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

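    # Usage sketch: the keys returned by `attn_processors` are exactly what `set_attn_processor`
    # expects, so every layer can be re-assigned explicitly, e.g.:
    #
    #     controlnet.set_attn_processor({name: AttnProcessor() for name in controlnet.attn_processors})
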
| 465 | 
         
            +
                # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
         
     | 
| 466 | 
         
            +
                def set_attention_slice(self, slice_size):
         
     | 
| 467 | 
         
            +
                    r"""
         
     | 
| 468 | 
         
            +
                    Enable sliced attention computation.
         
     | 
| 469 | 
         
            +
             
     | 
| 470 | 
         
            +
                    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
         
     | 
| 471 | 
         
            +
                    several steps. This is useful for saving some memory in exchange for a small decrease in speed.
         
     | 
| 472 | 
         
            +
             
     | 
| 473 | 
         
            +
                    Args:
         
     | 
| 474 | 
         
            +
                        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
         
     | 
| 475 | 
         
            +
                            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
         
     | 
| 476 | 
         
            +
                            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
         
     | 
| 477 | 
         
            +
                            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
         
     | 
| 478 | 
         
            +
                            must be a multiple of `slice_size`.
         
     | 
| 479 | 
         
            +
                    """
         
     | 
| 480 | 
         
            +
                    sliceable_head_dims = []
         
     | 
| 481 | 
         
            +
             
     | 
| 482 | 
         
            +
                    def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
         
     | 
| 483 | 
         
            +
                        if hasattr(module, "set_attention_slice"):
         
     | 
| 484 | 
         
            +
                            sliceable_head_dims.append(module.sliceable_head_dim)
         
     | 
| 485 | 
         
            +
             
     | 
| 486 | 
         
            +
                        for child in module.children():
         
     | 
| 487 | 
         
            +
                            fn_recursive_retrieve_sliceable_dims(child)
         
     | 
| 488 | 
         
            +
             
     | 
| 489 | 
         
            +
                    # retrieve number of attention layers
         
     | 
| 490 | 
         
            +
                    for module in self.children():
         
     | 
| 491 | 
         
            +
                        fn_recursive_retrieve_sliceable_dims(module)
         
     | 
| 492 | 
         
            +
             
     | 
| 493 | 
         
            +
                    num_sliceable_layers = len(sliceable_head_dims)
         
     | 
| 494 | 
         
            +
             
     | 
| 495 | 
         
            +
                    if slice_size == "auto":
         
     | 
| 496 | 
         
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any child that exposes the `set_attention_slice` method gets the message.
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

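    # Illustrative usage (a sketch, assuming `controlnet` is a ControlNetModel instance;
    # not part of the upstream file):
    #
    #     controlnet.set_attention_slice("auto")  # half of each sliceable head dim
    #     controlnet.set_attention_slice("max")   # slice size 1, lowest memory use
    #     controlnet.set_attention_slice(2)       # one explicit size for every layer (must not exceed any head count)
    #
    # or pass a list whose length equals the number of sliceable attention layers.
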
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        The [`ControlNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.FloatTensor`):
                The conditional input tensor of shape `(batch_size, channels, height, width)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that, if specified, is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the content of the input image even
                if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. Control net blocks

        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0

            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )


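# `zero_module` is used to zero-initialize the ControlNet projection convolutions,
# so a freshly attached ControlNet initially contributes nothing to the frozen UNet
# and its influence is learned gradually (the "zero convolution" initialization from
# the ControlNet paper; the Flax port below uses `zeros_init()` for the same blocks).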
def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
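The guess-mode branch above weights each ControlNet residual with a log-spaced schedule, so the earliest (highest-resolution) hints are attenuated the most. A minimal standalone sketch of that schedule, assuming 12 down-block residuals plus the mid-block residual (the counts produced by the default four-block configuration) and `conditioning_scale=1.0`:

    import torch

    conditioning_scale = 1.0  # assumed for illustration
    num_down_residuals = 12   # assumed count; forward() uses len(down_block_res_samples)
    scales = torch.logspace(-1, 0, num_down_residuals + 1) * conditioning_scale
    print(scales[0].item(), scales[-1].item())  # ~0.1 for the first residual, 1.0 for the mid block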
    	
6DoF/diffusers/models/controlnet_flax.py
ADDED
@@ -0,0 +1,394 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
)


@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    """
    The output of [`FlaxControlNetModel`].

    Args:
        down_block_res_samples (`jnp.ndarray`):
        mid_block_res_sample (`jnp.ndarray`):
    """

    down_block_res_samples: jnp.ndarray
    mid_block_res_sample: jnp.ndarray


class FlaxControlNetConditioningEmbedding(nn.Module):
    conditioning_embedding_channels: int
    block_out_channels: Tuple[int] = (16, 32, 96, 256)
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_in = nn.Conv(
            self.block_out_channels[0],
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        blocks = []
        for i in range(len(self.block_out_channels) - 1):
            channel_in = self.block_out_channels[i]
            channel_out = self.block_out_channels[i + 1]
            conv1 = nn.Conv(
                channel_in,
                kernel_size=(3, 3),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv1)
            conv2 = nn.Conv(
                channel_out,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv2)
        self.blocks = blocks

        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = nn.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = nn.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


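# Note: with the default `(16, 32, 96, 256)` channels, the conditioning embedding above
# applies three stride-2 convolutions, so the conditioning image is downsampled by a
# factor of 8 to match the latent resolution fed to `conv_in` (see the `sample_size * 8`
# dummy input in `init_weights` below); its zero-initialized `conv_out` makes the
# conditioning signal start as a no-op.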
@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A ControlNet model.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block in the `conditioning_embedding` layer.
    """
    sample_size: int = 32
    in_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    controlnet_conditioning_channel_order: str = "rgb"
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

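    # Illustrative initialization (a sketch, mirroring how `init_weights` is meant to be
    # called; names below are assumptions, not part of the upstream file):
    #
    #     model = FlaxControlNetModel()
    #     params = model.init_weights(rng=jax.random.PRNGKey(0))
    #
    # Dummy latents are built at `sample_size`, the conditioning image at `sample_size * 8`,
    # and the resulting parameter FrozenDict is returned.
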
    def setup(self):
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=self.conditioning_embedding_out_channels,
        )

        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # down
        down_blocks = []
        controlnet_down_blocks = []

        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv(
            output_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )
        controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)

            for _ in range(self.layers_per_block):
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

        self.down_blocks = down_blocks
        self.controlnet_down_blocks = controlnet_down_blocks

        # mid
        mid_block_channel = block_out_channels[-1]
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )

        self.controlnet_mid_block = nn.Conv(
            mid_block_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale: float = 1.0,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxControlNetOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
            conditioning_scale (`float`): the scale factor for ControlNet outputs
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.controlnet_flax.FlaxControlNetOutput`] instead of a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.controlnet_flax.FlaxControlNetOutput`] or `tuple`:
            [`~models.controlnet_flax.FlaxControlNetOutput`] if `return_dict` is True, otherwise a `tuple`.
         
     | 
| 339 | 
         
            +
                        When returning a tuple, the first element is the sample tensor.
         
     | 
| 340 | 
         
            +
                    """
         
     | 
| 341 | 
         
            +
                    channel_order = self.controlnet_conditioning_channel_order
         
     | 
| 342 | 
         
            +
                    if channel_order == "bgr":
         
     | 
| 343 | 
         
            +
                        controlnet_cond = jnp.flip(controlnet_cond, axis=1)
         
     | 
| 344 | 
         
            +
             
     | 
| 345 | 
         
            +
                    # 1. time
         
     | 
| 346 | 
         
            +
                    if not isinstance(timesteps, jnp.ndarray):
         
     | 
| 347 | 
         
            +
                        timesteps = jnp.array([timesteps], dtype=jnp.int32)
         
     | 
| 348 | 
         
            +
                    elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
         
     | 
| 349 | 
         
            +
                        timesteps = timesteps.astype(dtype=jnp.float32)
         
     | 
| 350 | 
         
            +
                        timesteps = jnp.expand_dims(timesteps, 0)
         
     | 
| 351 | 
         
            +
             
     | 
| 352 | 
         
            +
                    t_emb = self.time_proj(timesteps)
         
     | 
| 353 | 
         
            +
                    t_emb = self.time_embedding(t_emb)
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
                    # 2. pre-process
         
     | 
| 356 | 
         
            +
                    sample = jnp.transpose(sample, (0, 2, 3, 1))
         
     | 
| 357 | 
         
            +
                    sample = self.conv_in(sample)
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                    controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
         
     | 
| 360 | 
         
            +
                    controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
         
     | 
| 361 | 
         
            +
                    sample += controlnet_cond
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
                    # 3. down
         
     | 
| 364 | 
         
            +
                    down_block_res_samples = (sample,)
         
     | 
| 365 | 
         
            +
                    for down_block in self.down_blocks:
         
     | 
| 366 | 
         
            +
                        if isinstance(down_block, FlaxCrossAttnDownBlock2D):
         
     | 
| 367 | 
         
            +
                            sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
         
     | 
| 368 | 
         
            +
                        else:
         
     | 
| 369 | 
         
            +
                            sample, res_samples = down_block(sample, t_emb, deterministic=not train)
         
     | 
| 370 | 
         
            +
                        down_block_res_samples += res_samples
         
     | 
| 371 | 
         
            +
             
     | 
| 372 | 
         
            +
                    # 4. mid
         
     | 
| 373 | 
         
            +
                    sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
         
     | 
| 374 | 
         
            +
             
     | 
| 375 | 
         
            +
                    # 5. contronet blocks
         
     | 
| 376 | 
         
            +
                    controlnet_down_block_res_samples = ()
         
     | 
| 377 | 
         
            +
                    for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
         
     | 
| 378 | 
         
            +
                        down_block_res_sample = controlnet_block(down_block_res_sample)
         
     | 
| 379 | 
         
            +
                        controlnet_down_block_res_samples += (down_block_res_sample,)
         
     | 
| 380 | 
         
            +
             
     | 
| 381 | 
         
            +
                    down_block_res_samples = controlnet_down_block_res_samples
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                    mid_block_res_sample = self.controlnet_mid_block(sample)
         
     | 
| 384 | 
         
            +
             
     | 
| 385 | 
         
            +
                    # 6. scaling
         
     | 
| 386 | 
         
            +
                    down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
         
     | 
| 387 | 
         
            +
                    mid_block_res_sample *= conditioning_scale
         
     | 
| 388 | 
         
            +
             
     | 
| 389 | 
         
            +
                    if not return_dict:
         
     | 
| 390 | 
         
            +
                        return (down_block_res_samples, mid_block_res_sample)
         
     | 
| 391 | 
         
            +
             
     | 
| 392 | 
         
            +
                    return FlaxControlNetOutput(
         
     | 
| 393 | 
         
            +
                        down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
         
     | 
| 394 | 
         
            +
                    )
         
     | 
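The `conditioning_scale` argument only enters in the `# 6. scaling` step above: every down-block residual and the mid-block residual are multiplied by the same scalar before a UNet consumes them. Below is a minimal, self-contained sketch of that behaviour; the helper name and the toy NHWC shapes are illustrative only and are not part of this diff.

# Sketch of the "# 6. scaling" step, with made-up NHWC shapes.
# scale_controlnet_residuals is an illustrative helper, not part of this diff.
import jax.numpy as jnp

def scale_controlnet_residuals(down_block_res_samples, mid_block_res_sample, conditioning_scale=1.0):
    # Every residual is multiplied by the same scalar, so 0.0 turns the ControlNet off
    # and 1.0 applies it at full strength.
    down = [r * conditioning_scale for r in down_block_res_samples]
    mid = mid_block_res_sample * conditioning_scale
    return down, mid

down = [jnp.ones((1, 64, 64, 320)), jnp.ones((1, 32, 32, 640))]
mid = jnp.ones((1, 8, 8, 1280))
down_scaled, mid_scaled = scale_controlnet_residuals(down, mid, conditioning_scale=0.5)
assert float(mid_scaled[0, 0, 0, 0]) == 0.5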
    	
        6DoF/diffusers/models/cross_attention.py
    ADDED
    
         @@ -0,0 +1,94 @@ 
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import deprecate
from .attention_processor import (  # noqa: F401
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
    LoRALinearLayer,
    LoRAXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    SlicedAttnProcessor,
    XFormersAttnProcessor,
)
from .attention_processor import AttnProcessor as AttnProcessorRename  # noqa: F401


deprecate(
    "cross_attention",
    "0.20.0",
    "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
    standard_warn=False,
)


AttnProcessor = AttentionProcessor


class CrossAttention(Attention):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class CrossAttnProcessor(AttnProcessorRename):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class LoRACrossAttnProcessor(LoRAAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class XFormersCrossAttnProcessor(XFormersAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnProcessor(SlicedAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)
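The module above is purely a backwards-compatibility shim: it re-exports the renamed classes from `attention_processor` and emits a deprecation warning, so new code should import `Attention` and the processor classes from `diffusers.models.attention_processor` directly. The sketch below illustrates the same rename-with-warning pattern using only the standard library; the class bodies are stand-ins, not the diffusers implementation (which routes the message through `deprecate`).

# Stand-in sketch of the rename-with-deprecation pattern; runs without diffusers installed.
import warnings

class Attention:  # stands in for diffusers.models.attention_processor.Attention
    pass

class CrossAttention(Attention):  # old name kept as a thin subclass
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "CrossAttention is deprecated; import Attention from attention_processor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)

old = CrossAttention()             # still constructible, but warns
assert isinstance(old, Attention)  # downstream isinstance checks keep passing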
    	
        6DoF/diffusers/models/dual_transformer_2d.py
    ADDED
    
         @@ -0,0 +1,151 @@ 
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from torch import nn

from .transformer_2d import Transformer2DModel, Transformer2DModelOutput


class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`torch.FloatTensor`, *optional*):
                Optional attention mask to be applied in Attention
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            transformer_index = self.transformer_index_for_condition[i]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)
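`DualTransformer2DModel.forward` splits `encoder_hidden_states` into two token spans (77 and 257 tokens by default), routes each span through one of the two wrapped transformers, keeps only the residual each one adds, and blends the residuals with `mix_ratio` before adding the input back. Below is a toy sketch of just that arithmetic; `fake_transformer` and the shapes are made up for illustration and stand in for `Transformer2DModel`.

# Toy sketch of the token splitting and residual mixing in forward(),
# with an identity-like stand-in for Transformer2DModel so the arithmetic is visible.
import torch

condition_lengths = [77, 257]  # text tokens, image tokens (the defaults above)
mix_ratio = 0.5

def fake_transformer(hidden_states, encoder_hidden_states):
    # stand-in: shift the input by the mean of its condition tokens
    return hidden_states + encoder_hidden_states.mean()

hidden_states = torch.zeros(1, 4, 8, 8)
encoder_hidden_states = torch.cat(
    [torch.full((1, 77, 16), 1.0), torch.full((1, 257, 16), 3.0)], dim=1
)

encoded_states, tokens_start = [], 0
for i in range(2):
    cond = encoder_hidden_states[:, tokens_start : tokens_start + condition_lengths[i]]
    out = fake_transformer(hidden_states, cond)
    encoded_states.append(out - hidden_states)  # keep only the residual, as forward() does
    tokens_start += condition_lengths[i]

output = encoded_states[0] * mix_ratio + encoded_states[1] * (1 - mix_ratio) + hidden_states
assert torch.allclose(output, torch.full_like(output, 2.0))  # 0.5 * 1.0 + 0.5 * 3.0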
    	
        6DoF/diffusers/models/embeddings.py
    ADDED
    
         @@ -0,0 +1,546 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
            import math
         
     | 
| 15 | 
         
            +
            from typing import Optional
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            import numpy as np
         
     | 
| 18 | 
         
            +
            import torch
         
     | 
| 19 | 
         
            +
            from torch import nn
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            from .activations import get_activation
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            def get_timestep_embedding(
         
     | 
| 25 | 
         
            +
                timesteps: torch.Tensor,
         
     | 
| 26 | 
         
            +
                embedding_dim: int,
         
     | 
| 27 | 
         
            +
                flip_sin_to_cos: bool = False,
         
     | 
| 28 | 
         
            +
                downscale_freq_shift: float = 1,
         
     | 
| 29 | 
         
            +
                scale: float = 1,
         
     | 
| 30 | 
         
            +
                max_period: int = 10000,
         
     | 
| 31 | 
         
            +
            ):
         
     | 
| 32 | 
         
            +
                """
         
     | 
| 33 | 
         
            +
                This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
                :param timesteps: a 1-D Tensor of N indices, one per batch element.
         
     | 
| 36 | 
         
            +
                                  These may be fractional.
         
     | 
| 37 | 
         
            +
                :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
         
     | 
| 38 | 
         
            +
                embeddings. :return: an [N x dim] Tensor of positional embeddings.
         
     | 
| 39 | 
         
            +
                """
         
     | 
| 40 | 
         
            +
                assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                half_dim = embedding_dim // 2
         
     | 
| 43 | 
         
            +
                exponent = -math.log(max_period) * torch.arange(
         
     | 
| 44 | 
         
            +
                    start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
         
     | 
| 45 | 
         
            +
                )
         
     | 
| 46 | 
         
            +
                exponent = exponent / (half_dim - downscale_freq_shift)
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                emb = torch.exp(exponent)
         
     | 
| 49 | 
         
            +
                emb = timesteps[:, None].float() * emb[None, :]
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                # scale embeddings
         
     | 
| 52 | 
         
            +
                emb = scale * emb
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                # concat sine and cosine embeddings
         
     | 
| 55 | 
         
            +
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
                # flip sine and cosine embeddings
         
     | 
| 58 | 
         
            +
                if flip_sin_to_cos:
         
     | 
| 59 | 
         
            +
                    emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                # zero pad
         
     | 
| 62 | 
         
            +
                if embedding_dim % 2 == 1:
         
     | 
| 63 | 
         
            +
                    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
         
     | 
| 64 | 
         
            +
                return emb
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
             
     | 
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

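A short illustrative sketch of the sin-cos helpers above; the grid size and channel count are arbitrary choices, and `np` is the module's NumPy import:

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
# the first 384 channels encode the row coordinate, the last 384 the column coordinate
print(pos_embed.shape)  # (256, 768) -> one row per cell of the 16x16 grid
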
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        return latent + self.pos_embed

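A hedged shape walk-through for PatchEmbed; the 64x64, 4-channel latent below is only an example configuration, not a value used elsewhere in this file:

patch_embed = PatchEmbed(height=64, width=64, patch_size=8, in_channels=4, embed_dim=768)
tokens = patch_embed(torch.randn(2, 4, 64, 64))
print(tokens.shape)  # torch.Size([2, 64, 768]) -> (64 // 8) * (64 // 8) = 64 tokens per image
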
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb

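Timesteps and TimestepEmbedding are typically chained: the first produces fixed sinusoidal features and the second maps them through a learned two-layer MLP. A minimal sketch with assumed dimensions (320 sinusoidal channels, 1280-d embedding):

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

t = torch.tensor([10, 500])
t_emb = time_embedding(time_proj(t))
print(t_emb.shape)  # torch.Size([2, 1280])
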
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out

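GaussianFourierProjection embeds a continuous noise level rather than an integer timestep; with log=True the input is log-transformed first, and the output has 2 * embedding_size channels. An illustrative sketch (the sigma values and scale are assumptions):

fourier = GaussianFourierProjection(embedding_size=256, scale=16.0)
sigmas = torch.tensor([0.1, 1.0, 10.0])
print(fourier(sigmas).shape)  # torch.Size([3, 512])
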
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different from the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixel embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb

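An illustrative sketch for ImagePositionalEmbeddings; the 32x32 latent grid and 1024-entry codebook below are assumed values:

pos = ImagePositionalEmbeddings(num_embed=1024, height=32, width=32, embed_dim=512)
token_ids = torch.randint(0, 1024, (2, 32 * 32))
print(pos(token_ids).shape)  # torch.Size([2, 1024, 512]) -> token embedding + row/column embeddings
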
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = torch.tensor(force_drop_ids == 1)
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings

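Note that when dropout_prob > 0 the embedding table gets one extra row, and dropped labels are remapped to index num_classes, which serves as the unconditional embedding for classifier-free guidance. A hedged sketch (class count and sizes are arbitrary):

label_emb = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
label_emb.train()                            # label dropout only happens in training mode
out = label_emb(torch.tensor([3, 517, 42]))  # roughly 10% of labels become index 1000
print(out.shape)                             # torch.Size([3, 768])
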
class TextImageProjection(nn.Module):
    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        batch_size = text_embeds.shape[0]

        # image
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)


class ImageProjection(nn.Module):
    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        batch_size = image_embeds.shape[0]

        # image
        image_embeds = self.image_embeds(image_embeds)
        image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
        image_embeds = self.norm(image_embeds)
        return image_embeds

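A shape sketch for ImageProjection, which expands one pooled image embedding into a short sequence of cross-attention tokens; the sizes below are just the defaults, spelled out for clarity:

image_proj = ImageProjection(image_embed_dim=768, cross_attention_dim=768, num_image_text_embeds=32)
tokens = image_proj(torch.randn(2, 768))
print(tokens.shape)  # torch.Size([2, 32, 768])
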
class CombinedTimestepLabelEmbeddings(nn.Module):
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning

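A hedged sketch of the combined conditioning: the timestep embedding and the class embedding share embedding_dim and are simply summed (the class count and dimension below are assumptions):

cond = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=1152)
conditioning = cond(torch.tensor([10, 500]), torch.tensor([3, 7]), hidden_dtype=torch.float32)
print(conditioning.shape)  # torch.Size([2, 1152])
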
class TextTimeEmbedding(nn.Module):
    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class TextImageTimeEmbedding(nn.Module):
    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class ImageTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        return time_image_embeds


class ImageHintTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(256, 4, 3, padding=1),
        )

    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        hint = self.input_hint_block(hint)
        return time_image_embeds, hint

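In ImageHintTimeEmbedding, the three stride-2 convolutions reduce the RGB hint by a factor of 8 in each spatial dimension and map it to 4 channels, i.e. a latent-sized tensor. Illustrative shapes only (the 512x512 hint is an assumption):

hint_embed = ImageHintTimeEmbedding(image_embed_dim=768, time_embed_dim=1536)
time_emb, hint = hint_embed(torch.randn(1, 768), torch.randn(1, 3, 512, 512))
print(time_emb.shape, hint.shape)  # torch.Size([1, 1536]) torch.Size([1, 4, 64, 64])
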
class AttentionPooling(nn.Module):
    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, class_token_length, dim_per_head)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, length+class_token_length, dim_per_head)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs, 1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token

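A brief sketch of AttentionPooling: a query token derived from the sequence mean attends over the full sequence, so a (batch, length, width) input collapses to (batch, width); the 77-token, 768-channel sequence is an illustrative choice:

pool = AttentionPooling(num_heads=8, embed_dim=768)
pooled = pool(torch.randn(2, 77, 768))
print(pooled.shape)  # torch.Size([2, 768])
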
    	
6DoF/diffusers/models/embeddings_flax.py
ADDED
@@ -0,0 +1,95 @@
| 1 | 
         
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
            import math
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            import flax.linen as nn
         
     | 
| 17 | 
         
            +
            import jax.numpy as jnp
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            def get_sinusoidal_embeddings(
         
     | 
| 21 | 
         
            +
                timesteps: jnp.ndarray,
         
     | 
| 22 | 
         
            +
                embedding_dim: int,
         
     | 
| 23 | 
         
            +
                freq_shift: float = 1,
         
     | 
| 24 | 
         
            +
                min_timescale: float = 1,
         
     | 
| 25 | 
         
            +
                max_timescale: float = 1.0e4,
         
     | 
| 26 | 
         
            +
                flip_sin_to_cos: bool = False,
         
     | 
| 27 | 
         
            +
                scale: float = 1.0,
         
     | 
| 28 | 
         
            +
            ) -> jnp.ndarray:
         
     | 
| 29 | 
         
            +
                """Returns the positional encoding (same as Tensor2Tensor).
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                Args:
         
     | 
| 32 | 
         
            +
                    timesteps: a 1-D Tensor of N indices, one per batch element.
         
     | 
| 33 | 
         
            +
                    These may be fractional.
         
     | 
| 34 | 
         
            +
                    embedding_dim: The number of output channels.
         
     | 
| 35 | 
         
            +
                    min_timescale: The smallest time unit (should probably be 0.0).
         
     | 
| 36 | 
         
            +
                    max_timescale: The largest time unit.
         
     | 
| 37 | 
         
            +
                Returns:
         
     | 
| 38 | 
         
            +
                    a Tensor of timing signals [N, num_channels]
         
     | 
| 39 | 
         
            +
                """
         
     | 
| 40 | 
         
            +
                assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
         
     | 
| 41 | 
         
            +
                assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
         
     | 
| 42 | 
         
            +
                num_timescales = float(embedding_dim // 2)
         
     | 
| 43 | 
         
            +
                log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
         
     | 
| 44 | 
         
            +
                inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
         
     | 
| 45 | 
         
            +
                emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                # scale embeddings
         
     | 
| 48 | 
         
            +
                scaled_time = scale * emb
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                if flip_sin_to_cos:
         
     | 
| 51 | 
         
            +
                    signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
         
     | 
| 52 | 
         
            +
                else:
         
     | 
| 53 | 
         
            +
                    signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
         
     | 
| 54 | 
         
            +
                signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
         
     | 
| 55 | 
         
            +
                return signal
         
     | 
| 56 | 
         
class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step Embedding Module. Learns embeddings for input time steps.

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
                Parameters `dtype`
    """
    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        temb = nn.silu(temb)
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
        return temb


class FlaxTimesteps(nn.Module):
    r"""
    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
    """
    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        return get_sinusoidal_embeddings(
            timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
        )
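As a minimal sketch (not part of the committed file), the two modules above can be exercised standalone with the usual Flax init/apply pattern; the dimensions below are arbitrary illustrative choices.

# Hypothetical usage sketch for the two timestep modules above.
import jax
import jax.numpy as jnp

timesteps = jnp.array([1, 500, 999])          # one diffusion timestep per sample

# Sinusoidal projection: no learnable parameters, but still initialised/applied as a Flax module.
proj = FlaxTimesteps(dim=320)
proj_vars = proj.init(jax.random.PRNGKey(0), timesteps)
t_proj = proj.apply(proj_vars, timesteps)     # shape (3, 320)

# The two-layer MLP embedding does carry parameters.
emb = FlaxTimestepEmbedding(time_embed_dim=1280)
emb_params = emb.init(jax.random.PRNGKey(1), t_proj)
t_emb = emb.apply(emb_params, t_proj)         # shape (3, 1280)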
    	
6DoF/diffusers/models/modeling_flax_pytorch_utils.py
ADDED
@@ -0,0 +1,118 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from ..utils import logging


logger = logging.get_logger(__name__)


def rename_key(key):
    regex = r"\w+[.]\d+"
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key


#####################
# PyTorch => Flax #
#####################


# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor


def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    # Step 1: Convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    random_flax_state_dict = flatten_dict(random_flax_params)
    flax_state_dict = {}

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        renamed_pt_key = rename_key(pt_key)
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
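As a minimal sketch (not part of the committed file), the helpers above can be exercised directly: rename_key collapses indexed PyTorch module names, and rename_key_and_reshape_tensor maps a plain 2-D "weight" through the linear-layer branch, renaming it to "kernel" and transposing it. The keys and shapes below are hypothetical.

import numpy as np

print(rename_key("down_blocks.0.attentions.1.proj.weight"))
# -> "down_blocks_0.attentions_1.proj.weight"

# Fake flattened Flax params, keyed by tuples as the converter expects.
random_flax_state_dict = {
    ("proj", "kernel"): np.zeros((4, 8), dtype=np.float32),
}

flax_key, flax_tensor = rename_key_and_reshape_tensor(
    ("proj", "weight"), np.ones((8, 4), dtype=np.float32), random_flax_state_dict
)
print(flax_key, flax_tensor.shape)  # ('proj', 'kernel') (4, 8)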
    	
6DoF/diffusers/models/modeling_flax_utils.py
ADDED
@@ -0,0 +1,534 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pickle import UnpicklingError
from typing import Any, Dict, Union

import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from .. import __version__, is_torch_available
from ..utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    FLAX_WEIGHTS_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    WEIGHTS_NAME,
    logging,
)
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax


logger = logging.get_logger(__name__)


class FlaxModelMixin:
    r"""
    Base class for all Flax models.

    [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
    saving models.

        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
    """
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _flax_internal_args = ["name", "parent", "dtype"]

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            return jax.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_flatten(mask)

        for masked, key in zip(flat_mask, flat_params.keys()):
            if masked:
                param = flat_params[key]
                flat_params[key] = conditional_cast(param)

        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> params = model.to_bf16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_bf16(params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.bfloat16, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> params = model.to_fp16(params)
        >>> # now cast back to fp32
        >>> params = model.to_fp32(params)
        ```"""
        return self._cast_floating_to(params, jnp.float32, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
        `params` in place.

        This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to cast these to float16
        >>> params = model.to_fp16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_fp16(params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
        raise NotImplementedError(f"init_weights method has to be implemented for {self}")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    ):
        r"""
        Instantiate a pretrained Flax model from a pretrained model configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      using [`~FlaxModelMixin.save_pretrained`].
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified, all the computation will be performed with the given `dtype`.

                <Tip>

                This only specifies the dtype of the *computation* and does not influence the dtype of model
                parameters.

                If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
                [`~FlaxModelMixin.to_bf16`].

                </Tip>

            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments are passed to the underlying model's `__init__` method.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the model (for
                example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
                      model's `__init__` method (we assume all relevant updates to the configuration have already been
                      done).
                    - If a configuration is not provided, `kwargs` are first passed to the configuration class
                      initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
                      to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
                      Remaining keys that do not correspond to any configuration attribute are passed to the underlying
                      model's `__init__` function.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
        ```

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```bash
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        from_pt = kwargs.pop("from_pt", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "flax",
        }

        # Load config if we don't provide a configuration
        config_path = config if config is not None else pretrained_model_name_or_path
        model, model_kwargs = cls.from_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            # model args
            dtype=dtype,
            **kwargs,
        )

        # Load model
        pretrained_path_with_subfolder = (
            pretrained_model_name_or_path
            if subfolder is None
            else os.path.join(pretrained_model_name_or_path, subfolder)
        )
        if os.path.isdir(pretrained_path_with_subfolder):
            if from_pt:
                if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                    )
                model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
            # Check if pytorch weights exist instead
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model
         
     | 
| 343 | 
         
            +
                                " using `from_pt=True`."
         
     | 
| 344 | 
         
            +
                            )
         
     | 
| 345 | 
         
            +
                        else:
         
     | 
| 346 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 347 | 
         
            +
                                f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
         
     | 
| 348 | 
         
            +
                                f"{pretrained_path_with_subfolder}."
         
     | 
| 349 | 
         
            +
                            )
         
     | 
| 350 | 
         
            +
                    else:
         
     | 
| 351 | 
         
            +
                        try:
         
     | 
| 352 | 
         
            +
                            model_file = hf_hub_download(
         
     | 
| 353 | 
         
            +
                                pretrained_model_name_or_path,
         
     | 
| 354 | 
         
            +
                                filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
         
     | 
| 355 | 
         
            +
                                cache_dir=cache_dir,
         
     | 
| 356 | 
         
            +
                                force_download=force_download,
         
     | 
| 357 | 
         
            +
                                proxies=proxies,
         
     | 
| 358 | 
         
            +
                                resume_download=resume_download,
         
     | 
| 359 | 
         
            +
                                local_files_only=local_files_only,
         
     | 
| 360 | 
         
            +
                                use_auth_token=use_auth_token,
         
     | 
| 361 | 
         
            +
                                user_agent=user_agent,
         
     | 
| 362 | 
         
            +
                                subfolder=subfolder,
         
     | 
| 363 | 
         
            +
                                revision=revision,
         
     | 
| 364 | 
         
            +
                            )
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                        except RepositoryNotFoundError:
         
     | 
| 367 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 368 | 
         
            +
                                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
         
     | 
| 369 | 
         
            +
                                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
         
     | 
| 370 | 
         
            +
                                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
         
     | 
| 371 | 
         
            +
                                "login`."
         
     | 
| 372 | 
         
            +
                            )
         
     | 
| 373 | 
         
            +
                        except RevisionNotFoundError:
         
     | 
| 374 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 375 | 
         
            +
                                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
         
     | 
| 376 | 
         
            +
                                "this model name. Check the model page at "
         
     | 
| 377 | 
         
            +
                                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
         
     | 
| 378 | 
         
            +
                            )
         
     | 
| 379 | 
         
            +
                        except EntryNotFoundError:
         
     | 
| 380 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 381 | 
         
            +
                                f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
         
     | 
| 382 | 
         
            +
                            )
         
     | 
| 383 | 
         
            +
                        except HTTPError as err:
         
     | 
| 384 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 385 | 
         
            +
                                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
         
     | 
| 386 | 
         
            +
                                f"{err}"
         
     | 
| 387 | 
         
            +
                            )
         
     | 
| 388 | 
         
            +
                        except ValueError:
         
     | 
| 389 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 390 | 
         
            +
                                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
         
     | 
| 391 | 
         
            +
                                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
         
     | 
| 392 | 
         
            +
                                f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
         
     | 
| 393 | 
         
            +
                                " internet connection or see how to run the library in offline mode at"
         
     | 
| 394 | 
         
            +
                                " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
         
     | 
| 395 | 
         
            +
                            )
         
     | 
| 396 | 
         
            +
                        except EnvironmentError:
         
     | 
| 397 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 398 | 
         
            +
                                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
         
     | 
| 399 | 
         
            +
                                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
         
     | 
| 400 | 
         
            +
                                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
         
     | 
| 401 | 
         
            +
                                f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
         
     | 
| 402 | 
         
            +
                            )
         
     | 
| 403 | 
         
            +
             
     | 
| 404 | 
         
            +
                    if from_pt:
         
     | 
| 405 | 
         
            +
                        if is_torch_available():
         
     | 
| 406 | 
         
            +
                            from .modeling_utils import load_state_dict
         
     | 
| 407 | 
         
            +
                        else:
         
     | 
| 408 | 
         
            +
                            raise EnvironmentError(
         
     | 
| 409 | 
         
            +
                                "Can't load the model in PyTorch format because PyTorch is not installed. "
         
     | 
| 410 | 
         
            +
                                "Please, install PyTorch or use native Flax weights."
         
     | 
| 411 | 
         
            +
                            )
         
     | 
| 412 | 
         
            +
             
     | 
| 413 | 
         
            +
                        # Step 1: Get the pytorch file
         
     | 
| 414 | 
         
            +
                        pytorch_model_file = load_state_dict(model_file)
         
     | 
| 415 | 
         
            +
             
     | 
| 416 | 
         
            +
                        # Step 2: Convert the weights
         
     | 
| 417 | 
         
            +
                        state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
         
     | 
| 418 | 
         
            +
                    else:
         
     | 
| 419 | 
         
            +
                        try:
         
     | 
| 420 | 
         
            +
                            with open(model_file, "rb") as state_f:
         
     | 
| 421 | 
         
            +
                                state = from_bytes(cls, state_f.read())
         
     | 
| 422 | 
         
            +
                        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
         
     | 
| 423 | 
         
            +
                            try:
         
     | 
| 424 | 
         
            +
                                with open(model_file) as f:
         
     | 
| 425 | 
         
            +
                                    if f.read().startswith("version"):
         
     | 
| 426 | 
         
            +
                                        raise OSError(
         
     | 
| 427 | 
         
            +
                                            "You seem to have cloned a repository without having git-lfs installed. Please"
         
     | 
| 428 | 
         
            +
                                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
         
     | 
| 429 | 
         
            +
                                            " folder you cloned."
         
     | 
| 430 | 
         
            +
                                        )
         
     | 
| 431 | 
         
            +
                                    else:
         
     | 
| 432 | 
         
            +
                                        raise ValueError from e
         
     | 
| 433 | 
         
            +
                            except (UnicodeDecodeError, ValueError):
         
     | 
| 434 | 
         
            +
                                raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
         
     | 
| 435 | 
         
            +
                        # make sure all arrays are stored as jnp.ndarray
         
     | 
| 436 | 
         
            +
                        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
         
     | 
| 437 | 
         
            +
                        # https://github.com/google/flax/issues/1261
         
     | 
| 438 | 
         
            +
                    state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
         
     | 
| 439 | 
         
            +
             
     | 
| 440 | 
         
            +
                    # flatten dicts
         
     | 
| 441 | 
         
            +
                    state = flatten_dict(state)
         
     | 
| 442 | 
         
            +
             
     | 
| 443 | 
         
            +
                    params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
         
     | 
| 444 | 
         
            +
                    required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
         
     | 
| 445 | 
         
            +
             
     | 
| 446 | 
         
            +
                    shape_state = flatten_dict(unfreeze(params_shape_tree))
         
     | 
| 447 | 
         
            +
             
     | 
| 448 | 
         
            +
                    missing_keys = required_params - set(state.keys())
         
     | 
| 449 | 
         
            +
                    unexpected_keys = set(state.keys()) - required_params
         
     | 
| 450 | 
         
            +
             
     | 
| 451 | 
         
            +
                    if missing_keys:
         
     | 
| 452 | 
         
            +
                        logger.warning(
         
     | 
| 453 | 
         
            +
                            f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
         
     | 
| 454 | 
         
            +
                            "Make sure to call model.init_weights to initialize the missing weights."
         
     | 
| 455 | 
         
            +
                        )
         
     | 
| 456 | 
         
            +
                        cls._missing_keys = missing_keys
         
     | 
| 457 | 
         
            +
             
     | 
| 458 | 
         
            +
                    for key in state.keys():
         
     | 
| 459 | 
         
            +
                        if key in shape_state and state[key].shape != shape_state[key].shape:
         
     | 
| 460 | 
         
            +
                            raise ValueError(
         
     | 
| 461 | 
         
            +
                                f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
         
     | 
| 462 | 
         
            +
                                f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
         
     | 
| 463 | 
         
            +
                            )
         
     | 
| 464 | 
         
            +
             
     | 
| 465 | 
         
            +
                    # remove unexpected keys to not be saved again
         
     | 
| 466 | 
         
            +
                    for unexpected_key in unexpected_keys:
         
     | 
| 467 | 
         
            +
                        del state[unexpected_key]
         
     | 
| 468 | 
         
            +
             
     | 
| 469 | 
         
            +
                    if len(unexpected_keys) > 0:
         
     | 
| 470 | 
         
            +
                        logger.warning(
         
     | 
| 471 | 
         
            +
                            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
         
     | 
| 472 | 
         
            +
                            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
         
     | 
| 473 | 
         
            +
                            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
         
     | 
| 474 | 
         
            +
                            " with another architecture."
         
     | 
| 475 | 
         
            +
                        )
         
     | 
| 476 | 
         
            +
                    else:
         
     | 
| 477 | 
         
            +
                        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         
     | 
| 478 | 
         
            +
             
     | 
| 479 | 
         
            +
                    if len(missing_keys) > 0:
         
     | 
| 480 | 
         
            +
                        logger.warning(
         
     | 
| 481 | 
         
            +
                            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
         
     | 
| 482 | 
         
            +
                            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
         
     | 
| 483 | 
         
            +
                            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
         
     | 
| 484 | 
         
            +
                        )
         
     | 
| 485 | 
         
            +
                    else:
         
     | 
| 486 | 
         
            +
                        logger.info(
         
     | 
| 487 | 
         
            +
                            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
         
     | 
| 488 | 
         
            +
                            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
         
     | 
| 489 | 
         
            +
                            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
         
     | 
| 490 | 
         
            +
                            " training."
         
     | 
| 491 | 
         
            +
                        )
         
     | 
| 492 | 
         
            +
             
     | 
| 493 | 
         
            +
                    return model, unflatten_dict(state)
         
     | 
| 494 | 
         
            +
             
     | 
| 495 | 
         
            +
                def save_pretrained(
         
     | 
| 496 | 
         
            +
                    self,
         
     | 
| 497 | 
         
            +
                    save_directory: Union[str, os.PathLike],
         
     | 
| 498 | 
         
            +
                    params: Union[Dict, FrozenDict],
         
     | 
| 499 | 
         
            +
                    is_main_process: bool = True,
         
     | 
| 500 | 
         
            +
                ):
         
     | 
| 501 | 
         
            +
                    """
         
     | 
| 502 | 
         
            +
                    Save a model and its configuration file to a directory so that it can be reloaded using the
         
     | 
| 503 | 
         
            +
                    [`~FlaxModelMixin.from_pretrained`] class method.
         
     | 
| 504 | 
         
            +
             
     | 
| 505 | 
         
            +
                    Arguments:
         
     | 
| 506 | 
         
            +
                        save_directory (`str` or `os.PathLike`):
         
     | 
| 507 | 
         
            +
                            Directory to save a model and its configuration file to. Will be created if it doesn't exist.
         
     | 
| 508 | 
         
            +
                        params (`Union[Dict, FrozenDict]`):
         
     | 
| 509 | 
         
            +
                            A `PyTree` of model parameters.
         
     | 
| 510 | 
         
            +
                        is_main_process (`bool`, *optional*, defaults to `True`):
         
     | 
| 511 | 
         
            +
                            Whether the process calling this is the main process or not. Useful during distributed training and you
         
     | 
| 512 | 
         
            +
                            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
         
     | 
| 513 | 
         
            +
                            process to avoid race conditions.
         
     | 
| 514 | 
         
            +
                    """
         
     | 
| 515 | 
         
            +
                    if os.path.isfile(save_directory):
         
     | 
| 516 | 
         
            +
                        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
         
     | 
| 517 | 
         
            +
                        return
         
     | 
| 518 | 
         
            +
             
     | 
| 519 | 
         
            +
                    os.makedirs(save_directory, exist_ok=True)
         
     | 
| 520 | 
         
            +
             
     | 
| 521 | 
         
            +
                    model_to_save = self
         
     | 
| 522 | 
         
            +
             
     | 
| 523 | 
         
            +
                    # Attach architecture to the config
         
     | 
| 524 | 
         
            +
                    # Save the config
         
     | 
| 525 | 
         
            +
                    if is_main_process:
         
     | 
| 526 | 
         
            +
                        model_to_save.save_config(save_directory)
         
     | 
| 527 | 
         
            +
             
     | 
| 528 | 
         
            +
                    # save model
         
     | 
| 529 | 
         
            +
                    output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
         
     | 
| 530 | 
         
            +
                    with open(output_model_file, "wb") as f:
         
     | 
| 531 | 
         
            +
                        model_bytes = to_bytes(params)
         
     | 
| 532 | 
         
            +
                        f.write(model_bytes)
         
     | 
| 533 | 
         
            +
             
     | 
| 534 | 
         
            +
                    logger.info(f"Model weights saved in {output_model_file}")
         
     | 
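For orientation, a minimal usage sketch of the two methods above (not part of the commit). It assumes this vendored `6DoF/diffusers` package is importable as `diffusers` and that its Flax UNet class is `FlaxUNet2DConditionModel`; `from_pretrained` returns the module and its parameter PyTree separately, and `save_pretrained` takes those params explicitly.

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel  # assumed class / import path

# Returns (module, params): Flax modules are stateless, so the weights come back
# as a separate PyTree rather than being stored on the model object.
unet, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", dtype=jnp.float32
)

# Passing from_pt=True instead would fetch the PyTorch checkpoint and convert it
# on the fly (the from_pt branch above).

# save_pretrained writes the config plus a msgpack weights file (FLAX_WEIGHTS_NAME).
unet.save_pretrained("./flax-unet", params=params)
```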
    	
6DoF/diffusers/models/modeling_pytorch_flax_utils.py
ADDED
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""

from pickle import UnpicklingError

import jax
import jax.numpy as jnp
import numpy as np
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict

from ..utils import logging


logger = logging.get_logger(__name__)


#####################
# Flax => PyTorch #
#####################


# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    try:
        with open(model_file, "rb") as flax_state_f:
            flax_state = from_bytes(None, flax_state_f.read())
    except UnpicklingError as e:
        try:
            with open(model_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please"
                        " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                        " folder you cloned."
                    )
                else:
                    raise ValueError from e
        except (UnicodeDecodeError, ValueError):
            raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(pt_model, flax_state)


def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load flax checkpoints in a PyTorch model"""

    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16

        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    pt_model.base_model_prefix = ""

    flax_state_dict = flatten_dict(flax_state, sep=".")
    pt_model_dict = pt_model.state_dict()

    # keep track of unexpected & missing keys
    unexpected_keys = []
    missing_keys = set(pt_model_dict.keys())

    for flax_key_tuple, flax_tensor in flax_state_dict.items():
        flax_key_tuple_array = flax_key_tuple.split(".")

        if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        elif flax_key_tuple_array[-1] == "kernel":
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = flax_tensor.T
        elif flax_key_tuple_array[-1] == "scale":
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]

        if "time_embedding" not in flax_key_tuple_array:
            for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
                flax_key_tuple_array[i] = (
                    flax_key_tuple_string.replace("_0", ".0")
                    .replace("_1", ".1")
                    .replace("_2", ".2")
                    .replace("_3", ".3")
                    .replace("_4", ".4")
                    .replace("_5", ".5")
                    .replace("_6", ".6")
                    .replace("_7", ".7")
                    .replace("_8", ".8")
                    .replace("_9", ".9")
                )

        flax_key = ".".join(flax_key_tuple_array)

        if flax_key in pt_model_dict:
            if flax_tensor.shape != pt_model_dict[flax_key].shape:
                raise ValueError(
                    f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
                    f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )
            else:
                # add weight to pytorch dict
                flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
                pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
                # remove from missing keys
                missing_keys.remove(flax_key)
        else:
            # weight is not expected by PyTorch model
            unexpected_keys.append(flax_key)

    pt_model.load_state_dict(pt_model_dict)

    # re-transform missing_keys to list
    missing_keys = list(missing_keys)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    return pt_model
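A hedged sketch of how the helper above might be used (not part of the commit). The import path mirrors this file's location; `UNet2DConditionModel`, `load_config`, the local paths, and the `diffusion_flax_model.msgpack` filename are assumptions for illustration.

```python
from diffusers import UNet2DConditionModel  # assumed to be exposed by this vendored copy
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

# Build a randomly initialized PyTorch model with the right architecture ...
config = UNet2DConditionModel.load_config("path/to/unet")  # hypothetical local config directory
pt_unet = UNet2DConditionModel.from_config(config)

# ... then overwrite its state dict with the converted Flax parameters.
# Conv kernels are transposed HWIO -> OIHW and `kernel` / `scale` leaves are renamed
# to `weight` by load_flax_weights_in_pytorch_model.
pt_unet = load_flax_checkpoint_in_pytorch_model(
    pt_unet, "path/to/unet/diffusion_flax_model.msgpack"  # hypothetical checkpoint path
)
```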
    	
6DoF/diffusers/models/modeling_utils.py
ADDED
@@ -0,0 +1,980 @@
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import itertools
import os
import re
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor, device, nn

from .. import __version__
from ..utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    FLAX_WEIGHTS_NAME,
    HF_HUB_OFFLINE,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    _add_variant,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_safetensors_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)


if is_torch_version(">=", "1.9.0"):
    _LOW_CPU_MEM_USAGE_DEFAULT = True
else:
    _LOW_CPU_MEM_USAGE_DEFAULT = False


if is_accelerate_available():
    import accelerate
    from accelerate.utils import set_module_tensor_to_device
    from accelerate.utils.versions import is_torch_version

if is_safetensors_available():
    import safetensors


def get_parameter_device(parameter: torch.nn.Module):
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module):
    try:
        params = tuple(parameter.parameters())
        if len(params) > 0:
            return params[0].dtype

        buffers = tuple(parameter.buffers())
        if len(buffers) > 0:
            return buffers[0].dtype

    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


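# Illustrative sketch, not part of the original file: the two helpers above
# resolve a module's device and dtype from its first parameter or buffer.
# `nn` and `torch` are already imported at the top of this file.
def _example_parameter_device_and_dtype():
    layer = nn.Linear(4, 4)  # any torch.nn.Module works
    # On a default CPU install this returns (device(type="cpu"), torch.float32).
    return get_parameter_device(layer), get_parameter_dtype(layer)

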
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
            return torch.load(checkpoint_file, map_location="cpu")
        else:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )


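# Illustrative sketch, not part of the original file: `load_state_dict` above
# dispatches on the file name, using `torch.load` when the basename matches
# `WEIGHTS_NAME` (optionally with a `variant`) and `safetensors.torch.load_file`
# otherwise. The checkpoint path below is hypothetical.
def _example_inspect_checkpoint(checkpoint_path="./my_model/diffusion_pytorch_model.safetensors"):
    state_dict = load_state_dict(checkpoint_path)
    # Map each tensor name to its shape, e.g. to sanity-check a downloaded checkpoint.
    return {name: tuple(tensor.shape) for name, tensor in state_dict.items()}

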
def _load_state_dict_into_model(model_to_load, state_dict):
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix=""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs


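# Illustrative sketch, not part of the original file: `_load_state_dict_into_model`
# recursively calls every submodule's `_load_from_state_dict`, so shape mismatches
# are collected as strings in `error_msgs` rather than raised immediately.
def _example_load_into_model(model: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    error_msgs = _load_state_dict_into_model(model, state_dict)
    if error_msgs:
        logger.warning("Problems while loading weights: %s", "; ".join(error_msgs))
    return model

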
class ModelMixin(torch.nn.Module):
    r"""
    Base class for all models.

    [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
    saving models.

        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
    """
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None

    def __init__(self):
        super().__init__()

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
        __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
            return self._internal_dict[name]

        # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        return super().__getattr__(name)

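    # Illustrative note, not part of the original file: the `__getattr__` override
    # above keeps legacy attribute access such as `unet.in_channels` working while
    # emitting a deprecation warning; the supported spelling goes through the config:
    #
    #     >>> unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    #     >>> unet.config.in_channels  # preferred
    #     4
    #     >>> unet.in_channels  # still works, but triggers `deprecate(...)`
    #     4
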
    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        Whether gradient checkpointing is activated for this model or not.
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        """
        Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
        *checkpoint activations* in other frameworks).
        """
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self):
        """
        Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
        *checkpoint activations* in other frameworks).
        """
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))

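    # Illustrative note, not part of the original file: gradient checkpointing is
    # toggled via `self.apply(partial(self._set_gradient_checkpointing, value=...))`,
    # so it only has an effect on subclasses that implement `_set_gradient_checkpointing`
    # and set `_supports_gradient_checkpointing = True`:
    #
    #     >>> model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    #     >>> model.enable_gradient_checkpointing()  # trade extra compute for lower activation memory
    #     >>> model.is_gradient_checkpointing
    #     True
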
    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
        inference. Speed up during training is not guaranteed.

        <Tip warning={true}>

        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
        precedent.

        </Tip>

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import UNet2DConditionModel
        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

        >>> model = UNet2DConditionModel.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
        ... )
        >>> model = model.to("cuda")
        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        ```
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = False,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory so that it can be reloaded using the
        [`~models.ModelMixin.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a model and its configuration file to. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
        """
        if safe_serialization and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        model_to_save = self

        # Attach architecture to the config
        # Save the config
        if is_main_process:
            model_to_save.save_config(save_directory)

        # Save the model
        state_dict = model_to_save.state_dict()

        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
        weights_name = _add_variant(weights_name, variant)

        # Save the model
        if safe_serialization:
            safetensors.torch.save_file(
                state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
            )
        else:
            torch.save(state_dict, os.path.join(save_directory, weights_name))

        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

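    # Illustrative note, not part of the original file: `save_pretrained` writes the
    # config via `save_config` plus a single weights file, named from `WEIGHTS_NAME`
    # or `SAFETENSORS_WEIGHTS_NAME` with the optional `variant` inserted by `_add_variant`.
    # A round-trip sketch with a hypothetical local directory:
    #
    #     >>> model.save_pretrained("./my-unet", safe_serialization=True, variant="fp16")
    #     >>> reloaded = model.__class__.from_pretrained("./my-unet", variant="fp16")
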
| 331 | 
         
            +
                @classmethod
         
     | 
| 332 | 
         
            +
                def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         
     | 
| 333 | 
         
            +
                    r"""
         
     | 
| 334 | 
         
            +
                    Instantiate a pretrained PyTorch model from a pretrained model configuration.
         
     | 
| 335 | 
         
            +
             
     | 
| 336 | 
         
            +
                    The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
         
     | 
| 337 | 
         
            +
                    train the model, set it back in training mode with `model.train()`.
         
     | 
| 338 | 
         
            +
             
     | 
| 339 | 
         
            +
                    Parameters:
         
     | 
| 340 | 
         
            +
                        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
         
     | 
| 341 | 
         
            +
                            Can be either:
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
                                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
         
     | 
| 344 | 
         
            +
                                  the Hub.
         
     | 
| 345 | 
         
            +
                                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
         
     | 
| 346 | 
         
            +
                                  with [`~ModelMixin.save_pretrained`].
         
     | 
| 347 | 
         
            +
             
     | 
| 348 | 
         
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         
     | 
| 349 | 
         
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         
     | 
| 350 | 
         
            +
                            is not used.
         
     | 
| 351 | 
         
            +
                        torch_dtype (`str` or `torch.dtype`, *optional*):
         
     | 
| 352 | 
         
            +
                            Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
         
     | 
| 353 | 
         
            +
                            dtype is automatically derived from the model's weights.
         
     | 
| 354 | 
         
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         
     | 
| 355 | 
         
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         
     | 
| 356 | 
         
            +
                            cached versions if they exist.
         
     | 
| 357 | 
         
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         
     | 
| 358 | 
         
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         
     | 
| 359 | 
         
            +
                            incompletely downloaded files are deleted.
         
     | 
| 360 | 
         
            +
                        proxies (`Dict[str, str]`, *optional*):
         
     | 
| 361 | 
         
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         
     | 
| 362 | 
         
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         
     | 
| 363 | 
         
            +
                        output_loading_info (`bool`, *optional*, defaults to `False`):
         
     | 
| 364 | 
         
            +
                            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         
     | 
| 365 | 
         
            +
                        local_files_only(`bool`, *optional*, defaults to `False`):
         
     | 
| 366 | 
         
            +
                            Whether to only load local model weights and configuration files or not. If set to `True`, the model
         
     | 
| 367 | 
         
            +
                            won't be downloaded from the Hub.
         
     | 
| 368 | 
         
            +
                        use_auth_token (`str` or *bool*, *optional*):
         
     | 
| 369 | 
         
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         
     | 
| 370 | 
         
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         
     | 
| 371 | 
         
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         
     | 
| 372 | 
         
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         
     | 
| 373 | 
         
            +
                            allowed by Git.
         
     | 
| 374 | 
         
            +
                        from_flax (`bool`, *optional*, defaults to `False`):
         
     | 
| 375 | 
         
            +
                            Load the model weights from a Flax checkpoint save file.
         
     | 
| 376 | 
         
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         
     | 
| 377 | 
         
            +
                            The subfolder location of a model file within a larger model repository on the Hub or locally.
         
     | 
| 378 | 
         
            +
                        mirror (`str`, *optional*):
         
     | 
| 379 | 
         
            +
                            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
         
     | 
| 380 | 
         
            +
                            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
         
     | 
| 381 | 
         
            +
                            information.
         
     | 
| 382 | 
         
            +
                        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
         
     | 
| 383 | 
         
            +
                            A map that specifies where each submodule should go. It doesn't need to be defined for each
         
     | 
| 384 | 
         
            +
                            parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
         
     | 
| 385 | 
         
            +
                            same device.
         
     | 
| 386 | 
         
            +
             
     | 
| 387 | 
         
            +
                            Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
         
     | 
| 388 | 
         
            +
                            more information about each option see [designing a device
         
     | 
| 389 | 
         
            +
                            map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
         
     | 
| 390 | 
         
            +
                        max_memory (`Dict`, *optional*):
         
     | 
| 391 | 
         
            +
                            A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
         
     | 
| 392 | 
         
            +
                            each GPU and the available CPU RAM if unset.
         
     | 
| 393 | 
         
            +
                        offload_folder (`str` or `os.PathLike`, *optional*):
         
     | 
| 394 | 
         
            +
                            The path to offload weights if `device_map` contains the value `"disk"`.
         
     | 
| 395 | 
         
            +
                        offload_state_dict (`bool`, *optional*):
         
     | 
| 396 | 
         
            +
                            If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
         
     | 
| 397 | 
         
            +
                            the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
         
     | 
| 398 | 
         
            +
                            when there is some disk offload.
         
     | 
| 399 | 
         
            +
                        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
         
     | 
| 400 | 
         
            +
                            Speed up model loading only loading the pretrained weights and not initializing the weights. This also
         
     | 
| 401 | 
         
            +
                            tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
         
     | 
| 402 | 
         
            +
                            Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
         
     | 
| 403 | 
         
            +
                            argument to `True` will raise an error.
         
     | 
| 404 | 
         
            +
                        variant (`str`, *optional*):
         
     | 
| 405 | 
         
            +
                            Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
         
     | 
| 406 | 
         
            +
                            loading `from_flax`.
         
     | 
| 407 | 
         
            +
                        use_safetensors (`bool`, *optional*, defaults to `None`):
         
     | 
| 408 | 
         
            +
                            If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
         
     | 
| 409 | 
         
            +
                            `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
         
     | 
| 410 | 
         
            +
                            weights. If set to `False`, `safetensors` weights are not loaded.
         
     | 
| 411 | 
         
            +
             
     | 
| 412 | 
         
            +
                    <Tip>
         
     | 
| 413 | 
         
            +
             
     | 
| 414 | 
         
            +
                    To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
         
     | 
| 415 | 
         
            +
                    `huggingface-cli login`. You can also activate the special
         
     | 
| 416 | 
         
            +
                    ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
         
     | 
| 417 | 
         
            +
                    firewalled environment.
         
     | 
| 418 | 
         
            +
             
     | 
| 419 | 
         
            +
                    </Tip>
         
     | 
| 420 | 
         
            +
             
     | 
| 421 | 
         
            +
                    Example:
         
     | 
| 422 | 
         
            +
             
     | 
| 423 | 
         
            +
                    ```py
         
     | 
| 424 | 
         
            +
                    from diffusers import UNet2DConditionModel
         
     | 
| 425 | 
         
            +
             
     | 
| 426 | 
         
            +
                    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
         
     | 
| 427 | 
         
            +
                    ```
         
     | 
| 428 | 
         
            +
             
     | 
| 429 | 
         
            +
                    If you get the error message below, you need to finetune the weights for your downstream task:
         
     | 
| 430 | 
         
            +
             
     | 
| 431 | 
         
            +
                    ```bash
         
     | 
| 432 | 
         
            +
                    Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
         
     | 
| 433 | 
         
            +
                    - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
         
     | 
| 434 | 
         
            +
                    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         
     | 
| 435 | 
         
            +
                    ```
         
     | 
| 436 | 
         
            +
                    """
         
     | 
| 437 | 
         
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         
     | 
| 438 | 
         
            +
                    ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         
     | 
| 439 | 
         
            +
                    force_download = kwargs.pop("force_download", False)
         
     | 
| 440 | 
         
            +
                    from_flax = kwargs.pop("from_flax", False)
         
     | 
| 441 | 
         
            +
                    resume_download = kwargs.pop("resume_download", False)
         
     | 
| 442 | 
         
            +
                    proxies = kwargs.pop("proxies", None)
         
     | 
| 443 | 
         
            +
                    output_loading_info = kwargs.pop("output_loading_info", False)
         
     | 
| 444 | 
         
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         
     | 
| 445 | 
         
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         
     | 
| 446 | 
         
            +
                    revision = kwargs.pop("revision", None)
         
     | 
| 447 | 
         
            +
                    torch_dtype = kwargs.pop("torch_dtype", None)
         
     | 
| 448 | 
         
            +
                    subfolder = kwargs.pop("subfolder", None)
         
     | 
| 449 | 
         
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`"
            )

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        # Check if we can handle device_map and dispatching the weights
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            user_agent=user_agent,
            **kwargs,
        )

        # load model
        model_file = None
        if from_flax:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
            model = cls.from_config(config, **unused_kwargs)

            # Convert the weights
            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            if use_safetensors:
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                    )
                except IOError as e:
                    if not allow_pickle:
                        raise e
                    pass
            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )

            if low_cpu_mem_usage:
                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **unused_kwargs)

                # if device_map is None, load the state dict and move the params from meta device to the cpu
                if device_map is None:
                    param_device = "cpu"
                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)
                    # move the params from meta device to cpu
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )
                    unexpected_keys = []

                    empty_state_dict = model.state_dict()
                    for param_name, param in state_dict.items():
                        accepts_dtype = "dtype" in set(
                            inspect.signature(set_module_tensor_to_device).parameters.keys()
                        )

                        if param_name not in empty_state_dict:
                            unexpected_keys.append(param_name)
                            continue

                        if empty_state_dict[param_name].shape != param.shape:
                            raise ValueError(
                                f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                            )

                        if accepts_dtype:
                            set_module_tensor_to_device(
                                model, param_name, param_device, value=param, dtype=torch_dtype
                            )
                        else:
                            set_module_tensor_to_device(model, param_name, param_device, value=param)

                    if cls._keys_to_ignore_on_load_unexpected is not None:
                        for pat in cls._keys_to_ignore_on_load_unexpected:
                            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                    if len(unexpected_keys) > 0:
                        logger.warn(
                            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                        )

                else:  # else let accelerate handle loading and dispatching.
                    # Load weights and dispatch according to the device_map
                    # by default the device_map is None and the weights are loaded on the CPU
                    try:
                        accelerate.load_checkpoint_and_dispatch(
                            model,
                            model_file,
                            device_map,
                            max_memory=max_memory,
                            offload_folder=offload_folder,
                            offload_state_dict=offload_state_dict,
                            dtype=torch_dtype,
                        )
                    except AttributeError as e:
                        # When using accelerate loading, we do not have the ability to load the state
                        # dict and rename the weight names manually. Additionally, accelerate skips
                        # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                        # (which look like they should be private variables?), so we can't use the standard hooks
                        # to rename parameters on load. We need to mimic the original weight names so the correct
                        # attributes are available. After we have loaded the weights, we convert the deprecated
                        # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                        # the weights so we don't have to do this again.

                        if "'Attention' object has no attribute" in str(e):
                            logger.warn(
                                f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                                " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                                " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                                " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                                " please also re-upload it or open a PR on the original repository."
                            )
                            model._temp_convert_self_to_deprecated_attention_blocks()
                            accelerate.load_checkpoint_and_dispatch(
                                model,
                                model_file,
                                device_map,
                                max_memory=max_memory,
                                offload_folder=offload_folder,
                                offload_state_dict=offload_state_dict,
                                dtype=torch_dtype,
                            )
                            model._undo_temp_convert_self_to_deprecated_attention_blocks()
                        else:
                            raise e

                loading_info = {
                    "missing_keys": [],
                    "unexpected_keys": [],
                    "mismatched_keys": [],
                    "error_msgs": [],
                }
            else:
                model = cls.from_config(config, **unused_kwargs)

                state_dict = load_state_dict(model_file, variant=variant)
                model._convert_deprecated_attention_blocks(state_dict)

                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                    model,
                    state_dict,
                    model_file,
                    pretrained_model_name_or_path,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                )

                loading_info = {
                    "missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "mismatched_keys": mismatched_keys,
                    "error_msgs": error_msgs,
                }

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model

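For context, a minimal usage sketch of the `from_pretrained` path above (not part of this commit): the checkpoint id and subfolder are placeholders taken from the docstring example further down in this file, and the sketch assumes this vendored package is importable as `diffusers`. It exercises the `torch_dtype`, `low_cpu_mem_usage`, and `output_loading_info` kwargs handled above.

```py
# Hedged usage sketch; hub id and subfolder are placeholders.
import torch
from diffusers import UNet2DConditionModel

model, loading_info = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint id
    subfolder="unet",
    torch_dtype=torch.float16,   # cast after loading, validated by the isinstance check above
    low_cpu_mem_usage=True,      # needs `accelerate`; silently falls back to False if missing
    output_loading_info=True,    # also return the missing/unexpected/mismatched key dict
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
```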
    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
    ):
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        loaded_keys = list(state_dict.keys())

        expected_keys = list(model_state_dict.keys())

        original_loaded_keys = loaded_keys

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        # Make sure we are able to load base models as well as derived models (with heads)
        model_to_load = model

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    model_key = checkpoint_key

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys

        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                " identical (initializing a BertForSequenceClassification model from a"
                " BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                " without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                " able to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

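As a toy illustration of the key bookkeeping in `_load_pretrained_model` (the key names below are made up, not from any real checkpoint): missing and unexpected keys are plain set differences between the model's state dict and the checkpoint, while shape mismatches are only collected when `ignore_mismatched_sizes=True`.

```py
# Toy illustration only; not library code.
expected_keys = ["conv.weight", "conv.bias", "attn.to_q.weight"]  # from model.state_dict()
loaded_keys = ["conv.weight", "conv.bias", "attn.query.weight"]   # from the checkpoint

missing_keys = list(set(expected_keys) - set(loaded_keys))     # ['attn.to_q.weight']
unexpected_keys = list(set(loaded_keys) - set(expected_keys))  # ['attn.query.weight']

# With ignore_mismatched_sizes=False, a shape mismatch instead surfaces as a
# "size mismatch" error message and is re-raised as a RuntimeError above.
print(missing_keys, unexpected_keys)
```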
    @property
    def device(self) -> device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (trainable or non-embedding) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters.
            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embedding parameters.

        Returns:
            `int`: The number of parameters.

        Example:

        ```py
        from diffusers import UNet2DConditionModel

        model_id = "runwayml/stable-diffusion-v1-5"
        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
        unet.num_parameters(only_trainable=True)
        859520964
        ```
        """

        if exclude_embeddings:
            embedding_param_names = [
                f"{name}.weight"
                for name, module_type in self.named_modules()
                if isinstance(module_type, torch.nn.Embedding)
            ]
            non_embedding_parameters = [
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            ]
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

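Continuing the docstring example above, the two properties and `num_parameters` can be combined as follows (sketch only; the checkpoint id is the same placeholder used in the docstring):

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
print(unet.device, unet.dtype)                       # e.g. cpu torch.float32
print(unet.num_parameters(exclude_embeddings=True))  # parameter count excluding nn.Embedding weights
```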
    def _convert_deprecated_attention_blocks(self, state_dict):
        deprecated_attention_block_paths = []

        def recursive_find_attn_block(name, module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_paths.append(name)

            for sub_name, sub_module in module.named_children():
                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
                recursive_find_attn_block(sub_name, sub_module)

        recursive_find_attn_block("", self)

        # NOTE: we have to check if the deprecated parameters are in the state dict
        # because it is possible we are loading from a state dict that was already
        # converted

        for path in deprecated_attention_block_paths:
            # group_norm path stays the same

            # query -> to_q
            if f"{path}.query.weight" in state_dict:
                state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
            if f"{path}.query.bias" in state_dict:
                state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")

            # key -> to_k
            if f"{path}.key.weight" in state_dict:
                state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
            if f"{path}.key.bias" in state_dict:
                state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")

            # value -> to_v
            if f"{path}.value.weight" in state_dict:
                state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
            if f"{path}.value.bias" in state_dict:
                state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")

            # proj_attn -> to_out.0
            if f"{path}.proj_attn.weight" in state_dict:
                state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
            if f"{path}.proj_attn.bias" in state_dict:
                state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")

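A small sketch of the renaming scheme applied by `_convert_deprecated_attention_blocks` on a checkpoint that still uses the deprecated attention names (the block path `mid_block.attentions.0` and the tensor shapes are hypothetical, chosen only to keep the example self-contained):

```py
# Toy state dict with deprecated attention key names; shapes are irrelevant here.
import torch

path = "mid_block.attentions.0"  # hypothetical deprecated attention block path
state_dict = {
    f"{path}.query.weight": torch.zeros(4, 4),
    f"{path}.key.weight": torch.zeros(4, 4),
    f"{path}.value.weight": torch.zeros(4, 4),
    f"{path}.proj_attn.weight": torch.zeros(4, 4),
}

# Same mapping as above: query -> to_q, key -> to_k, value -> to_v, proj_attn -> to_out.0
rename = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}
for old, new in rename.items():
    for suffix in ("weight", "bias"):
        old_key = f"{path}.{old}.{suffix}"
        if old_key in state_dict:
            state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)

print(sorted(state_dict))  # ...to_k.weight, ...to_out.0.weight, ...to_q.weight, ...to_v.weight
```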
    def _temp_convert_self_to_deprecated_attention_blocks(self):
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.query = module.to_q
            module.key = module.to_k
            module.value = module.to_v
            module.proj_attn = module.to_out[0]

            # We don't _have_ to delete the old attributes, but it's helpful to ensure
            # that _all_ the weights are loaded into the new attributes and we're not
            # making an incorrect assumption that this model should be converted when
            # it really shouldn't be.
            del module.to_q
            del module.to_k
            del module.to_v
            del module.to_out

    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.to_q = module.query
            module.to_k = module.key
            module.to_v = module.value
            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])

            del module.query
            del module.key
            del module.value
            del module.proj_attn
    	
6DoF/diffusers/models/prior_transformer.py
ADDED
@@ -0,0 +1,364 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin


@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    The output of [`PriorTransformer`].

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin):
    """
    A Prior Transformer model.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
        num_embeddings (`int`, *optional*, defaults to 77):
            The number of embeddings of the model input `hidden_states`
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
            The activation function to use to create timestep embeddings.
        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
            passing to Transformer blocks. Set it to `None` if normalization is not needed.
        embedding_proj_norm_type (`str`, *optional*, defaults to None):
            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
            needed.
        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
            `encoder_hidden_states` is `None`.
        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
            Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
            product between the text embedding and image embedding as proposed in the unclip paper
            https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
        time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
            If None, will be set to `num_attention_heads * attention_head_dim`
        embedding_proj_dim (`int`, *optional*, default to None):
            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
        clip_embed_dim (`int`, *optional*, default to None):
            The dimension of the output. If None, will be set to `embedding_dim`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        time_embed_act_fn: str = "silu",
        norm_in_type: Optional[str] = None,  # layer
        embedding_proj_norm_type: Optional[str] = None,  # layer
        encoder_hid_proj_type: Optional[str] = "linear",  # linear
        added_emb_type: Optional[str] = "prd",  # prd
        time_embed_dim: Optional[int] = None,
        embedding_proj_dim: Optional[int] = None,
        clip_embed_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        time_embed_dim = time_embed_dim or inner_dim
        embedding_proj_dim = embedding_proj_dim or embedding_dim
        clip_embed_dim = clip_embed_dim or embedding_dim

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        if embedding_proj_norm_type is None:
            self.embedding_proj_norm = None
        elif embedding_proj_norm_type == "layer":
            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
        else:
            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")

        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)

        if encoder_hid_proj_type is None:
            self.encoder_hidden_states_proj = None
        elif encoder_hid_proj_type == "linear":
            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        else:
            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        if added_emb_type == "prd":
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        elif added_emb_type is None:
            self.prd_embedding = None
        else:
            raise ValueError(
                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        if norm_in_type == "layer":
            self.norm_in = nn.LayerNorm(inner_dim)
        elif norm_in_type is None:
            self.norm_in = None
        else:
            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")

        self.norm_out = nn.LayerNorm(inner_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                Current denoising step.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
                If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        if self.embedding_proj_norm is not None:
            proj_embedding = self.embedding_proj_norm(proj_embedding)

        proj_embeddings = self.embedding_proj(proj_embedding)
        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")

        hidden_states = self.proj_in(hidden_states)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        additional_embeds = []
        additional_embeddings_len = 0

        if encoder_hidden_states is not None:
            additional_embeds.append(encoder_hidden_states)
            additional_embeddings_len += encoder_hidden_states.shape[1]

        if len(proj_embeddings.shape) == 2:
            proj_embeddings = proj_embeddings[:, None, :]

        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states[:, None, :]

        additional_embeds = additional_embeds + [
            proj_embeddings,
            time_embeddings[:, None, :],
            hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            additional_embeds.append(prd_embedding)

        hidden_states = torch.cat(
            additional_embeds,
            dim=1,
        )

        # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            positional_embeddings = F.pad(
                positional_embeddings,
                (
                    0,
                    0,
                    additional_embeddings_len,
                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
                ),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)

        if self.prd_embedding is not None:
            hidden_states = hidden_states[:, -1]
        else:
            hidden_states = hidden_states[:, additional_embeddings_len:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
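The file above mirrors the upstream diffusers PriorTransformer. Below is a minimal sketch of exercising it end to end, assuming the upstream diffusers package is installed (it exports the same class); the tiny config values and tensor shapes are illustrative assumptions, not those of any released checkpoint.

import torch
from diffusers import PriorTransformer

# Small illustrative configuration: inner_dim = 2 * 8 = 16, 7 text tokens + 4 extra tokens.
model = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=2,
    embedding_dim=16,
    num_embeddings=7,
)

batch = 2
hidden_states = torch.randn(batch, 16)              # current image-embedding estimate
proj_embedding = torch.randn(batch, 16)             # pooled text embedding
encoder_hidden_states = torch.randn(batch, 7, 16)   # per-token text embeddings
attention_mask = torch.ones(batch, 7, dtype=torch.bool)

out = model(
    hidden_states,
    timestep=10,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(out.predicted_image_embedding.shape)  # torch.Size([2, 16])

With the default added_emb_type="prd", the prediction is read off the last token of the sequence, so the output has shape (batch_size, clip_embed_dim).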
    	
6DoF/diffusers/models/resnet.py
ADDED
@@ -0,0 +1,877 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm

class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs

class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)

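# Illustrative usage sketch (editorial, not part of the original file): both 1D layers
# above operate on (batch, channels, length) tensors and change only the length
# dimension by a factor of 2, assuming the imports at the top of this module.
#
#     x = torch.randn(2, 32, 64)                         # (N, C, L)
#     up = Upsample1D(32, use_conv=True)                  # nearest x2, then 3-tap Conv1d
#     down = Downsample1D(32, use_conv=True, padding=1)   # strided Conv1d, L -> L // 2
#     assert up(x).shape == (2, 32, 128)
#     assert down(x).shape == (2, 32, 32)
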
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states

class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states

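# Illustrative usage sketch (editorial, not part of the original file): the 2D layers
# above double or halve the spatial resolution of a (batch, channels, height, width)
# tensor, and `output_size` on Upsample2D overrides the fixed x2 factor. Shapes below
# assume the default `name="conv"`.
#
#     x = torch.randn(1, 64, 32, 32)                       # (N, C, H, W)
#     up = Upsample2D(64, use_conv=True)                    # 32x32 -> 64x64
#     down = Downsample2D(64, use_conv=True, padding=1)     # 32x32 -> 16x16
#     assert up(x).shape == (1, 64, 64, 64)
#     assert down(x).shape == (1, 64, 16, 16)
#     assert up(x, output_size=(48, 48)).shape == (1, 64, 48, 48)
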
class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height

class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
            factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `x`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states

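# Illustrative sketch (editorial, not part of the original file): the FIR layers above
# build their smoothing filter from the separable `fir_kernel` by an outer product and
# normalisation, then fuse it with the optional 3x3 convolution; spatial sizes still
# change by the default factor of 2. The calls assume the `upfirdn2d_native` helper
# defined further down in this file.
#
#     k = torch.tensor((1.0, 3.0, 3.0, 1.0))
#     k2d = torch.outer(k, k) / k.sum() ** 2               # 4x4 low-pass kernel, sums to 1
#     x = torch.randn(1, 16, 32, 32)
#     assert FirUpsample2D(16, use_conv=True)(x).shape == (1, 16, 64, 64)
#     assert FirDownsample2D(16, use_conv=True)(x).shape == (1, 16, 16, 16)
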
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)

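# Illustrative usage sketch (editorial, not part of the original file): the k-upscaler
# layers above take no constructor arguments; the per-channel blur kernel is registered
# as a buffer and the channel count is inferred from the input at call time.
#
#     x = torch.randn(1, 8, 32, 32)
#     assert KDownsample2D()(x).shape == (1, 8, 16, 16)
#     assert KUpsample2D()(x).shape == (1, 8, 64, 64)
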
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
         
     | 
| 581 | 
         
            +
             
     | 
| 582 | 
         
            +
                    self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
         
     | 
| 583 | 
         
            +
             
     | 
| 584 | 
         
            +
                    self.conv_shortcut = None
         
     | 
| 585 | 
         
            +
                    if self.use_in_shortcut:
         
     | 
| 586 | 
         
            +
                        self.conv_shortcut = torch.nn.Conv2d(
         
     | 
| 587 | 
         
            +
                            in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
         
     | 
| 588 | 
         
            +
                        )
         
     | 
| 589 | 
         
            +
             
     | 
| 590 | 
         
            +
                def forward(self, input_tensor, temb):
         
     | 
| 591 | 
         
            +
                    hidden_states = input_tensor
         
     | 
| 592 | 
         
            +
             
     | 
| 593 | 
         
            +
                    if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
         
     | 
| 594 | 
         
            +
                        hidden_states = self.norm1(hidden_states, temb)
         
     | 
| 595 | 
         
            +
                    else:
         
     | 
| 596 | 
         
            +
                        hidden_states = self.norm1(hidden_states)
         
     | 
| 597 | 
         
            +
             
     | 
| 598 | 
         
            +
                    hidden_states = self.nonlinearity(hidden_states)
         
     | 
| 599 | 
         
            +
             
     | 
| 600 | 
         
            +
                    if self.upsample is not None:
         
     | 
| 601 | 
         
            +
                        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
         
     | 
| 602 | 
         
            +
                        if hidden_states.shape[0] >= 64:
         
     | 
| 603 | 
         
            +
                            input_tensor = input_tensor.contiguous()
         
     | 
| 604 | 
         
            +
                            hidden_states = hidden_states.contiguous()
         
     | 
| 605 | 
         
            +
                        input_tensor = self.upsample(input_tensor)
         
     | 
| 606 | 
         
            +
                        hidden_states = self.upsample(hidden_states)
         
     | 
| 607 | 
         
            +
                    elif self.downsample is not None:
         
     | 
| 608 | 
         
            +
                        input_tensor = self.downsample(input_tensor)
         
     | 
| 609 | 
         
            +
                        hidden_states = self.downsample(hidden_states)
         
     | 
| 610 | 
         
            +
             
     | 
| 611 | 
         
            +
                    hidden_states = self.conv1(hidden_states)
         
     | 
| 612 | 
         
            +
             
     | 
| 613 | 
         
            +
                    if self.time_emb_proj is not None:
         
     | 
| 614 | 
         
            +
                        if not self.skip_time_act:
         
     | 
| 615 | 
         
            +
                            temb = self.nonlinearity(temb)
         
     | 
| 616 | 
         
            +
                        temb = self.time_emb_proj(temb)[:, :, None, None]
         
     | 
| 617 | 
         
            +
             
     | 
| 618 | 
         
            +
                    if temb is not None and self.time_embedding_norm == "default":
         
     | 
| 619 | 
         
            +
                        hidden_states = hidden_states + temb
         
     | 
| 620 | 
         
            +
             
     | 
| 621 | 
         
            +
                    if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
         
     | 
| 622 | 
         
            +
                        hidden_states = self.norm2(hidden_states, temb)
         
     | 
| 623 | 
         
            +
                    else:
         
     | 
| 624 | 
         
            +
                        hidden_states = self.norm2(hidden_states)
         
     | 
| 625 | 
         
            +
             
     | 
| 626 | 
         
            +
                    if temb is not None and self.time_embedding_norm == "scale_shift":
         
     | 
| 627 | 
         
            +
                        scale, shift = torch.chunk(temb, 2, dim=1)
         
     | 
| 628 | 
         
            +
                        hidden_states = hidden_states * (1 + scale) + shift
         
     | 
| 629 | 
         
            +
             
     | 
| 630 | 
         
            +
                    hidden_states = self.nonlinearity(hidden_states)
         
     | 
| 631 | 
         
            +
             
     | 
| 632 | 
         
            +
                    hidden_states = self.dropout(hidden_states)
         
     | 
| 633 | 
         
            +
                    hidden_states = self.conv2(hidden_states)
         
     | 
| 634 | 
         
            +
             
     | 
| 635 | 
         
            +
                    if self.conv_shortcut is not None:
         
     | 
| 636 | 
         
            +
                        input_tensor = self.conv_shortcut(input_tensor)
         
     | 
| 637 | 
         
            +
             
     | 
| 638 | 
         
            +
                    output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
         
     | 
| 639 | 
         
            +
             
     | 
| 640 | 
         
            +
                    return output_tensor
         
     | 
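The forward pass above injects the time embedding between the two convolutions and adds a skip connection that is 1x1-projected whenever the channel count changes. A minimal usage sketch (the constructor keyword names in_channels, out_channels and temb_channels are assumed from the __init__ earlier in this file; all shapes are example values):

import torch

block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
x = torch.randn(2, 64, 32, 32)   # [batch, channels, height, width]
temb = torch.randn(2, 512)       # one time-embedding vector per sample
out = block(x, temb)
print(out.shape)                 # torch.Size([2, 128, 32, 32])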


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
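rearrange_dims toggles between the 3-D layout used by Conv1d and the 4-D layout expected by GroupNorm: a [batch, channels, horizon] tensor gains a singleton dimension, and a 4-D tensor drops it again. For example:

import torch

x = torch.randn(4, 8, 16)                        # [batch, channels, horizon]
print(rearrange_dims(x).shape)                   # torch.Size([4, 8, 1, 16])
print(rearrange_dims(rearrange_dims(x)).shape)   # back to torch.Size([4, 8, 16])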


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output
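Because the convolution pads by kernel_size // 2, odd kernel sizes leave the horizon unchanged while only the channel count is remapped:

import torch

block = Conv1dBlock(inp_channels=14, out_channels=32, kernel_size=5)
x = torch.randn(2, 14, 128)   # [batch, transition_dim, horizon]
print(block(x).shape)         # torch.Size([2, 32, 128])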


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
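The time embedding t is projected to out_channels and broadcast across the horizon via rearrange_dims before the second Conv1dBlock, with a 1x1 convolution (or identity) on the residual path. A sketch with example shapes:

import torch

block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=64)
x = torch.randn(2, 14, 128)   # [batch, inp_channels, horizon]
t = torch.randn(2, 64)        # [batch, embed_dim]
print(block(x, t).shape)      # torch.Size([2, 32, 128])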


def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output
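With the default kernel of [1] * factor this reduces to nearest-neighbor upsampling; the (1, 3, 3, 1) FIR kernel used by the "fir" branch above gives a smoothed 2x upsample instead. For example:

import torch

x = torch.randn(1, 3, 16, 16)
y = upsample_2d(x, kernel=(1, 3, 3, 1), factor=2)
print(y.shape)   # torch.Size([1, 3, 32, 32])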


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output
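downsample_2d mirrors the upsampling path: the kernel is normalized but not rescaled by factor**2, and upfirdn2d_native is called with down=factor, so the default kernel amounts to average pooling:

import torch

x = torch.randn(1, 3, 32, 32)
y = downsample_2d(x, kernel=(1, 3, 3, 1), factor=2)
print(y.shape)   # torch.Size([1, 3, 16, 16])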


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
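upfirdn2d_native is the pad, zero-insert upsample, FIR convolve, subsample primitive; the spatial output size follows the formula computed at the end, (in * up + pad0 + pad1 - kernel) // down + 1. A quick numeric check with the same arguments upsample_2d would pass for factor 2 and the (1, 3, 3, 1) kernel:

import torch

x = torch.randn(1, 3, 16, 16)
k = torch.outer(torch.tensor([1.0, 3.0, 3.0, 1.0]), torch.tensor([1.0, 3.0, 3.0, 1.0]))
k = k / k.sum()
out = upfirdn2d_native(x, k * 4, up=2, pad=(2, 1))   # gain * factor**2 = 4
print(out.shape)   # torch.Size([1, 3, 32, 32]): (16 * 2 + 2 + 1 - 4) // 1 + 1 = 32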


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
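Since conv4's last convolution is zero-initialized, a freshly constructed TemporalConvLayer acts as an identity over the frame axis, which lets it be inserted into a pretrained image UNet without disturbing its outputs. The input uses the flattened [batch * frames, channels, height, width] layout (example shapes below):

import torch

layer = TemporalConvLayer(in_dim=32)
x = torch.randn(2 * 8, 32, 16, 16)   # 2 clips of 8 frames each, flattened
out = layer(x, num_frames=8)
print(out.shape)                     # torch.Size([16, 32, 16, 16])
print(torch.allclose(out, x))        # True at initialization (residual branch is zero)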
    	
6DoF/diffusers/models/resnet_flax.py
ADDED
@@ -0,0 +1,124 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxDownsample2D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),  # padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        # hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states
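Unlike the PyTorch modules earlier in the diff, the Flax blocks expect NHWC inputs and are used functionally through init/apply. A minimal sketch with example shapes:

import jax
import jax.numpy as jnp

up = FlaxUpsample2D(out_channels=64)
x = jnp.ones((1, 16, 16, 64))                 # NHWC
params = up.init(jax.random.PRNGKey(0), x)
print(up.apply(params, x).shape)              # (1, 32, 32, 64)

down = FlaxDownsample2D(out_channels=64)
params = down.init(jax.random.PRNGKey(0), x)
print(down.apply(params, x).shape)            # (1, 8, 8, 64)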


class FlaxResnetBlock2D(nn.Module):
    in_channels: int
    out_channels: int = None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.swish(temb))
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual
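FlaxResnetBlock2D mirrors the "default" time_embedding_norm path of the PyTorch ResnetBlock2D: the projected time embedding is broadcast over the spatial dimensions and added between the two convolutions, with a 1x1 shortcut when the channel count changes. A sketch with example shapes (dropout is deterministic by default, so no dropout RNG is needed):

import jax
import jax.numpy as jnp

block = FlaxResnetBlock2D(in_channels=64, out_channels=128)
x = jnp.ones((1, 32, 32, 64))     # NHWC
temb = jnp.ones((1, 512))
params = block.init(jax.random.PRNGKey(0), x, temb)
print(block.apply(params, x, temb).shape)   # (1, 32, 32, 128)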
    	
6DoF/diffusers/models/t5_film_transformer.py
ADDED
@@ -0,0 +1,321 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from .attention_processor import Attention
from .embeddings import get_timestep_embedding
from .modeling_utils import ModelMixin


class T5FilmDecoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, d_model * 4, bias=False),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model * 4, bias=False),
            nn.SiLU(),
        )

        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)

        self.decoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            # FiLM conditional T5 decoder
            lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
            self.decoders.append(lyr)

        self.decoder_norm = T5LayerNorm(d_model)

        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input, key_input):
        mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
        return mask.unsqueeze(-3)

    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        batch, _, _ = decoder_input_tokens.shape
        assert decoder_noise_time.shape == (batch,)

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        time_steps = get_timestep_embedding(
            decoder_noise_time * self.config.max_decoder_noise_time,
            embedding_dim=self.config.d_model,
            max_period=self.config.max_decoder_noise_time,
        ).to(dtype=self.dtype)

        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)

        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
         
            +
                    seq_length = decoder_input_tokens.shape[1]
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                    # If we want to use relative positions for audio context, we can just offset
         
     | 
| 88 | 
         
            +
                    # this sequence by the length of encodings_and_masks.
         
     | 
| 89 | 
         
            +
                    decoder_positions = torch.broadcast_to(
         
     | 
| 90 | 
         
            +
                        torch.arange(seq_length, device=decoder_input_tokens.device),
         
     | 
| 91 | 
         
            +
                        (batch, seq_length),
         
     | 
| 92 | 
         
            +
                    )
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                    position_encodings = self.position_encoding(decoder_positions)
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                    inputs = self.continuous_inputs_projection(decoder_input_tokens)
         
     | 
| 97 | 
         
            +
                    inputs += position_encodings
         
     | 
| 98 | 
         
            +
                    y = self.dropout(inputs)
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                    # decoder: No padding present.
         
     | 
| 101 | 
         
            +
                    decoder_mask = torch.ones(
         
     | 
| 102 | 
         
            +
                        decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
         
     | 
| 103 | 
         
            +
                    )
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
                    # Translate encoding masks to encoder-decoder masks.
         
     | 
| 106 | 
         
            +
                    encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                    # cross attend style: concat encodings
         
     | 
| 109 | 
         
            +
                    encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
         
     | 
| 110 | 
         
            +
                    encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                    for lyr in self.decoders:
         
     | 
| 113 | 
         
            +
                        y = lyr(
         
     | 
| 114 | 
         
            +
                            y,
         
     | 
| 115 | 
         
            +
                            conditioning_emb=conditioning_emb,
         
     | 
| 116 | 
         
            +
                            encoder_hidden_states=encoded,
         
     | 
| 117 | 
         
            +
                            encoder_attention_mask=encoder_decoder_mask,
         
     | 
| 118 | 
         
            +
                        )[0]
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                    y = self.decoder_norm(y)
         
     | 
| 121 | 
         
            +
                    y = self.post_dropout(y)
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                    spec_out = self.spec_out(y)
         
     | 
| 124 | 
         
            +
                    return spec_out
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
            class DecoderLayer(nn.Module):
         
     | 
| 128 | 
         
            +
                def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
         
     | 
| 129 | 
         
            +
                    super().__init__()
         
     | 
| 130 | 
         
            +
                    self.layer = nn.ModuleList()
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                    # cond self attention: layer 0
         
     | 
| 133 | 
         
            +
                    self.layer.append(
         
     | 
| 134 | 
         
            +
                        T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
         
     | 
| 135 | 
         
            +
                    )
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                    # cross attention: layer 1
         
     | 
| 138 | 
         
            +
                    self.layer.append(
         
     | 
| 139 | 
         
            +
                        T5LayerCrossAttention(
         
     | 
| 140 | 
         
            +
                            d_model=d_model,
         
     | 
| 141 | 
         
            +
                            d_kv=d_kv,
         
     | 
| 142 | 
         
            +
                            num_heads=num_heads,
         
     | 
| 143 | 
         
            +
                            dropout_rate=dropout_rate,
         
     | 
| 144 | 
         
            +
                            layer_norm_epsilon=layer_norm_epsilon,
         
     | 
| 145 | 
         
            +
                        )
         
     | 
| 146 | 
         
            +
                    )
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                    # Film Cond MLP + dropout: last layer
         
     | 
| 149 | 
         
            +
                    self.layer.append(
         
     | 
| 150 | 
         
            +
                        T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
         
     | 
| 151 | 
         
            +
                    )
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                def forward(
         
     | 
| 154 | 
         
            +
                    self,
         
     | 
| 155 | 
         
            +
                    hidden_states,
         
     | 
| 156 | 
         
            +
                    conditioning_emb=None,
         
     | 
| 157 | 
         
            +
                    attention_mask=None,
         
     | 
| 158 | 
         
            +
                    encoder_hidden_states=None,
         
     | 
| 159 | 
         
            +
                    encoder_attention_mask=None,
         
     | 
| 160 | 
         
            +
                    encoder_decoder_position_bias=None,
         
     | 
| 161 | 
         
            +
                ):
         
     | 
| 162 | 
         
            +
                    hidden_states = self.layer[0](
         
     | 
| 163 | 
         
            +
                        hidden_states,
         
     | 
| 164 | 
         
            +
                        conditioning_emb=conditioning_emb,
         
     | 
| 165 | 
         
            +
                        attention_mask=attention_mask,
         
     | 
| 166 | 
         
            +
                    )
         
     | 
| 167 | 
         
            +
             
     | 
| 168 | 
         
            +
                    if encoder_hidden_states is not None:
         
     | 
| 169 | 
         
            +
                        encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
         
     | 
| 170 | 
         
            +
                            encoder_hidden_states.dtype
         
     | 
| 171 | 
         
            +
                        )
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                        hidden_states = self.layer[1](
         
     | 
| 174 | 
         
            +
                            hidden_states,
         
     | 
| 175 | 
         
            +
                            key_value_states=encoder_hidden_states,
         
     | 
| 176 | 
         
            +
                            attention_mask=encoder_extended_attention_mask,
         
     | 
| 177 | 
         
            +
                        )
         
     | 
| 178 | 
         
            +
             
     | 
| 179 | 
         
            +
                    # Apply Film Conditional Feed Forward layer
         
     | 
| 180 | 
         
            +
                    hidden_states = self.layer[-1](hidden_states, conditioning_emb)
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    return (hidden_states,)
         
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
            class T5LayerSelfAttentionCond(nn.Module):
         
     | 
| 186 | 
         
            +
                def __init__(self, d_model, d_kv, num_heads, dropout_rate):
         
     | 
| 187 | 
         
            +
                    super().__init__()
         
     | 
| 188 | 
         
            +
                    self.layer_norm = T5LayerNorm(d_model)
         
     | 
| 189 | 
         
            +
                    self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
         
     | 
| 190 | 
         
            +
                    self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
         
     | 
| 191 | 
         
            +
                    self.dropout = nn.Dropout(dropout_rate)
         
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
                def forward(
         
     | 
| 194 | 
         
            +
                    self,
         
     | 
| 195 | 
         
            +
                    hidden_states,
         
     | 
| 196 | 
         
            +
                    conditioning_emb=None,
         
     | 
| 197 | 
         
            +
                    attention_mask=None,
         
     | 
| 198 | 
         
            +
                ):
         
     | 
| 199 | 
         
            +
                    # pre_self_attention_layer_norm
         
     | 
| 200 | 
         
            +
                    normed_hidden_states = self.layer_norm(hidden_states)
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                    if conditioning_emb is not None:
         
     | 
| 203 | 
         
            +
                        normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                    # Self-attention block
         
     | 
| 206 | 
         
            +
                    attention_output = self.attention(normed_hidden_states)
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                    hidden_states = hidden_states + self.dropout(attention_output)
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                    return hidden_states
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
            class T5LayerCrossAttention(nn.Module):
         
     | 
| 214 | 
         
            +
                def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
         
     | 
| 215 | 
         
            +
                    super().__init__()
         
     | 
| 216 | 
         
            +
                    self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
         
     | 
| 217 | 
         
            +
                    self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
         
     | 
| 218 | 
         
            +
                    self.dropout = nn.Dropout(dropout_rate)
         
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
                def forward(
         
     | 
| 221 | 
         
            +
                    self,
         
     | 
| 222 | 
         
            +
                    hidden_states,
         
     | 
| 223 | 
         
            +
                    key_value_states=None,
         
     | 
| 224 | 
         
            +
                    attention_mask=None,
         
     | 
| 225 | 
         
            +
                ):
         
     | 
| 226 | 
         
            +
                    normed_hidden_states = self.layer_norm(hidden_states)
         
     | 
| 227 | 
         
            +
                    attention_output = self.attention(
         
     | 
| 228 | 
         
            +
                        normed_hidden_states,
         
     | 
| 229 | 
         
            +
                        encoder_hidden_states=key_value_states,
         
     | 
| 230 | 
         
            +
                        attention_mask=attention_mask.squeeze(1),
         
     | 
| 231 | 
         
            +
                    )
         
     | 
| 232 | 
         
            +
                    layer_output = hidden_states + self.dropout(attention_output)
         
     | 
| 233 | 
         
            +
                    return layer_output
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
             
     | 
| 236 | 
         
            +
            class T5LayerFFCond(nn.Module):
         
     | 
| 237 | 
         
            +
                def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
         
     | 
| 238 | 
         
            +
                    super().__init__()
         
     | 
| 239 | 
         
            +
                    self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
         
     | 
| 240 | 
         
            +
                    self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
         
     | 
| 241 | 
         
            +
                    self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
         
     | 
| 242 | 
         
            +
                    self.dropout = nn.Dropout(dropout_rate)
         
     | 
| 243 | 
         
            +
             
     | 
| 244 | 
         
            +
                def forward(self, hidden_states, conditioning_emb=None):
         
     | 
| 245 | 
         
            +
                    forwarded_states = self.layer_norm(hidden_states)
         
     | 
| 246 | 
         
            +
                    if conditioning_emb is not None:
         
     | 
| 247 | 
         
            +
                        forwarded_states = self.film(forwarded_states, conditioning_emb)
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
                    forwarded_states = self.DenseReluDense(forwarded_states)
         
     | 
| 250 | 
         
            +
                    hidden_states = hidden_states + self.dropout(forwarded_states)
         
     | 
| 251 | 
         
            +
                    return hidden_states
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
             
     | 
| 254 | 
         
            +
            class T5DenseGatedActDense(nn.Module):
         
     | 
| 255 | 
         
            +
                def __init__(self, d_model, d_ff, dropout_rate):
         
     | 
| 256 | 
         
            +
                    super().__init__()
         
     | 
| 257 | 
         
            +
                    self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
         
     | 
| 258 | 
         
            +
                    self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
         
     | 
| 259 | 
         
            +
                    self.wo = nn.Linear(d_ff, d_model, bias=False)
         
     | 
| 260 | 
         
            +
                    self.dropout = nn.Dropout(dropout_rate)
         
     | 
| 261 | 
         
            +
                    self.act = NewGELUActivation()
         
     | 
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
                def forward(self, hidden_states):
         
     | 
| 264 | 
         
            +
                    hidden_gelu = self.act(self.wi_0(hidden_states))
         
     | 
| 265 | 
         
            +
                    hidden_linear = self.wi_1(hidden_states)
         
     | 
| 266 | 
         
            +
                    hidden_states = hidden_gelu * hidden_linear
         
     | 
| 267 | 
         
            +
                    hidden_states = self.dropout(hidden_states)
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
                    hidden_states = self.wo(hidden_states)
         
     | 
| 270 | 
         
            +
                    return hidden_states
         
     | 
| 271 | 
         
            +
             
     | 
| 272 | 
         
            +
             
     | 
| 273 | 
         
            +
            class T5LayerNorm(nn.Module):
         
     | 
| 274 | 
         
            +
                def __init__(self, hidden_size, eps=1e-6):
         
     | 
| 275 | 
         
            +
                    """
         
     | 
| 276 | 
         
            +
                    Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
         
     | 
| 277 | 
         
            +
                    """
         
     | 
| 278 | 
         
            +
                    super().__init__()
         
     | 
| 279 | 
         
            +
                    self.weight = nn.Parameter(torch.ones(hidden_size))
         
     | 
| 280 | 
         
            +
                    self.variance_epsilon = eps
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
                def forward(self, hidden_states):
         
     | 
| 283 | 
         
            +
                    # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         
     | 
| 284 | 
         
            +
                    # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         
     | 
| 285 | 
         
            +
                    # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
         
     | 
| 286 | 
         
            +
                    # half-precision inputs is done in fp32
         
     | 
| 287 | 
         
            +
             
     | 
| 288 | 
         
            +
                    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         
     | 
| 289 | 
         
            +
                    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                    # convert into half-precision if necessary
         
     | 
| 292 | 
         
            +
                    if self.weight.dtype in [torch.float16, torch.bfloat16]:
         
     | 
| 293 | 
         
            +
                        hidden_states = hidden_states.to(self.weight.dtype)
         
     | 
| 294 | 
         
            +
             
     | 
| 295 | 
         
            +
                    return self.weight * hidden_states
         
     | 
| 296 | 
         
            +
             
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
            class NewGELUActivation(nn.Module):
         
     | 
| 299 | 
         
            +
                """
         
     | 
| 300 | 
         
            +
                Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
         
     | 
| 301 | 
         
            +
                the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
         
     | 
| 302 | 
         
            +
                """
         
     | 
| 303 | 
         
            +
             
     | 
| 304 | 
         
            +
                def forward(self, input: torch.Tensor) -> torch.Tensor:
         
     | 
| 305 | 
         
            +
                    return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
         
     | 
| 306 | 
         
            +
             
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
            class T5FiLMLayer(nn.Module):
         
     | 
| 309 | 
         
            +
                """
         
     | 
| 310 | 
         
            +
                FiLM Layer
         
     | 
| 311 | 
         
            +
                """
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
                def __init__(self, in_features, out_features):
         
     | 
| 314 | 
         
            +
                    super().__init__()
         
     | 
| 315 | 
         
            +
                    self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
         
     | 
| 316 | 
         
            +
             
     | 
| 317 | 
         
            +
                def forward(self, x, conditioning_emb):
         
     | 
| 318 | 
         
            +
                    emb = self.scale_bias(conditioning_emb)
         
     | 
| 319 | 
         
            +
                    scale, shift = torch.chunk(emb, 2, -1)
         
     | 
| 320 | 
         
            +
                    x = x * (1 + scale) + shift
         
     | 
| 321 | 
         
            +
                    return x
         
     | 
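The file above builds its decoder from three small numeric primitives: a scale-only (RMS) layer norm, the tanh approximation of GELU, and FiLM scale/shift conditioning (x * (1 + scale) + shift). The snippet below is a minimal, self-contained illustration of those three operations in plain PyTorch; it is a sketch for orientation only, not part of the committed file, and the tensor shapes are arbitrary example values.

import math

import torch


def rms_norm(x, weight, eps=1e-6):
    # scale-only layer norm as in T5LayerNorm: x / sqrt(mean(x**2) + eps) * weight
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(x.dtype)


def new_gelu(x):
    # tanh approximation of GELU, matching NewGELUActivation
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0))))


def film(x, scale, shift):
    # feature-wise linear modulation as in T5FiLMLayer; scale/shift broadcast over the sequence axis
    return x * (1 + scale) + shift


x = torch.randn(2, 8, 16)        # (batch, seq, d_model) -- example sizes
weight = torch.ones(16)
scale = torch.randn(2, 1, 16)    # per-feature scale from the conditioning embedding
shift = torch.randn(2, 1, 16)    # per-feature shift from the conditioning embedding

y = film(new_gelu(rms_norm(x, weight)), scale, shift)
print(y.shape)                   # torch.Size([2, 8, 16])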
    	
6DoF/diffusers/models/transformer_2d.py
ADDED
@@ -0,0 +1,343 @@
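The file that follows vendors diffusers' Transformer2DModel (with an extra `posemb` argument in `forward`). For orientation, the continuous-input path (GroupNorm, proj_in, a stack of BasicTransformerBlock, proj_out, residual add) can be exercised roughly as below. This is a hedged sketch, assuming the vendored 6DoF/diffusers package, or an upstream diffusers release with the same Transformer2DModel API, is importable as `diffusers`; it is not part of the commit.

import torch
from diffusers.models.transformer_2d import Transformer2DModel

# Continuous-input configuration: in_channels is set, patch_size and num_vector_embeds stay None.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,      # inner_dim = 2 * 8 = 16
    in_channels=32,            # must be divisible by norm_num_groups (default 32)
    num_layers=1,
    cross_attention_dim=None,  # no encoder_hidden_states -> pure self-attention
)

sample = torch.randn(1, 32, 16, 16)   # (batch, channels, height, width)
out = model(sample)                   # returns a Transformer2DModelOutput by default
print(out.sample.shape)               # torch.Size([1, 32, 16, 16])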
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .modeling_utils import ModelMixin


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = nn.Linear(in_channels, inner_dim)
            else:
                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = nn.Linear(inner_dim, in_channels)
            else:
                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        posemb: Optional = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
         
     | 
| 247 | 
         
            +
                    """
         
     | 
| 248 | 
         
            +
                    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         
     | 
| 249 | 
         
            +
                    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         
     | 
| 250 | 
         
            +
                    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
         
     | 
| 251 | 
         
            +
                    # expects mask of shape:
         
     | 
| 252 | 
         
            +
                    #   [batch, key_tokens]
         
     | 
| 253 | 
         
            +
                    # adds singleton query_tokens dimension:
         
     | 
| 254 | 
         
            +
                    #   [batch,                    1, key_tokens]
         
     | 
| 255 | 
         
            +
                    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
         
     | 
| 256 | 
         
            +
                    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
         
     | 
| 257 | 
         
            +
                    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
         
     | 
| 258 | 
         
            +
                    if attention_mask is not None and attention_mask.ndim == 2:
         
     | 
| 259 | 
         
            +
                        # assume that mask is expressed as:
         
     | 
| 260 | 
         
            +
                        #   (1 = keep,      0 = discard)
         
     | 
| 261 | 
         
            +
                        # convert mask into a bias that can be added to attention scores:
         
     | 
| 262 | 
         
            +
                        #       (keep = +0,     discard = -10000.0)
         
     | 
| 263 | 
         
            +
                        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
         
     | 
| 264 | 
         
            +
                        attention_mask = attention_mask.unsqueeze(1)
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    # convert encoder_attention_mask to a bias the same way we do for attention_mask
         
     | 
| 267 | 
         
            +
                    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
         
     | 
| 268 | 
         
            +
                        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
         
     | 
| 269 | 
         
            +
                        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
         
     | 
| 270 | 
         
            +
             
     | 
| 271 | 
         
            +
                    # 1. Input
         
     | 
| 272 | 
         
            +
                    if self.is_input_continuous:
         
     | 
| 273 | 
         
            +
                        batch, _, height, width = hidden_states.shape
         
     | 
| 274 | 
         
            +
                        residual = hidden_states
         
     | 
| 275 | 
         
            +
             
     | 
| 276 | 
         
            +
                        hidden_states = self.norm(hidden_states)
         
     | 
| 277 | 
         
            +
                        if not self.use_linear_projection:
         
     | 
| 278 | 
         
            +
                            hidden_states = self.proj_in(hidden_states)
         
     | 
| 279 | 
         
            +
                            inner_dim = hidden_states.shape[1]
         
     | 
| 280 | 
         
            +
                            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         
     | 
| 281 | 
         
            +
                        else:
         
     | 
| 282 | 
         
            +
                            inner_dim = hidden_states.shape[1]
         
     | 
| 283 | 
         
            +
                            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         
     | 
| 284 | 
         
            +
                            hidden_states = self.proj_in(hidden_states)
         
     | 
| 285 | 
         
            +
                    elif self.is_input_vectorized:
         
     | 
| 286 | 
         
            +
                        hidden_states = self.latent_image_embedding(hidden_states)
         
     | 
| 287 | 
         
            +
                    elif self.is_input_patches:
         
     | 
| 288 | 
         
            +
                        hidden_states = self.pos_embed(hidden_states)
         
     | 
| 289 | 
         
            +
             
     | 
| 290 | 
         
            +
                    # 2. Blocks
         
     | 
| 291 | 
         
            +
                    for block in self.transformer_blocks:
         
     | 
| 292 | 
         
            +
                        hidden_states = block(
         
     | 
| 293 | 
         
            +
                            hidden_states,
         
     | 
| 294 | 
         
            +
                            attention_mask=attention_mask,
         
     | 
| 295 | 
         
            +
                            encoder_hidden_states=encoder_hidden_states,
         
     | 
| 296 | 
         
            +
                            encoder_attention_mask=encoder_attention_mask,
         
     | 
| 297 | 
         
            +
                            timestep=timestep,
         
     | 
| 298 | 
         
            +
                            posemb=posemb,
         
     | 
| 299 | 
         
            +
                            cross_attention_kwargs=cross_attention_kwargs,
         
     | 
| 300 | 
         
            +
                            class_labels=class_labels,
         
     | 
| 301 | 
         
            +
                        )
         
     | 
| 302 | 
         
            +
             
     | 
| 303 | 
         
            +
                    # 3. Output
         
     | 
| 304 | 
         
            +
                    if self.is_input_continuous:
         
     | 
| 305 | 
         
            +
                        if not self.use_linear_projection:
         
     | 
| 306 | 
         
            +
                            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
         
     | 
| 307 | 
         
            +
                            hidden_states = self.proj_out(hidden_states)
         
     | 
| 308 | 
         
            +
                        else:
         
     | 
| 309 | 
         
            +
                            hidden_states = self.proj_out(hidden_states)
         
     | 
| 310 | 
         
            +
                            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
         
     | 
| 311 | 
         
            +
             
     | 
| 312 | 
         
            +
                        output = hidden_states + residual
         
     | 
| 313 | 
         
            +
                    elif self.is_input_vectorized:
         
     | 
| 314 | 
         
            +
                        hidden_states = self.norm_out(hidden_states)
         
     | 
| 315 | 
         
            +
                        logits = self.out(hidden_states)
         
     | 
| 316 | 
         
            +
                        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
         
     | 
| 317 | 
         
            +
                        logits = logits.permute(0, 2, 1)
         
     | 
| 318 | 
         
            +
             
     | 
| 319 | 
         
            +
                        # log(p(x_0))
         
     | 
| 320 | 
         
            +
                        output = F.log_softmax(logits.double(), dim=1).float()
         
     | 
| 321 | 
         
            +
                    elif self.is_input_patches:
         
     | 
| 322 | 
         
            +
                        # TODO: cleanup!
         
     | 
| 323 | 
         
            +
                        conditioning = self.transformer_blocks[0].norm1.emb(
         
     | 
| 324 | 
         
            +
                            timestep, class_labels, hidden_dtype=hidden_states.dtype
         
     | 
| 325 | 
         
            +
                        )
         
     | 
| 326 | 
         
            +
                        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
         
     | 
| 327 | 
         
            +
                        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
         
     | 
| 328 | 
         
            +
                        hidden_states = self.proj_out_2(hidden_states)
         
     | 
| 329 | 
         
            +
             
     | 
| 330 | 
         
            +
                        # unpatchify
         
     | 
| 331 | 
         
            +
                        height = width = int(hidden_states.shape[1] ** 0.5)
         
     | 
| 332 | 
         
            +
                        hidden_states = hidden_states.reshape(
         
     | 
| 333 | 
         
            +
                            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
         
     | 
| 334 | 
         
            +
                        )
         
     | 
| 335 | 
         
            +
                        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
         
     | 
| 336 | 
         
            +
                        output = hidden_states.reshape(
         
     | 
| 337 | 
         
            +
                            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
         
     | 
| 338 | 
         
            +
                        )
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                    if not return_dict:
         
     | 
| 341 | 
         
            +
                        return (output,)
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
                    return Transformer2DModelOutput(sample=output)
         
     | 
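A note on the mask handling in the forward method above: `attention_mask` and `encoder_attention_mask` can arrive either as a 2-D keep/discard mask or as an already-converted additive bias, and the 2-D form is folded into a bias before the transformer blocks run. The following is a minimal, self-contained sketch of that conversion; the shapes and values are illustrative only and are not taken from this commit:

import torch

batch, key_tokens = 2, 4
mask = torch.tensor([[1, 1, 0, 0],
                     [1, 1, 1, 0]], dtype=torch.float32)  # 1 = keep, 0 = discard

# "discard" positions become large negative biases; "keep" positions become 0
bias = (1 - mask) * -10000.0   # [batch, key_tokens]
bias = bias.unsqueeze(1)       # [batch, 1, key_tokens] -- singleton query dimension

# the singleton query dimension lets the bias broadcast over attention scores of
# shape [batch, query_tokens, key_tokens] (or [batch * heads, query_tokens, key_tokens])
scores = torch.randn(batch, 3, key_tokens)
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])  # the two discarded keys receive ~0 probability after softmax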
    	
6DoF/diffusers/models/transformer_temporal.py
ADDED
@@ -0,0 +1,179 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .modeling_utils import ModelMixin


@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    The output of [`TransformerTemporalModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor


class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerTemporalModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input hidden_states.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of
                a plain tuple.

        Returns:
            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
                returned, otherwise a `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        residual = hidden_states

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states[None, None, :]
            .reshape(batch_size, height, width, channel, num_frames)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
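Worth noting about the forward pass above: the temporal transformer folds the frame axis out of the batch and treats every spatial location as a short sequence over frames, then inverts the reshaping after the blocks. Below is a minimal, shape-only sketch of that round trip; the sizes are illustrative and the GroupNorm and linear projections are omitted:

import torch

batch_size, num_frames, channel, height, width = 2, 4, 8, 16, 16
batch_frames = batch_size * num_frames
hidden_states = torch.randn(batch_frames, channel, height, width)

# (batch*frames, C, H, W) -> (batch, C, frames, H, W)
x = hidden_states.reshape(batch_size, num_frames, channel, height, width).permute(0, 2, 1, 3, 4)

# -> (batch*H*W, frames, C): each spatial location becomes a short temporal sequence,
# so the attention blocks mix information across frames rather than across pixels.
x = x.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
print(x.shape)  # torch.Size([512, 4, 8])

# inverse path, mirroring what happens after the transformer blocks
x = (
    x[None, None, :]
    .reshape(batch_size, height, width, channel, num_frames)
    .permute(0, 3, 4, 1, 2)
    .contiguous()
    .reshape(batch_frames, channel, height, width)
)
print(x.shape)  # torch.Size([8, 8, 16, 16])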
    	
6DoF/diffusers/models/unet_1d.py
ADDED
@@ -0,0 +1,255 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
    """
    The output of [`UNet1DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Number of additional channels to be added to the input of the first down block. Useful for cases where the
            input data has more channels than what the model was initially designed for.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
        downsample_each_block (`bool`, *optional*, defaults to `False`):
            Experimental feature for using a UNet without upsampling.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: Tuple[str] = "UNetMidBlock1D",
        out_block_type: str = None,
        block_out_channels: Tuple[int] = (32, 32, 64),
        act_fn: str = None,
        norm_num_groups: int = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
    ):
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if i == 0:
                input_channel += extra_in_channels

            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
         
     | 
| 192 | 
         
            +
                        fc_dim=block_out_channels[-1] // 4,
         
     | 
| 193 | 
         
            +
                    )
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                def forward(
         
     | 
| 196 | 
         
            +
                    self,
         
     | 
| 197 | 
         
            +
                    sample: torch.FloatTensor,
         
     | 
| 198 | 
         
            +
                    timestep: Union[torch.Tensor, float, int],
         
     | 
| 199 | 
         
            +
                    return_dict: bool = True,
         
     | 
| 200 | 
         
            +
                ) -> Union[UNet1DOutput, Tuple]:
         
     | 
| 201 | 
         
            +
                    r"""
         
     | 
| 202 | 
         
            +
                    The [`UNet1DModel`] forward method.
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                    Args:
         
     | 
| 205 | 
         
            +
                        sample (`torch.FloatTensor`):
         
     | 
| 206 | 
         
            +
                            The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
         
     | 
| 207 | 
         
            +
                        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
         
     | 
| 208 | 
         
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         
     | 
| 209 | 
         
            +
                            Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                    Returns:
         
     | 
| 212 | 
         
            +
                        [`~models.unet_1d.UNet1DOutput`] or `tuple`:
         
     | 
| 213 | 
         
            +
                            If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
         
     | 
| 214 | 
         
            +
                            returned where the first element is the sample tensor.
         
     | 
| 215 | 
         
            +
                    """
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                    # 1. time
         
     | 
| 218 | 
         
            +
                    timesteps = timestep
         
     | 
| 219 | 
         
            +
                    if not torch.is_tensor(timesteps):
         
     | 
| 220 | 
         
            +
                        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         
     | 
| 221 | 
         
            +
                    elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
         
     | 
| 222 | 
         
            +
                        timesteps = timesteps[None].to(sample.device)
         
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
                    timestep_embed = self.time_proj(timesteps)
         
     | 
| 225 | 
         
            +
                    if self.config.use_timestep_embedding:
         
     | 
| 226 | 
         
            +
                        timestep_embed = self.time_mlp(timestep_embed)
         
     | 
| 227 | 
         
            +
                    else:
         
     | 
| 228 | 
         
            +
                        timestep_embed = timestep_embed[..., None]
         
     | 
| 229 | 
         
            +
                        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
         
     | 
| 230 | 
         
            +
                        timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
         
     | 
| 231 | 
         
            +
             
     | 
| 232 | 
         
            +
                    # 2. down
         
     | 
| 233 | 
         
            +
                    down_block_res_samples = ()
         
     | 
| 234 | 
         
            +
                    for downsample_block in self.down_blocks:
         
     | 
| 235 | 
         
            +
                        sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
         
     | 
| 236 | 
         
            +
                        down_block_res_samples += res_samples
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                    # 3. mid
         
     | 
| 239 | 
         
            +
                    if self.mid_block:
         
     | 
| 240 | 
         
            +
                        sample = self.mid_block(sample, timestep_embed)
         
     | 
| 241 | 
         
            +
             
     | 
| 242 | 
         
            +
                    # 4. up
         
     | 
| 243 | 
         
            +
                    for i, upsample_block in enumerate(self.up_blocks):
         
     | 
| 244 | 
         
            +
                        res_samples = down_block_res_samples[-1:]
         
     | 
| 245 | 
         
            +
                        down_block_res_samples = down_block_res_samples[:-1]
         
     | 
| 246 | 
         
            +
                        sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
                    # 5. post-process
         
     | 
| 249 | 
         
            +
                    if self.out_block:
         
     | 
| 250 | 
         
            +
                        sample = self.out_block(sample, timestep_embed)
         
     | 
| 251 | 
         
            +
             
     | 
| 252 | 
         
            +
                    if not return_dict:
         
     | 
| 253 | 
         
            +
                        return (sample,)
         
     | 
| 254 | 
         
            +
             
     | 
| 255 | 
         
            +
                    return UNet1DOutput(sample=sample)
         
     | 
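To make the `use_timestep_embedding=False` branch of the forward pass above concrete, here is a small standalone sketch of how the projected timestep embedding is tiled along the sequence axis and broadcast over the batch; the 32-dim embedding, the 14-channel/128-sample input, and the random stand-in for `self.time_proj` are illustrative assumptions, not values taken from this repository.

import torch

sample = torch.randn(4, 14, 128)                      # (batch_size, num_channels, sample_size) -- assumed shape
timesteps = torch.tensor([10], dtype=torch.long)      # a scalar timestep wrapped into a 1-element tensor

timestep_embed = torch.randn(timesteps.shape[0], 32)  # stand-in for self.time_proj(timesteps)
timestep_embed = timestep_embed[..., None]            # (1, 32, 1)
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)             # (1, 32, 128)
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))  # (4, 32, 128)
print(timestep_embed.shape)                           # torch.Size([4, 32, 128])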
    	
6DoF/diffusers/models/unet_1d_blocks.py
ADDED
@@ -0,0 +1,656 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn.functional as F
from torch import nn

from .activations import get_activation
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims

class DownResnetBlock1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)

    def forward(self, hidden_states, temb=None):
        output_states = ()

        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        output_states += (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, output_states

class UpResnetBlock1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.add_upsample = add_upsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        if add_upsample:
            self.upsample = Upsample1D(out_channels, use_conv_transpose=True)

    def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
        if res_hidden_states_tuple is not None:
            res_hidden_states = res_hidden_states_tuple[-1]
            hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)

        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states

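A minimal shape-check sketch of how these two blocks pair up through the skip connection; the sizes, the 128-dim time embedding, and the import path for the vendored package are assumptions for illustration, not part of the diff.

import torch
from diffusers.models.unet_1d_blocks import DownResnetBlock1D, UpResnetBlock1D  # assumed import path

down = DownResnetBlock1D(in_channels=32, out_channels=64, temb_channels=128, add_downsample=False)
up = UpResnetBlock1D(in_channels=64, out_channels=32, temb_channels=128, add_upsample=False)

x = torch.randn(2, 32, 64)                           # (batch, channels, length)
temb = torch.randn(2, 128)
x, skips = down(x, temb)                             # x: (2, 64, 64); skips collects the resnet output for the up path
y = up(x, res_hidden_states_tuple=skips, temb=temb)  # concatenates the skip along channels, returns (2, 32, 64)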
class ValueFunctionMidBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, embed_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim

        self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
        self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
        self.down2 = Downsample1D(out_channels // 4, use_conv=True)

    def forward(self, x, temb=None):
        x = self.res1(x, temb)
        x = self.down1(x)
        x = self.res2(x, temb)
        x = self.down2(x)
        return x

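A quick shape sketch for this block: each residual block halves the channel count and each Downsample1D halves the sequence length, so both shrink by 4x overall (sizes and the import path are illustrative assumptions).

import torch
from diffusers.models.unet_1d_blocks import ValueFunctionMidBlock1D  # assumed import path

mid = ValueFunctionMidBlock1D(in_channels=64, out_channels=64, embed_dim=32)
x = torch.randn(2, 64, 32)        # (batch, channels, length)
temb = torch.randn(2, 32)
print(mid(x, temb).shape)         # torch.Size([2, 16, 8])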
class MidResTemporalBlock1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        if add_upsample:
            self.upsample = Downsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

        if self.upsample and self.downsample:
            raise ValueError("Block cannot downsample and upsample")

    def forward(self, hidden_states, temb):
        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample:
            hidden_states = self.upsample(hidden_states)
        if self.downsample:
            hidden_states = self.downsample(hidden_states)

        return hidden_states

class OutConv1DBlock(nn.Module):
    def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
        super().__init__()
        self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
        self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
        self.final_conv1d_act = get_activation(act_fn)
        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.final_conv1d_1(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_gn(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_act(hidden_states)
        hidden_states = self.final_conv1d_2(hidden_states)
        return hidden_states

class OutValueFunctionBlock(nn.Module):
    def __init__(self, fc_dim, embed_dim):
        super().__init__()
        self.final_block = nn.ModuleList(
            [
                nn.Linear(fc_dim + embed_dim, fc_dim // 2),
                nn.Mish(),
                nn.Linear(fc_dim // 2, 1),
            ]
        )

    def forward(self, hidden_states, temb):
        hidden_states = hidden_states.view(hidden_states.shape[0], -1)
        hidden_states = torch.cat((hidden_states, temb), dim=-1)
        for layer in self.final_block:
            hidden_states = layer(hidden_states)

        return hidden_states

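A small sketch of the value head above: it flattens the features, concatenates the time embedding, and maps everything to a single scalar per sample (sizes and the import path are illustrative assumptions).

import torch
from diffusers.models.unet_1d_blocks import OutValueFunctionBlock  # assumed import path

head = OutValueFunctionBlock(fc_dim=64, embed_dim=32)
features = torch.randn(8, 16, 4)     # flattened to (8, 64) inside forward
temb = torch.randn(8, 32)
print(head(features, temb).shape)    # torch.Size([8, 1])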
_kernels = {
    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
    "lanczos3": [
        0.003689131001010537,
        0.015056144446134567,
        -0.03399861603975296,
        -0.066637322306633,
        0.13550527393817902,
        0.44638532400131226,
        0.44638532400131226,
        0.13550527393817902,
        -0.066637322306633,
        -0.03399861603975296,
        0.015056144446134567,
        0.003689131001010537,
    ],
}

class Downsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states):
        hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
        indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
        weight[indices, indices] = kernel
        return F.conv1d(hidden_states, weight, stride=2)


class Upsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states, temb=None):
        hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
        indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
        weight[indices, indices] = kernel
        return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)

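These two modules implement anti-aliased resampling with the fixed FIR kernels from `_kernels`: a per-channel stride-2 convolution for downsampling and the matching transposed convolution for upsampling. A shape sketch (sizes and the import path are illustrative assumptions):

import torch
from diffusers.models.unet_1d_blocks import Downsample1d, Upsample1d  # assumed import path

x = torch.randn(1, 8, 64)           # (batch, channels, length)
down = Downsample1d(kernel="cubic")
up = Upsample1d(kernel="cubic")

print(down(x).shape)                # torch.Size([1, 8, 32]) -- length halved
print(up(down(x)).shape)            # torch.Size([1, 8, 64]) -- length restored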
| 311 | 
         
            +
            class SelfAttention1d(nn.Module):
         
     | 
| 312 | 
         
            +
                def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
         
     | 
| 313 | 
         
            +
                    super().__init__()
         
     | 
| 314 | 
         
            +
                    self.channels = in_channels
         
     | 
| 315 | 
         
            +
                    self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
         
     | 
| 316 | 
         
            +
                    self.num_heads = n_head
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
            +
                    self.query = nn.Linear(self.channels, self.channels)
         
     | 
| 319 | 
         
            +
                    self.key = nn.Linear(self.channels, self.channels)
         
     | 
| 320 | 
         
            +
                    self.value = nn.Linear(self.channels, self.channels)
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                    self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
         
     | 
| 323 | 
         
            +
             
     | 
| 324 | 
         
            +
                    self.dropout = nn.Dropout(dropout_rate, inplace=True)
         
     | 
| 325 | 
         
            +
             
     | 
| 326 | 
         
            +
                def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         
     | 
| 327 | 
         
            +
                    new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
         
     | 
| 328 | 
         
            +
                    # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
         
     | 
| 329 | 
         
            +
                    new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
         
     | 
| 330 | 
         
            +
                    return new_projection
         
     | 
| 331 | 
         
            +
             
     | 
| 332 | 
         
            +
                def forward(self, hidden_states):
         
     | 
| 333 | 
         
            +
                    residual = hidden_states
         
     | 
| 334 | 
         
            +
                    batch, channel_dim, seq = hidden_states.shape
         
     | 
| 335 | 
         
            +
             
     | 
| 336 | 
         
            +
                    hidden_states = self.group_norm(hidden_states)
         
     | 
| 337 | 
         
            +
                    hidden_states = hidden_states.transpose(1, 2)
         
     | 
| 338 | 
         
            +
             
     | 
| 339 | 
         
            +
                    query_proj = self.query(hidden_states)
         
     | 
| 340 | 
         
            +
                    key_proj = self.key(hidden_states)
         
     | 
| 341 | 
         
            +
                    value_proj = self.value(hidden_states)
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
                    query_states = self.transpose_for_scores(query_proj)
         
     | 
| 344 | 
         
            +
                    key_states = self.transpose_for_scores(key_proj)
         
     | 
| 345 | 
         
            +
                    value_states = self.transpose_for_scores(value_proj)
         
     | 
| 346 | 
         
            +
             
     | 
| 347 | 
         
            +
                    scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                    attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
         
     | 
| 350 | 
         
            +
                    attention_probs = torch.softmax(attention_scores, dim=-1)
         
     | 
| 351 | 
         
            +
             
     | 
| 352 | 
         
            +
                    # compute attention output
         
     | 
| 353 | 
         
            +
                    hidden_states = torch.matmul(attention_probs, value_states)
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
                    hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
         
     | 
| 356 | 
         
            +
                    new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
         
     | 
| 357 | 
         
            +
                    hidden_states = hidden_states.view(new_hidden_states_shape)
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                    # compute next hidden_states
         
     | 
| 360 | 
         
            +
                    hidden_states = self.proj_attn(hidden_states)
         
     | 
| 361 | 
         
            +
                    hidden_states = hidden_states.transpose(1, 2)
         
     | 
| 362 | 
         
            +
                    hidden_states = self.dropout(hidden_states)
         
     | 
| 363 | 
         
            +
             
     | 
| 364 | 
         
            +
                    output = hidden_states + residual
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                    return output
         

class ResConvBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.has_conv_skip = in_channels != out_channels

        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
        self.gelu_1 = nn.GELU()
        self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)

        if not self.is_last:
            self.group_norm_2 = nn.GroupNorm(1, out_channels)
            self.gelu_2 = nn.GELU()

    def forward(self, hidden_states):
        residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states

        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.group_norm_1(hidden_states)
        hidden_states = self.gelu_1(hidden_states)
        hidden_states = self.conv_2(hidden_states)

        if not self.is_last:
            hidden_states = self.group_norm_2(hidden_states)
            hidden_states = self.gelu_2(hidden_states)

        output = hidden_states + residual
        return output
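# --- Illustrative sketch (added note, not part of the original file) ----------------
# ResConvBlock preserves the sequence length (kernel 5, padding 2) and only changes the
# channel count; a 1x1 conv on the skip path is added when in_channels != out_channels.
#
#     block = ResConvBlock(in_channels=32, mid_channels=64, out_channels=128)
#     x = torch.randn(1, 32, 96)     # (batch, channels, seq)
#     assert block(x).shape == (1, 128, 96)
# -------------------------------------------------------------------------------------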
         

class UNetMidBlock1D(nn.Module):
    def __init__(self, mid_channels, in_channels, out_channels=None):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        # there is always at least one resnet
        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]
        self.up = Upsample1d(kernel="cubic")

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states
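# --- Illustrative note (added, not part of the original file) ------------------------
# UNetMidBlock1D resamples the sequence with self.down before its six
# ResConvBlock/SelfAttention1d pairs and with self.up afterwards, so the pairs run at
# reduced temporal resolution while channels go in_channels -> out_channels.
# -------------------------------------------------------------------------------------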
         

class AttnDownBlock1D(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        return hidden_states, (hidden_states,)


class DownBlock1D(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states, (hidden_states,)


class DownBlock1DNoSkip(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = torch.cat([hidden_states, temb], dim=1)
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states, (hidden_states,)


class AttnUpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class UpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class UpBlock1DNoSkip(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states
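# --- Illustrative sketch (added note, not part of the original file) ----------------
# The down blocks return (hidden_states, (hidden_states,)); the mirrored up blocks pull
# the skip tensor out of that tuple and concatenate it onto their own input along the
# channel axis, which is why their first ResConvBlock takes 2 * in_channels. A rough
# pairing with assumed channel counts:
#
#     down = AttnDownBlock1D(out_channels=64, in_channels=32)
#     up = AttnUpBlock1D(in_channels=64, out_channels=32)
#     h, skips = down(torch.randn(1, 32, 64))
#     y = up(h, skips)               # torch.cat -> 128 channels into the first resnet
# -------------------------------------------------------------------------------------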
         

def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
    if down_block_type == "DownResnetBlock1D":
        return DownResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    elif down_block_type == "DownBlock1D":
        return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
    elif down_block_type == "AttnDownBlock1D":
        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
    elif down_block_type == "DownBlock1DNoSkip":
        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    if up_block_type == "UpResnetBlock1D":
        return UpResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
        )
    elif up_block_type == "UpBlock1D":
        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
    elif up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
    elif up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{up_block_type} does not exist.")


def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
    if mid_block_type == "MidResTemporalBlock1D":
        return MidResTemporalBlock1D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            embed_dim=embed_dim,
            add_downsample=add_downsample,
        )
    elif mid_block_type == "ValueFunctionMidBlock1D":
        return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
    elif mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
    raise ValueError(f"{mid_block_type} does not exist.")


def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    if out_block_type == "OutConv1DBlock":
        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
    elif out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim)
    return None
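# --- Illustrative sketch (added note, not part of the original file) ----------------
# These helpers map the string block types found in a UNet1D-style config to block
# instances; unknown types raise ValueError, except get_out_block, which returns None.
# The argument values below are assumptions for the example, not from a real config:
#
#     down = get_down_block("DownBlock1D", num_layers=1, in_channels=32,
#                           out_channels=64, temb_channels=128, add_downsample=True)
#     mid = get_mid_block("UNetMidBlock1D", num_layers=1, in_channels=64,
#                         mid_channels=64, out_channels=64, embed_dim=128,
#                         add_downsample=False)
#     up = get_up_block("UpBlock1D", num_layers=1, in_channels=64, out_channels=32,
#                       temb_channels=128, add_upsample=True)
# -------------------------------------------------------------------------------------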
    	
6DoF/diffusers/models/unet_2d.py
ADDED
@@ -0,0 +1,329 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


@dataclass
class UNet2DOutput(BaseOutput):
    """
    The output of [`UNet2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.FloatTensor
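# --- Illustrative note (added, not part of the original file) ------------------------
# BaseOutput subclasses behave like both dataclasses and tuples, so a caller of
# UNet2DModel (defined below) can use either access style:
#
#     out = model(sample, timestep)                    # default return_dict=True
#     pred = out.sample                                # or out[0]
#     (pred,) = model(sample, timestep, return_dict=False)   # plain tuple
# -------------------------------------------------------------------------------------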
         

class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
            Block type for the middle of the UNet; it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet".
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet".
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
         
     | 
| 235 | 
         
            +
                    r"""
         
     | 
| 236 | 
         
            +
                    The [`UNet2DModel`] forward method.
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                    Args:
         
     | 
| 239 | 
         
            +
                        sample (`torch.FloatTensor`):
         
     | 
| 240 | 
         
            +
                            The noisy input tensor with the following shape `(batch, channel, height, width)`.
         
     | 
| 241 | 
         
            +
                        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
         
     | 
| 242 | 
         
            +
                        class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
         
     | 
| 243 | 
         
            +
                            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
         
     | 
| 244 | 
         
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         
     | 
| 245 | 
         
            +
                            Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
         
     | 
| 246 | 
         
            +
             
     | 
| 247 | 
         
            +
                    Returns:
         
     | 
| 248 | 
         
            +
                        [`~models.unet_2d.UNet2DOutput`] or `tuple`:
         
     | 
| 249 | 
         
            +
                            If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
         
     | 
| 250 | 
         
            +
                            returned where the first element is the sample tensor.
         
     | 
| 251 | 
         
            +
                    """
         
     | 
| 252 | 
         
            +
                    # 0. center input if necessary
         
     | 
| 253 | 
         
            +
                    if self.config.center_input_sample:
         
     | 
| 254 | 
         
            +
                        sample = 2 * sample - 1.0
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                    # 1. time
         
     | 
| 257 | 
         
            +
                    timesteps = timestep
         
     | 
| 258 | 
         
            +
                    if not torch.is_tensor(timesteps):
         
     | 
| 259 | 
         
            +
                        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         
     | 
| 260 | 
         
            +
                    elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
         
     | 
| 261 | 
         
            +
                        timesteps = timesteps[None].to(sample.device)
         
     | 
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         
     | 
| 264 | 
         
            +
                    timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    t_emb = self.time_proj(timesteps)
         
     | 
| 267 | 
         
            +
             
     | 
| 268 | 
         
            +
                    # timesteps does not contain any weights and will always return f32 tensors
         
     | 
| 269 | 
         
            +
                    # but time_embedding might actually be running in fp16. so we need to cast here.
         
     | 
| 270 | 
         
            +
                    # there might be better ways to encapsulate this.
         
     | 
| 271 | 
         
            +
                    t_emb = t_emb.to(dtype=self.dtype)
         
     | 
| 272 | 
         
            +
                    emb = self.time_embedding(t_emb)
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                    if self.class_embedding is not None:
         
     | 
| 275 | 
         
            +
                        if class_labels is None:
         
     | 
| 276 | 
         
            +
                            raise ValueError("class_labels should be provided when doing class conditioning")
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
                        if self.config.class_embed_type == "timestep":
         
     | 
| 279 | 
         
            +
                            class_labels = self.time_proj(class_labels)
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
            +
                        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
         
     | 
| 282 | 
         
            +
                        emb = emb + class_emb
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                    # 2. pre-process
         
     | 
| 285 | 
         
            +
                    skip_sample = sample
         
     | 
| 286 | 
         
            +
                    sample = self.conv_in(sample)
         
     | 
| 287 | 
         
            +
             
     | 
| 288 | 
         
            +
                    # 3. down
         
     | 
| 289 | 
         
            +
                    down_block_res_samples = (sample,)
         
     | 
| 290 | 
         
            +
                    for downsample_block in self.down_blocks:
         
     | 
| 291 | 
         
            +
                        if hasattr(downsample_block, "skip_conv"):
         
     | 
| 292 | 
         
            +
                            sample, res_samples, skip_sample = downsample_block(
         
     | 
| 293 | 
         
            +
                                hidden_states=sample, temb=emb, skip_sample=skip_sample
         
     | 
| 294 | 
         
            +
                            )
         
     | 
| 295 | 
         
            +
                        else:
         
     | 
| 296 | 
         
            +
                            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                        down_block_res_samples += res_samples
         
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
                    # 4. mid
         
     | 
| 301 | 
         
            +
                    sample = self.mid_block(sample, emb)
         
     | 
| 302 | 
         
            +
             
     | 
| 303 | 
         
            +
                    # 5. up
         
     | 
| 304 | 
         
            +
                    skip_sample = None
         
     | 
| 305 | 
         
            +
                    for upsample_block in self.up_blocks:
         
     | 
| 306 | 
         
            +
                        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
         
     | 
| 307 | 
         
            +
                        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
         
     | 
| 308 | 
         
            +
             
     | 
| 309 | 
         
            +
                        if hasattr(upsample_block, "skip_conv"):
         
     | 
| 310 | 
         
            +
                            sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
         
     | 
| 311 | 
         
            +
                        else:
         
     | 
| 312 | 
         
            +
                            sample = upsample_block(sample, res_samples, emb)
         
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
                    # 6. post-process
         
     | 
| 315 | 
         
            +
                    sample = self.conv_norm_out(sample)
         
     | 
| 316 | 
         
            +
                    sample = self.conv_act(sample)
         
     | 
| 317 | 
         
            +
                    sample = self.conv_out(sample)
         
     | 
| 318 | 
         
            +
             
     | 
| 319 | 
         
            +
                    if skip_sample is not None:
         
     | 
| 320 | 
         
            +
                        sample += skip_sample
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                    if self.config.time_embedding_type == "fourier":
         
     | 
| 323 | 
         
            +
                        timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
         
     | 
| 324 | 
         
            +
                        sample = sample / timesteps
         
     | 
| 325 | 
         
            +
             
     | 
| 326 | 
         
            +
                    if not return_dict:
         
     | 
| 327 | 
         
            +
                        return (sample,)
         
     | 
| 328 | 
         
            +
             
     | 
| 329 | 
         
            +
                    return UNet2DOutput(sample=sample)
         
     | 
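The forward pass above is easier to follow end to end with a small concrete configuration. The snippet below is a minimal usage sketch, not part of this commit: it assumes a standard `diffusers` install (rather than the vendored `6DoF/diffusers` copy), and the channel sizes, block types, and timestep value are arbitrary choices for illustration.

# Minimal usage sketch for the UNet2DModel forward pass shown above.
# All configuration values here are arbitrary assumptions for illustration.
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

noisy = torch.randn(1, 3, 32, 32)   # (batch, channel, height, width)
timestep = torch.tensor([10])       # broadcast to the batch inside forward
out = model(noisy, timestep)        # returns a UNet2DOutput by default
print(out.sample.shape)             # torch.Size([1, 3, 32, 32])

With `return_dict=False` the same call returns a plain `(sample,)` tuple, matching the final branch of the forward method above.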
    	
6DoF/diffusers/models/unet_2d_blocks.py ADDED
    The diff for this file is too large to render. See raw diff.
6DoF/diffusers/models/unet_2d_blocks_flax.py ADDED
@@ -0,0 +1,377 @@
| 1 | 
         
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            import flax.linen as nn
         
     | 
| 16 | 
         
            +
            import jax.numpy as jnp
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            from .attention_flax import FlaxTransformer2DModel
         
     | 
| 19 | 
         
            +
            from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            class FlaxCrossAttnDownBlock2D(nn.Module):
         
     | 
| 23 | 
         
            +
                r"""
         
     | 
| 24 | 
         
            +
                Cross Attention 2D Downsizing block - original architecture from Unet transformers:
         
     | 
| 25 | 
         
            +
                https://arxiv.org/abs/2103.06104
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
                Parameters:
         
     | 
| 28 | 
         
            +
                    in_channels (:obj:`int`):
         
     | 
| 29 | 
         
            +
                        Input channels
         
     | 
| 30 | 
         
            +
                    out_channels (:obj:`int`):
         
     | 
| 31 | 
         
            +
                        Output channels
         
     | 
| 32 | 
         
            +
                    dropout (:obj:`float`, *optional*, defaults to 0.0):
         
     | 
| 33 | 
         
            +
                        Dropout rate
         
     | 
| 34 | 
         
            +
                    num_layers (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 35 | 
         
            +
                        Number of attention blocks layers
         
     | 
| 36 | 
         
            +
                    num_attention_heads (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 37 | 
         
            +
                        Number of attention heads of each spatial transformer block
         
     | 
| 38 | 
         
            +
                    add_downsample (:obj:`bool`, *optional*, defaults to `True`):
         
     | 
| 39 | 
         
            +
                        Whether to add downsampling layer before each final output
         
     | 
| 40 | 
         
            +
                    use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
         
     | 
| 41 | 
         
            +
                        enable memory efficient attention https://arxiv.org/abs/2112.05682
         
     | 
| 42 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 43 | 
         
            +
                        Parameters `dtype`
         
     | 
| 44 | 
         
            +
                """
         
     | 
| 45 | 
         
            +
                in_channels: int
         
     | 
| 46 | 
         
            +
                out_channels: int
         
     | 
| 47 | 
         
            +
                dropout: float = 0.0
         
     | 
| 48 | 
         
            +
                num_layers: int = 1
         
     | 
| 49 | 
         
            +
                num_attention_heads: int = 1
         
     | 
| 50 | 
         
            +
                add_downsample: bool = True
         
     | 
| 51 | 
         
            +
                use_linear_projection: bool = False
         
     | 
| 52 | 
         
            +
                only_cross_attention: bool = False
         
     | 
| 53 | 
         
            +
                use_memory_efficient_attention: bool = False
         
     | 
| 54 | 
         
            +
                dtype: jnp.dtype = jnp.float32
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                def setup(self):
         
     | 
| 57 | 
         
            +
                    resnets = []
         
     | 
| 58 | 
         
            +
                    attentions = []
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                    for i in range(self.num_layers):
         
     | 
| 61 | 
         
            +
                        in_channels = self.in_channels if i == 0 else self.out_channels
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
                        res_block = FlaxResnetBlock2D(
         
     | 
| 64 | 
         
            +
                            in_channels=in_channels,
         
     | 
| 65 | 
         
            +
                            out_channels=self.out_channels,
         
     | 
| 66 | 
         
            +
                            dropout_prob=self.dropout,
         
     | 
| 67 | 
         
            +
                            dtype=self.dtype,
         
     | 
| 68 | 
         
            +
                        )
         
     | 
| 69 | 
         
            +
                        resnets.append(res_block)
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                        attn_block = FlaxTransformer2DModel(
         
     | 
| 72 | 
         
            +
                            in_channels=self.out_channels,
         
     | 
| 73 | 
         
            +
                            n_heads=self.num_attention_heads,
         
     | 
| 74 | 
         
            +
                            d_head=self.out_channels // self.num_attention_heads,
         
     | 
| 75 | 
         
            +
                            depth=1,
         
     | 
| 76 | 
         
            +
                            use_linear_projection=self.use_linear_projection,
         
     | 
| 77 | 
         
            +
                            only_cross_attention=self.only_cross_attention,
         
     | 
| 78 | 
         
            +
                            use_memory_efficient_attention=self.use_memory_efficient_attention,
         
     | 
| 79 | 
         
            +
                            dtype=self.dtype,
         
     | 
| 80 | 
         
            +
                        )
         
     | 
| 81 | 
         
            +
                        attentions.append(attn_block)
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
                    self.resnets = resnets
         
     | 
| 84 | 
         
            +
                    self.attentions = attentions
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                    if self.add_downsample:
         
     | 
| 87 | 
         
            +
                        self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
         
     | 
| 90 | 
         
            +
                    output_states = ()
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                    for resnet, attn in zip(self.resnets, self.attentions):
         
     | 
| 93 | 
         
            +
                        hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
         
     | 
| 94 | 
         
            +
                        hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
         
     | 
| 95 | 
         
            +
                        output_states += (hidden_states,)
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                    if self.add_downsample:
         
     | 
| 98 | 
         
            +
                        hidden_states = self.downsamplers_0(hidden_states)
         
     | 
| 99 | 
         
            +
                        output_states += (hidden_states,)
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                    return hidden_states, output_states
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
            class FlaxDownBlock2D(nn.Module):
         
     | 
| 105 | 
         
            +
                r"""
         
     | 
| 106 | 
         
            +
                Flax 2D downsizing block
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                Parameters:
         
     | 
| 109 | 
         
            +
                    in_channels (:obj:`int`):
         
     | 
| 110 | 
         
            +
                        Input channels
         
     | 
| 111 | 
         
            +
                    out_channels (:obj:`int`):
         
     | 
| 112 | 
         
            +
                        Output channels
         
     | 
| 113 | 
         
            +
                    dropout (:obj:`float`, *optional*, defaults to 0.0):
         
     | 
| 114 | 
         
            +
                        Dropout rate
         
     | 
| 115 | 
         
            +
                    num_layers (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 116 | 
         
            +
                        Number of attention blocks layers
         
     | 
| 117 | 
         
            +
                    add_downsample (:obj:`bool`, *optional*, defaults to `True`):
         
     | 
| 118 | 
         
            +
                        Whether to add downsampling layer before each final output
         
     | 
| 119 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 120 | 
         
            +
                        Parameters `dtype`
         
     | 
| 121 | 
         
            +
                """
         
     | 
| 122 | 
         
            +
                in_channels: int
         
     | 
| 123 | 
         
            +
                out_channels: int
         
     | 
| 124 | 
         
            +
                dropout: float = 0.0
         
     | 
| 125 | 
         
            +
                num_layers: int = 1
         
     | 
| 126 | 
         
            +
                add_downsample: bool = True
         
     | 
| 127 | 
         
            +
                dtype: jnp.dtype = jnp.float32
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                def setup(self):
         
     | 
| 130 | 
         
            +
                    resnets = []
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                    for i in range(self.num_layers):
         
     | 
| 133 | 
         
            +
                        in_channels = self.in_channels if i == 0 else self.out_channels
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                        res_block = FlaxResnetBlock2D(
         
     | 
| 136 | 
         
            +
                            in_channels=in_channels,
         
     | 
| 137 | 
         
            +
                            out_channels=self.out_channels,
         
     | 
| 138 | 
         
            +
                            dropout_prob=self.dropout,
         
     | 
| 139 | 
         
            +
                            dtype=self.dtype,
         
     | 
| 140 | 
         
            +
                        )
         
     | 
| 141 | 
         
            +
                        resnets.append(res_block)
         
     | 
| 142 | 
         
            +
                    self.resnets = resnets
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    if self.add_downsample:
         
     | 
| 145 | 
         
            +
                        self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                def __call__(self, hidden_states, temb, deterministic=True):
         
     | 
| 148 | 
         
            +
                    output_states = ()
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                    for resnet in self.resnets:
         
     | 
| 151 | 
         
            +
                        hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
         
     | 
| 152 | 
         
            +
                        output_states += (hidden_states,)
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
                    if self.add_downsample:
         
     | 
| 155 | 
         
            +
                        hidden_states = self.downsamplers_0(hidden_states)
         
     | 
| 156 | 
         
            +
                        output_states += (hidden_states,)
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                    return hidden_states, output_states
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
            class FlaxCrossAttnUpBlock2D(nn.Module):
         
     | 
| 162 | 
         
            +
                r"""
         
     | 
| 163 | 
         
            +
                Cross Attention 2D Upsampling block - original architecture from Unet transformers:
         
     | 
| 164 | 
         
            +
                https://arxiv.org/abs/2103.06104
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
                Parameters:
         
     | 
| 167 | 
         
            +
                    in_channels (:obj:`int`):
         
     | 
| 168 | 
         
            +
                        Input channels
         
     | 
| 169 | 
         
            +
                    out_channels (:obj:`int`):
         
     | 
| 170 | 
         
            +
                        Output channels
         
     | 
| 171 | 
         
            +
                    dropout (:obj:`float`, *optional*, defaults to 0.0):
         
     | 
| 172 | 
         
            +
                        Dropout rate
         
     | 
| 173 | 
         
            +
                    num_layers (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 174 | 
         
            +
                        Number of attention blocks layers
         
     | 
| 175 | 
         
            +
                    num_attention_heads (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 176 | 
         
            +
                        Number of attention heads of each spatial transformer block
         
     | 
| 177 | 
         
            +
                    add_upsample (:obj:`bool`, *optional*, defaults to `True`):
         
     | 
| 178 | 
         
            +
                        Whether to add upsampling layer before each final output
         
     | 
| 179 | 
         
            +
                    use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
         
     | 
| 180 | 
         
            +
                        enable memory efficient attention https://arxiv.org/abs/2112.05682
         
     | 
| 181 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 182 | 
         
            +
                        Parameters `dtype`
         
     | 
| 183 | 
         
            +
                """
         
     | 
| 184 | 
         
            +
                in_channels: int
         
     | 
| 185 | 
         
            +
                out_channels: int
         
     | 
| 186 | 
         
            +
                prev_output_channel: int
         
     | 
| 187 | 
         
            +
                dropout: float = 0.0
         
     | 
| 188 | 
         
            +
                num_layers: int = 1
         
     | 
| 189 | 
         
            +
                num_attention_heads: int = 1
         
     | 
| 190 | 
         
            +
                add_upsample: bool = True
         
     | 
| 191 | 
         
            +
                use_linear_projection: bool = False
         
     | 
| 192 | 
         
            +
                only_cross_attention: bool = False
         
     | 
| 193 | 
         
            +
                use_memory_efficient_attention: bool = False
         
     | 
| 194 | 
         
            +
                dtype: jnp.dtype = jnp.float32
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
                def setup(self):
         
     | 
| 197 | 
         
            +
                    resnets = []
         
     | 
| 198 | 
         
            +
                    attentions = []
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                    for i in range(self.num_layers):
         
     | 
| 201 | 
         
            +
                        res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
         
     | 
| 202 | 
         
            +
                        resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                        res_block = FlaxResnetBlock2D(
         
     | 
| 205 | 
         
            +
                            in_channels=resnet_in_channels + res_skip_channels,
         
     | 
| 206 | 
         
            +
                            out_channels=self.out_channels,
         
     | 
| 207 | 
         
            +
                            dropout_prob=self.dropout,
         
     | 
| 208 | 
         
            +
                            dtype=self.dtype,
         
     | 
| 209 | 
         
            +
                        )
         
     | 
| 210 | 
         
            +
                        resnets.append(res_block)
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
                        attn_block = FlaxTransformer2DModel(
         
     | 
| 213 | 
         
            +
                            in_channels=self.out_channels,
         
     | 
| 214 | 
         
            +
                            n_heads=self.num_attention_heads,
         
     | 
| 215 | 
         
            +
                            d_head=self.out_channels // self.num_attention_heads,
         
     | 
| 216 | 
         
            +
                            depth=1,
         
     | 
| 217 | 
         
            +
                            use_linear_projection=self.use_linear_projection,
         
     | 
| 218 | 
         
            +
                            only_cross_attention=self.only_cross_attention,
         
     | 
| 219 | 
         
            +
                            use_memory_efficient_attention=self.use_memory_efficient_attention,
         
     | 
| 220 | 
         
            +
                            dtype=self.dtype,
         
     | 
| 221 | 
         
            +
                        )
         
     | 
| 222 | 
         
            +
                        attentions.append(attn_block)
         
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
                    self.resnets = resnets
         
     | 
| 225 | 
         
            +
                    self.attentions = attentions
         
     | 
| 226 | 
         
            +
             
     | 
| 227 | 
         
            +
                    if self.add_upsample:
         
     | 
| 228 | 
         
            +
                        self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
         
     | 
| 231 | 
         
            +
                    for resnet, attn in zip(self.resnets, self.attentions):
         
     | 
| 232 | 
         
            +
                        # pop res hidden states
         
     | 
| 233 | 
         
            +
                        res_hidden_states = res_hidden_states_tuple[-1]
         
     | 
| 234 | 
         
            +
                        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
         
     | 
| 235 | 
         
            +
                        hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                        hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
         
     | 
| 238 | 
         
            +
                        hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
                    if self.add_upsample:
         
     | 
| 241 | 
         
            +
                        hidden_states = self.upsamplers_0(hidden_states)
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                    return hidden_states
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
            class FlaxUpBlock2D(nn.Module):
         
     | 
| 247 | 
         
            +
                r"""
         
     | 
| 248 | 
         
            +
                Flax 2D upsampling block
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
                Parameters:
         
     | 
| 251 | 
         
            +
                    in_channels (:obj:`int`):
         
     | 
| 252 | 
         
            +
                        Input channels
         
     | 
| 253 | 
         
            +
                    out_channels (:obj:`int`):
         
     | 
| 254 | 
         
            +
                        Output channels
         
     | 
| 255 | 
         
            +
                    prev_output_channel (:obj:`int`):
         
     | 
| 256 | 
         
            +
                        Output channels from the previous block
         
     | 
| 257 | 
         
            +
                    dropout (:obj:`float`, *optional*, defaults to 0.0):
         
     | 
| 258 | 
         
            +
                        Dropout rate
         
     | 
| 259 | 
         
            +
                    num_layers (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 260 | 
         
            +
                        Number of attention blocks layers
         
     | 
| 261 | 
         
            +
                    add_downsample (:obj:`bool`, *optional*, defaults to `True`):
         
     | 
| 262 | 
         
            +
                        Whether to add downsampling layer before each final output
         
     | 
| 263 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 264 | 
         
            +
                        Parameters `dtype`
         
     | 
| 265 | 
         
            +
                """
         
     | 
| 266 | 
         
            +
                in_channels: int
         
     | 
| 267 | 
         
            +
                out_channels: int
         
     | 
| 268 | 
         
            +
                prev_output_channel: int
         
     | 
| 269 | 
         
            +
                dropout: float = 0.0
         
     | 
| 270 | 
         
            +
                num_layers: int = 1
         
     | 
| 271 | 
         
            +
                add_upsample: bool = True
         
     | 
| 272 | 
         
            +
                dtype: jnp.dtype = jnp.float32
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                def setup(self):
         
     | 
| 275 | 
         
            +
                    resnets = []
         
     | 
| 276 | 
         
            +
             
     | 
| 277 | 
         
            +
                    for i in range(self.num_layers):
         
     | 
| 278 | 
         
            +
                        res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
         
     | 
| 279 | 
         
            +
                        resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
            +
                        res_block = FlaxResnetBlock2D(
         
     | 
| 282 | 
         
            +
                            in_channels=resnet_in_channels + res_skip_channels,
         
     | 
| 283 | 
         
            +
                            out_channels=self.out_channels,
         
     | 
| 284 | 
         
            +
                            dropout_prob=self.dropout,
         
     | 
| 285 | 
         
            +
                            dtype=self.dtype,
         
     | 
| 286 | 
         
            +
                        )
         
     | 
| 287 | 
         
            +
                        resnets.append(res_block)
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
                    self.resnets = resnets
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                    if self.add_upsample:
         
     | 
| 292 | 
         
            +
                        self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
         
     | 
| 293 | 
         
            +
             
     | 
| 294 | 
         
            +
                def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
         
     | 
| 295 | 
         
            +
                    for resnet in self.resnets:
         
     | 
| 296 | 
         
            +
                        # pop res hidden states
         
     | 
| 297 | 
         
            +
                        res_hidden_states = res_hidden_states_tuple[-1]
         
     | 
| 298 | 
         
            +
                        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
         
     | 
| 299 | 
         
            +
                        hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
         
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
            +
                        hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
         
     | 
| 302 | 
         
            +
             
     | 
| 303 | 
         
            +
                    if self.add_upsample:
         
     | 
| 304 | 
         
            +
                        hidden_states = self.upsamplers_0(hidden_states)
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                    return hidden_states
         
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
             
     | 
| 309 | 
         
            +
            class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         
     | 
| 310 | 
         
            +
                r"""
         
     | 
| 311 | 
         
            +
                Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
                Parameters:
         
     | 
| 314 | 
         
            +
                    in_channels (:obj:`int`):
         
     | 
| 315 | 
         
            +
                        Input channels
         
     | 
| 316 | 
         
            +
                    dropout (:obj:`float`, *optional*, defaults to 0.0):
         
     | 
| 317 | 
         
            +
                        Dropout rate
         
     | 
| 318 | 
         
            +
                    num_layers (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 319 | 
         
            +
                        Number of attention blocks layers
         
     | 
| 320 | 
         
            +
                    num_attention_heads (:obj:`int`, *optional*, defaults to 1):
         
     | 
| 321 | 
         
            +
                        Number of attention heads of each spatial transformer block
         
     | 
| 322 | 
         
            +
                    use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
         
     | 
| 323 | 
         
            +
                        enable memory efficient attention https://arxiv.org/abs/2112.05682
         
     | 
| 324 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 325 | 
         
            +
                        Parameters `dtype`
         
     | 
| 326 | 
         
            +
                """
         
     | 
| 327 | 
         
            +
                in_channels: int
         
     | 
| 328 | 
         
            +
                dropout: float = 0.0
         
     | 
| 329 | 
         
            +
                num_layers: int = 1
         
     | 
| 330 | 
         
            +
                num_attention_heads: int = 1
         
     | 
| 331 | 
         
            +
                use_linear_projection: bool = False
         
     | 
| 332 | 
         
            +
                use_memory_efficient_attention: bool = False
         
     | 
| 333 | 
         
            +
                dtype: jnp.dtype = jnp.float32
         
     | 
| 334 | 
         
            +
             
     | 
| 335 | 
         
            +
                def setup(self):
         
     | 
| 336 | 
         
            +
                    # there is always at least one resnet
         
     | 
| 337 | 
         
            +
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
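
An aside on the block above (not part of the file): `__call__` runs the first resnet on its own and then alternates attention and resnet pairs, which is why `setup` builds `num_layers + 1` resnets but only `num_layers` attentions. A minimal sketch of that control flow with stand-in callables (names below are illustrative only):

    # Sketch of the mid-block data flow; the lambdas stand in for FlaxResnetBlock2D / FlaxTransformer2DModel.
    def run_mid_block(hidden_states, temb, encoder_hidden_states, resnets, attentions):
        hidden_states = resnets[0](hidden_states, temb)           # lone leading resnet
        for attn, resnet in zip(attentions, resnets[1:]):         # then (attention, resnet) pairs
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states

    # num_layers = 2 -> 3 resnets, 2 attentions; identity stand-ins leave the value unchanged.
    out = run_mid_block(1.0, None, None, [lambda h, t: h] * 3, [lambda h, c: h] * 2)
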
    	
6DoF/diffusers/models/unet_2d_condition.py
ADDED
@@ -0,0 +1,980 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .activations import get_activation
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import (
    GaussianFourierProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None

class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for the middle of the UNet, it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to `None`):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest
            of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to
            `False` otherwise.
    """
         
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )
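
Aside (not part of the file): `conv_in_padding = (conv_in_kernel - 1) // 2` is "same" padding for odd kernels, so `conv_in` only changes the channel count, never the spatial resolution. A quick check with the defaults from this file:

    import torch
    import torch.nn as nn

    kernel = 3
    padding = (kernel - 1) // 2                      # = 1
    conv_in = nn.Conv2d(4, 320, kernel_size=kernel, padding=padding)
    print(conv_in(torch.randn(1, 4, 64, 64)).shape)  # torch.Size([1, 320, 64, 64])
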
        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )
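
Aside (not part of the file): in the default `positional` branch, `time_proj` turns a batch of scalar timesteps into sinusoidal features of width `block_out_channels[0]`, and `time_embedding` projects them up to `time_embed_dim = 4 * block_out_channels[0]`. A shape sketch, assuming the upstream `diffusers` package exposes these helpers as below:

    import torch
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    block_out_channels = (320, 640, 1280, 1280)       # defaults in this file
    time_embed_dim = block_out_channels[0] * 4        # 1280

    time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
    time_embedding = TimestepEmbedding(block_out_channels[0], time_embed_dim, act_fn="silu")

    t = torch.tensor([10, 500])                       # one timestep per batch element
    t_emb = time_proj(t)                              # (2, 320) sinusoidal features
    emb = time_embedding(t_emb)                       # (2, 1280) embedding used throughout the UNet
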
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # it is set to `cross_attention_dim` here as this is exactly the required dimension for the currently
            # only use case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None
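
Aside (not part of the file): for the plain `text_proj` case, the projection above is just a token-wise linear map from `encoder_hid_dim` to `cross_attention_dim` applied to `encoder_hidden_states` before cross-attention. A minimal sketch with illustrative dimensions:

    import torch
    import torch.nn as nn

    encoder_hid_dim, cross_attention_dim = 1024, 1280            # illustrative values
    encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)

    encoder_hidden_states = torch.randn(2, 77, encoder_hid_dim)  # (batch, tokens, encoder_hid_dim)
    projected = encoder_hid_proj(encoder_hidden_states)          # (2, 77, 1280) fed to cross-attention
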
        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None
| 359 | 
         
            +
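        # Hedged usage sketch (hypothetical configuration, for illustration only): with
        # `class_embed_type="projection"` an arbitrary per-sample vector can be injected via `class_labels`,
        # e.g. a 4-dim pose vector:
        #
        #     unet = UNet2DConditionModel(..., class_embed_type="projection",
        #                                 projection_class_embeddings_input_dim=4)
        #     out = unet(latents, t, encoder_hidden_states=cond, class_labels=pose)  # pose: (batch, 4)
        #
        # The vector is mapped by `TimestepEmbedding` to `time_embed_dim` features and added to (or, with
        # `class_embeddings_concat=True`, concatenated with) the time embedding.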
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
            )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

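        # Hedged note on the "text_time" branch above: at call time the model expects
        # `added_cond_kwargs={"text_embeds": ..., "time_ids": ...}`. Each scalar in `time_ids` is expanded by
        # `add_time_proj` to `addition_time_embed_dim` features, concatenated with `text_embeds`, and projected
        # by `add_embedding`, so `projection_class_embeddings_input_dim` should equal
        # text_embed_dim + num_time_ids * addition_time_embed_dim.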
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

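        # Worked example (assuming the common setup where time_embed_dim = block_out_channels[0] * 4):
        # with block_out_channels[0] == 320 the time embedding has 1280 features, so
        # `class_embeddings_concat=True` makes the blocks expect temb_channels == 2560, while the default
        # additive path keeps it at 1280.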
        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)

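        # Worked example: with block_out_channels == (320, 640, 1280, 1280) the loop above builds four down
        # blocks with (in_channels, out_channels) of (320, 320), (320, 640), (640, 1280), (1280, 1280);
        # only the final block is created without a downsampler.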
        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

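        # Hedged explanation: `num_layers=reversed_layers_per_block[i] + 1` gives each up block one more
        # resnet than its mirrored down block so that, roughly, every skip connection produced on the way
        # down (including the extra downsampler/stem output) is consumed on the way up. Continuing the
        # (320, 640, 1280, 1280) example, the up blocks see (prev_output_channel, out_channels) of
        # (1280, 1280), (1280, 1280), (1280, 640), (640, 320), and only the last one uses add_upsample=False.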
        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )

            self.conv_act = get_activation(act_fn)

        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

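    # Hedged usage sketch: the property above returns a flat dict keyed by module path, e.g.
    # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor" (key format inferred from the
    # f"{name}.processor" construction). Typical use, assuming a constructed model `unet`:
    #
    #     procs = unet.attn_processors          # Dict[str, AttentionProcessor]
    #     print(len(procs), sorted(procs)[0])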
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

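    # Hedged usage sketch (assuming a constructed model `unet`; `AttnProcessor` is the class already used by
    # `set_default_attn_processor` above):
    #
    #     unet.set_attn_processor(AttnProcessor())                                     # same processor everywhere
    #     unet.set_attn_processor({k: AttnProcessor() for k in unet.attn_processors})  # per-layer dict
    #     unet.set_default_attn_processor()                                            # reset to the default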
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

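    # Hedged usage sketch (assuming a constructed model `unet`):
    #
    #     unet.set_attention_slice("auto")  # each head dim halved -> attention computed in two steps
    #     unet.set_attention_slice("max")   # slice size 1 everywhere -> lowest memory, slowest
    #     unet.set_attention_slice(8)       # a layer with attention_head_dim == 64 runs in 64 // 8 = 8 slices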
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
            module.gradient_checkpointing = value

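    # Hedged note: this hook is usually not called directly; in diffusers' `ModelMixin` base class,
    # `enable_gradient_checkpointing()` / `disable_gradient_checkpointing()` are expected to apply it
    # recursively to the block types listed above (assumption based on the upstream library layout).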
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
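        # Hedged usage sketch (illustrative variable names): a single denoising step typically looks like
        #
        #     noise_pred = unet(
        #         latents,                                # (batch, in_channels, h, w)
        #         t,                                      # scalar or (batch,) tensor of timesteps
        #         encoder_hidden_states=text_embeddings,  # (batch, seq_len, cross_attention_dim)
        #     ).sample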
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

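        # Worked example (assuming a UNet with 3 upsamplers): the factor is 2**3 == 8, so a 64x64 latent
        # needs no special handling, while e.g. a 60x60 latent sets `forward_upsample_size` so each
        # upsampler is later given its exact target resolution.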
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

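        # Worked example: a keep/discard mask [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0]
        # with shape (batch, 1, key_tokens); adding -10000 before the softmax drives the corresponding
        # attention weights to ~0.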
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

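        # Illustration: a plain Python `int`/`float` timestep is wrapped into a tensor above and then
        # expanded, so passing `t=10` with a batch of 4 yields timesteps == tensor([10, 10, 10, 10])
        # before the sinusoidal projection in `time_proj`.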
        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

                    emb = emb + aug_emb if aug_emb is not None else emb
         
     | 
| 875 | 
         
            +
             
     | 
| 876 | 
         
            +
                    if self.time_embed_act is not None:
         
     | 
| 877 | 
         
            +
                        emb = self.time_embed_act(emb)
         
     | 
| 878 | 
         
            +
             
     | 
| 879 | 
         
            +
                    if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
         
     | 
| 880 | 
         
            +
                        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         
     | 
| 881 | 
         
            +
                    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
         
     | 
| 882 | 
         
            +
                        # Kadinsky 2.1 - style
         
     | 
| 883 | 
         
            +
                        if "image_embeds" not in added_cond_kwargs:
         
     | 
| 884 | 
         
            +
                            raise ValueError(
         
     | 
| 885 | 
         
            +
                                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
         
     | 
| 886 | 
         
            +
                            )
         
     | 
| 887 | 
         
            +
             
     | 
| 888 | 
         
            +
                        image_embeds = added_cond_kwargs.get("image_embeds")
         
     | 
| 889 | 
         
            +
                        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
         
     | 
| 890 | 
         
            +
                    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
         
     | 
| 891 | 
         
            +
                        # Kandinsky 2.2 - style
         
     | 
| 892 | 
         
            +
                        if "image_embeds" not in added_cond_kwargs:
         
     | 
| 893 | 
         
            +
                            raise ValueError(
         
     | 
| 894 | 
         
            +
                                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
         
     | 
| 895 | 
         
            +
                            )
         
     | 
| 896 | 
         
            +
                        image_embeds = added_cond_kwargs.get("image_embeds")
         
     | 
| 897 | 
         
            +
                        encoder_hidden_states = self.encoder_hid_proj(image_embeds)
         
     | 
| 898 | 
         
            +
                    # 2. pre-process
         
     | 
| 899 | 
         
            +
                    sample = self.conv_in(sample)
         
     | 
| 900 | 
         
            +
             
     | 
| 901 | 
         
            +
                    # 3. down
         
     | 
| 902 | 
         
            +
                    down_block_res_samples = (sample,)
         
     | 
| 903 | 
         
            +
                    for downsample_block in self.down_blocks:
         
     | 
| 904 | 
         
            +
                        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
         
     | 
| 905 | 
         
            +
                            sample, res_samples = downsample_block(
         
     | 
| 906 | 
         
            +
                                hidden_states=sample,
         
     | 
| 907 | 
         
            +
                                temb=emb,
         
     | 
| 908 | 
         
            +
                                encoder_hidden_states=encoder_hidden_states,
         
     | 
| 909 | 
         
            +
                                attention_mask=attention_mask,
         
     | 
| 910 | 
         
            +
                                cross_attention_kwargs=cross_attention_kwargs,
         
     | 
| 911 | 
         
            +
                                encoder_attention_mask=encoder_attention_mask,
         
     | 
| 912 | 
         
            +
                            )
         
     | 
| 913 | 
         
            +
                        else:
         
     | 
| 914 | 
         
            +
                            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
         
     | 
| 915 | 
         
            +
             
     | 
| 916 | 
         
            +
                        down_block_res_samples += res_samples
         
     | 
| 917 | 
         
            +
             
     | 
| 918 | 
         
            +
                    if down_block_additional_residuals is not None:
         
     | 
| 919 | 
         
            +
                        new_down_block_res_samples = ()
         
     | 
| 920 | 
         
            +
             
     | 
| 921 | 
         
            +
                        for down_block_res_sample, down_block_additional_residual in zip(
         
     | 
| 922 | 
         
            +
                            down_block_res_samples, down_block_additional_residuals
         
     | 
| 923 | 
         
            +
                        ):
         
     | 
| 924 | 
         
            +
                            down_block_res_sample = down_block_res_sample + down_block_additional_residual
         
     | 
| 925 | 
         
            +
                            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
         
     | 
| 926 | 
         
            +
             
     | 
| 927 | 
         
            +
                        down_block_res_samples = new_down_block_res_samples
         
     | 
| 928 | 
         
            +
             
     | 
| 929 | 
         
            +
                    # 4. mid
         
     | 
| 930 | 
         
            +
                    if self.mid_block is not None:
         
     | 
| 931 | 
         
            +
                        sample = self.mid_block(
         
     | 
| 932 | 
         
            +
                            sample,
         
     | 
| 933 | 
         
            +
                            emb,
         
     | 
| 934 | 
         
            +
                            encoder_hidden_states=encoder_hidden_states,
         
     | 
| 935 | 
         
            +
                            attention_mask=attention_mask,
         
     | 
| 936 | 
         
            +
                            cross_attention_kwargs=cross_attention_kwargs,
         
     | 
| 937 | 
         
            +
                            encoder_attention_mask=encoder_attention_mask,
         
     | 
| 938 | 
         
            +
                        )
         
     | 
| 939 | 
         
            +
             
     | 
| 940 | 
         
            +
                    if mid_block_additional_residual is not None:
         
     | 
| 941 | 
         
            +
                        sample = sample + mid_block_additional_residual
         
     | 
| 942 | 
         
            +
             
     | 
| 943 | 
         
            +
                    # 5. up
         
     | 
| 944 | 
         
            +
                    for i, upsample_block in enumerate(self.up_blocks):
         
     | 
| 945 | 
         
            +
                        is_final_block = i == len(self.up_blocks) - 1
         
     | 
| 946 | 
         
            +
             
     | 
| 947 | 
         
            +
                        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
         
     | 
| 948 | 
         
            +
                        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
         
     | 
| 949 | 
         
            +
             
     | 
| 950 | 
         
            +
                        # if we have not reached the final block and need to forward the
         
     | 
| 951 | 
         
            +
                        # upsample size, we do it here
         
     | 
| 952 | 
         
            +
                        if not is_final_block and forward_upsample_size:
         
     | 
| 953 | 
         
            +
                            upsample_size = down_block_res_samples[-1].shape[2:]
         
     | 
| 954 | 
         
            +
             
     | 
| 955 | 
         
            +
                        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
         
     | 
| 956 | 
         
            +
                            sample = upsample_block(
         
     | 
| 957 | 
         
            +
                                hidden_states=sample,
         
     | 
| 958 | 
         
            +
                                temb=emb,
         
     | 
| 959 | 
         
            +
                                res_hidden_states_tuple=res_samples,
         
     | 
| 960 | 
         
            +
                                encoder_hidden_states=encoder_hidden_states,
         
     | 
| 961 | 
         
            +
                                cross_attention_kwargs=cross_attention_kwargs,
         
     | 
| 962 | 
         
            +
                                upsample_size=upsample_size,
         
     | 
| 963 | 
         
            +
                                attention_mask=attention_mask,
         
     | 
| 964 | 
         
            +
                                encoder_attention_mask=encoder_attention_mask,
         
     | 
| 965 | 
         
            +
                            )
         
     | 
| 966 | 
         
            +
                        else:
         
     | 
| 967 | 
         
            +
                            sample = upsample_block(
         
     | 
| 968 | 
         
            +
                                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
         
     | 
| 969 | 
         
            +
                            )
         
     | 
| 970 | 
         
            +
             
     | 
| 971 | 
         
            +
                    # 6. post-process
         
     | 
| 972 | 
         
            +
                    if self.conv_norm_out:
         
     | 
| 973 | 
         
            +
                        sample = self.conv_norm_out(sample)
         
     | 
| 974 | 
         
            +
                        sample = self.conv_act(sample)
         
     | 
| 975 | 
         
            +
                    sample = self.conv_out(sample)
         
     | 
| 976 | 
         
            +
             
     | 
| 977 | 
         
            +
                    if not return_dict:
         
     | 
| 978 | 
         
            +
                        return (sample,)
         
     | 
| 979 | 
         
            +
             
     | 
| 980 | 
         
            +
                    return UNet2DConditionOutput(sample=sample)
         
     | 
    	
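This ends the PyTorch `UNet2DConditionModel` forward pass touched by the commit. As a quick illustration only (not part of the diff), the sketch below drives a forward call of this kind through the upstream `diffusers` package; the small config, tensor shapes, and variable names are assumptions chosen for the example, and the vendored 6DoF copy may differ from upstream in its constructor.

# Illustrative sketch: exercising the forward path shown above with the
# upstream diffusers package; all config values and shapes are arbitrary.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(64, 128),
    layers_per_block=1,
    cross_attention_dim=768,
)

sample = torch.randn(1, 4, 32, 32)               # noisy latents
timestep = torch.tensor([10])                    # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 768)  # conditioning features

# `down_block_additional_residuals` / `mid_block_additional_residual` could also be
# passed here (e.g. from a ControlNet); the forward above sums them into the skip
# connections and the mid-block output.
out = unet(sample, timestep, encoder_hidden_states)
print(out.sample.shape)  # torch.Size([1, 4, 32, 32])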
        6DoF/diffusers/models/unet_2d_condition_flax.py
    ADDED
    
         @@ -0,0 +1,357 @@ 
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 13 | 
         
            +
            # limitations under the License.
         
     | 
| 14 | 
         
            +
            from typing import Optional, Tuple, Union
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            import flax
         
     | 
| 17 | 
         
            +
            import flax.linen as nn
         
     | 
| 18 | 
         
            +
            import jax
         
     | 
| 19 | 
         
            +
            import jax.numpy as jnp
         
     | 
| 20 | 
         
            +
            from flax.core.frozen_dict import FrozenDict
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            from ..configuration_utils import ConfigMixin, flax_register_to_config
         
     | 
| 23 | 
         
            +
            from ..utils import BaseOutput
         
     | 
| 24 | 
         
            +
            from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
         
     | 
| 25 | 
         
            +
            from .modeling_flax_utils import FlaxModelMixin
         
     | 
| 26 | 
         
            +
            from .unet_2d_blocks_flax import (
         
     | 
| 27 | 
         
            +
                FlaxCrossAttnDownBlock2D,
         
     | 
| 28 | 
         
            +
                FlaxCrossAttnUpBlock2D,
         
     | 
| 29 | 
         
            +
                FlaxDownBlock2D,
         
     | 
| 30 | 
         
            +
                FlaxUNetMidBlock2DCrossAttn,
         
     | 
| 31 | 
         
            +
                FlaxUpBlock2D,
         
     | 
| 32 | 
         
            +
            )
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            @flax.struct.dataclass
         
     | 
| 36 | 
         
            +
            class FlaxUNet2DConditionOutput(BaseOutput):
         
     | 
| 37 | 
         
            +
                """
         
     | 
| 38 | 
         
            +
                The output of [`FlaxUNet2DConditionModel`].
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                Args:
         
     | 
| 41 | 
         
            +
                    sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
         
     | 
| 42 | 
         
            +
                        The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
         
     | 
| 43 | 
         
            +
                """
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                sample: jnp.ndarray
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            @flax_register_to_config
         
     | 
| 49 | 
         
            +
            class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         
     | 
| 50 | 
         
            +
                r"""
         
     | 
| 51 | 
         
            +
                A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
         
     | 
| 52 | 
         
            +
                shaped output.
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
         
     | 
| 55 | 
         
            +
                implemented for all models (such as downloading or saving).
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
                This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
         
     | 
| 58 | 
         
            +
                subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
         
     | 
| 59 | 
         
            +
                general usage and behavior.
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                Inherent JAX features such as the following are supported:
         
     | 
| 62 | 
         
            +
                - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
         
     | 
| 63 | 
         
            +
                - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
         
     | 
| 64 | 
         
            +
                - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
         
     | 
| 65 | 
         
            +
                - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
                Parameters:
         
     | 
| 68 | 
         
            +
                    sample_size (`int`, *optional*):
         
     | 
| 69 | 
         
            +
                        The size of the input sample.
         
     | 
| 70 | 
         
            +
                    in_channels (`int`, *optional*, defaults to 4):
         
     | 
| 71 | 
         
            +
                        The number of channels in the input sample.
         
     | 
| 72 | 
         
            +
                    out_channels (`int`, *optional*, defaults to 4):
         
     | 
| 73 | 
         
            +
                        The number of channels in the output.
         
     | 
| 74 | 
         
            +
                    down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
         
     | 
| 75 | 
         
            +
                        The tuple of downsample blocks to use.
         
     | 
| 76 | 
         
            +
                    up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
         
     | 
| 77 | 
         
            +
                        The tuple of upsample blocks to use.
         
     | 
| 78 | 
         
            +
                    block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
         
     | 
| 79 | 
         
            +
                        The tuple of output channels for each block.
         
     | 
| 80 | 
         
            +
                    layers_per_block (`int`, *optional*, defaults to 2):
         
     | 
| 81 | 
         
            +
                        The number of layers per block.
         
     | 
| 82 | 
         
            +
                    attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
         
     | 
| 83 | 
         
            +
                        The dimension of the attention heads.
         
     | 
| 84 | 
         
            +
                    num_attention_heads (`int` or `Tuple[int]`, *optional*):
         
     | 
| 85 | 
         
            +
                        The number of attention heads.
         
     | 
| 86 | 
         
            +
                    cross_attention_dim (`int`, *optional*, defaults to 768):
         
     | 
| 87 | 
         
            +
                        The dimension of the cross attention features.
         
     | 
| 88 | 
         
            +
                    dropout (`float`, *optional*, defaults to 0):
         
     | 
| 89 | 
         
            +
                        Dropout probability for down, up and bottleneck blocks.
         
     | 
| 90 | 
         
            +
                    flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
         
     | 
| 91 | 
         
            +
                        Whether to flip the sin to cos in the time embedding.
         
     | 
| 92 | 
         
            +
                    freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         
     | 
| 93 | 
         
            +
                    use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
         
     | 
| 94 | 
         
            +
                        Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
         
     | 
| 95 | 
         
            +
                """
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                sample_size: int = 32
         
     | 
| 98 | 
         
            +
                in_channels: int = 4
         
     | 
| 99 | 
         
            +
                out_channels: int = 4
         
     | 
| 100 | 
         
            +
                down_block_types: Tuple[str] = (
         
     | 
| 101 | 
         
            +
                    "CrossAttnDownBlock2D",
         
     | 
| 102 | 
         
            +
                    "CrossAttnDownBlock2D",
         
     | 
| 103 | 
         
            +
                    "CrossAttnDownBlock2D",
         
     | 
| 104 | 
         
            +
                    "DownBlock2D",
         
     | 
| 105 | 
         
            +
                )
         
     | 
| 106 | 
         
            +
                up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
         
     | 
| 107 | 
         
            +
                only_cross_attention: Union[bool, Tuple[bool]] = False
         
     | 
| 108 | 
         
            +
                block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
         
     | 
| 109 | 
         
            +
                layers_per_block: int = 2
         
     | 
| 110 | 
         
            +
                attention_head_dim: Union[int, Tuple[int]] = 8
         
     | 
| 111 | 
         
            +
                num_attention_heads: Optional[Union[int, Tuple[int]]] = None
         
     | 
| 112 | 
         
            +
                cross_attention_dim: int = 1280
         
     | 
| 113 | 
         
            +
                dropout: float = 0.0
         
     | 
| 114 | 
         
            +
                use_linear_projection: bool = False
         
     | 
| 115 | 
         
            +
                dtype: jnp.dtype = jnp.float32
         
     | 
| 116 | 
         
            +
                flip_sin_to_cos: bool = True
         
     | 
| 117 | 
         
            +
                freq_shift: int = 0
         
     | 
| 118 | 
         
            +
                use_memory_efficient_attention: bool = False
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         
     | 
| 121 | 
         
            +
                    # init input tensors
         
     | 
| 122 | 
         
            +
                    sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
         
     | 
| 123 | 
         
            +
                    sample = jnp.zeros(sample_shape, dtype=jnp.float32)
         
     | 
| 124 | 
         
            +
                    timesteps = jnp.ones((1,), dtype=jnp.int32)
         
     | 
| 125 | 
         
            +
                    encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                    params_rng, dropout_rng = jax.random.split(rng)
         
     | 
| 128 | 
         
            +
                    rngs = {"params": params_rng, "dropout": dropout_rng}
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                def setup(self):
         
     | 
| 133 | 
         
            +
                    block_out_channels = self.block_out_channels
         
     | 
| 134 | 
         
            +
                    time_embed_dim = block_out_channels[0] * 4
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
                    if self.num_attention_heads is not None:
         
     | 
| 137 | 
         
            +
                        raise ValueError(
         
     | 
| 138 | 
         
            +
                            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
         
     | 
| 139 | 
         
            +
                        )
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                    # If `num_attention_heads` is not defined (which is the case for most models)
         
     | 
| 142 | 
         
            +
                    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
         
     | 
| 143 | 
         
            +
                    # The reason for this behavior is to correct for incorrectly named variables that were introduced
         
     | 
| 144 | 
         
            +
                    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
         
     | 
| 145 | 
         
            +
                    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
         
     | 
| 146 | 
         
            +
                    # which is why we correct for the naming here.
         
     | 
| 147 | 
         
            +
                    num_attention_heads = self.num_attention_heads or self.attention_head_dim
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                    # input
         
     | 
| 150 | 
         
            +
                    self.conv_in = nn.Conv(
         
     | 
| 151 | 
         
            +
                        block_out_channels[0],
         
     | 
| 152 | 
         
            +
                        kernel_size=(3, 3),
         
     | 
| 153 | 
         
            +
                        strides=(1, 1),
         
     | 
| 154 | 
         
            +
                        padding=((1, 1), (1, 1)),
         
     | 
| 155 | 
         
            +
                        dtype=self.dtype,
         
     | 
| 156 | 
         
            +
                    )
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                    # time
         
     | 
| 159 | 
         
            +
                    self.time_proj = FlaxTimesteps(
         
     | 
| 160 | 
         
            +
                        block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
         
     | 
| 161 | 
         
            +
                    )
         
     | 
| 162 | 
         
            +
                    self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
                    only_cross_attention = self.only_cross_attention
         
     | 
| 165 | 
         
            +
                    if isinstance(only_cross_attention, bool):
         
     | 
| 166 | 
         
            +
                        only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
         
     | 
| 167 | 
         
            +
             
     | 
| 168 | 
         
            +
                    if isinstance(num_attention_heads, int):
         
     | 
| 169 | 
         
            +
                        num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                    # down
         
     | 
| 172 | 
         
            +
                    down_blocks = []
         
     | 
| 173 | 
         
            +
                    output_channel = block_out_channels[0]
         
     | 
| 174 | 
         
            +
                    for i, down_block_type in enumerate(self.down_block_types):
         
     | 
| 175 | 
         
            +
                        input_channel = output_channel
         
     | 
| 176 | 
         
            +
                        output_channel = block_out_channels[i]
         
     | 
| 177 | 
         
            +
                        is_final_block = i == len(block_out_channels) - 1
         
     | 
| 178 | 
         
            +
             
     | 
| 179 | 
         
            +
                        if down_block_type == "CrossAttnDownBlock2D":
         
     | 
| 180 | 
         
            +
                            down_block = FlaxCrossAttnDownBlock2D(
         
     | 
| 181 | 
         
            +
                                in_channels=input_channel,
         
     | 
| 182 | 
         
            +
                                out_channels=output_channel,
         
     | 
| 183 | 
         
            +
                                dropout=self.dropout,
         
     | 
| 184 | 
         
            +
                                num_layers=self.layers_per_block,
         
     | 
| 185 | 
         
            +
                                num_attention_heads=num_attention_heads[i],
         
     | 
| 186 | 
         
            +
                                add_downsample=not is_final_block,
         
     | 
| 187 | 
         
            +
                                use_linear_projection=self.use_linear_projection,
         
     | 
| 188 | 
         
            +
                                only_cross_attention=only_cross_attention[i],
         
     | 
| 189 | 
         
            +
                                use_memory_efficient_attention=self.use_memory_efficient_attention,
         
     | 
| 190 | 
         
            +
                                dtype=self.dtype,
         
     | 
| 191 | 
         
            +
                            )
         
     | 
| 192 | 
         
            +
                        else:
         
     | 
| 193 | 
         
            +
                            down_block = FlaxDownBlock2D(
         
     | 
| 194 | 
         
            +
                                in_channels=input_channel,
         
     | 
| 195 | 
         
            +
                                out_channels=output_channel,
         
     | 
| 196 | 
         
            +
                                dropout=self.dropout,
         
     | 
| 197 | 
         
            +
                                num_layers=self.layers_per_block,
         
     | 
| 198 | 
         
            +
                                add_downsample=not is_final_block,
         
     | 
| 199 | 
         
            +
                                dtype=self.dtype,
         
     | 
| 200 | 
         
            +
                            )
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                        down_blocks.append(down_block)
         
     | 
| 203 | 
         
            +
                    self.down_blocks = down_blocks
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                    # mid
         
     | 
| 206 | 
         
            +
                    self.mid_block = FlaxUNetMidBlock2DCrossAttn(
         
     | 
| 207 | 
         
            +
                        in_channels=block_out_channels[-1],
         
     | 
| 208 | 
         
            +
                        dropout=self.dropout,
         
     | 
| 209 | 
         
            +
                        num_attention_heads=num_attention_heads[-1],
         
     | 
| 210 | 
         
            +
                        use_linear_projection=self.use_linear_projection,
         
     | 
| 211 | 
         
            +
                        use_memory_efficient_attention=self.use_memory_efficient_attention,
         
     | 
| 212 | 
         
            +
                        dtype=self.dtype,
         
     | 
| 213 | 
         
            +
                    )
         
     | 
| 214 | 
         
            +
             
     | 
| 215 | 
         
            +
                    # up
         
     | 
| 216 | 
         
            +
                    up_blocks = []
         
     | 
| 217 | 
         
            +
                    reversed_block_out_channels = list(reversed(block_out_channels))
         
     | 
| 218 | 
         
            +
                    reversed_num_attention_heads = list(reversed(num_attention_heads))
         
     | 
| 219 | 
         
            +
                    only_cross_attention = list(reversed(only_cross_attention))
         
     | 
| 220 | 
         
            +
                    output_channel = reversed_block_out_channels[0]
         
     | 
| 221 | 
         
            +
                    for i, up_block_type in enumerate(self.up_block_types):
         
     | 
| 222 | 
         
            +
                        prev_output_channel = output_channel
         
     | 
| 223 | 
         
            +
                        output_channel = reversed_block_out_channels[i]
         
     | 
| 224 | 
         
            +
                        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
                        is_final_block = i == len(block_out_channels) - 1
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                        if up_block_type == "CrossAttnUpBlock2D":
         
     | 
| 229 | 
         
            +
                            up_block = FlaxCrossAttnUpBlock2D(
         
     | 
| 230 | 
         
            +
                                in_channels=input_channel,
         
     | 
| 231 | 
         
            +
                                out_channels=output_channel,
         
     | 
| 232 | 
         
            +
                                prev_output_channel=prev_output_channel,
         
     | 
| 233 | 
         
            +
                                num_layers=self.layers_per_block + 1,
         
     | 
| 234 | 
         
            +
                                num_attention_heads=reversed_num_attention_heads[i],
         
     | 
| 235 | 
         
            +
                                add_upsample=not is_final_block,
         
     | 
| 236 | 
         
            +
                                dropout=self.dropout,
         
     | 
| 237 | 
         
            +
                                use_linear_projection=self.use_linear_projection,
         
     | 
| 238 | 
         
            +
                                only_cross_attention=only_cross_attention[i],
         
     | 
| 239 | 
         
            +
                                use_memory_efficient_attention=self.use_memory_efficient_attention,
         
     | 
| 240 | 
         
            +
                                dtype=self.dtype,
         
     | 
| 241 | 
         
            +
                            )
         
     | 
| 242 | 
         
            +
                        else:
         
     | 
| 243 | 
         
            +
                            up_block = FlaxUpBlock2D(
         
     | 
| 244 | 
         
            +
                                in_channels=input_channel,
         
     | 
| 245 | 
         
            +
                                out_channels=output_channel,
         
     | 
| 246 | 
         
            +
                                prev_output_channel=prev_output_channel,
         
     | 
| 247 | 
         
            +
                                num_layers=self.layers_per_block + 1,
         
     | 
| 248 | 
         
            +
                                add_upsample=not is_final_block,
         
     | 
| 249 | 
         
            +
                                dropout=self.dropout,
         
     | 
| 250 | 
         
            +
                                dtype=self.dtype,
         
     | 
| 251 | 
         
            +
                            )
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                        up_blocks.append(up_block)
         
     | 
| 254 | 
         
            +
                        prev_output_channel = output_channel
         
     | 
| 255 | 
         
            +
                    self.up_blocks = up_blocks
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
                    # out
         
     | 
| 258 | 
         
            +
                    self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
         
     | 
| 259 | 
         
            +
                    self.conv_out = nn.Conv(
         
     | 
| 260 | 
         
            +
                        self.out_channels,
         
     | 
| 261 | 
         
            +
                        kernel_size=(3, 3),
         
     | 
| 262 | 
         
            +
                        strides=(1, 1),
         
     | 
| 263 | 
         
            +
                        padding=((1, 1), (1, 1)),
         
     | 
| 264 | 
         
            +
                        dtype=self.dtype,
         
     | 
| 265 | 
         
            +
                    )
         
     | 
| 266 | 
         
            +
             
     | 
| 267 | 
         
            +
                def __call__(
         
     | 
| 268 | 
         
            +
                    self,
         
     | 
| 269 | 
         
            +
                    sample,
         
     | 
| 270 | 
         
            +
                    timesteps,
         
     | 
| 271 | 
         
            +
                    encoder_hidden_states,
         
     | 
| 272 | 
         
            +
                    down_block_additional_residuals=None,
         
     | 
| 273 | 
         
            +
                    mid_block_additional_residual=None,
         
     | 
| 274 | 
         
            +
                    return_dict: bool = True,
         
     | 
| 275 | 
         
            +
                    train: bool = False,
         
     | 
| 276 | 
         
            +
                ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
         
     | 
| 277 | 
         
            +
                    r"""
         
     | 
| 278 | 
         
            +
                    Args:
         
     | 
| 279 | 
         
            +
                        sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
         
     | 
| 280 | 
         
            +
                        timestep (`jnp.ndarray` or `float` or `int`): timesteps
         
     | 
| 281 | 
         
            +
                        encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
         
     | 
| 282 | 
         
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         
     | 
| 283 | 
         
            +
                            Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
         
     | 
| 284 | 
         
            +
                            plain tuple.
         
     | 
| 285 | 
         
            +
                        train (`bool`, *optional*, defaults to `False`):
         
     | 
| 286 | 
         
            +
                            Use deterministic functions and disable dropout when not training.
         
     | 
| 287 | 
         
            +
             
     | 
| 288 | 
         
            +
                    Returns:
         
     | 
| 289 | 
         
            +
                        [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
         
     | 
| 290 | 
         
            +
                        [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
         
     | 
| 291 | 
         
            +
                        When returning a tuple, the first element is the sample tensor.
         
     | 
| 292 | 
         
            +
                    """
         
     | 
| 293 | 
         
            +
                    # 1. time
         
     | 
| 294 | 
         
            +
                    if not isinstance(timesteps, jnp.ndarray):
         
     | 
| 295 | 
         
            +
                        timesteps = jnp.array([timesteps], dtype=jnp.int32)
         
     | 
| 296 | 
         
            +
                    elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
         
     | 
| 297 | 
         
            +
                        timesteps = timesteps.astype(dtype=jnp.float32)
         
     | 
| 298 | 
         
            +
                        timesteps = jnp.expand_dims(timesteps, 0)
         
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
                    t_emb = self.time_proj(timesteps)
         
     | 
| 301 | 
         
            +
                    t_emb = self.time_embedding(t_emb)
         
     | 
| 302 | 
         
            +
             
        # 2. pre-process
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample += down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        if mid_block_additional_residual is not None:
            sample += mid_block_additional_residual

        # 5. up
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock2D):
                sample = up_block(
                    sample,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states,
                    res_hidden_states_tuple=res_samples,
                    deterministic=not train,
                )
            else:
                sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)

        return FlaxUNet2DConditionOutput(sample=sample)
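The block above shows how ControlNet-style residuals are folded into the Flax UNet's skip connections: each down-block residual is summed element-wise with the matching additional residual, and the mid-block residual is added directly to `sample`. A minimal sketch of that merge pattern (not part of the diff), using dummy `jnp` arrays whose NHWC shapes are illustrative assumptions:

import jax.numpy as jnp

# dummy skip tensors; shapes are assumptions for illustration only
down_block_res_samples = (jnp.ones((1, 64, 64, 320)), jnp.ones((1, 32, 32, 640)))
down_block_additional_residuals = (jnp.zeros((1, 64, 64, 320)), jnp.zeros((1, 32, 32, 640)))

new_down_block_res_samples = ()
for res_sample, additional_residual in zip(down_block_res_samples, down_block_additional_residuals):
    # element-wise sum, the same operation performed in the forward pass above
    new_down_block_res_samples += (res_sample + additional_residual,)
down_block_res_samples = new_down_block_res_samples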
    	
6DoF/diffusers/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,679 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{up_block_type} does not exist.")

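As an aside (not part of the added file): a hedged sketch of how these two factory functions might be invoked when assembling a 3D UNet. The channel widths, head count, and cross-attention dimension below are illustrative assumptions, and the import path assumes this vendored diffusers package is on sys.path.

from diffusers.models.unet_3d_blocks import get_down_block, get_up_block

# build one cross-attention down block and a plain up block (all values are assumptions)
down_block = get_down_block(
    "CrossAttnDownBlock3D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    resnet_groups=32,
    cross_attention_dim=1024,
    downsample_padding=1,
)

up_block = get_up_block(
    "UpBlock3D",
    num_layers=3,
    in_channels=640,
    out_channels=320,
    prev_output_channel=640,
    temb_channels=1280,
    add_upsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    resnet_groups=32,
)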
class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=True,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        temp_convs = [
            TemporalConvLayer(
                in_channels,
                in_channels,
                dropout=0.1,
            )
        ]
        attentions = []
        temp_attentions = []

        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    in_channels,
                    in_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        hidden_states = self.resnets[0](hidden_states, temb)
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        return hidden_states

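A hedged usage sketch (not part of the file): running the mid block on dummy tensors. Video frames are folded into the batch dimension, so with batch 1 and num_frames=4 every tensor has a leading dimension of 4; all sizes below are assumptions for illustration, and the import path assumes the vendored package is importable.

import torch
from diffusers.models.unet_3d_blocks import UNetMidBlock3DCrossAttn

mid_block = UNetMidBlock3DCrossAttn(
    in_channels=1280,
    temb_channels=1280,
    num_attention_heads=8,
    cross_attention_dim=1024,
)
hidden_states = torch.randn(4, 1280, 8, 8)        # (batch * num_frames, C, H, W)
temb = torch.randn(4, 1280)                       # time embedding, repeated per frame
encoder_hidden_states = torch.randn(4, 77, 1024)  # e.g. text-encoder states, repeated per frame
out = mid_block(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, num_frames=4)
print(out.shape)  # spatial shape is preserved: (4, 1280, 8, 8)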
class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        temp_attentions = []
        temp_convs = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states

class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, num_frames=1):
        output_states = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states

| 477 | 
         
            +
            class CrossAttnUpBlock3D(nn.Module):
         
     | 
| 478 | 
         
            +
                def __init__(
         
     | 
| 479 | 
         
            +
                    self,
         
     | 
| 480 | 
         
            +
                    in_channels: int,
         
     | 
| 481 | 
         
            +
                    out_channels: int,
         
     | 
| 482 | 
         
            +
                    prev_output_channel: int,
         
     | 
| 483 | 
         
            +
                    temb_channels: int,
         
     | 
| 484 | 
         
            +
                    dropout: float = 0.0,
         
     | 
| 485 | 
         
            +
                    num_layers: int = 1,
         
     | 
| 486 | 
         
            +
                    resnet_eps: float = 1e-6,
         
     | 
| 487 | 
         
            +
                    resnet_time_scale_shift: str = "default",
         
     | 
| 488 | 
         
            +
                    resnet_act_fn: str = "swish",
         
     | 
| 489 | 
         
            +
                    resnet_groups: int = 32,
         
     | 
| 490 | 
         
            +
                    resnet_pre_norm: bool = True,
         
     | 
| 491 | 
         
            +
                    num_attention_heads=1,
         
     | 
| 492 | 
         
            +
                    cross_attention_dim=1280,
         
     | 
| 493 | 
         
            +
                    output_scale_factor=1.0,
         
     | 
| 494 | 
         
            +
                    add_upsample=True,
         
     | 
| 495 | 
         
            +
                    dual_cross_attention=False,
         
     | 
| 496 | 
         
            +
                    use_linear_projection=False,
         
     | 
| 497 | 
         
            +
                    only_cross_attention=False,
         
     | 
| 498 | 
         
            +
                    upcast_attention=False,
         
     | 
| 499 | 
         
            +
                ):
         
     | 
| 500 | 
         
            +
                    super().__init__()
         
     | 
| 501 | 
         
            +
                    resnets = []
         
     | 
| 502 | 
         
            +
                    temp_convs = []
         
     | 
| 503 | 
         
            +
                    attentions = []
         
     | 
| 504 | 
         
            +
                    temp_attentions = []
         
     | 
| 505 | 
         
            +
             
     | 
| 506 | 
         
            +
                    self.has_cross_attention = True
         
     | 
| 507 | 
         
            +
                    self.num_attention_heads = num_attention_heads
         
     | 
| 508 | 
         
            +
             
     | 
| 509 | 
         
            +
                    for i in range(num_layers):
         
     | 
| 510 | 
         
            +
                        res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
         
     | 
| 511 | 
         
            +
                        resnet_in_channels = prev_output_channel if i == 0 else out_channels
         
     | 
| 512 | 
         
            +
             
     | 
| 513 | 
         
            +
                        resnets.append(
         
     | 
| 514 | 
         
            +
                            ResnetBlock2D(
         
     | 
| 515 | 
         
            +
                                in_channels=resnet_in_channels + res_skip_channels,
         
     | 
| 516 | 
         
            +
                                out_channels=out_channels,
         
     | 
| 517 | 
         
            +
                                temb_channels=temb_channels,
         
     | 
| 518 | 
         
            +
                                eps=resnet_eps,
         
     | 
| 519 | 
         
            +
                                groups=resnet_groups,
         
     | 
| 520 | 
         
            +
                                dropout=dropout,
         
     | 
| 521 | 
         
            +
                                time_embedding_norm=resnet_time_scale_shift,
         
     | 
| 522 | 
         
            +
                                non_linearity=resnet_act_fn,
         
     | 
| 523 | 
         
            +
                                output_scale_factor=output_scale_factor,
         
     | 
| 524 | 
         
            +
                                pre_norm=resnet_pre_norm,
         
     | 
| 525 | 
         
            +
                            )
         
     | 
| 526 | 
         
            +
                        )
         
     | 
| 527 | 
         
            +
                        temp_convs.append(
         
     | 
| 528 | 
         
            +
                            TemporalConvLayer(
         
     | 
| 529 | 
         
            +
                                out_channels,
         
     | 
| 530 | 
         
            +
                                out_channels,
         
     | 
| 531 | 
         
            +
                                dropout=0.1,
         
     | 
| 532 | 
         
            +
                            )
         
     | 
| 533 | 
         
            +
                        )
         
     | 
| 534 | 
         
            +
                        attentions.append(
         
     | 
| 535 | 
         
            +
                            Transformer2DModel(
         
     | 
| 536 | 
         
            +
                                out_channels // num_attention_heads,
         
     | 
| 537 | 
         
            +
                                num_attention_heads,
         
     | 
| 538 | 
         
            +
                                in_channels=out_channels,
         
     | 
| 539 | 
         
            +
                                num_layers=1,
         
     | 
| 540 | 
         
            +
                                cross_attention_dim=cross_attention_dim,
         
     | 
| 541 | 
         
            +
                                norm_num_groups=resnet_groups,
         
     | 
| 542 | 
         
            +
                                use_linear_projection=use_linear_projection,
         
     | 
| 543 | 
         
            +
                                only_cross_attention=only_cross_attention,
         
     | 
| 544 | 
         
            +
                                upcast_attention=upcast_attention,
         
     | 
| 545 | 
         
            +
                            )
         
     | 
| 546 | 
         
            +
                        )
         
     | 
| 547 | 
         
            +
                        temp_attentions.append(
         
     | 
| 548 | 
         
            +
                            TransformerTemporalModel(
         
     | 
| 549 | 
         
            +
                                out_channels // num_attention_heads,
         
     | 
| 550 | 
         
            +
                                num_attention_heads,
         
     | 
| 551 | 
         
            +
                                in_channels=out_channels,
         
     | 
| 552 | 
         
            +
                                num_layers=1,
         
     | 
| 553 | 
         
            +
                                cross_attention_dim=cross_attention_dim,
         
     | 
| 554 | 
         
            +
                                norm_num_groups=resnet_groups,
         
     | 
| 555 | 
         
            +
                            )
         
     | 
| 556 | 
         
            +
                        )
         
     | 
| 557 | 
         
            +
                    self.resnets = nn.ModuleList(resnets)
         
     | 
| 558 | 
         
            +
                    self.temp_convs = nn.ModuleList(temp_convs)
         
     | 
| 559 | 
         
            +
                    self.attentions = nn.ModuleList(attentions)
         
     | 
| 560 | 
         
            +
                    self.temp_attentions = nn.ModuleList(temp_attentions)
         
     | 
| 561 | 
         
            +
             
     | 
| 562 | 
         
            +
                    if add_upsample:
         
     | 
| 563 | 
         
            +
                        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         
     | 
| 564 | 
         
            +
                    else:
         
     | 
| 565 | 
         
            +
                        self.upsamplers = None
         
     | 
| 566 | 
         
            +
             
     | 
| 567 | 
         
            +
                    self.gradient_checkpointing = False
         
     | 
| 568 | 
         
            +
             
     | 
| 569 | 
         
            +
                def forward(
         
     | 
| 570 | 
         
            +
                    self,
         
     | 
| 571 | 
         
            +
                    hidden_states,
         
     | 
| 572 | 
         
            +
                    res_hidden_states_tuple,
         
     | 
| 573 | 
         
            +
                    temb=None,
         
     | 
| 574 | 
         
            +
                    encoder_hidden_states=None,
         
     | 
| 575 | 
         
            +
                    upsample_size=None,
         
     | 
| 576 | 
         
            +
                    attention_mask=None,
         
     | 
| 577 | 
         
            +
                    num_frames=1,
         
     | 
| 578 | 
         
            +
                    cross_attention_kwargs=None,
         
     | 
| 579 | 
         
            +
                ):
         
     | 
| 580 | 
         
            +
                    # TODO(Patrick, William) - attention mask is not used
         
     | 
| 581 | 
         
            +
                    for resnet, temp_conv, attn, temp_attn in zip(
         
     | 
| 582 | 
         
            +
                        self.resnets, self.temp_convs, self.attentions, self.temp_attentions
         
     | 
| 583 | 
         
            +
                    ):
         
     | 
| 584 | 
         
            +
                        # pop res hidden states
         
     | 
| 585 | 
         
            +
                        res_hidden_states = res_hidden_states_tuple[-1]
         
     | 
| 586 | 
         
            +
                        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
         
     | 
| 587 | 
         
            +
                        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
         
     | 
| 588 | 
         
            +
             
     | 
| 589 | 
         
            +
                        hidden_states = resnet(hidden_states, temb)
         
     | 
| 590 | 
         
            +
                        hidden_states = temp_conv(hidden_states, num_frames=num_frames)
         
     | 
| 591 | 
         
            +
                        hidden_states = attn(
         
     | 
| 592 | 
         
            +
                            hidden_states,
         
     | 
| 593 | 
         
            +
                            encoder_hidden_states=encoder_hidden_states,
         
     | 
| 594 | 
         
            +
                            cross_attention_kwargs=cross_attention_kwargs,
         
     | 
| 595 | 
         
            +
                            return_dict=False,
         
     | 
| 596 | 
         
            +
                        )[0]
         
     | 
| 597 | 
         
            +
                        hidden_states = temp_attn(
         
     | 
| 598 | 
         
            +
                            hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
         
     | 
| 599 | 
         
            +
                        )[0]
         
     | 
| 600 | 
         
            +
             
     | 
| 601 | 
         
            +
                    if self.upsamplers is not None:
         
     | 
| 602 | 
         
            +
                        for upsampler in self.upsamplers:
         
     | 
| 603 | 
         
            +
                            hidden_states = upsampler(hidden_states, upsample_size)
         
     | 
| 604 | 
         
            +
             
     | 
| 605 | 
         
            +
                    return hidden_states
         
     | 
| 606 | 
         
            +
             
     | 
| 607 | 
         
            +
             
     | 
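For readers following the diff, here is a minimal usage sketch of the `CrossAttnUpBlock3D` defined above. It is not part of the committed file: the import path assumes the vendored package under `6DoF/` is importable as `diffusers`, and the batch size, frame count, channel widths, sequence length of 77, and 16x16 spatial size are illustrative values only. It shows the block's calling convention: features arrive as `(batch * frames, C, H, W)`, one skip tensor per resnet layer is popped from the end of `res_hidden_states_tuple`, and `encoder_hidden_states` carries the per-frame cross-attention conditioning. The plain `UpBlock3D` that follows next uses the same skip contract, just without the attention layers.

import torch
from diffusers.models.unet_3d_blocks import CrossAttnUpBlock3D  # assumed import path (vendored copy under 6DoF/)

batch, frames = 1, 4                       # features are laid out as (batch * frames, C, H, W)
block = CrossAttnUpBlock3D(
    in_channels=320,                       # channel width expected by the last popped skip tensor
    out_channels=320,
    prev_output_channel=640,               # channels of the feature map arriving from the block below
    temb_channels=1280,
    num_layers=3,                          # consumes three skip tensors, popped from the end of the tuple
    cross_attention_dim=1024,
    num_attention_heads=8,
)

hidden_states = torch.randn(batch * frames, 640, 16, 16)
skips = tuple(torch.randn(batch * frames, 320, 16, 16) for _ in range(3))   # same resolution as hidden_states
temb = torch.randn(batch * frames, 1280)                                    # time embedding, repeated per frame
context = torch.randn(batch * frames, 77, 1024)                             # per-frame conditioning tokens

out = block(hidden_states, skips, temb=temb, encoder_hidden_states=context, num_frames=frames)
print(out.shape)  # torch.Size([4, 320, 32, 32]) after the final Upsample2D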
class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
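To make the skip-connection contract of the plain 3D blocks concrete, a small down/up round trip may help (again not part of the commit; the import path and all sizes are assumptions). `DownBlock3D.forward` returns the running feature map plus a tuple containing one entry per resnet layer and, when downsampling is enabled, one extra downsampled entry; an up block with N resnet layers then consumes the last N entries of such a tuple at its own resolution.

import torch
from diffusers.models.unet_3d_blocks import DownBlock3D, UpBlock3D  # assumed import path (vendored copy under 6DoF/)

batch, frames = 1, 4
x = torch.randn(batch * frames, 320, 32, 32)        # (batch * frames, C, H, W)
temb = torch.randn(batch * frames, 1280)

down = DownBlock3D(in_channels=320, out_channels=320, temb_channels=1280, num_layers=2)
h, skips = down(x, temb, num_frames=frames)
# skips holds the two resnet outputs at 32x32 plus the downsampled map at 16x16
print(h.shape, [s.shape[-1] for s in skips])         # torch.Size([4, 320, 16, 16])  [32, 32, 16]

up = UpBlock3D(
    in_channels=320, prev_output_channel=320, out_channels=320,
    temb_channels=1280, num_layers=1,                # one resnet layer -> consumes one skip tensor
)
y = up(h, skips[-1:], temb, num_frames=frames)       # feed only the matching-resolution skip
print(y.shape)                                       # torch.Size([4, 320, 32, 32]) after upsampling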
    	
6DoF/diffusers/models/unet_3d_condition.py
ADDED
@@ -0,0 +1,627 @@
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# Copyright 2023 The ModelScope Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    The output of [`UNet3DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor

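Because `UNet3DConditionOutput` derives from `BaseOutput`, callers can read the result by attribute, by key, or by index. A tiny illustration (not part of the file; the import path is assumed for the vendored copy and the 5-D tensor is a placeholder following the docstring above):

import torch
from diffusers.models.unet_3d_condition import UNet3DConditionOutput  # assumed import path (vendored copy under 6DoF/)

out = UNet3DConditionOutput(sample=torch.randn(1, 8, 4, 32, 32))  # placeholder 5-D tensor
assert torch.equal(out.sample, out["sample"])   # dict-style access
assert torch.equal(out.sample, out[0])          # tuple-style access provided by BaseOutput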
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*): The number of attention heads.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise NotImplementedError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_kernel = 3
        conv_out_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        self.transformer_in = TransformerTemporalModel(
            num_attention_heads=8,
            attention_head_dim=attention_head_dim,
            in_channels=block_out_channels[0],
            num_layers=1,
        )

        # class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
            )
         
     | 
| 253 | 
         
            +
                        self.up_blocks.append(up_block)
         
     | 
| 254 | 
         
            +
                        prev_output_channel = output_channel
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                    # out
         
     | 
| 257 | 
         
            +
                    if norm_num_groups is not None:
         
     | 
| 258 | 
         
            +
                        self.conv_norm_out = nn.GroupNorm(
         
     | 
| 259 | 
         
            +
                            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
         
     | 
| 260 | 
         
            +
                        )
         
     | 
| 261 | 
         
            +
                        self.conv_act = nn.SiLU()
         
     | 
| 262 | 
         
            +
                    else:
         
     | 
| 263 | 
         
            +
                        self.conv_norm_out = None
         
     | 
| 264 | 
         
            +
                        self.conv_act = None
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    conv_out_padding = (conv_out_kernel - 1) // 2
         
     | 
| 267 | 
         
            +
                    self.conv_out = nn.Conv2d(
         
     | 
| 268 | 
         
            +
                        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         
     | 
| 269 | 
         
            +
                    )
         
     | 
| 270 | 
         
            +
             
     | 
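The up-path channel bookkeeping above (reversed_block_out_channels, prev_output_channel, and the min(i + 1, ...) index for the skip connection) is easiest to see with concrete numbers. A standalone sketch, not part of the diff, using an assumed block_out_channels of (320, 640, 1280, 1280):

# Illustrative sketch only (not part of the diff): mirrors the up-block channel
# bookkeeping from the constructor above for an assumed configuration.
block_out_channels = (320, 640, 1280, 1280)  # assumed example values

reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
    prev_output_channel = output_channel
    output_channel = reversed_block_out_channels[i]
    input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
    add_upsample = i != len(block_out_channels) - 1
    print(i, input_channel, output_channel, prev_output_channel, add_upsample)
# 0 1280 1280 1280 True
# 1 640 1280 1280 True
# 2 320 640 1280 True
# 3 320 320 640 False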
    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

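The property above keys every processor by the dotted path of the module that owns it, with a trailing `.processor`. A hypothetical usage sketch (assuming `model` is an instance of this UNet):

# Hypothetical usage sketch: list every attention layer and its processor class.
# `model` is assumed to be an instance of the 3D UNet defined in this file.
for name, processor in model.attn_processors.items():
    # keys look like "<dotted.module.path>.processor"
    print(name, type(processor).__name__)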
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

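Before the recursive dispatch, `slice_size` is normalized into one integer per sliceable layer. A standalone sketch of that normalization (illustrative only, mirroring the branches above):

# Illustrative sketch of the slice_size normalization used above.
def resolve_slice_sizes(slice_size, sliceable_head_dims):
    if slice_size == "auto":
        # two slices per layer: half the sliceable head dimension
        return [dim // 2 for dim in sliceable_head_dims]
    if slice_size == "max":
        # smallest possible slices: most memory-frugal, slowest
        return [1] * len(sliceable_head_dims)
    if not isinstance(slice_size, list):
        return [slice_size] * len(sliceable_head_dims)
    return slice_size

print(resolve_slice_sizes("auto", [8, 8, 16]))  # [4, 4, 8]
print(resolve_slice_sizes("max", [8, 8, 16]))   # [1, 1, 1]
print(resolve_slice_sizes(4, [8, 8, 16]))       # [4, 4, 4]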
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

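A hypothetical usage sketch for the setter above, assuming the vendored package is importable as `diffusers` and `model` is an instance of this UNet; the two calls are equivalent here:

# Hypothetical usage sketch (names assumed): reset every attention layer's processor.
from diffusers.models.attention_processor import AttnProcessor

model.set_attn_processor(AttnProcessor())  # a single processor shared by all layers
# or an explicit per-layer dict, keyed exactly like `attn_processors`:
model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors.keys()})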
    def enable_forward_chunking(self, chunk_size=None, dim=0):
        """
        Sets the feed-forward layers to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def disable_forward_chunking(self):
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

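Feed-forward chunking trades one large intermediate activation for several smaller ones; because the feed-forward network acts on each position independently, the result is unchanged. A minimal, self-contained sketch of the idea (not this model's own feed-forward module):

# Self-contained sketch of chunked feed-forward (illustrative only).
import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 10, 64)  # (batch, sequence, features)

full = ff(x)                                                    # one (2, 10, 256) intermediate
chunked = torch.cat([ff(c) for c in x.chunk(5, dim=1)], dim=1)  # five (2, 2, 256) intermediates
assert torch.allclose(full, chunked, atol=1e-6)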
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

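`_set_gradient_checkpointing` is only a hook; in diffusers-style models it is normally driven by the `enable_gradient_checkpointing()` / `disable_gradient_checkpointing()` helpers on the model mixin, which apply it recursively. A hypothetical usage sketch (assuming this class keeps that mixin behaviour):

# Hypothetical usage sketch: toggle gradient checkpointing on the 3D blocks.
model.enable_gradient_checkpointing()    # sets gradient_checkpointing = True via the hook above
model.disable_gradient_checkpointing()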
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].

        Returns:
            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

        # 2. pre-process
        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
        sample = self.conv_in(sample)

        sample = self.transformer_in(
            sample,
            num_frames=num_frames,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    num_frames=num_frames,
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)

        sample = self.conv_out(sample)

        # reshape back to (batch, channel, num_frames, height, width)
        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

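A hypothetical end-to-end call of the forward method above. The shapes and config values (4 latent channels, cross-attention width 1024, 77 conditioning tokens) are assumed for illustration, not taken from this repository:

# Hypothetical usage sketch; `model` is assumed to be an instance of this UNet
# built with in_channels=4 and cross_attention_dim=1024.
import torch

sample = torch.randn(1, 4, 8, 32, 32)              # (batch, channel, num_frames, height, width)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 1024)   # (batch, sequence_length, cross_attention_dim)

out = model(sample, timestep, encoder_hidden_states)
print(out.sample.shape)                            # (batch, out_channels, num_frames, height, width)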
6DoF/diffusers/models/vae.py
ADDED
@@ -0,0 +1,441 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from ..utils import BaseOutput, is_torch_version, randn_tensor
from .attention_processor import SpatialNorm
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
         
     | 
| 101 | 
         
            +
                    self.conv_act = nn.SiLU()
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                    conv_out_channels = 2 * out_channels if double_z else out_channels
         
     | 
| 104 | 
         
            +
                    self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                    self.gradient_checkpointing = False
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                def forward(self, x):
         
     | 
| 109 | 
         
            +
                    sample = x
         
     | 
| 110 | 
         
            +
                    sample = self.conv_in(sample)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                    if self.training and self.gradient_checkpointing:
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                        def create_custom_forward(module):
         
     | 
| 115 | 
         
            +
                            def custom_forward(*inputs):
         
     | 
| 116 | 
         
            +
                                return module(*inputs)
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                            return custom_forward
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                        # down
         
     | 
| 121 | 
         
            +
                        if is_torch_version(">=", "1.11.0"):
         
     | 
| 122 | 
         
            +
                            for down_block in self.down_blocks:
         
     | 
| 123 | 
         
            +
                                sample = torch.utils.checkpoint.checkpoint(
         
     | 
| 124 | 
         
            +
                                    create_custom_forward(down_block), sample, use_reentrant=False
         
     | 
| 125 | 
         
            +
                                )
         
     | 
| 126 | 
         
            +
                            # middle
         
     | 
| 127 | 
         
            +
                            sample = torch.utils.checkpoint.checkpoint(
         
     | 
| 128 | 
         
            +
                                create_custom_forward(self.mid_block), sample, use_reentrant=False
         
     | 
| 129 | 
         
            +
                            )
         
     | 
| 130 | 
         
            +
                        else:
         
     | 
| 131 | 
         
            +
                            for down_block in self.down_blocks:
         
     | 
| 132 | 
         
            +
                                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
         
     | 
| 133 | 
         
            +
                            # middle
         
     | 
| 134 | 
         
            +
                            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
                    else:
         
     | 
| 137 | 
         
            +
                        # down
         
     | 
| 138 | 
         
            +
                        for down_block in self.down_blocks:
         
     | 
| 139 | 
         
            +
                            sample = down_block(sample)
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                        # middle
         
     | 
| 142 | 
         
            +
                        sample = self.mid_block(sample)
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    # post-process
         
     | 
| 145 | 
         
            +
                    sample = self.conv_norm_out(sample)
         
     | 
| 146 | 
         
            +
                    sample = self.conv_act(sample)
         
     | 
| 147 | 
         
            +
                    sample = self.conv_out(sample)
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                    return sample
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
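# --- Editor's sketch (illustrative, not part of vae.py in this commit) ---
# Minimal shape check for the Encoder above. With a single down block, that block
# is also the final one, so nothing is downsampled, and `double_z=True` doubles
# the output channels so the result can later be split into mean and logvar.
# Assumes this file's module-level imports (torch, the block helpers) resolve.
_sketch_encoder = Encoder(in_channels=3, out_channels=4, block_out_channels=(64,))
with torch.no_grad():
    _sketch_moments = _sketch_encoder(torch.randn(1, 3, 64, 64))
assert _sketch_moments.shape == (1, 8, 64, 64)  # 2 * out_channels, spatial size kept

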
class Decoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z, latent_embeds=None):
        sample = z
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


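# --- Editor's sketch (illustrative, not part of vae.py in this commit) ---
# Mirror-image check for the Decoder above: it maps latents back to image space.
# With the default single up block (which is also the final block) no upsampling
# happens, so only the channel count changes. Assumes the same import context as
# the Encoder sketch.
_sketch_decoder = Decoder(in_channels=4, out_channels=3, block_out_channels=(64,))
with torch.no_grad():
    _sketch_image = _sketch_decoder(torch.randn(1, 4, 64, 64))
assert _sketch_image.shape == (1, 3, 64, 64)

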
class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
    ):
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


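# --- Editor's sketch (illustrative, not part of vae.py in this commit) ---
# The quantizer snaps each spatial latent vector to its nearest codebook entry,
# returns the codebook/commitment loss, and passes gradients straight through
# via z_q = z + (z_q - z).detach(). `sane_index_shape=True` keeps the indices
# as a (batch, height, width) map. The sizes below are arbitrary illustration
# choices, not values taken from this commit.
_sketch_vq = VectorQuantizer(n_e=256, vq_embed_dim=4, beta=0.25, legacy=False, sane_index_shape=True)
_sketch_z = torch.randn(2, 4, 8, 8)
_sketch_zq, _sketch_loss, (_, _, _sketch_idx) = _sketch_vq(_sketch_z)
assert _sketch_zq.shape == (2, 4, 8, 8) and _sketch_idx.shape == (2, 8, 8)

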
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean
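# --- Editor's sketch (illustrative, not part of vae.py in this commit) ---
# How the pieces fit together in an AutoencoderKL-style pipeline: the Encoder's
# `double_z` output is split into mean/logvar, a latent is drawn with the
# reparameterization trick (mean + std * eps), and kl() returns the per-sample
# KL divergence against a standard normal. Assumes this file's `randn_tensor`
# import is available, as used by `sample()` above.
_sketch_dist = DiagonalGaussianDistribution(torch.randn(1, 8, 32, 32))
_sketch_latents = _sketch_dist.sample()
_sketch_kl = _sketch_dist.kl()
assert _sketch_latents.shape == (1, 4, 32, 32) and _sketch_kl.shape == (1,)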
    	
6DoF/diffusers/models/vae_flax.py ADDED
@@ -0,0 +1,869 @@
| 1 | 
         
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 4 | 
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 5 | 
         
            +
            # You may obtain a copy of the License at
         
     | 
| 6 | 
         
            +
            #
         
     | 
| 7 | 
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 8 | 
         
            +
            #
         
     | 
| 9 | 
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 10 | 
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 11 | 
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 12 | 
         
            +
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers

import math
from functools import partial
from typing import Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .modeling_flax_utils import FlaxModelMixin


@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
    """

    sample: jnp.ndarray


@flax.struct.dataclass
class FlaxAutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`FlaxDiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
            `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "FlaxDiagonalGaussianDistribution"

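# Both dataclasses above are plain `flax.struct.dataclass` containers used as return
# types. `latent_dist` is annotated with the string "FlaxDiagonalGaussianDistribution"
# as a forward reference, since that distribution class is defined further down in
# this module.
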
class FlaxUpsample2D(nn.Module):
    """
    Flax implementation of 2D Upsample layer

    Args:
        in_channels (`int`):
            Input channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states

class FlaxDownsample2D(nn.Module):
    """
    Flax implementation of 2D Downsample layer

    Args:
        in_channels (`int`):
            Input channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states

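# A minimal usage sketch for the two resampling layers above (shapes are
# illustrative): both operate on NHWC tensors, as is conventional in Flax.
# FlaxUpsample2D doubles the spatial size with nearest-neighbor resizing before a
# 3x3 convolution; FlaxDownsample2D pads only the bottom/right edge by one pixel
# and then applies a stride-2 "VALID" convolution, reproducing the asymmetric
# padding used by the PyTorch VAE downsampler.
#
#     x = jnp.zeros((1, 64, 64, 128))                # NHWC input
#     down = FlaxDownsample2D(in_channels=128)
#     variables = down.init(jax.random.PRNGKey(0), x)
#     y = down.apply(variables, x)                   # -> (1, 32, 32, 128)
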
class FlaxResnetBlock2D(nn.Module):
    """
    Flax implementation of 2D Resnet Block.

    Args:
        in_channels (`int`):
            Input channels
        out_channels (`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm.
        use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
            Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int = None
    dropout: float = 0.0
    groups: int = 32
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual

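# The block above follows a pre-activation pattern (GroupNorm -> swish -> 3x3 conv,
# twice) and adds the residual at the end; when the channel count changes, the
# residual is first projected with a 1x1 "nin" shortcut so the shapes match. A
# minimal sketch with illustrative shapes:
#
#     block = FlaxResnetBlock2D(in_channels=128, out_channels=256)
#     x = jnp.zeros((1, 32, 32, 128))
#     variables = block.init(jax.random.PRNGKey(0), x)
#     y = block.apply(variables, x)                  # -> (1, 32, 32, 256)
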
class FlaxAttentionBlock(nn.Module):
    r"""
    Flax Convolutional based multi-head attention block for diffusion-based VAE.

    Parameters:
        channels (:obj:`int`):
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):
            Number of channels per attention head. If `None`, a single attention head is used.
        num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    channels: int
    num_head_channels: int = None
    num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

        dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        self.query, self.key, self.value = dense(), dense(), dense()
        self.proj_attn = dense()

    def transpose_for_scores(self, projection):
        new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
        new_projection = projection.reshape(new_projection_shape)
        # (B, T, H, D) -> (B, H, T, D)
        new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
        return new_projection

    def __call__(self, hidden_states):
        residual = hidden_states
        batch, height, width, channels = hidden_states.shape

        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.reshape((batch, height * width, channels))

        query = self.query(hidden_states)
        key = self.key(hidden_states)
        value = self.value(hidden_states)

        # transpose
        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        # compute attentions
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(attn_weights, axis=-1)

        # attend to values
        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)

        hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
        new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
        hidden_states = hidden_states.reshape(new_hidden_states_shape)

        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.reshape((batch, height, width, channels))
        hidden_states = hidden_states + residual
        return hidden_states

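# In the attention block above, the spatial grid is flattened into a sequence of
# height * width tokens and attended over with plain dot-product attention. The
# scale 1 / sqrt(sqrt(channels / num_heads)) is applied to both query and key,
# which is numerically equivalent to dividing the attention scores by
# sqrt(head_dim) while keeping intermediate values smaller. With
# num_head_channels=None a single head is used; the encoder and decoder below pass
# num_attention_heads=None to the mid block, so the VAE attention is single-head.
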
class FlaxDownEncoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Encoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks in the block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add a downsampling layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)

        return hidden_states

class FlaxUpDecoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks in the block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add an upsampling layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states

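# Both blocks above are plain stacks of FlaxResnetBlock2D in which only the first
# ResNet sees `in_channels`; the optional resampler is stored as `downsamplers_0` /
# `upsamplers_0`, presumably so the parameter paths line up with the
# `downsamplers[0]` / `upsamplers[0]` entries of the corresponding PyTorch modules
# when checkpoints are converted.
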
class FlaxUNetMidBlock2D(nn.Module):
    r"""
    Flax Unet Mid-Block module.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (attention, ResNet) block pairs appended after the initial ResNet block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet and Attention block group norm
        num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
            Number of attention heads for each attention block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    num_attention_heads: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxAttentionBlock(
                channels=self.in_channels,
                num_head_channels=self.num_attention_heads,
                num_groups=resnet_groups,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        return hidden_states

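# The mid block keeps the channel count constant: one ResNet block, followed by
# `num_layers` (attention, ResNet) pairs. The encoder and decoder below each place
# a single mid block at the lowest spatial resolution.
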
class FlaxEncoder(nn.Module):
    r"""
    Flax Implementation of VAE Encoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of ResNet layers for each block
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            Number of groups for group normalization
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function
        double_z (:obj:`bool`, *optional*, defaults to `False`):
            Whether to double the last output channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    double_z: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels
        # in
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # downsampling
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, _ in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = FlaxDownEncoderBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                resnet_groups=self.norm_num_groups,
                add_downsample=not is_final_block,
                dtype=self.dtype,
            )
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # end
        conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            conv_out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # in
        sample = self.conv_in(sample)

        # downsampling
        for block in self.down_blocks:
            sample = block(sample, deterministic=deterministic)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # end
        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample

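# A minimal usage sketch for the encoder above (the configuration values are
# illustrative): with four down blocks, the first three downsample, so a 256x256
# RGB image is reduced by a factor of 8, and double_z=True doubles the output
# channels so the result can be split into a mean and a log-variance.
#
#     enc = FlaxEncoder(
#         in_channels=3,
#         out_channels=4,
#         down_block_types=("DownEncoderBlock2D",) * 4,
#         block_out_channels=(128, 256, 512, 512),
#         layers_per_block=2,
#         double_z=True,
#     )
#     x = jnp.zeros((1, 256, 256, 3))                # NHWC input
#     variables = enc.init(jax.random.PRNGKey(0), x)
#     moments = enc.apply(variables, x)              # -> (1, 32, 32, 8)
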
            class FlaxDecoder(nn.Module):
         
     | 
| 570 | 
         
            +
                r"""
         
     | 
| 571 | 
         
            +
                Flax Implementation of VAE Decoder.
         
     | 
| 572 | 
         
            +
             
     | 
| 573 | 
         
            +
                This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
         
     | 
| 574 | 
         
            +
                subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
         
     | 
| 575 | 
         
            +
                general usage and behavior.
         
     | 
| 576 | 
         
            +
             
     | 
| 577 | 
         
            +
                Finally, this model supports inherent JAX features such as:
         
     | 
| 578 | 
         
            +
                - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
         
     | 
| 579 | 
         
            +
                - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
         
     | 
| 580 | 
         
            +
                - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
         
     | 
| 581 | 
         
            +
                - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
         
     | 
| 582 | 
         
            +
             
     | 
| 583 | 
         
            +
                Parameters:
         
     | 
| 584 | 
         
            +
                    in_channels (:obj:`int`, *optional*, defaults to 3):
         
     | 
| 585 | 
         
            +
                        Input channels
         
     | 
| 586 | 
         
            +
                    out_channels (:obj:`int`, *optional*, defaults to 3):
         
     | 
| 587 | 
         
            +
                        Output channels
         
     | 
| 588 | 
         
            +
                    up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
         
     | 
| 589 | 
         
            +
                        UpDecoder block type
         
     | 
| 590 | 
         
            +
                    block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
         
     | 
| 591 | 
         
            +
                        Tuple containing the number of output channels for each block
         
     | 
| 592 | 
         
            +
                    layers_per_block (:obj:`int`, *optional*, defaults to `2`):
         
     | 
| 593 | 
         
            +
                        Number of Resnet layer for each block
         
     | 
| 594 | 
         
            +
                    norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
         
     | 
| 595 | 
         
            +
                        norm num group
         
     | 
| 596 | 
         
            +
                    act_fn (:obj:`str`, *optional*, defaults to `silu`):
         
     | 
| 597 | 
         
            +
                        Activation function
         
     | 
| 598 | 
         
            +
                    double_z (:obj:`bool`, *optional*, defaults to `False`):
         
     | 
| 599 | 
         
            +
                        Whether to double the last output channels
         
     | 
| 600 | 
         
            +
                    dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
         
     | 
| 601 | 
         
            +
                        parameters `dtype`
         
     | 
| 602 | 
         
            +
                """
         
    in_channels: int = 3
    out_channels: int = 3
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels

        # z to block_in
        self.conv_in = nn.Conv(
            block_out_channels[-1],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # upsampling
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        up_blocks = []
        for i, _ in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = FlaxUpDecoderBlock2D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block + 1,
                resnet_groups=self.norm_num_groups,
                add_upsample=not is_final_block,
                dtype=self.dtype,
            )
            up_blocks.append(up_block)
            prev_output_channel = output_channel

        self.up_blocks = up_blocks

        # end
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # z to block_in
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # upsampling
        for block in self.up_blocks:
            sample = block(sample, deterministic=deterministic)

        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample
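
Editor's note, not part of the committed file: a minimal sketch of how this decoder can be exercised on its own, assuming the vendored package is importable as `diffusers.models.vae_flax` and JAX/Flax are installed. With the default single `UpDecoderBlock2D` there is no upsampling, so a channels-last `(1, 32, 32, 4)` latent comes back as a `(1, 32, 32, 3)` image-shaped array.

    # Illustrative only (editor's sketch, not from the diff).
    import jax
    import jax.numpy as jnp
    from diffusers.models.vae_flax import FlaxDecoder  # the vendored module above

    decoder = FlaxDecoder(in_channels=4, out_channels=3, block_out_channels=(64,))
    latent = jnp.zeros((1, 32, 32, 4))                    # NHWC, as Flax convolutions expect
    variables = decoder.init(jax.random.PRNGKey(0), latent)
    image = decoder.apply(variables, latent)              # -> shape (1, 32, 32, 3)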

class FlaxDiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        # Last axis to account for channels-last
        self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
        self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = jnp.zeros_like(self.mean)

    def sample(self, key):
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def kl(self, other=None):
        if self.deterministic:
            return jnp.array([0.0])

        if other is None:
            return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])

        return 0.5 * jnp.sum(
            jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            axis=[1, 2, 3],
        )

    def nll(self, sample, axis=[1, 2, 3]):
        if self.deterministic:
            return jnp.array([0.0])

        logtwopi = jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)

    def mode(self):
        return self.mean

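Editor's note, not part of the committed file: `kl()` and `nll()` above are the standard closed forms for a diagonal Gaussian. Against a standard normal prior,

    D_{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big)
        = \tfrac{1}{2}\sum_i \left(\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2\right),

which is exactly `0.5 * (mean**2 + var - 1 - logvar)` summed over the non-batch axes, and the negative log-likelihood of a sample x under the distribution is

    -\log p(x) = \tfrac{1}{2}\sum_i \left(\log 2\pi + \log\sigma_i^2 + \frac{(x_i-\mu_i)^2}{\sigma_i^2}\right).
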

@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax implementation of a VAE model with KL loss for decoding latent representations.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3):
            Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`):
            Number of ResNet layers for each block.
        act_fn (`str`, *optional*, defaults to `silu`):
            The activation function to use.
        latent_channels (`int`, *optional*, defaults to `4`):
            Number of channels in the latent space.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups for normalization.
        sample_size (`int`, *optional*, defaults to 32):
            Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
    """
    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 1
    act_fn: str = "silu"
    latent_channels: int = 4
    norm_num_groups: int = 32
    sample_size: int = 32
    scaling_factor: float = 0.18215
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.encoder = FlaxEncoder(
            in_channels=self.config.in_channels,
            out_channels=self.config.latent_channels,
            down_block_types=self.config.down_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            act_fn=self.config.act_fn,
            norm_num_groups=self.config.norm_num_groups,
            double_z=True,
            dtype=self.dtype,
        )
        self.decoder = FlaxDecoder(
            in_channels=self.config.latent_channels,
            out_channels=self.config.out_channels,
            up_block_types=self.config.up_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            norm_num_groups=self.config.norm_num_groups,
            act_fn=self.config.act_fn,
            dtype=self.dtype,
        )
        self.quant_conv = nn.Conv(
            2 * self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        self.post_quant_conv = nn.Conv(
            self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)

        params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
        rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}

        return self.init(rngs, sample)["params"]

    def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
        sample = jnp.transpose(sample, (0, 2, 3, 1))

        hidden_states = self.encoder(sample, deterministic=deterministic)
        moments = self.quant_conv(hidden_states)
        posterior = FlaxDiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return FlaxAutoencoderKLOutput(latent_dist=posterior)

    def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
        if latents.shape[-1] != self.config.latent_channels:
            latents = jnp.transpose(latents, (0, 2, 3, 1))

        hidden_states = self.post_quant_conv(latents)
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)

        hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

        if not return_dict:
            return (hidden_states,)

        return FlaxDecoderOutput(sample=hidden_states)

    def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
        posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
        if sample_posterior:
            rng = self.make_rng("gaussian")
            hidden_states = posterior.latent_dist.sample(rng)
        else:
            hidden_states = posterior.latent_dist.mode()

        sample = self.decode(hidden_states, return_dict=return_dict).sample

        if not return_dict:
            return (sample,)

        return FlaxDecoderOutput(sample=sample)
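
Editor's note, not part of the committed file: a minimal encode/decode round trip under the default config, assuming the vendored package is importable (real checkpoints are normally loaded with `FlaxAutoencoderKL.from_pretrained`, which returns the module together with its parameters). Inputs and outputs are NCHW, `encode` transposes to NHWC internally, and `scaling_factor` is applied exactly as the docstring describes.

    # Illustrative only (editor's sketch, not from the diff).
    import jax
    import jax.numpy as jnp
    from diffusers.models.vae_flax import FlaxAutoencoderKL

    vae = FlaxAutoencoderKL()                                  # default config from above
    params = vae.init_weights(jax.random.PRNGKey(0))

    images = jnp.zeros((1, 3, 32, 32))                         # NCHW input
    posterior = vae.apply({"params": params}, images, method=vae.encode).latent_dist
    latents = posterior.mode() * vae.scaling_factor            # z = z * scaling_factor

    decoded = vae.apply(
        {"params": params}, latents / vae.scaling_factor, method=vae.decode
    ).sample                                                   # NCHW, shape (1, 3, 32, 32)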
    	
6DoF/diffusers/models/vq_model.py
ADDED
@@ -0,0 +1,167 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer


@dataclass
class VQEncoderOutput(BaseOutput):
    """
    Output of VQModel encoding method.

    Args:
        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The encoded output sample from the last layer of the model.
    """

    latents: torch.FloatTensor


class VQModel(ModelMixin, ConfigMixin):
    r"""
    A VQ-VAE model for decoding latent representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
        scaling_factor (`float`, *optional*, defaults to `0.18215`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 3,
        sample_size: int = 32,
        num_vq_embeddings: int = 256,
        norm_num_groups: int = 32,
        vq_embed_dim: Optional[int] = None,
        scaling_factor: float = 0.18215,
        norm_type: str = "group",  # group, spatial
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=False,
        )

        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels

        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
        )

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
        h = self.encoder(x)
        h = self.quant_conv(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant2 = self.post_quant_conv(quant)
        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        The [`VQModel`] forward method.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        x = sample
        h = self.encode(x).latents
        dec = self.decode(h).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
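
Editor's note, not part of the committed file: a minimal sketch of the encode/quantize/decode path, assuming the vendored package is importable as `diffusers.models.vq_model`. With the default single-block config the spatial size is preserved, so a `(1, 3, 32, 32)` image reconstructs to the same shape; passing `force_not_quantize=True` to `decode` would skip the codebook lookup.

    # Illustrative only (editor's sketch, not from the diff).
    import torch
    from diffusers.models.vq_model import VQModel

    model = VQModel()                                  # default config from above
    image = torch.randn(1, 3, 32, 32)

    with torch.no_grad():
        latents = model.encode(image).latents          # continuous codes before quantization
        recon = model.decode(latents).sample           # quantized, then decoded: (1, 3, 32, 32)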
    	
6DoF/diffusers/optimization.py
ADDED
@@ -0,0 +1,354 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for diffusion models."""

import math
from enum import Enum
from typing import Optional, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

from .utils import logging


logger = logging.get_logger(__name__)


class SchedulerType(Enum):
    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

| 81 | 
         
            +
            def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
         
     | 
| 82 | 
         
            +
                """
         
     | 
| 83 | 
         
            +
                Create a schedule with a constant learning rate, using the learning rate set in optimizer.
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                Args:
         
     | 
| 86 | 
         
            +
                    optimizer ([`~torch.optim.Optimizer`]):
         
     | 
| 87 | 
         
            +
                        The optimizer for which to schedule the learning rate.
         
     | 
| 88 | 
         
            +
                    step_rules (`string`):
         
     | 
| 89 | 
         
            +
                        The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
         
     | 
| 90 | 
         
            +
                        if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
         
     | 
| 91 | 
         
            +
                        steps and multiple 0.005 for the other steps.
         
     | 
| 92 | 
         
            +
                    last_epoch (`int`, *optional*, defaults to -1):
         
     | 
| 93 | 
         
            +
                        The index of the last epoch when resuming training.
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
                Return:
         
     | 
| 96 | 
         
            +
                    `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
         
     | 
| 97 | 
         
            +
                """
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                rules_dict = {}
         
     | 
| 100 | 
         
            +
                rule_list = step_rules.split(",")
         
     | 
| 101 | 
         
            +
                for rule_str in rule_list[:-1]:
         
     | 
| 102 | 
         
            +
                    value_str, steps_str = rule_str.split(":")
         
     | 
| 103 | 
         
            +
                    steps = int(steps_str)
         
     | 
| 104 | 
         
            +
                    value = float(value_str)
         
     | 
| 105 | 
         
            +
                    rules_dict[steps] = value
         
     | 
| 106 | 
         
            +
                last_lr_multiple = float(rule_list[-1])
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                def create_rules_function(rules_dict, last_lr_multiple):
         
     | 
| 109 | 
         
            +
                    def rule_func(steps: int) -> float:
         
     | 
| 110 | 
         
            +
                        sorted_steps = sorted(rules_dict.keys())
         
     | 
| 111 | 
         
            +
                        for i, sorted_step in enumerate(sorted_steps):
         
     | 
| 112 | 
         
            +
                            if steps < sorted_step:
         
     | 
| 113 | 
         
            +
                                return rules_dict[sorted_steps[i]]
         
     | 
| 114 | 
         
            +
                        return last_lr_multiple
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                    return rule_func
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                rules_func = create_rules_function(rules_dict, last_lr_multiple)
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
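As a quick illustration of the `step_rules` format parsed above, here is a minimal usage sketch. It is not part of the committed file; the model, optimizer, learning rate, and step counts are made-up, and it assumes the vendored package is importable as `diffusers`.

# Illustrative sketch only (not part of the committed file).
import torch
from diffusers.optimization import get_piecewise_constant_schedule  # assumed import path

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# "1:100,0.1:200,0.01" -> multiply the base lr by 1 for steps 0-99,
# by 0.1 for steps 100-199, and by 0.01 from step 200 onwards.
scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:100,0.1:200,0.01")

for _ in range(300):
    optimizer.step()
    scheduler.step()
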
         
            +
            def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
         
     | 
| 124 | 
         
            +
                """
         
     | 
| 125 | 
         
            +
                Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
         
     | 
| 126 | 
         
            +
                a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                Args:
         
     | 
| 129 | 
         
            +
                    optimizer ([`~torch.optim.Optimizer`]):
         
     | 
| 130 | 
         
            +
                        The optimizer for which to schedule the learning rate.
         
     | 
| 131 | 
         
            +
                    num_warmup_steps (`int`):
         
     | 
| 132 | 
         
            +
                        The number of steps for the warmup phase.
         
     | 
| 133 | 
         
            +
                    num_training_steps (`int`):
         
     | 
| 134 | 
         
            +
                        The total number of training steps.
         
     | 
| 135 | 
         
            +
                    last_epoch (`int`, *optional*, defaults to -1):
         
     | 
| 136 | 
         
            +
                        The index of the last epoch when resuming training.
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                Return:
         
     | 
| 139 | 
         
            +
                    `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
         
     | 
| 140 | 
         
            +
                """
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                def lr_lambda(current_step: int):
         
     | 
| 143 | 
         
            +
                    if current_step < num_warmup_steps:
         
     | 
| 144 | 
         
            +
                        return float(current_step) / float(max(1, num_warmup_steps))
         
     | 
| 145 | 
         
            +
                    return max(
         
     | 
| 146 | 
         
            +
                        0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
         
     | 
| 147 | 
         
            +
                    )
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                return LambdaLR(optimizer, lr_lambda, last_epoch)
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
            def get_cosine_schedule_with_warmup(
         
     | 
| 153 | 
         
            +
                optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
         
     | 
| 154 | 
         
            +
            ):
         
     | 
| 155 | 
         
            +
                """
         
     | 
| 156 | 
         
            +
                Create a schedule with a learning rate that decreases following the values of the cosine function between the
         
     | 
| 157 | 
         
            +
                initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
         
     | 
| 158 | 
         
            +
                initial lr set in the optimizer.
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                Args:
         
     | 
| 161 | 
         
            +
                    optimizer ([`~torch.optim.Optimizer`]):
         
     | 
| 162 | 
         
            +
                        The optimizer for which to schedule the learning rate.
         
     | 
| 163 | 
         
            +
                    num_warmup_steps (`int`):
         
     | 
| 164 | 
         
            +
                        The number of steps for the warmup phase.
         
     | 
| 165 | 
         
            +
                    num_training_steps (`int`):
         
     | 
| 166 | 
         
            +
                        The total number of training steps.
         
     | 
| 167 | 
         
            +
                    num_periods (`float`, *optional*, defaults to 0.5):
         
     | 
| 168 | 
         
            +
                        The number of periods of the cosine function in a schedule (the default is to just decrease from the max
         
     | 
| 169 | 
         
            +
                        value to 0 following a half-cosine).
         
     | 
| 170 | 
         
            +
                    last_epoch (`int`, *optional*, defaults to -1):
         
     | 
| 171 | 
         
            +
                        The index of the last epoch when resuming training.
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                Return:
         
     | 
| 174 | 
         
            +
                    `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
         
     | 
| 175 | 
         
            +
                """
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                def lr_lambda(current_step):
         
     | 
| 178 | 
         
            +
                    if current_step < num_warmup_steps:
         
     | 
| 179 | 
         
            +
                        return float(current_step) / float(max(1, num_warmup_steps))
         
     | 
| 180 | 
         
            +
                    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
         
     | 
| 181 | 
         
            +
                    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                return LambdaLR(optimizer, lr_lambda, last_epoch)
         
     | 
| 184 | 
         
            +
             
     | 
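To make the shape of this schedule concrete, the following standalone sketch evaluates the same multiplier formula as the `lr_lambda` above at a few steps. It is not part of the committed file, and the warmup and training step counts are made-up values.

# Standalone check of the cosine-with-warmup multiplier (illustrative only).
import math

num_warmup_steps, num_training_steps, num_cycles = 100, 1100, 0.5

def multiplier(step):
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

print(multiplier(0), multiplier(100), multiplier(600), multiplier(1100))
# 0.0 (start of warmup), 1.0 (end of warmup), ~0.5 (halfway through decay), 0.0 (end of training)
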

def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)

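For intuition about `lr_end` and `power`, here is a standalone sketch (not part of the committed file) that mirrors the `lr_lambda` above but returns the actual learning rate, since `LambdaLR` multiplies the result by the initial lr. Warmup is disabled and all numbers are made-up values.

# Standalone check of the polynomial-decay schedule (illustrative only).
lr_init, lr_end, power = 1e-4, 1e-7, 1.0
num_warmup_steps, num_training_steps = 0, 1000

def lr_at(step):
    # Mirrors the lr_lambda above (warmup branch omitted because num_warmup_steps=0).
    if step > num_training_steps:
        return lr_end
    pct_remaining = 1 - (step - num_warmup_steps) / (num_training_steps - num_warmup_steps)
    return (lr_init - lr_end) * pct_remaining**power + lr_end

print(lr_at(0), lr_at(500), lr_at(1000))  # 0.0001, ~5.005e-05, 1e-07 (linear because power=1.0)
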

TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See the `POLYNOMIAL` scheduler.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            power=power,
            last_epoch=last_epoch,
        )

    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
    )
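
Since `get_scheduler` is the entry point a training script would normally call, here is a minimal usage sketch. It is not part of the committed file; the model, learning rate, and step counts are placeholders, and it assumes the vendored package is importable as `diffusers`.

# Minimal usage sketch for get_scheduler (illustrative only).
import torch
from diffusers.optimization import get_scheduler  # assumed import path

model = torch.nn.Linear(16, 16)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine",                    # any SchedulerType value, e.g. "linear" or "constant_with_warmup"
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
)

for _ in range(10_000):
    optimizer.step()             # loss computation / backward omitted
    lr_scheduler.step()          # advance the schedule once per optimizer step
    optimizer.zero_grad()
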
    	
        6DoF/diffusers/pipeline_utils.py
    ADDED
    
@@ -0,0 +1,29 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

# limitations under the License.

# NOTE: This file is deprecated and will be removed in a future version.
# It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works

from .pipelines import DiffusionPipeline, ImagePipelineOutput  # noqa: F401
from .utils import deprecate


deprecate(
    "pipelines_utils",
    "0.22.0",
    "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
    standard_warn=False,
    stacklevel=3,
)
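
For reference, a short sketch (not part of the committed file) of what this shim means for callers: the old import path keeps working but triggers the deprecation warning registered above, and the warning message points to the supported location. It assumes this package (or upstream diffusers) is on the import path.

# Old location: still works through this shim, but emits the deprecation warning above.
from diffusers.pipeline_utils import DiffusionPipeline  # noqa: F401

# Supported location named in the deprecation message:
# from diffusers.pipelines.pipeline_utils import DiffusionPipeline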