# VALL-E-X / models/vallex.py
# Last commit by Plachta: "Added safety measure" (cf8f15a)
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, Iterator, List, Tuple, Union
import gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from icefall.utils import make_pad_mask
# from torchmetrics.classification import MulticlassAccuracy
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.transformer import (
AdaptiveLayerNorm,
LayerNorm,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
import psutil
def get_memory_usage():
    """Return the resident set size (RSS) reported by psutil, in MiB.

    The value is read from ``psutil.Process().memory_info().rss`` and
    divided by 1024 * 1024.
    """
    bytes_per_mib = 1024 * 1024
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_mib
class Transpose(nn.Identity):
    """Module that swaps dimensions 1 and 2: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` with dimensions 1 and 2 exchanged."""
        return torch.transpose(input, 1, 2)
# NOTE: There are two ways to implement the model
# 1) [VALL-F] standard TransformerDecoder, use x as memory
# 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
# use x as the prefix of decoder inputs
class VALLF(nn.Module):
"""It implements https://arxiv.org/abs/2301.02111
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
"""
def __init__(
    self,
    d_model: int,
    nhead: int,
    num_layers: int,
    norm_first: bool = True,
    add_prenet: bool = False,
    decoder_cls: Union[
        nn.TransformerDecoder, nn.TransformerEncoder
    ] = nn.TransformerDecoder,
    decoder_layer_cls: Union[
        TransformerDecoderLayer, TransformerEncoderLayer
    ] = TransformerDecoderLayer,
    prefix_mode: int = 0,
    share_embedding: bool = True,
    nar_scale_factor: float = 1.0,
    prepend_bos: bool = True,
    num_quantizers: int = 8,
):
    """Build the autoregressive (AR) stack and, if needed, the NAR stack.

    Args:
      d_model:
        The number of expected features in the input (required).
      nhead:
        The number of heads in the multiheadattention models (required).
      num_layers:
        The number of sub-decoder-layers in the decoder (required).
      norm_first:
        Forwarded to every decoder layer; when true a final norm module is
        also attached to each decoder.
      add_prenet:
        When true, text and audio embeddings pass through conv/linear
        prenets; otherwise the prenets are ``nn.Identity``.
      decoder_cls:
        Class used to build both the AR and the NAR decoder.
      decoder_layer_cls:
        Class used to build the layer handed to ``decoder_cls``.
      prefix_mode:
        Stored as ``self.prefix_mode``.
      share_embedding:
        When true, ``nar_predict_layers[j].weight`` is tied to
        ``nar_audio_embeddings[j + 2].weight``.
      nar_scale_factor:
        Multiplier applied to ``d_model``, ``nhead`` and ``num_layers`` to
        size the NAR stack (each result is truncated with ``int``).
      prepend_bos:
        Stored as ``self.ar_audio_prepend_bos``; adds one extra entry to the
        AR audio embedding table.
      num_quantizers:
        Must be >= 1. The NAR modules are only created when it is > 1.
    """
    super().__init__()
    nar_d_model = int(d_model * nar_scale_factor)

    # ---- local builders (creation order inside each mirrors the layout) ----
    def conv_stage(width):
        # Conv1d -> BatchNorm1d -> ReLU -> Dropout
        return [
            nn.Conv1d(width, width, kernel_size=5, padding="same"),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Dropout(0.5),
        ]

    def text_prenet(width):
        if not add_prenet:
            return nn.Identity()
        layers = [Transpose()]
        for _ in range(3):
            layers.extend(conv_stage(width))
        layers.append(Transpose())
        layers.append(nn.Linear(width, width))
        return nn.Sequential(*layers)

    def audio_prenet(width):
        if not add_prenet:
            return nn.Identity()
        return nn.Sequential(
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, width),
        )

    def sine_position(width, dropout, alpha):
        return SinePositionalEmbedding(
            width, dropout=dropout, scale=False, alpha=alpha
        )

    # ---- text embeddings for both stacks ----
    self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
    self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

    # ---- AR stack ----
    # ID NUM_AUDIO_TOKENS     -> PAD
    # ID NUM_AUDIO_TOKENS + 1 -> BOS
    self.ar_audio_prepend_bos = prepend_bos
    ar_audio_vocab = NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
    self.ar_audio_embedding = TokenEmbedding(d_model, ar_audio_vocab)

    self.ar_text_prenet = text_prenet(d_model)
    self.ar_audio_prenet = audio_prenet(d_model)

    self.ar_text_position = sine_position(d_model, dropout=0.1, alpha=True)
    self.ar_audio_position = sine_position(d_model, dropout=0.1, alpha=True)

    ar_layer = decoder_layer_cls(
        d_model,
        nhead,
        dim_feedforward=d_model * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=norm_first,
    )
    ar_final_norm = LayerNorm(d_model) if norm_first else None
    self.ar_decoder = decoder_cls(
        ar_layer, num_layers=num_layers, norm=ar_final_norm
    )
    self.ar_predict_layer = nn.Linear(
        d_model, NUM_AUDIO_TOKENS + 1, bias=False
    )

    # ---- bookkeeping ----
    self.rng = random.Random(0)
    self.num_heads = nhead
    self.prefix_mode = prefix_mode
    self.num_quantizers = num_quantizers
    assert num_quantizers >= 1

    # ---- NAR stack (only needed when there is more than one quantizer) ----
    if num_quantizers > 1:
        num_nar_stages = num_quantizers - 1

        first_embedding = TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)
        later_embeddings = [
            TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
            for _ in range(num_nar_stages)
        ]
        self.nar_audio_embeddings = nn.ModuleList(
            [first_embedding, *later_embeddings]
        )  # W_a

        self.nar_text_prenet = text_prenet(nar_d_model)
        self.nar_audio_prenet = audio_prenet(nar_d_model)

        self.nar_text_position = sine_position(
            nar_d_model, dropout=0.0, alpha=False
        )
        self.nar_audio_position = sine_position(
            nar_d_model, dropout=0.1, alpha=False
        )

        nar_layer = decoder_layer_cls(
            nar_d_model,
            int(nhead * nar_scale_factor),
            dim_feedforward=nar_d_model * 4,
            dropout=0.1,
            batch_first=True,
            norm_first=norm_first,
            adaptive_layer_norm=True,
        )
        if norm_first:
            nar_final_norm = AdaptiveLayerNorm(
                nar_d_model, norm=nn.LayerNorm(nar_d_model)
            )
        else:
            nar_final_norm = None
        self.nar_decoder = decoder_cls(
            nar_layer,
            num_layers=int(num_layers * nar_scale_factor),
            norm=nar_final_norm,
        )

        self.nar_predict_layers = nn.ModuleList(
            [
                nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                for _ in range(num_nar_stages)
            ]
        )
        self.nar_stage_embeddings = nn.ModuleList(
            [TokenEmbedding(nar_d_model, 1) for _ in range(num_nar_stages)]
        )

        if share_embedding:
            # The AR output projection is deliberately NOT tied to the AR
            # audio embedding:
            # NOTE(Feiteng): In the experiment, this undermines accuracy
            # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

            # Tie each NAR prediction layer (index j) to the acoustic
            # embedding at index j + 2.
            for stage in range(num_quantizers - 2):
                tied_weight = self.nar_audio_embeddings[stage + 2].weight
                self.nar_predict_layers[stage].weight = tied_weight
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
    """Yield the parameters that belong to one training stage.

    Stage 1 yields parameters whose name starts with ``"ar_"``; stage 2
    yields those whose name starts with ``"nar_"``. Any other positive
    stage yields nothing. Each selected parameter name is printed just
    before the parameter is yielded.

    This is a generator, so the ``stage > 0`` assertion only fires once
    iteration begins.
    """
    assert stage > 0
    if stage == 1:
        ar_items = (
            item
            for item in self.named_parameters()
            if item[0].startswith("ar_")
        )
        for name, param in ar_items:
            print(f" AR parameter: {name}")
            yield param
    if stage == 2:
        nar_items = (
            item
            for item in self.named_parameters()
            if item[0].startswith("nar_")
        )
        for name, param in nar_items:
            print(f"NAR parameter: {name}")
            yield param
