from dataclasses import dataclass
from itertools import product
import re
from typing import List, Tuple, Union

import numpy as np
import open_clip

from modules import prompt_parser, shared
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase as CLIP
from scripts.cutofflib.utils import log


class ClipWrapper:
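    """Uniform access to the tokenizer behind a webui CLIP text encoder.

    `v1` is True when the wrapped model exposes a HuggingFace tokenizer
    (the SD1.x-style CLIP encoder); otherwise the open_clip tokenizer is
    used (the SD2.x-style encoder).
    """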
    def __init__(self, te: CLIP):
        self.te = te
        self.v1 = hasattr(te.wrapped, 'tokenizer')
        self.t = (
            te.wrapped.tokenizer if self.v1
            else open_clip.tokenizer._tokenizer
        )

    def token_to_id(self, token: str) -> int:
        if self.v1:
            return self.t._convert_token_to_id(token)
        else:
            return self.t.encoder[token]

    def id_to_token(self, id: int) -> str:
        if self.v1:
            return self.t.convert_ids_to_tokens(id)
        else:
            return self.t.decoder[id]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        if self.v1:
            return self.t.convert_ids_to_tokens(ids)
        else:
            return [self.t.decoder[id] for id in ids]

    def token(self, token: Union[int, str]):
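        """Build a Token from either a vocabulary id or a token string."""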
        if isinstance(token, int):
            return Token(token, self.id_to_token(token))
        else:
            return Token(self.token_to_id(token), token)


@dataclass
class Token:
    id: int
    token: str


class CutoffPrompt:
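    """A prompt split on commas, with a 'cut' variant of each block.

    `base` holds the original comma-separated blocks, `cut` holds the same
    blocks with every target word replaced by the padding token, and `sw`
    marks which blocks currently use the cut variant.
    """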

    @staticmethod
    def _cutoff(prompt: str, clip: CLIP, tokens: List[str], padding: str):
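        """Split `prompt` on commas and build a cut variant of each block.

        Inside every block, each occurrence of a target word is replaced by
        `padding` repeated once per token that the target occupies, so the
        cut block keeps roughly the same token layout. Returns a list of
        (original_block, cut_block) pairs.
        """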
        def token_count(text: str):
            tt = token_to_block(clip, text)
            for index, (t, _) in enumerate(tt):
                if t.id == clip.id_end:
                    return index - 1
            return 0

        re_targets = [re.compile(r'\b' + re.escape(x) + r'\b') for x in tokens]
        replacer = [' ' + ' '.join([padding] * token_count(x)) + ' ' for x in tokens]

        rows: List[Tuple[str, str]] = []
        for block in prompt.split(','):
            b0 = block
            for r, p in zip(re_targets, replacer):
                block = r.sub(p, block)
            b1 = block
            rows.append((b0, b1))

        return rows

    def __init__(self, prompt: str, clip: CLIP, tokens: List[str], padding: str):
        self.prompt = prompt
        rows = CutoffPrompt._cutoff(prompt, clip, tokens, padding)
        self.base = np.array([x[0] for x in rows])
        self.cut = np.array([x[1] for x in rows])
        self.sw = np.array([False] * len(rows))

    @property
    def block_count(self):
        return self.base.shape[0]

    def switch(self, block_index: int, to: Union[bool, None] = None):
        if to is None:
            to = not self.sw[block_index]
        self.sw[block_index] = to
        return to

    def text(self, sw=None):
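        """Reassemble the prompt, using the cut variant of each block where `sw` is True."""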
        if sw is None:
            sw = self.sw
        blocks = np.where(sw, self.cut, self.base)
        return ','.join(blocks)

    def active_blocks(self) -> np.ndarray:
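        """Indices of the blocks that were actually changed by the cutoff substitution."""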
        indices, = (self.base != self.cut).nonzero()
        return indices

    def generate(self):
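        """Yield (switches, prompt_text) for every on/off combination of the active blocks."""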
        indices = self.active_blocks()
        for diff_sw in product([False, True], repeat=indices.shape[0]):
            sw = np.full_like(self.sw, False)
            sw[indices] = diff_sw
            yield diff_sw, self.text(sw)


def generate_prompts(
    clip: CLIP,
    prompt: str,
    targets: List[str],
    padding: Union[str, int, Token],
) -> CutoffPrompt:
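    """Build a CutoffPrompt that replaces each of `targets` with `padding`.

    `padding` may be a token id, a token string, or a Token; a ValueError is
    raised if it resolves to the end-of-text token (i.e. it is not a usable
    padding token). Every generated prompt variant is logged for inspection.
    """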
    te = ClipWrapper(clip)

    if not isinstance(padding, Token):
        o_pad = padding
        padding = te.token(padding)
        if padding.id == clip.id_end:
            raise ValueError(f'`{o_pad}` is not a valid token.')

    result = CutoffPrompt(prompt, clip, targets, padding.token.replace('</w>', ''))

    log(f'[Cutoff] replace: {", ".join(targets)}')
    log(f'[Cutoff] to: {padding.token} ({padding.id})')
    log(f'[Cutoff] original: {prompt}')
    for i, (_, pp) in enumerate(result.generate()):
        log(f'[Cutoff] #{i}: {pp}')

    return result


def token_to_block(clip: CLIP, prompt: str):
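    """Tokenize `prompt` the way webui does and tag each token with its comma block.

    Mirrors the chunking behaviour of sd_hijack_clip (75-token chunks, comma
    backtracking via `comma_padding_backtrack`, the BREAK keyword, textual
    inversion embeddings) and returns a list of (Token, block_index) pairs.
    `block_index` counts the comma-separated blocks of the prompt and is -1
    for markers such as start/end tokens, padding and the commas themselves.
    """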
    te = ClipWrapper(clip)

    parsed = prompt_parser.parse_prompt_attention(prompt)
    tokenized: List[List[int]] = clip.tokenize([text for text, _ in parsed])

    CHUNK_LENGTH = 75
    id_start = te.token(clip.id_start)
    id_end = te.token(clip.id_end)
    comma = te.token(',</w>')

    last_comma = -1
    current_block = 0
    current_tokens: List[Tuple[Token, int]] = []
    result: List[Tuple[Token, int]] = []

    def next_chunk():
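        """Pad the current chunk to CHUNK_LENGTH, wrap it in start/end tokens, and flush it to `result`."""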
        nonlocal current_tokens, last_comma

        to_add = CHUNK_LENGTH - len(current_tokens)
        if 0 < to_add:
            current_tokens += [(id_end, -1)] * to_add

        current_tokens = [(id_start, -1)] + current_tokens + [(id_end, -1)]

        last_comma = -1
        result.extend(current_tokens)
        current_tokens = []

    for tokens, (text, weight) in zip(tokenized, parsed):
        # `BREAK` forces the start of a new chunk.
        if text == 'BREAK' and weight == -1:
            next_chunk()
            continue

        p = 0
        while p < len(tokens):
            token = tokens[p]

            if token == comma.id:
                # A comma marks a possible backtrack point and starts the
                # next comma-separated block.
                last_comma = len(current_tokens)
                current_block += 1

            elif (
                shared.opts.comma_padding_backtrack != 0
                and len(current_tokens) == CHUNK_LENGTH
                and last_comma != -1
                and len(current_tokens) - last_comma <= shared.opts.comma_padding_backtrack
            ):
                # The chunk is full and a comma is close enough: break the
                # chunk at the comma and carry the trailing tokens over into
                # the next chunk.
                break_location = last_comma + 1
                reloc_tokens = current_tokens[break_location:]
                current_tokens = current_tokens[:break_location]
                next_chunk()
                current_tokens = reloc_tokens

            if len(current_tokens) == CHUNK_LENGTH:
                next_chunk()

            embedding, embedding_length_in_tokens = clip.hijack.embedding_db.find_embedding_at_position(tokens, p)
            if embedding is None:
                if token == comma.id:
                    current_tokens.append((te.token(token), -1))
                else:
                    current_tokens.append((te.token(token), current_block))
                p += 1
                continue

            # Textual inversion embeddings occupy one or more token positions;
            # represent them with placeholder tokens assigned to the current block.
            emb_len = int(embedding.vec.shape[0])
            if len(current_tokens) + emb_len > CHUNK_LENGTH:
                next_chunk()

            current_tokens += [(te.token(0), current_block)] * emb_len
            p += embedding_length_in_tokens

    if len(current_tokens) > 0:
        next_chunk()

    return result