# RelightVid — misc_utils/ptp_utils.py
# (Hugging Face upload by aleafy, commit 0a63786 "Start fresh")
from dataclasses import dataclass
import torch
import numpy as np
@dataclass
class Edit:
    """A replacement piece: `old` text in the source prompt becomes `new` in the edited prompt."""
    old: str
    new: str
    weight: float = 1.0  # scale applied to this piece's value embeddings in encode_text()
@dataclass
class Insert:
    """A piece present only in the edited prompt: empty `old`, `text` as `new`."""
    text: str
    weight: float = 1.0  # scale applied to this piece's value embeddings in encode_text()
    @property
    def old(self):
        # No counterpart in the source prompt.
        return ""
    @property
    def new(self):
        return self.text
@dataclass
class Delete:
    """A piece present only in the source prompt: `text` as `old`, empty `new`."""
    text: str
    weight: float = 1.0  # scale applied to this piece's value embeddings in encode_text()
    @property
    def old(self):
        return self.text
    @property
    def new(self):
        # Removed from the edited prompt.
        return ""
@dataclass
class Text:
    """An unchanged piece: the same `text` appears in both prompts."""
    text: str
    weight: float = 1.0  # scale applied to this piece's value embeddings in encode_text()
    @property
    def old(self):
        return self.text
    @property
    def new(self):
        return self.text
@torch.inference_mode()
def get_text_embedding(prompt, tokenizer, text_encoder):
    """Encode `prompt` with a HuggingFace-style tokenizer/encoder pair.

    The prompt is padded/truncated to `tokenizer.model_max_length`; the first
    element of the encoder's output (the hidden states) is returned.
    """
    tokenized = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids.to(text_encoder.device)
    return text_encoder(input_ids)[0]
@torch.inference_mode()
def encode_text(text_pieces, tokenizer, text_encoder):
    """Build prompt-to-prompt key/value embeddings from a list of edit pieces.

    Each piece exposes `.old`, `.new`, and `.weight`. The old and new prompts
    are the space-joined `.old` / `.new` strings. For every token of the new
    prompt a source index into the old prompt is derived (or -1 for inserted
    text); `key` rows are then pulled from the old-prompt embeddings at those
    indices, while `value` rows are the new-prompt embeddings scaled by the
    piece weight. Returns the `(key, value)` pair, each of shape (1, seq, dim).
    """
    mapping = []        # per new-prompt token: index into the old prompt, or -1
    token_weights = []  # per new-prompt token: scale for the value embedding
    old_count = 0
    new_count = 0
    for piece in text_pieces:
        src, dst = piece.old, piece.new
        n_src = len(tokenizer.tokenize(src))
        n_dst = len(tokenizer.tokenize(dst))
        if n_src == 0 and n_dst == 0:
            continue
        if src == dst:
            # unchanged text: identity mapping over this span
            old_count += n_src
            new_count += n_dst
            mapping.extend(range(old_count - n_src, old_count))
        elif n_src == 0:
            # insert: the new tokens have no source position
            mapping.extend([-1] * n_dst)
            new_count += n_dst
        elif n_dst == 0:
            # delete: the old tokens are simply skipped
            old_count += n_src
        else:
            # replace: spread the new tokens evenly across the old span
            old_count += n_src
            new_count += n_dst
            span = np.linspace(old_count - n_src, old_count, n_dst, endpoint=False)
            mapping.extend(list(span.astype(int)))
        token_weights.extend([piece.weight] * n_dst)

    def _embed(prompt):
        # Tokenize padded to model_max_length and return the hidden states.
        ids = tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        return text_encoder(ids.to(text_encoder.device))[0]

    old_emb = _embed(" ".join(piece.old for piece in text_pieces))
    new_emb = _embed(" ".join(piece.new for piece in text_pieces))

    key = new_emb.clone()    # (1, seq, dim)
    value = new_emb.clone()
    # NOTE(review): this assumes token i of the prompt sits at row i of the
    # encoder output; if the tokenizer prepends a BOS/SOT token the rows are
    # shifted by one — confirm against the tokenizer in use.
    for pos, (src_pos, w) in enumerate(zip(mapping, token_weights)):
        if 0 <= src_pos < old_emb.shape[1]:
            key[0, pos] = old_emb[0, src_pos]
        value[0, pos] *= w
    return key, value
