import io
import os
import random
import re
import uuid
from os.path import exists
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import torch
import torchaudio
from num2words import num2words
from PIL import Image

from src.utils.image_utils import is_valid_image

IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]

def is_local(url):
    # Treat "file://" URLs and bare paths as local; verify the path exists.
    url_parsed = urlparse(url)
    if url_parsed.scheme in ("file", ""):
        return exists(url_parsed.path)
    return False

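# Usage sketch (illustrative, not from the original source):
#   is_local("file:///tmp/sample.wav")       # True only if /tmp/sample.wav exists
#   is_local("https://example.com/a.wav")    # False
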
def replace_numbers_with_words(sentence):
    # Pad digit runs with spaces so word boundaries match in the pass below.
    sentence = re.sub(r"(\d+)", r" \1 ", sentence)

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)
        except Exception:
            # Leave the token unchanged if num2words cannot convert it.
            return num

    return re.sub(r"\b\d+\b", replace_with_words, sentence)

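# Usage sketch (illustrative, not from the original source). The padding pass
# inserts spaces around digit runs, so the output keeps double spaces:
#   replace_numbers_with_words("I have 3 cats")  ->  "I have  three  cats"
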
def save_to_buffer(audio_tensors, codec_audio_sr):
    # Concatenate the chunks along the time axis and serialize as in-memory WAV bytes.
    result = torch.cat(audio_tensors, 1)
    buffer = io.BytesIO()
    torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
    buffer.seek(0)
    return buffer.read()

def save_to_file(audio_tensors, codec_audio_sr):
    generated_audio_dir = "media/voicecraft/generated"
    Path(generated_audio_dir).mkdir(parents=True, exist_ok=True)
    filename = f"{generated_audio_dir}/{uuid.uuid4()}.wav"
    tensors = torch.cat(audio_tensors, 1)
    torchaudio.save(filename, tensors, int(codec_audio_sr), format="wav")
    return filename

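# Usage sketch (illustrative, not from the original source): both savers expect a
# list of (channels, samples) tensors sharing one sample rate, e.g.
#   chunks = [torch.zeros(1, 16000), torch.zeros(1, 8000)]  # hypothetical audio
#   wav_bytes = save_to_buffer(chunks, codec_audio_sr=16000)
#   wav_path = save_to_file(chunks, codec_audio_sr=16000)
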
def split_line_to_sentences(line):
    # Capitalize, ensure terminal punctuation, then split on sentence-ending marks.
    line = line.strip().capitalize()
    line = line + "." if line and line[-1] not in (".", "!", "?") else line
    sentences = re.findall(r"\w+.*?[.?!]", line.replace("\n", " "), flags=re.S)
    return sentences

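# Illustrative example (not from the original source):
#   >>> split_line_to_sentences("hello world. how are you")
#   ['Hello world.', 'how are you.']
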
def seed_everything(seed=1):
    # Seed every RNG in play (Python hash, random, NumPy, PyTorch CPU/CUDA)
    # and pin cuDNN to its deterministic kernels for reproducibility.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_length, image_token):
    # Prepend `image_seq_length` image placeholder tokens, then BOS, then the text prompt.
    return f"{image_token * image_seq_length}{bos_token}{prefix_prompt}\n"

def rescale(
    image: np.ndarray, scale: float, dtype: np.dtype = np.float32
) -> np.ndarray:
    rescaled_image = image * scale
    rescaled_image = rescaled_image.astype(dtype)
    return rescaled_image

def resize(
    image: Image.Image,
    size: Tuple[int, int],
    resample: Image.Resampling = None,
    reducing_gap: Optional[int] = None,
) -> Image.Image:
    # `size` is (height, width); PIL's resize expects (width, height).
    height, width = size
    resized_image = image.resize(
        (width, height), resample=resample, reducing_gap=reducing_gap
    )
    return resized_image

def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
) -> np.ndarray:
    mean = np.array(mean, dtype=image.dtype)
    std = np.array(std, dtype=image.dtype)
    image = (image - mean) / std
    return image

def process_images(
    images: List[Image.Image],
    size: Tuple[int, int] = None,
    resample: Image.Resampling = None,
    rescale_factor: float = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
) -> List[np.ndarray]:
    height, width = size[0], size[1]
    images = [
        resize(image=image, size=(height, width), resample=resample)
        for image in images
    ]
    # Convert each image to a numpy array
    images = [np.array(image) for image in images]
    # Rescale the pixel values to be in the range [0, 1]
    images = [rescale(image, scale=rescale_factor) for image in images]
    # Normalize the images to have mean 0 and standard deviation 1
    images = [normalize(image, mean=image_mean, std=image_std) for image in images]
    # Move the channel dimension to the first dimension. The model expects images
    # in the format [Channel, Height, Width].
    images = [image.transpose(2, 0, 1) for image in images]
    return images

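# Usage sketch (illustrative defaults, not from the original source):
#   imgs = process_images(
#       [Image.new("RGB", (640, 480))],          # hypothetical input image
#       size=(224, 224),
#       resample=Image.Resampling.BICUBIC,
#       rescale_factor=1 / 255.0,
#       image_mean=IMAGENET_STANDARD_MEAN,
#       image_std=IMAGENET_STANDARD_STD,
#   )  # -> list of float32 arrays with shape (3, 224, 224)
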
def sample_top_p(probs: torch.Tensor, p: float):
    # (B, vocab_size)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # (B, vocab_size)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # (B, vocab_size)
    # (Subtracting "probs_sort" shifts the cumulative sum one position to the right before masking)
    mask = probs_sum - probs_sort > p
    # Zero out all the probabilities of tokens that are not selected by the Top P
    probs_sort[mask] = 0.0
    # Redistribute the probabilities so that they sum up to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample a token (its index) from the top p distribution
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Get the token position in the vocabulary corresponding to the sampled index
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

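# Usage sketch (illustrative, not from the original source): top-p expects
# probabilities, so apply a (temperature-scaled) softmax to the logits first.
#   logits = torch.randn(2, 32000)              # hypothetical (B, vocab_size)
#   probs = torch.softmax(logits / 0.8, dim=-1)
#   next_token = sample_top_p(probs, p=0.9)     # shape (2, 1)
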
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    return expanded_lengths >= lengths.unsqueeze(-1)

def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
    is_training: bool = False,
    token_type_ids: Optional[torch.Tensor] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, does nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        is_training (`bool`):
            Whether the model is in training mode or in inference. The condition is checked by presence/absence of
            `token_type_ids`/`labels`.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix.
        # Training-specific attention for the prefix is handled below.
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask[:, :sequence_length] = 0.0

        causal_mask *= torch.arange(
            target_length, device=cache_position.device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)
            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
            if is_training:
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
    return causal_mask

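# Usage sketch (illustrative values, not from the original source): build the mask
# for a 4-token prefill step against a static cache of length 8.
#   attn = torch.ones(1, 4, dtype=torch.long)   # hypothetical 2D padding mask
#   mask = _prepare_4d_causal_attention_mask_with_cache_position(
#       attention_mask=attn,
#       sequence_length=4,
#       target_length=8,
#       dtype=torch.float32,
#       device=torch.device("cpu"),
#       min_dtype=torch.finfo(torch.float32).min,
#       cache_position=torch.arange(4),
#       batch_size=1,
#   )  # -> shape (1, 1, 4, 8)
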
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")


# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem)


def _is_str_or_image(elem):
    return isinstance(elem, str) or is_image_or_image_url(elem)

def generate_partial_autoregressive_mask(sz, start, end):
    # Boolean mask of shape (sz, sz) where True marks *blocked* attention:
    # positions in [start, end) attend causally among themselves, and tokens
    # outside that span cannot attend into it; everything else is fully visible.
    mask = torch.zeros(sz, sz).bool()
    mask[start:end, start:end] = torch.triu(
        torch.ones(end - start, end - start, dtype=torch.bool), diagonal=1
    )
    mask[:start, start:end] = True
    mask[end:, start:end] = True
    return mask

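# Illustrative example (not from the original source): rows are queries,
# columns are keys, and True blocks attention.
#   >>> generate_partial_autoregressive_mask(4, 1, 3)
#   tensor([[False,  True,  True, False],
#           [False, False,  True, False],
#           [False, False, False, False],
#           [False,  True,  True, False]])
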
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
    # One run of `image_seq_len` placeholder tokens per image, then BOS, then the text prompt.
    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"

def is_torchdynamo_compiling():
    # Newer PyTorch exposes torch.compiler.is_compiling(); fall back to the
    # private torch._dynamo API on older versions, and to False if neither exists.
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        try:
            import torch._dynamo as dynamo  # noqa: F401

            return dynamo.is_compiling()
        except Exception:
            return False