import torch
# import argparse
# from omegaconf import OmegaConf
# from models import get_models
# import sys
# import os
# from PIL import Image
# from copy import deepcopy


def tca_transform_model(model):
    """Apply the TCA transform to the first transformer block of every attention
    module in the model's down, mid and up blocks. Blocks without attention
    layers (e.g. plain ResNet blocks) are skipped."""
    for down_block in model.down_blocks:
        try:
            for attention in down_block.attentions:
                attention.transformer_blocks[0].tca_transform()
        except AttributeError:
            # This down block has no `attentions`; nothing to transform.
            continue
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].tca_transform()
    for up_block in model.up_blocks:
        try:
            for attention in up_block.attentions:
                attention.transformer_blocks[0].tca_transform()
        except AttributeError:
            continue
    return model
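
# Usage sketch (assumption: `model` is this repo's UNet-style module whose
# transformer blocks implement `tca_transform`). The traversal mutates the
# model in place and also returns it, so either form works:
#
#   model = tca_transform_model(model)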


class ImageProjModel(torch.nn.Module):
    """Projection model: maps a pooled CLIP image embedding to a small set of
    extra context tokens for cross-attention."""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        # (batch, clip_embeddings_dim) -> (batch, clip_extra_context_tokens, cross_attention_dim)
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
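
# Shape sketch (assumptions: pooled CLIP image embeddings of size 1024, e.g. an
# OpenCLIP ViT-H encoder, and an SD1.5-style cross-attention dim of 768):
#
#   proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024,
#                         clip_extra_context_tokens=4)
#   image_embeds = torch.randn(2, 1024)   # (batch, clip_embeddings_dim)
#   tokens = proj(image_embeds)           # -> (batch, 4, 768)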
    

def ip_transform_model(model):
    """Attach an image-prompt projection head and convert every cross-attention
    (attn2) layer so it can consume the projected image tokens."""
    model.image_proj_model = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024,
                                            clip_extra_context_tokens=4).to(model.device)
    for down_block in model.down_blocks:
        try:
            for attention in down_block.attentions:
                attention.transformer_blocks[0].attn2.ip_transform()
        except AttributeError:
            # This down block has no attention layers; skip it.
            continue
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.ip_transform()
    for up_block in model.up_blocks:
        try:
            for attention in up_block.attentions:
                attention.transformer_blocks[0].attn2.ip_transform()
        except AttributeError:
            continue
    return model


def ip_scale_set(model, scale):
    """Set the image-prompt conditioning scale on every converted cross-attention
    (attn2) layer."""
    for down_block in model.down_blocks:
        try:
            for attention in down_block.attentions:
                attention.transformer_blocks[0].attn2.set_scale(scale)
        except AttributeError:
            continue
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.set_scale(scale)
    for up_block in model.up_blocks:
        try:
            for attention in up_block.attentions:
                attention.transformer_blocks[0].attn2.set_scale(scale)
        except AttributeError:
            continue
    return model
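
# Scale semantics sketch (assumption: `set_scale` weights the image-prompt
# attention branch the way IP-Adapter-style layers do, so 0.0 disables image
# conditioning and 1.0 applies it at full strength):
#
#   model = ip_scale_set(model, 0.5)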


def ip_train_set(model):
    """Freeze the whole model, then re-enable gradients for the image projection
    head and the image-prompt parts of every cross-attention (attn2) layer."""
    model.requires_grad_(False)
    model.image_proj_model.requires_grad_(True)
    for down_block in model.down_blocks:
        try:
            for attention in down_block.attentions:
                attention.transformer_blocks[0].attn2.ip_train_set()
        except AttributeError:
            continue
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.ip_train_set()
    for up_block in model.up_blocks:
        try:
            for attention in up_block.attentions:
                attention.transformer_blocks[0].attn2.ip_train_set()
        except AttributeError:
            continue
    return model
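

if __name__ == "__main__":
    # Minimal smoke test for ImageProjModel; it only needs torch, so it can run
    # standalone. Wiring the ip_* helpers onto an actual UNet requires this
    # repo's model (e.g. via the commented-out `get_models` import above) and is
    # only sketched in the comments below.
    proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024,
                          clip_extra_context_tokens=4)
    dummy_embeds = torch.randn(2, 1024)
    tokens = proj(dummy_embeds)
    assert tokens.shape == (2, 4, 768)
    print("ImageProjModel output shape:", tuple(tokens.shape))

    # Wiring sketch (assumptions: `unet` exposes down_blocks / mid_block /
    # up_blocks and its attn2 layers implement ip_transform, set_scale and
    # ip_train_set):
    #
    #   unet = ...                       # load the repo's UNet, e.g. via get_models
    #   unet = ip_transform_model(unet)  # add image_proj_model + IP attention
    #   unet = ip_train_set(unet)        # freeze everything except the IP parts
    #   unet = ip_scale_set(unet, 1.0)   # full-strength image conditioning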