def stage_named_parameters(
    self, stage: int = 1
) -> Iterator[Tuple[str, nn.Parameter]]:
    """Yield ``(name, parameter)`` pairs that belong to one training stage.

    Stage 1 selects names starting with ``"ar_"``; stage 2 selects names
    starting with ``"nar_"``. Any other positive stage yields nothing.

    This is a generator, so the ``stage > 0`` assertion only fires once
    iteration begins.
    """
    assert stage > 0
    if stage == 1:
        yield from (
            item
            for item in self.named_parameters()
            if item[0].startswith("ar_")
        )
    if stage == 2:
        yield from (
            item
            for item in self.named_parameters()
            if item[0].startswith("nar_")
        )
def pad_y_eos(self, y, y_mask_int, eos_id):
    """Append an EOS slot to ``y`` and derive (inputs, targets) for the AR stage.

    ``y`` is right-padded with one 0 and ``y_mask_int`` with one 1 along the
    last dimension; ``targets`` is ``padded_y + eos_id * padded_mask``.
    NOTE(review): presumably ``y_mask_int`` is 1 at padded positions and 0
    elsewhere, so that padding and the appended slot become EOS -- confirm
    with the caller.

    Returns:
      A pair ``(inputs, targets)``:
        * with BOS prepending enabled: ``targets`` minus its last column,
          left-padded with ``NUM_AUDIO_TOKENS + 1``, and the full ``targets``;
        * otherwise: ``targets[:, :-1]`` and ``targets[:, 1:]``.
    """
    padded_codes = F.pad(y, (0, 1), value=0)
    eos_slots = F.pad(y_mask_int, (0, 1), value=1)
    targets = padded_codes + eos_id * eos_slots

    all_but_last = targets[:, :-1]
    if not self.ar_audio_prepend_bos:
        return all_but_last, targets[:, 1:]

    bos_id = NUM_AUDIO_TOKENS + 1
    return F.pad(all_but_last, (1, 0), value=bos_id), targets
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
    """Build the NAR audio embedding (acoustic prompt + current input).

    5.1 For the NAR acoustic prompt tokens, the paper selects a random
    segment waveform of 3 seconds from the same utterance. This is
    implemented differently here, depending on ``prefix_mode``:

      * 0 -- no prefix: sum the embeddings of quantizers ``0..nar_stage-1``.
      * 1 -- prefix at the beginning: a random-length leading slice of the
        utterance (capped at 225 frames) is embedded with every quantizer,
        the rest with quantizers ``0..nar_stage-1``.
      * 2 -- random prefix: a segment of each utterance is cloned as prompt
        and its ``nar_stage`` codes are overwritten with
        ``NUM_AUDIO_TOKENS``. This mutates ``codes`` in place.
      * 4 -- the prompt is given by ``y_prompts_codes``.

    Args:
      y: first-quantizer codes. NOTE(review): assumed (N, T) -- confirm.
      y_lens: per-utterance lengths (only its minimum / per-row values are
        read).
      codes: all-quantizer codes, indexed as ``codes[..., j]``.
      nar_stage: index of the quantizer being predicted.
      y_prompts_codes: prompt codes, only read when ``prefix_mode == 4``.
      prefix_mode: one of 0, 1, 2, 4.

    Returns:
      ``(y_emb, prefix_len)``.

    Raises:
      ValueError: for an unsupported ``prefix_mode``, or when
        ``prefix_mode == 4`` and ``y_prompts_codes`` is ``None``.
    """
    embed = self.nar_audio_embeddings

    if prefix_mode == 0:
        # no prefix
        prefix_len = 0
        y_emb = embed[0](y)
        for j in range(1, nar_stage):
            # Formula (4) (5)
            y_emb = y_emb + embed[j](codes[..., j])
    elif prefix_mode == 1:
        # prefix at beginning
        int_low = (0.25 * y_lens.min()).type(torch.int64).item()
        # torch.randint requires high > low; for very short utterances
        # (int_low == 0) fall back to an empty prefix instead of crashing.
        if int_low > 0:
            prefix_len = torch.randint(0, int_low * 2, size=()).item()
        else:
            prefix_len = 0
        prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

        y_prompts = embed[0](y[:, :prefix_len])
        y_emb = embed[0](y[:, prefix_len:])
        for j in range(1, self.num_quantizers):
            y_prompts += embed[j](codes[:, :prefix_len, j])
            if j < nar_stage:
                y_emb += embed[j](codes[:, prefix_len:, j])
        y_emb = torch.concat([y_prompts, y_emb], dim=1)
    elif prefix_mode in [2, 4]:
        if prefix_mode == 2:
            # random prefix
            prefix_len = min(225, int(0.25 * y_lens.min().item()))

            prompt_segments = []
            for b in range(codes.shape[0]):
                start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                prompt_segments.append(
                    torch.clone(codes[b, start : start + prefix_len])
                )
                # Mask the cloned segment at the stage being predicted.
                codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
            y_prompts_codes = torch.stack(prompt_segments, dim=0)
        else:
            if y_prompts_codes is None:
                raise ValueError(
                    "prefix_mode=4 requires y_prompts_codes to be provided"
                )
            prefix_len = y_prompts_codes.shape[1]

        y_prompts = embed[0](y_prompts_codes[..., 0])
        y_emb = embed[0](y)
        for j in range(1, self.num_quantizers):
            y_prompts += embed[j](y_prompts_codes[..., j])
            if j < nar_stage:
                y_emb += embed[j](codes[..., j])
        y_emb = torch.concat([y_prompts, y_emb], dim=1)
    else:
        raise ValueError(
            f"unsupported prefix_mode: {prefix_mode!r} (expected 0, 1, 2 or 4)"
        )

    return y_emb, prefix_len
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: Union[torch.Tensor],
    y_lens: Union[torch.Tensor],
    reduction: str = "sum",
    train_stage: int = 0,
    **kwargs,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Training forward pass -- not implemented in this class.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    enroll_x_lens: Union[torch.Tensor, None] = None,
    top_k: int = -100,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Inference entry point -- not implemented in this class.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError
