Spaces:
Runtime error
Runtime error
File size: 2,666 Bytes
ef9fd1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from modules import sd_hijack_clip, sd_hijack, shared
from modules.sd_hijack import StableDiffusionModelHijack, EmbeddingsWithFixes, apply_optimizations
try:
from modules.sd_hijack import fix_checkpoint
def clear_any_hijacks():
StableDiffusionModelHijack.hijack = default_hijack
except (ModuleNotFoundError, ImportError):
from modules.sd_hijack_checkpoint import add, remove
def fix_checkpoint():
add()
def clear_any_hijacks():
remove()
StableDiffusionModelHijack.hijack = default_hijack
import ldm.modules.encoders.modules
default_hijack = StableDiffusionModelHijack.hijack
def trigger_sd_hijack(enabled, pretrained_key):
clear_any_hijacks()
if not enabled or pretrained_key == '':
pretrained_key = 'openai/clip-vit-large-patch14'
StableDiffusionModelHijack.hijack = create_lambda(pretrained_key)
print("Hijacked clip text model!")
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
sd_hijack.model_hijack.hijack(shared.sd_model)
if not enabled:
StableDiffusionModelHijack.hijack = default_hijack
def create_lambda(model):
def hijack_lambda(self, m):
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
from transformers import CLIPTextModel, CLIPTokenizer
print(f"Changing CLIP model to {model}")
try:
m.cond_stage_model.transformer = CLIPTextModel.from_pretrained(
model).to(m.cond_stage_model.transformer.device)
m.cond_stage_model.transformer.requires_grad_(False)
m.cond_stage_model.tokenizer = CLIPTokenizer.from_pretrained(
model)
except:
print(f"Cannot initiate from given model key {model}!")
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
res = [el]
for c in flattened:
res += c
return res
self.layers = flatten(m)
else:
print("CLIP change can be only applied to FrozenCLIPEmbedder class")
return default_hijack(self, m)
return hijack_lambda
|