# Swap the CLIP text encoder used by the Stable Diffusion webui for an
# arbitrary Hugging Face CLIP text model by wrapping
# StableDiffusionModelHijack.hijack.
from modules import sd_hijack_clip, sd_hijack, shared
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations
# Older webui builds export fix_checkpoint from modules.sd_hijack; newer ones
# moved the checkpoint patching into modules.sd_hijack_checkpoint, so shim
# over both layouts here.
try:
    from modules.sd_hijack import fix_checkpoint

    def clear_any_hijacks():
        StableDiffusionModelHijack.hijack = default_hijack
except ImportError:
    from modules.sd_hijack_checkpoint import add, remove

    def fix_checkpoint():
        add()

    def clear_any_hijacks():
        remove()
        StableDiffusionModelHijack.hijack = default_hijack
import ldm.modules.encoders.modules

# Keep a reference to the stock hijack method so it can be restored later.
default_hijack = StableDiffusionModelHijack.hijack

def trigger_sd_hijack(enabled, pretrained_key):
    """Re-run the webui model hijack, optionally swapping in a new CLIP text model."""
    clear_any_hijacks()
    if not enabled or not pretrained_key:
        # Fall back to the CLIP checkpoint Stable Diffusion v1 was trained against.
        pretrained_key = 'openai/clip-vit-large-patch14'
    StableDiffusionModelHijack.hijack = create_lambda(pretrained_key)
    print("Hijacked CLIP text model!")
    # Undo and redo the hijack so the replacement method runs on the live model.
    sd_hijack.model_hijack.undo_hijack(shared.sd_model)
    sd_hijack.model_hijack.hijack(shared.sd_model)
    if not enabled:
        # Restore the stock hijack once the default CLIP has been re-applied.
        StableDiffusionModelHijack.hijack = default_hijack

def create_lambda(model):
    def hijack_lambda(self, m):
        # Only the stock FrozenCLIPEmbedder conditioning stage is supported;
        # the exact type check mirrors the webui's own dispatch logic.
        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            from transformers import CLIPTextModel, CLIPTokenizer
            print(f"Changing CLIP model to {model}")
            try:
                m.cond_stage_model.transformer = CLIPTextModel.from_pretrained(
                    model).to(m.cond_stage_model.transformer.device)
                m.cond_stage_model.transformer.requires_grad_(False)
                m.cond_stage_model.tokenizer = CLIPTokenizer.from_pretrained(
                    model)
            except Exception as e:
                # Keep the current transformer/tokenizer if loading fails.
                print(f"Cannot initiate from given model key {model}: {e}")
            # Re-apply the webui's usual wrapping on top of the swapped model.
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
            self.optimization_method = apply_optimizations()
            self.clip = m.cond_stage_model
            fix_checkpoint()

            def flatten(el):
                # Collect the module and all of its descendants into a flat list.
                flattened = [flatten(children) for children in el.children()]
                res = [el]
                for c in flattened:
                    res += c
                return res

            self.layers = flatten(m)
        else:
            print("CLIP change can only be applied to the FrozenCLIPEmbedder class")
            return default_hijack(self, m)
    return hijack_lambda
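
# Usage sketch (hypothetical; not part of the original file). In a webui
# extension or Space callback, the swap could be driven from two UI controls,
# e.g. a checkbox and a text field holding a Hugging Face model key:
#
#     trigger_sd_hijack(True, 'openai/clip-vit-base-patch32')  # swap the encoder
#     trigger_sd_hijack(False, '')                             # restore stock CLIP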