def visualize(
    self,
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    """Visualization hook -- not implemented in this class.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError
class VALLE(VALLF):
"""It implements https://arxiv.org/abs/2301.02111
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
"""
def __init__(
    self,
    d_model: int,
    nhead: int,
    num_layers: int,
    norm_first: bool = True,
    add_prenet: bool = False,
    prefix_mode: int = 0,
    share_embedding: bool = True,
    nar_scale_factor: float = 1.0,
    **kwargs,
):
    """Build a VALLF configured with encoder-style decoders plus language embeddings.

    Args:
      d_model:
        The number of expected features in the input (required).
      nhead:
        The number of heads in the multiheadattention models (required).
      num_layers:
        The number of sub-decoder-layers in the decoder (required).
      norm_first, add_prenet, prefix_mode, share_embedding, nar_scale_factor:
        Forwarded unchanged to the parent constructor.
      **kwargs:
        Extra keyword arguments forwarded to the parent constructor.
    """
    super(VALLE, self).__init__(
        d_model,
        nhead,
        num_layers,
        norm_first=norm_first,
        add_prenet=add_prenet,
        decoder_cls=TransformerEncoder,
        decoder_layer_cls=TransformerEncoderLayer,
        prefix_mode=prefix_mode,
        share_embedding=share_embedding,
        nar_scale_factor=nar_scale_factor,
        **kwargs,
    )
    # Language code -> embedding index.
    self.language_ID = {
        'en': 0,
        'zh': 1,
        'ja': 2,
    }
    num_languages = len(self.language_ID)
    self.ar_language_embedding = TokenEmbedding(d_model, num_languages)
    # The NAR language embedding is added to the NAR text embedding, whose
    # width is int(d_model * nar_scale_factor); size it the same way so the
    # addition is valid for any scale factor (identical when the factor is 1).
    nar_d_model = int(d_model * nar_scale_factor)
    self.nar_language_embedding = TokenEmbedding(nar_d_model, num_languages)
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: Union[torch.Tensor],
    y_lens: Union[torch.Tensor],
    reduction: str = "sum",
    train_stage: int = 0,
    **kwargs,
):
    """Training forward pass -- not implemented in this class.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    enroll_x_lens: torch.Tensor,
    top_k: int = -100,
    temperature: float = 1.0,
    prompt_language: str = None,
    text_language: str = None,
) -> torch.Tensor:
    """Generate audio codes for ``x`` conditioned on the acoustic prompt ``y``.

    The first quantizer is sampled autoregressively until EOS (or a length
    limit) is hit; the remaining quantizers are predicted greedily by the
    non-autoregressive decoder, one stage per quantizer.

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
      enroll_x_lens:
        Split point along the token axis of ``x``: positions before it
        receive the prompt-language embedding, positions from it onward
        receive the text-language embedding.
      top_k: (`optional`) int
        The number of highest probability tokens to keep for top-k-filtering. Default to -100.
      temperature: (`optional`) float
        The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
      prompt_language:
        Key of ``self.language_ID`` for the prompt part of ``x``. Required
        despite the ``None`` default.
      text_language:
        Key of ``self.language_ID`` (or a list of keys) for the text part
        of ``x``. Required despite the ``None`` default.
    Returns:
      Return the predicted audio code matrix (the per-quantizer codes
      stacked on a new last axis).
    Raises:
      KeyError: if a language code is not in ``self.language_ID``.
      TypeError: if ``text_language`` is neither a str nor a list/tuple.
      SyntaxError: if the AR loop stops while ``y`` is exactly as long as
        the prompt (exception type kept for backward compatibility).
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape
    assert torch.all(x_lens > 0)

    def language_id_tensor(languages, device):
        # Translate language codes into a 1-D LongTensor of embedding ids.
        ids = []
        for code in languages:
            if code not in self.language_ID:
                raise KeyError(
                    f"unsupported language {code!r}; "
                    f"expected one of {sorted(self.language_ID)}"
                )
            ids.append(self.language_ID[code])
        return torch.tensor(ids, dtype=torch.long, device=device)

    if isinstance(text_language, str):
        text_languages = [text_language]
    elif isinstance(text_language, (list, tuple)):
        text_languages = list(text_language)
    else:
        # Previously this path failed later with an UnboundLocalError.
        raise TypeError(
            "text_language must be a str or a list of str, got "
            f"{type(text_language).__name__}"
        )

    # NOTE: x has been padded in TextTokenCollater
    text = x
    x = self.ar_text_embedding(text)
    # Add language embedding
    prompt_language_id = language_id_tensor([prompt_language], x.device)
    text_language_id = language_id_tensor(text_languages, x.device)
    x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
    x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
    x = self.ar_text_prenet(x)
    x = self.ar_text_position(x)

    text_len = x_lens.max()

    prompts = y
    prefix_len = y.shape[1]

    # AR Decoder
    # TODO: Managing decoder steps avoid repetitive computation
    y = prompts[..., 0]
    if self.ar_audio_prepend_bos:
        y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

    x_len = x_lens.max()
    # Text positions attend to all text positions (False = not masked).
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

    kv_cache = None
    use_kv_caching = True
    # Loop-invariant cap on the number of generated frames.
    max_new_tokens = x_lens.max() * 16
    while True:
        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

        y_len = y.shape[1]
        # Text rows must not see any audio column.
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),
            value=True,
        )
        # Audio rows see all text columns and are causal over audio columns.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat(
            [x_attn_mask_pad, y_attn_mask], dim=0
        ).to(y.device)

        if use_kv_caching and kv_cache is not None:
            # With a warm cache only the newest position is fed in.
            xy_pos = xy_pos[:, [-1]]

        xy_dec, kv_cache = self.ar_decoder.infer(
            xy_pos,
            mask=xy_attn_mask,
            past_kv=kv_cache,
            use_cache=use_kv_caching,
        )
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = topk_sampling(
            logits, top_k=top_k, top_p=1, temperature=temperature
        )

        if (
            torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
            or samples[0, 0] == NUM_AUDIO_TOKENS
            or (y.shape[1] - prompts.shape[1]) > max_new_tokens
        ):
            # NOTE(review): when BOS is prepended, y is one frame longer
            # than prompts from the start, so this guard cannot trigger on
            # the first step -- confirm whether that is intended.
            if prompts.shape[1] == y.shape[1]:
                raise SyntaxError(
                    "well trained model shouldn't reach here."
                )

            print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")

            memory_used = get_memory_usage()
            print(f"Current memory used: {memory_used:.2f} MB")
            break

        # safety measure, break if token sequence too long
        if y.shape[1] > 2250:
            print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
            break

        y = torch.concat([y, samples], dim=1)

    # First-quantizer codes generated after the prompt (and BOS, if any).
    codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
    if self.num_quantizers == 1:
        return torch.stack(codes, dim=-1)

    # Non-AR Decoders
    y_emb = self.nar_audio_embeddings[0](
        y[:, int(self.ar_audio_prepend_bos) :]
    )

    if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
        enrolled_len = enroll_x_lens.max().item()
        # SOS + Synthesis Text + EOS
        text = torch.concat(
            [
                text[:, :1],
                text[:, enrolled_len - 1 :],
            ],
            dim=1,
        )
        text_len = text_len - (enrolled_len - 2)
        assert text.shape[0] == 1

    x = self.nar_text_embedding(text)
    # Add language embedding
    prompt_language_id = language_id_tensor([prompt_language], x.device)
    text_language_id = language_id_tensor(text_languages, x.device)
    x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
    x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    # prefix_mode 0 adds the prompt's quantizer-(i + 1) embedding stage by
    # stage; every other mode adds all prompt quantizers up front.
    stagewise_prompt = self.prefix_mode == 0
    if not stagewise_prompt:
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                prompts[..., j]
            )

    for i, (predict_layer, embedding_layer) in enumerate(
        zip(
            self.nar_predict_layers,
            self.nar_audio_embeddings[1:],
        )
    ):
        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[i].weight)
        )
        logits = predict_layer(xy_dec[:, text_len + prefix_len :])

        samples = torch.argmax(logits, dim=-1)
        codes.append(samples)

        # The last stage's prediction is not fed back.
        if i < self.num_quantizers - 2:
            if stagewise_prompt:
                y_emb[:, :prefix_len] += embedding_layer(
                    prompts[..., i + 1]
                )
            y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    # Drop references to large intermediates before forcing a collection.
    del text_language_id, prompt_language_id, y_emb, x, y_pos, xy_pos, xy_dec, logits, samples, kv_cache, x_attn_mask, y_attn_mask, xy_attn_mask
    gc.collect()
    return torch.stack(codes, dim=-1)
def continual(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Re-predict quantizers 2..8 for the tail of ``y`` using its head as prompt.

    The first ``prefix_len = min(T // 2, 225)`` frames of ``y`` act as the
    acoustic prompt. First-quantizer codes of the remaining frames are
    taken from ``y`` as-is (no AR sampling); the other quantizers are
    predicted greedily by the NAR decoder.

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
    Returns:
      Return the predicted audio code matrix.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)
    assert self.num_quantizers == 8

    # NOTE: x has been padded in TextTokenCollater
    text = x
    text_len = x_lens.max()

    prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

    # AR Decoder: nothing is sampled here, so the AR text embedding is not
    # computed (it used to be built and immediately overwritten).
    prompts = y[:, :prefix_len]
    codes = [y[:, prefix_len:, 0]]

    # Non-AR Decoders
    x = self.nar_text_embedding(text)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    y_emb = self.nar_audio_embeddings[0](y[..., 0])

    # prefix_mode 0 adds the prompt's quantizer-(i + 1) embedding stage by
    # stage; every other mode adds all prompt quantizers up front.
    stagewise_prompt = self.prefix_mode == 0
    if not stagewise_prompt:
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                prompts[..., j]
            )

    for i, (predict_layer, embedding_layer) in enumerate(
        zip(
            self.nar_predict_layers,
            self.nar_audio_embeddings[1:],
        )
    ):
        # Prenet first, then position -- the same order as in `inference`
        # and in the non-zero prefix modes (the prefix_mode == 0 branch
        # used to apply them the other way round).
        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[i].weight)
        )
        logits = predict_layer(xy_dec[:, text_len + prefix_len :])

        samples = torch.argmax(logits, dim=-1)
        codes.append(samples)

        # The last stage's prediction is not fed back.
        if i < self.num_quantizers - 2:
            if stagewise_prompt:
                y_emb[:, :prefix_len] += embedding_layer(
                    prompts[..., i + 1]
                )
            y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    return torch.stack(codes, dim=-1)
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Filtering is applied along the last dimension, so ``logits`` may be
    1-D (vocabulary size) or carry leading batch dimensions.
    ``logits`` is modified IN PLACE and also returned.

    Args:
        logits: logits distribution shape (..., vocabulary size)
        top_k: if > 0, keep only top k tokens with highest probability
            (top-k filtering); values <= 0 disable it.
        top_p: if < 1.0, keep the top tokens with cumulative probability
            >= top_p (nucleus filtering). Nucleus filtering is described in
            Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value: value written into the removed positions.
        min_tokens_to_keep: make sure we keep at least this many tokens per
            batch example in the output.
    Returns:
        The same ``logits`` tensor, with removed positions set to
        ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original indexing. Use the
        # last dim (not a hard-coded dim 1) so 1-D logits work as well.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token id per row after temperature scaling and filtering.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
            When ``temperature == 1.0`` the tensor is handed to
            ``top_k_top_p_filtering`` without being copied.
        top_k: (`optional`) int
            The number of highest probability vocabulary tokens to keep for
            top-k-filtering. Default to 10.
        top_p: (`optional`) float
            The cumulative probability of parameter highest probability
            vocabulary tokens to keep for nucleus sampling. Must be between
            0 and 1. Default to 1.
        temperature: (`optional`) float
            The value used to module the next token probabilities. Must be
            strictly positive. Default to 1.0. Higher temperature => more
            likely to sample low probability tokens.
    Returns:
        The result of ``torch.multinomial(..., num_samples=1)`` on the
        softmax of the filtered logits.
    """
    # Temperature scaling (skipped entirely at exactly 1.0)
    scaled = logits if temperature == 1.0 else logits / temperature
    # Top-p/top-k filtering
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)
    # Sample
    probabilities = F.softmax(filtered, dim=-1)
    return torch.multinomial(probabilities, num_samples=1)