@torch.inference_mode()
def get_text_embedding_openclip(prompt, text_encoder, device='cuda'):
    """Tokenize `prompt` with OpenCLIP and run it through `text_encoder`."""
    import open_clip
    tokens = open_clip.tokenize(prompt)
    return text_encoder(tokens.to(device))
@torch.inference_mode()
def encode_text_openclip(text_pieces, text_encoder, device='cuda'):
    """Build prompt-to-prompt key/value embeddings using the OpenCLIP tokenizer.

    Mirrors encode_text(): each piece exposes `.old`, `.new`, `.weight`; the
    joined old/new prompts are encoded, `key` rows are pulled from the
    old-prompt embeddings via the per-token index mapping, and `value` rows
    are the new-prompt embeddings scaled by the piece weights. Returns
    `(key, value)`.

    NOTE(review): assumes `text_encoder` returns per-token embeddings of shape
    (1, seq, dim); a plain OpenCLIP `model.encode_text` returns a pooled
    (1, dim) vector — confirm against the caller.
    """
    import open_clip

    def _count_tokens(text):
        # Bug fix: open_clip.tokenize() returns a padded LongTensor of shape
        # [1, context_length], so len() of its result is always 1 (the batch
        # dim) — the original per-piece token counts were therefore wrong.
        # Count the real BPE tokens by dropping the zero padding and the
        # SOT/EOT special tokens.
        if not text:
            return 0
        ids = open_clip.tokenize(text)[0]
        return max(int((ids != 0).sum().item()) - 2, 0)

    n_old_tokens = 0
    n_new_tokens = 0
    new_id_to_old_id = []  # per new-prompt token: index into old prompt, or -1
    weights = []           # per new-prompt token: value-embedding scale
    for piece in text_pieces:
        old, new = piece.old, piece.new
        n_old = _count_tokens(old)
        n_new = _count_tokens(new)
        if n_old == 0 and n_new == 0:
            continue
        elif old == new:
            # unchanged text: identity mapping over this span
            n_old_tokens += n_old
            n_new_tokens += n_new
            new_id_to_old_id.extend(range(n_old_tokens - n_old, n_old_tokens))
        elif n_old == 0:
            # insert: the new tokens have no source position
            new_id_to_old_id.extend([-1] * n_new)
            n_new_tokens += n_new
        elif n_new == 0:
            # delete: the old tokens are simply skipped
            n_old_tokens += n_old
        else:
            # replace: spread the new tokens evenly across the old span
            n_old_tokens += n_old
            n_new_tokens += n_new
            start = n_old_tokens - n_old
            end = n_old_tokens
            ids = np.linspace(start, end, n_new, endpoint=False).astype(int)
            new_id_to_old_id.extend(list(ids))
        weights.extend([piece.weight] * n_new)
    old_prompt = " ".join([piece.old for piece in text_pieces])
    new_prompt = " ".join([piece.new for piece in text_pieces])
    old_text_input_ids = open_clip.tokenize(old_prompt)
    new_text_input_ids = open_clip.tokenize(new_prompt)
    old_text_embeddings = text_encoder(old_text_input_ids.to(device))
    new_text_embeddings = text_encoder(new_text_input_ids.to(device))
    value = new_text_embeddings.clone()  # (1, seq, dim)
    key = new_text_embeddings.clone()
    # NOTE(review): indices ignore the SOT token the tokenizer prepends, so
    # token i of the prompt is assumed to sit at row i of the encoding — confirm.
    for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
        if 0 <= j < old_text_embeddings.shape[1]:
            key[0, i] = old_text_embeddings[0, j]
        value[0, i] *= weight
    return key, value