# Spaces: Running on Zero
# (HuggingFace Spaces page-header residue from extraction, kept as a comment
# so the file parses as Python.)
import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.stopping_criteria import StoppingCriteria
class ASRLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the ASR task: masks every speech vocabulary
    position so that only text tokens (ids below ``text_token_num``) can
    be sampled.
    """

    def __init__(self, text_token_num: int):
        # Vocabulary ids below this boundary are text tokens; ids at or
        # beyond it belong to the speech vocabulary and are forbidden here.
        self.text_token_num = text_token_num

    def __call__(self, input_ids, scores):
        floor = torch.finfo(scores.dtype).min
        scores[..., self.text_token_num:] = floor
        return scores
class TTSLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the TTS task: masks every text vocabulary
    position (ids below ``text_token_num``) so that only speech tokens
    can be sampled.
    """

    def __init__(self, text_token_num: int):
        # Vocabulary ids below this boundary are text tokens and are
        # forbidden for TTS generation.
        self.text_token_num = text_token_num

    def __call__(self, input_ids, scores):
        floor = torch.finfo(scores.dtype).min
        scores[..., :self.text_token_num] = floor
        return scores
class S2SLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the speech-to-speech task; only supports batch_size=1.

    Generation starts in a text phase where only text-token logits are
    allowed. Once the text EOS token appears in the generated sequence,
    the processor permanently switches to the speech phase where only
    speech-token logits are allowed.

    Args:
        text_token_num: number of vocabulary entries reserved for text
            tokens; ids at or beyond this index belong to the speech
            vocabulary.
        text_eos_id: token id that marks the end of the text segment.
    """

    def __init__(self, text_token_num: int, text_eos_id: int):
        self.text_token_num = text_token_num
        self.text_eos_id = text_eos_id
        self.text_phase = True  # start by generating text

    def __call__(self, input_ids, scores):
        assert input_ids.size(0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
        floor = torch.finfo(scores.dtype).min
        if self.text_phase:
            # Text phase: forbid speech tokens.
            scores[..., self.text_token_num:] = floor
            # Permanently switch to the speech phase once the text EOS has
            # been emitted. `.any()` is required here: `torch.isin` returns
            # a boolean tensor shaped like `input_ids`, whose truthiness is
            # ambiguous (raises) once the sequence is longer than one token.
            if torch.isin(input_ids, self.text_eos_id).any():
                self.text_phase = False
        else:
            # Speech phase: forbid text tokens.
            scores[..., :self.text_token_num] = floor
        return scores
class S2SStopCriteria(StoppingCriteria):
    """Stopping criterion for the speech-to-speech task; only supports batch_size=1.

    Generation stops once the sequence contains the text EOS token and,
    from that position onward, more than one speech EOS token.
    """

    def __init__(self, text_eos_id: int, speech_eos_id: int):
        self.text_eos_id = text_eos_id
        self.speech_eos_id = speech_eos_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        seq = input_ids.reshape(-1)
        text_eos_mask = seq == self.text_eos_id
        if not text_eos_mask.any():
            return False
        # Index of the first text EOS in the flattened sequence.
        first_text_eos = int(text_eos_mask.nonzero(as_tuple=True)[0][0])
        # NOTE(review): `> 1` demands strictly more than one speech EOS after
        # the text EOS — preserved from the original; confirm it is intended.
        speech_eos_count = (seq[first_text_eos:] == self.speech_eos_id).sum()
        return bool(speech_eos_count > 1)
class MaxTokenStopper(StoppingCriteria):
    """Stop generation once the sequence reaches ``max_tokens`` tokens."""

    def __init__(self, max_tokens):
        self.max_tokens = max_tokens

    # TODO(wsy): intended to let callers change max_tokens mid-generation,
    # but it did not appear to take effect — revisit later.
    def change_max_tokens(self, max_tokens):
        self.max_tokens = max_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # Current sequence length is dim 1 of the (batch, seq_len) ids.
        current_len = input_ids.shape[1]
        return current_len >= self.max_tokens
class InterruptStopper(StoppingCriteria):
    """Stopping criterion that lets external code interrupt generation by
    setting ``stop`` to True (e.g. from another thread or a callback).

    The flag is NOT reset automatically: the original code carried a
    commented-out "reset" line that used ``==`` instead of ``=`` and was a
    no-op, so resetting has never happened here. Callers must set
    ``stop = False`` themselves before reusing the instance.
    """

    def __init__(self):
        self.stop = False  # flipped to True externally to interrupt

    def __call__(self, input_ids, scores, **kwargs):
        # `stop` only ever holds a bool, so returning it directly is
        # equivalent to the original `if self.stop == True` branch.
        return self.stop