from dataclasses import dataclass
from itertools import product
import re
from typing import Union, List, Tuple

import numpy as np
import open_clip
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase as CLIP
from modules import prompt_parser, shared

from scripts.cutofflib.utils import log

# Suffix the CLIP BPE vocabulary puts on the last sub-token of a word
# (e.g. a comma inside a prompt is tokenized as ',</w>', not ',').
_WORD_END = '</w>'


class ClipWrapper:
    """Uniform token <-> id lookup for the tokenizer behind ``te``.

    If the wrapped model exposes a ``tokenizer`` attribute (the "v1" path) that
    tokenizer is used; otherwise open_clip's module-level tokenizer is used
    (presumably the SD2 / OpenCLIP text encoder -- confirm against callers).
    """

    def __init__(self, te: CLIP):
        self.te = te
        self.v1 = hasattr(te.wrapped, 'tokenizer')
        self.t = (
            te.wrapped.tokenizer
            if self.v1
            else open_clip.tokenizer._tokenizer
        )

    def token_to_id(self, token: str) -> int:
        """Return the vocabulary id of ``token``.

        On the open_clip path an unknown token raises ``KeyError`` (plain dict
        lookup). On the v1 path the tokenizer decides; presumably unknown
        tokens map to the unk id, which for CLIP is ``<|endoftext|>`` -- confirm.
        """
        if self.v1:
            return self.t._convert_token_to_id(token)  # type: ignore
        else:
            return self.t.encoder[token]

    def id_to_token(self, id: int) -> str:
        """Return the vocabulary string for token ``id``."""
        if self.v1:
            return self.t.convert_ids_to_tokens(id)  # type: ignore
        else:
            return self.t.decoder[id]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        """List version of :meth:`id_to_token`."""
        if self.v1:
            return self.t.convert_ids_to_tokens(ids)  # type: ignore
        else:
            return [self.t.decoder[id] for id in ids]

    def token(self, token: Union[int,str]):
        """Build a :class:`Token` from either an id (``int``) or a vocabulary string."""
        if isinstance(token, int):
            return Token(token, self.id_to_token(token))
        else:
            return Token(self.token_to_id(token), token)


@dataclass
class Token:
    """A vocabulary entry: numeric id plus its string form."""
    id: int  # vocabulary id
    token: str  # vocabulary string; word-final entries carry the '</w>' suffix


class CutoffPrompt:
    """A prompt split on ',' into blocks, each kept in two forms.

    ``base[i]`` is the i-th ','-separated block as written, ``cut[i]`` is the
    same block with every target replaced by padding, and ``sw[i]`` selects
    which of the two :meth:`text` uses by default.
    """

    @staticmethod
    def _cutoff(prompt: str, clip: CLIP, tokens: List[str], padding: str):
        """Return ``(original_block, cut_block)`` for each ','-separated block.

        Every whole-word (``\\b``-delimited) occurrence of a target is replaced
        by as many space-separated copies of ``padding`` as the target has
        tokens, surrounded by single spaces.
        """
        def token_count(text: str):
            # Number of tokens ``text`` encodes to, measured on the first chunk.
            tt = token_to_block(clip, text)
            # tt[0] == clip.id_start (<|startoftext|>)
            for index, (t, _) in enumerate(tt):
                if t.id == clip.id_end:  # <|endoftext|>
                    return index - 1
            return 0  # must not happen...

        re_targets = [re.compile(r'\b' + re.escape(x) + r'\b') for x in tokens]
        replacer = [
            ' ' + ' '.join([padding] * token_count(x)) + ' '
            for x in tokens
        ]

        rows: List[Tuple[str,str]] = []
        for block in prompt.split(','):
            b0 = block
            for r, p in zip(re_targets, replacer):
                # Function replacement: the padding text is inserted verbatim,
                # so backslashes in it are not treated as template escapes.
                block = r.sub(lambda _m: p, block)
            b1 = block
            rows.append((b0, b1))

        return rows

    def __init__(self, prompt: str, clip: CLIP, tokens: List[str], padding: str):
        self.prompt = prompt
        rows = CutoffPrompt._cutoff(prompt, clip, tokens, padding)
        self.base = np.array([x[0] for x in rows])
        self.cut = np.array([x[1] for x in rows])
        self.sw = np.array([False] * len(rows))

    @property
    def block_count(self):
        """Number of ','-separated blocks in the prompt."""
        return self.base.shape[0]

    def switch(self, block_index: int, to: Union[bool,None] = None):
        """Set the switch of one block (toggle it when ``to`` is None); return the new value."""
        if to is None:
            to = not self.sw[block_index]
        self.sw[block_index] = to
        return to

    def text(self, sw=None):
        """Re-join the blocks with ',', taking ``cut`` where ``sw`` is True and ``base`` elsewhere."""
        if sw is None:
            sw = self.sw
        blocks = np.where(sw, self.cut, self.base)
        return ','.join(blocks)

    def active_blocks(self) -> np.ndarray:
        """Indices of the blocks whose cut form differs from the original."""
        indices, = (self.base != self.cut).nonzero()
        return indices

    def generate(self):
        """Yield ``(switches, text)`` for every on/off combination of the active blocks."""
        indices = self.active_blocks()
        for diff_sw in product([False, True], repeat=indices.shape[0]):
            sw = np.full_like(self.sw, False)
            sw[indices] = diff_sw
            yield diff_sw, self.text(sw)


