from dataclasses import dataclass

import torch
import numpy as np

@dataclass
class Edit:
    """Replace `old` text in the original prompt with `new` text in the edited prompt."""
    old: str
    new: str
    weight: float = 1.0


@dataclass
class Insert:
    """Insert `text` into the edited prompt; it has no counterpart in the original."""
    text: str
    weight: float = 1.0

    @property
    def old(self):
        return ""

    @property
    def new(self):
        return self.text


@dataclass
class Delete:
    """Remove `text` from the original prompt; it has no counterpart in the edited one."""
    text: str
    weight: float = 1.0

    @property
    def old(self):
        return self.text

    @property
    def new(self):
        return ""


@dataclass
class Text:
    """Text shared verbatim by the original and the edited prompt."""
    text: str
    weight: float = 1.0

    @property
    def old(self):
        return self.text

    @property
    def new(self):
        return self.text

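# A prompt edit is expressed as a list of the pieces above. Illustrative
# example (not from the original module): keep "a photo of a", replace "cat"
# with "dog", and emphasize an inserted phrase via its weight:
#
#   pieces = [Text("a photo of a"), Edit(old="cat", new="dog"),
#             Insert("wearing a hat", weight=1.5)]
#
# Each piece contributes `old` text to the source prompt and `new` text to the
# edited prompt; `encode_text` below aligns the two token sequences.
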
def get_text_embedding(prompt, tokenizer, text_encoder):
    # Tokenize the prompt (padded/truncated to the model's context length) and
    # return the per-token hidden states of the CLIP text encoder.
    text_input_ids = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    text_embeddings = text_encoder(text_input_ids.to(text_encoder.device))[0]
    return text_embeddings

def encode_text(text_pieces, tokenizer, text_encoder):
    # Align the tokens of the edited ("new") prompt with those of the original
    # ("old") prompt, then return per-token (key, value) embeddings: keys come
    # from the old prompt where an alignment exists, values come from the new
    # prompt and are scaled by each piece's weight.
    n_old_tokens = 0
    n_new_tokens = 0
    new_id_to_old_id = []
    weights = []
    for piece in text_pieces:
        old, new = piece.old, piece.new
        old_tokens = tokenizer.tokenize(old)
        new_tokens = tokenizer.tokenize(new)
        if len(old_tokens) == 0 and len(new_tokens) == 0:
            continue
        elif old == new:
            # unchanged text: each new token maps to the same old token
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens))
        elif len(old_tokens) == 0:
            # insert: the new tokens have no counterpart in the old prompt
            new_id_to_old_id.extend([-1] * len(new_tokens))
            n_new_tokens += len(new_tokens)
        elif len(new_tokens) == 0:
            # delete: the old tokens are skipped entirely
            n_old_tokens += len(old_tokens)
        else:
            # replace: spread the new tokens evenly over the old token span
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            start = n_old_tokens - len(old_tokens)
            end = n_old_tokens
            ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int)
            new_id_to_old_id.extend(list(ids))
        weights.extend([piece.weight] * len(new_tokens))

    old_prompt = " ".join([piece.old for piece in text_pieces])
    new_prompt = " ".join([piece.new for piece in text_pieces])
    old_text_input_ids = tokenizer(
        old_prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    new_text_input_ids = tokenizer(
        new_prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    old_text_embeddings = text_encoder(old_text_input_ids.to(text_encoder.device))[0]
    new_text_embeddings = text_encoder(new_text_input_ids.to(text_encoder.device))[0]

    value = new_text_embeddings.clone()  # batch (1), seq, dim
    key = new_text_embeddings.clone()
    for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
        if 0 <= j < old_text_embeddings.shape[1]:
            key[0, i] = old_text_embeddings[0, j]
        value[0, i] *= weight
    return key, value

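# Sketch of a call to `encode_text`, assuming `tokenizer` / `text_encoder` are
# a CLIP tokenizer and text model (e.g. `pipe.tokenizer` / `pipe.text_encoder`
# from a diffusers Stable Diffusion pipeline; those names are assumptions, not
# part of this module):
#
#   key, value = encode_text(
#       [Text("a photo of a"), Edit(old="cat", new="dog")],
#       tokenizer, text_encoder,
#   )
#
# Both tensors have shape (1, tokenizer.model_max_length, hidden_dim): `key`
# carries old-prompt embeddings where tokens are aligned, while `value` keeps
# the new-prompt embeddings scaled by the per-piece weights.
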
def get_text_embedding_openclip(prompt, text_encoder, device='cuda'):
    import open_clip
    text_input_ids = open_clip.tokenize(prompt)
    text_embeddings = text_encoder(text_input_ids.to(device))
    return text_embeddings

def encode_text_openclip(text_pieces, text_encoder, device='cuda'):
    # Same alignment logic as `encode_text`, but with the OpenCLIP tokenizer.
    # NOTE: `open_clip.tokenize` returns a fixed-length padded tensor (with
    # SOT/EOT markers), so the per-piece token lists are recovered by dropping
    # padding and the two special tokens rather than by using that tensor
    # directly. `text_encoder` is assumed to be a callable that returns
    # per-token embeddings of shape (1, seq, dim).
    import open_clip

    def piece_tokens(text):
        ids = open_clip.tokenize(text)[0]
        ids = ids[ids != 0]           # drop padding
        return ids[1:-1].tolist()     # drop SOT/EOT

    n_old_tokens = 0
    n_new_tokens = 0
    new_id_to_old_id = []
    weights = []
    for piece in text_pieces:
        old, new = piece.old, piece.new
        old_tokens = piece_tokens(old)
        new_tokens = piece_tokens(new)
        if len(old_tokens) == 0 and len(new_tokens) == 0:
            continue
        elif old == new:
            # unchanged text: each new token maps to the same old token
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens))
        elif len(old_tokens) == 0:
            # insert: the new tokens have no counterpart in the old prompt
            new_id_to_old_id.extend([-1] * len(new_tokens))
            n_new_tokens += len(new_tokens)
        elif len(new_tokens) == 0:
            # delete: the old tokens are skipped entirely
            n_old_tokens += len(old_tokens)
        else:
            # replace: spread the new tokens evenly over the old token span
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            start = n_old_tokens - len(old_tokens)
            end = n_old_tokens
            ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int)
            new_id_to_old_id.extend(list(ids))
        weights.extend([piece.weight] * len(new_tokens))

    old_prompt = " ".join([piece.old for piece in text_pieces])
    new_prompt = " ".join([piece.new for piece in text_pieces])
    old_text_input_ids = open_clip.tokenize(old_prompt)
    new_text_input_ids = open_clip.tokenize(new_prompt)
    old_text_embeddings = text_encoder(old_text_input_ids.to(device))
    new_text_embeddings = text_encoder(new_text_input_ids.to(device))

    value = new_text_embeddings.clone()  # batch (1), seq, dim
    key = new_text_embeddings.clone()
    for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
        if 0 <= j < old_text_embeddings.shape[1]:
            key[0, i] = old_text_embeddings[0, j]
        value[0, i] *= weight
    return key, value
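

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: it assumes the `transformers` package
    # is installed and that the "openai/clip-vit-large-patch14" checkpoint (the
    # Stable Diffusion v1.x text encoder) can be downloaded. The checkpoint
    # name and this demo are illustrative assumptions, not part of the module.
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    pieces = [
        Text("a photo of a"),
        Edit(old="cat", new="dog"),
        Insert("wearing a hat", weight=1.5),
    ]
    with torch.no_grad():
        key, value = encode_text(pieces, tokenizer, text_encoder)
    print(key.shape, value.shape)  # (1, 77, 768) for this checkpoint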