# MARS5-TTS: mars5/samplers.py
"""
Code for modifying categorical distributions to improve quality of sampling.
Adapted from:
- https://github.com/e-c-k-e-r/vall-e/blob/master/vall_e/samplers.py
- Microsoft UniLM
- Matthew Baas's typical sampling code.
- https://github.com/LostRuins/koboldcpp
"""
import math
import torch
import torch.nn.functional as F
import numpy as np
import logging
from torch import Tensor, nn
def freq_rep_penalty(logits: Tensor, previous: Tensor, alpha_frequency: float, alpha_presence: float, penalty_window: int = 100) -> Tensor:
""" Apply frequency and presence penalty according to openai's formuation.
Concretely: given `logits` (bs, vocab_size) and `previous` (bs, seq_len,)
Modified to support batched inference.
See: https://platform.openai.com/docs/guides/text-generation/parameter-details
"""
bs = logits.shape[0]
previous = previous[..., -penalty_window:]
    c = torch.zeros_like(logits, device=logits.device, dtype=torch.long) # (bs, vocab_size) token counts
for i in range(bs):
vals, cnts = previous[i].unique(return_counts=True)
c[i, vals] = cnts.to(c.device)
logits = logits - c * alpha_frequency - (c > 0).to(logits.dtype) * alpha_presence
return logits
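
# Illustrative usage sketch (not part of the original module; shapes and penalty
# strengths below are assumptions):
#
#     logits = torch.randn(2, 1024)                  # (bs, vocab_size)
#     previous = torch.randint(0, 1024, (2, 120))    # (bs, seq_len) of generated token ids
#     logits = freq_rep_penalty(logits, previous, alpha_frequency=0.3, alpha_presence=0.3)
#     probs = F.softmax(logits, dim=-1)              # repeated tokens are now less likely
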
def early_eos_penalty(logits: Tensor, n_generated: int, estimated_gen_length: int, decay: float, factor: float = 1, eos_index: int = 0) -> Tensor:
""" Penalize the `eos_index` of `logits` (bs, vocab_size) up to `estimated_gen_length`,
whereby we reduce the logit value by `factor`*(expected_length - current_length)^decay,
`n_generated` is the current number of generated samples. `decay` anneals the penalty relative to the distance.
Good values for decay are between 0 and 1. 0 = hard always apply penalty of 1, 1 = linearly scale penalty relative to distance.
Setting factor = 0 disabled penatly. Increasing factor increases penalty.
"""
if n_generated > estimated_gen_length: return logits
penalty = max(estimated_gen_length - n_generated, 1)
    modifier = factor * (penalty ** decay)
    # Subtract the penalty from the EOS logit, applied uniformly across the batch.
    logits[:, eos_index] -= modifier
return logits
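
# Illustrative usage sketch (not part of the original module; `model_step`, the step
# counter and hyperparameters are hypothetical):
#
#     for step in range(max_steps):
#         logits = model_step(...)                   # (bs, vocab_size) logits for the next token
#         logits = early_eos_penalty(logits, n_generated=step,
#                                    estimated_gen_length=200, decay=0.5, factor=1.0)
#         # the EOS logit is pushed down by 1.0 * (200 - step) ** 0.5 while step < 200
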
# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 /
# https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering(logits: Tensor, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1) -> Tensor:
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens per batch example in the output
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold (tokens marked 0 below are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens > 1:
# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
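
# Illustrative usage sketch (not part of the original module; the top_k/top_p values are
# assumptions):
#
#     logits = torch.randn(2, 1024)                  # (bs, vocab_size)
#     filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)
#     probs = F.softmax(filtered, dim=-1)            # filtered entries get probability 0
#     next_token = torch.multinomial(probs, num_samples=1)
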
def apply_typical_p(logprobs: Tensor, mass: float) -> Tensor:
""" Warp categorical logprobs associated with `x` to be in line with `mass`. Last dimension is the bin dimension.
`mass` corresponds to `tau` in the paper.
"""
if mass > 0.999: return logprobs
# see: https://arxiv.org/abs/2202.00666
    # calculate the entropy of the (re-)normalized distribution
    normalized = torch.nn.functional.log_softmax(logprobs, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = logprobs.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = logprobs.masked_fill(indices_to_remove, -float('Inf'))
return scores
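
# Minimal end-to-end sketch (not part of the original module): chains the samplers above
# on random logits. Shapes, penalty strengths and the assumed EOS index (0) are illustrative
# assumptions, not values used by MARS5 itself.
if __name__ == "__main__":
    torch.manual_seed(0)
    bs, vocab_size = 2, 1024
    logits = torch.randn(bs, vocab_size)
    previous = torch.randint(0, vocab_size, (bs, 50))           # previously generated token ids
    # Discourage recently repeated tokens and a premature EOS.
    logits = freq_rep_penalty(logits, previous, alpha_frequency=0.5, alpha_presence=0.5)
    logits = early_eos_penalty(logits, n_generated=50, estimated_gen_length=200, decay=0.5, factor=1.0)
    # Restrict the distribution with top-k / nucleus filtering, then typical sampling.
    logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = apply_typical_p(logprobs, mass=0.9)
    # Sample the next token from the warped distribution.
    next_token = torch.multinomial(logprobs.softmax(dim=-1), num_samples=1)
    print(next_token.shape)                                     # -> torch.Size([2, 1])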