def generate_prompts(
    clip: CLIP,
    prompt: str,
    targets: List[str],
    padding: Union[str,int,Token],
) -> CutoffPrompt:
    """Build a :class:`CutoffPrompt` replacing ``targets`` in ``prompt`` with ``padding``.

    ``padding`` may be a vocabulary string, a token id or a :class:`Token`.
    Every generated variant is logged.

    Raises:
        ValueError: ``padding`` is not a valid token (it resolves to the end
            token, or is unknown to the open_clip tokenizer).
    """
    te = ClipWrapper(clip)

    if not isinstance(padding, Token):
        o_pad = padding
        try:
            padding = te.token(padding)
        except KeyError:
            # The open_clip lookups are plain dict indexing; report unknown
            # tokens/ids the same way as the v1 path does below.
            raise ValueError(f'`{o_pad}` is not a valid token.') from None
        if padding.id == clip.id_end:
            raise ValueError(f'`{o_pad}` is not a valid token.')

    # Strip the word-end marker so the padding is inserted as plain prompt
    # text (e.g. '_</w>' -> '_'); otherwise '</w>' would be re-tokenized into
    # extra tokens and the padding length would no longer match the target.
    result = CutoffPrompt(prompt, clip, targets, padding.token.replace(_WORD_END, ''))

    log(f'[Cutoff] replace: {", ".join(targets)}')
    log(f'[Cutoff] to: {padding.token} ({padding.id})')
    log(f'[Cutoff] original: {prompt}')
    for i, (_, pp) in enumerate(result.generate()):
        log(f'[Cutoff] #{i}: {pp}')

    return result


def token_to_block(clip: CLIP, prompt: str):
    """Tokenize ``prompt`` into padded chunks and tag each position with its block.

    Chunking mirrors webui's sd_hijack_clip.py. Each emitted chunk is
    ``[start] + 75 slots + [end]``, and short chunks are filled with end tokens.
    A ``BREAK`` segment (text ``'BREAK'`` with weight -1) forces a new chunk.
    When ``shared.opts.comma_padding_backtrack`` is non-zero, tokens after a
    nearby comma are moved to the next chunk.

    Returns:
        A list of ``(Token, block_index)`` pairs, one per position. Start, end,
        padding and comma tokens get ``-1``. Every other token gets the number
        of commas seen before it (its ','-separated block). A textual-inversion
        embedding occupies ``emb_len`` positions of token id 0. An empty prompt
        yields an empty list.
    """
    te = ClipWrapper(clip)

    # cf. sd_hijack_clip.py
    parsed = prompt_parser.parse_prompt_attention(prompt)
    tokenized: List[List[int]] = clip.tokenize([text for text, _ in parsed])

    CHUNK_LENGTH = 75
    id_start = te.token(clip.id_start)  # type: ignore
    id_end = te.token(clip.id_end)  # type: ignore
    # A comma inside a prompt is a word of its own, so it is tokenized as
    # ',</w>'; the bare ',' vocabulary entry would never match.
    comma = te.token(',' + _WORD_END)

    last_comma = -1
    current_block = 0
    current_tokens: List[Tuple[Token,int]] = []
    result: List[Tuple[Token,int]] = []

    def next_chunk():
        # Pad to CHUNK_LENGTH, wrap in start/end, append to result, reset.
        nonlocal current_tokens, last_comma

        to_add = CHUNK_LENGTH - len(current_tokens)
        if 0 < to_add:
            current_tokens += [(id_end, -1)] * to_add

        current_tokens = [(id_start, -1)] + current_tokens + [(id_end, -1)]

        last_comma = -1
        result.extend(current_tokens)
        current_tokens = []

    for tokens, (text, weight) in zip(tokenized, parsed):
        if text == 'BREAK' and weight == -1:
            next_chunk()
            continue

        p = 0
        while p < len(tokens):
            token = tokens[p]

            if token == comma.id:
                last_comma = len(current_tokens)
                current_block += 1

            elif (
                shared.opts.comma_padding_backtrack != 0
                and len(current_tokens) == CHUNK_LENGTH
                and last_comma != -1
                and len(current_tokens) - last_comma <= shared.opts.comma_padding_backtrack
            ):
                # Chunk is full and a comma is close: move everything after
                # the comma into the next chunk.
                break_location = last_comma + 1
                reloc_tokens = current_tokens[break_location:]
                current_tokens = current_tokens[:break_location]
                next_chunk()
                current_tokens = reloc_tokens

            if len(current_tokens) == CHUNK_LENGTH:
                next_chunk()

            embedding, embedding_length_in_tokens = clip.hijack.embedding_db.find_embedding_at_position(tokens, p)
            if embedding is None:
                if token == comma.id:
                    current_tokens.append((te.token(token), -1))
                else:
                    current_tokens.append((te.token(token), current_block))
                p += 1
                continue

            # Embeddings are never split across chunks.
            emb_len = int(embedding.vec.shape[0])
            if len(current_tokens) + emb_len > CHUNK_LENGTH:
                next_chunk()

            current_tokens += [(te.token(0), current_block)] * emb_len
            p += embedding_length_in_tokens

    if len(current_tokens) > 0:
        next_chunk()

    return result