tts / ChatTTS /utils /infer_utils.py
Dzkaka's picture
Upload 9 files
b99882a verified
raw
history blame
1.73 kB
import torch
import torch.nn.functional as F
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Frequency-scaled repetition penalty over a sliding window of recent tokens.

    A token id that occurs ``k`` times among the last ``past_window`` positions
    of ``input_ids`` has its logit divided by ``penalty ** k`` when the logit is
    non-negative, or multiplied by ``penalty ** k`` when it is negative. Token
    ids ``>= max_input_ids`` are never penalized.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        """Store the penalty configuration.

        Args:
            penalty: Base of the penalty; must be a float > 0 (``1.0`` is a no-op).
            max_input_ids: Exclusive upper bound on the token ids that may be
                penalized; logits of ids at or above it are left unchanged.
            past_window: How many of the most recent positions of ``input_ids``
                are counted.

        Raises:
            ValueError: If ``penalty`` is not a strictly positive float.
        """
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the frequency-based penalty applied.

        Args:
            input_ids: 2-D token ids generated so far, presumably
                ``(batch, seq_len)``.
            scores: Next-token logits whose dim 1 is the vocabulary axis,
                presumably ``(batch, vocab)``.

        Returns:
            A new tensor with the same shape and dtype as ``scores``;
            ``scores`` itself is not modified.
        """
        # Only the most recent ``past_window`` positions contribute.
        # NOTE(review): past_window == 0 selects the whole history (since -0 == 0).
        input_ids = input_ids[:, -self.past_window:]
        # freq[b, v] = number of occurrences of token v in row b's window.
        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
        # Exempt token ids >= max_input_ids. This must slice the vocab axis
        # (dim 1); slicing dim 0 would zero whole batch rows instead.
        freq[:, self.max_input_ids:] = 0
        # penalty ** 0 == 1, so unseen/exempt tokens pass through unchanged.
        # ``float ** LongTensor`` yields the default float dtype, so cast back to
        # keep e.g. half-precision logits in their original dtype.
        alpha = (self.penalty ** freq).to(scores.dtype)
        # Move every penalized logit toward lower probability regardless of sign.
        return torch.where(scores < 0, scores * alpha, scores / alpha)
class CustomRepetitionPenaltyLogitsProcessor:
    """Flat repetition penalty over a sliding window of recent tokens.

    Every token id present among the last ``past_window`` positions of
    ``input_ids`` is penalized exactly once, however often it repeats. A
    non-negative logit is divided by ``penalty`` and a negative one is
    multiplied by it. Token ids ``>= max_input_ids`` keep their original logit.
    """

    def __init__(self, penalty: float, max_input_ids, past_window):
        """Validate and store the penalty configuration.

        Args:
            penalty: Penalty factor; must be a float > 0 (``1.0`` is a no-op).
            max_input_ids: Token ids at or above this value are never penalized.
            past_window: How many of the most recent positions are considered.

        Raises:
            ValueError: If ``penalty`` is not a strictly positive float.
        """
        is_valid = isinstance(penalty, float) and penalty > 0
        if not is_valid:
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Penalize recently generated tokens in ``scores``.

        ``scores`` is modified in place (via ``scatter_``), and the same tensor
        object is returned.
        """
        # NOTE(review): past_window == 0 selects the whole history (since -0 == 0).
        recent = input_ids[:, -self.past_window:]
        picked = scores.gather(1, recent)
        penalized = torch.where(picked < 0, picked * self.penalty, picked / self.penalty)
        # Exempt ids keep their original (gradient-detached) logit.
        exempt = recent >= self.max_input_ids
        updated = torch.where(exempt, picked.detach(), penalized)
        # Duplicate ids in ``recent`` all write the same value, so the result is deterministic.
        scores.scatter_(1, recent, updated)
        return scores