File size: 2,666 Bytes
ef9fd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from modules import sd_hijack_clip, sd_hijack, shared
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations
try:
    # Older webui builds expose fix_checkpoint directly on modules.sd_hijack;
    # in that case nothing extra needs undoing when we clear our patch.
    from modules.sd_hijack import fix_checkpoint

    def clear_any_hijacks():
        """Restore the stock hijack method on StableDiffusionModelHijack."""
        StableDiffusionModelHijack.hijack = default_hijack
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    # Newer webui builds moved checkpoint patching to sd_hijack_checkpoint.
    from modules.sd_hijack_checkpoint import add, remove

    def fix_checkpoint():
        """Apply checkpoint patches (newer-webui equivalent of the old helper)."""
        add()

    def clear_any_hijacks():
        """Undo checkpoint patches, then restore the stock hijack method."""
        remove()
        StableDiffusionModelHijack.hijack = default_hijack


import ldm.modules.encoders.modules

# Saved reference to the unpatched hijack method. Assigned after the shims
# above but before any of them run, so their late-bound reads resolve fine.
default_hijack = StableDiffusionModelHijack.hijack

def trigger_sd_hijack(enabled, pretrained_key):
    """Re-run the webui CLIP hijack, optionally swapping in a custom text model.

    Clears any patch we previously installed, installs a hijack that loads the
    requested pretrained CLIP checkpoint, and re-hijacks the loaded SD model.
    When *enabled* is false (or no key is given) the stock OpenAI weights are
    used and the stock hijack method is restored afterwards.
    """
    clear_any_hijacks()
    # Fall back to the stock OpenAI CLIP weights when disabled or no key given.
    use_custom = enabled and pretrained_key != ''
    key = pretrained_key if use_custom else 'openai/clip-vit-large-patch14'
    StableDiffusionModelHijack.hijack = create_lambda(key)
    print("Hijacked clip text model!")
    # Undo then redo the hijack so the patched method runs on the live model.
    sd_hijack.model_hijack.undo_hijack(shared.sd_model)
    sd_hijack.model_hijack.hijack(shared.sd_model)
    if not enabled:
        # Leave the stock hijack in place for future model reloads.
        StableDiffusionModelHijack.hijack = default_hijack




def create_lambda(model):
    """Return a drop-in replacement for ``StableDiffusionModelHijack.hijack``.

    The returned function mirrors the stock hijack, except that for plain
    ``FrozenCLIPEmbedder`` conditioning models it first loads the CLIP text
    encoder and tokenizer from the HuggingFace checkpoint named by *model*
    before applying the usual webui wrapping (``EmbeddingsWithFixes``,
    ``FrozenCLIPEmbedderWithCustomWords``).

    :param model: HuggingFace model key, e.g. ``'openai/clip-vit-large-patch14'``.
    """
    def hijack_lambda(self, m):
        # Exact type check on purpose: other embedder classes (e.g. SD2's) have
        # a different layout and must go through the stock hijack instead.
        if type(m.cond_stage_model) is not ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            print("CLIP change can be only applied to FrozenCLIPEmbedder class")
            return default_hijack(self, m)

        from transformers import CLIPTextModel, CLIPTokenizer
        print(f"Changing CLIP model to {model}")
        try:
            # Load both pieces before assigning either, so a failure leaves the
            # embedder untouched (previously a tokenizer failure could leave a
            # swapped transformer paired with the old tokenizer).
            transformer = CLIPTextModel.from_pretrained(
                model).to(m.cond_stage_model.transformer.device)
            tokenizer = CLIPTokenizer.from_pretrained(model)
        except Exception:  # narrowed from bare except: keep SystemExit/KeyboardInterrupt alive
            # Best-effort: fall through and hijack the existing weights instead.
            print(f"Cannot initiate from given model key {model}!")
        else:
            transformer.requires_grad_(False)  # inference only, no gradients
            m.cond_stage_model.transformer = transformer
            m.cond_stage_model.tokenizer = tokenizer

        # Standard webui hijack steps, re-applied for the (possibly new) weights.
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model

        fix_checkpoint()

        def flatten(el):
            # Depth-first list of el followed by all of its (recursive) submodules.
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)
    return hijack_lambda