import torch
from typing import List, Union

def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using an mps-friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
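

# A minimal sketch of calling the helper above in isolation, assuming an
# SD3-style checkpoint that stores a T5-XXL encoder under "text_encoder_3"
# ("repo_id" is a placeholder, not a name defined in this file):
#
#   from transformers import T5EncoderModel, T5TokenizerFast
#   t5 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3")
#   tok = T5TokenizerFast.from_pretrained(repo_id, subfolder="tokenizer_3")
#   embeds = _encode_prompt_with_t5(t5, tok, max_sequence_length=256, prompt="a cat")
#   # for T5-XXL, embeds.shape == (1, 256, 4096)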


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    # use the pooled projection output and the penultimate hidden state
    pooled_prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using an mps-friendly
    # method; note that the pooled embeddings are returned without duplication
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds
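

# Sketch for the CLIP helper, assuming the encoders are loaded as
# CLIPTextModelWithProjection so that output[0] is the pooled projection
# ("repo_id" is again a placeholder):
#
#   from transformers import CLIPTextModelWithProjection, CLIPTokenizer
#   clip = CLIPTextModelWithProjection.from_pretrained(repo_id, subfolder="text_encoder")
#   tok = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
#   embeds, pooled = _encode_prompt_with_clip(clip, tok, prompt="a cat")
#   # for CLIP-L, embeds.shape == (1, 77, 768) and pooled.shape == (1, 768)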


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: Union[str, List[str]],
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
    only_positive_t5=False,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            # when only_positive_t5 is set, the CLIP encoders see empty prompts
            # and only the T5 encoder receives the actual text
            prompt=prompt if not only_positive_t5 else [""] * len(prompt),
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    # concatenate the two CLIP encoders along the feature dimension
    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[-1].device,
    )

    # zero-pad the CLIP features up to the T5 feature dimension, then concatenate
    # the two sequences along the token dimension
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds,
        (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return clip_prompt_embeds, prompt_embeds, pooled_prompt_embeds
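

if __name__ == "__main__":
    # A minimal end-to-end sketch, assuming an SD3-style checkpoint layout
    # (two CLIP encoders plus one T5 encoder). The repo id below is an
    # illustrative assumption, not something referenced elsewhere in this file.
    from transformers import (
        CLIPTextModelWithProjection,
        CLIPTokenizer,
        T5EncoderModel,
        T5TokenizerFast,
    )

    repo = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed repo id
    tokenizers = [
        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
        CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2"),
        T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_3"),
    ]
    text_encoders = [
        CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder"),
        CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2"),
        T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_3"),
    ]

    with torch.no_grad():
        clip_embeds, prompt_embeds, pooled = encode_prompt(
            text_encoders,
            tokenizers,
            prompt="a photo of an astronaut riding a horse",
            max_sequence_length=256,
        )
    # With SD3-medium encoders: clip_embeds is the zero-padded CLIP sequence
    # (1, 77, 4096), prompt_embeds the joint sequence (1, 77 + 256, 4096),
    # and pooled the concatenated pooled projections (1, 2048).
    print(clip_embeds.shape, prompt_embeds.shape, pooled.shape)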