# OSUM-EChat / patches / cumstom_stop_criteria.py
# NOTE(review): the original first lines were non-Python residue from a web
# page (author caption, commit message "开始部署" = "begin deployment",
# commit hash 841f290); converted to comments so the file parses.
import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.stopping_criteria import StoppingCriteria
class ASRLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the ASR task: restrict generation to text tokens.

    Every vocabulary position at or beyond ``text_token_num`` (the non-text
    portion of the vocabulary) is masked to the dtype's minimum value so it
    can never be sampled.
    """

    def __init__(self, text_token_num: int):
        # Number of leading vocabulary slots reserved for text tokens.
        self.text_token_num = text_token_num

    def __call__(self, input_ids, scores):
        neg_inf = torch.finfo(scores.dtype).min
        scores[..., self.text_token_num:] = neg_inf
        return scores
class TTSLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the TTS task: restrict generation to speech tokens.

    All text positions (ids below ``text_token_num``) have their logits set
    to the dtype's minimum value so only speech tokens can be sampled.
    """

    def __init__(self, text_token_num: int):
        # Number of leading vocabulary slots reserved for text tokens.
        self.text_token_num = text_token_num

    def __call__(self, input_ids, scores):
        neg_inf = torch.finfo(scores.dtype).min
        scores[..., :self.text_token_num] = neg_inf
        return scores
class S2SLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the Speech-to-Speech task (batch_size == 1 only).

    Generation runs in two phases:
      1. text phase  - only text tokens (ids < ``text_token_num``) allowed;
      2. speech phase - only speech tokens (ids >= ``text_token_num``) allowed.
    The processor switches permanently from phase 1 to phase 2 once the text
    EOS token appears anywhere in ``input_ids``.

    Args:
        text_token_num: number of leading vocabulary slots reserved for text.
        text_eos_id: token id marking the end of the text phase.
    """

    def __init__(self, text_token_num: int, text_eos_id: int):
        self.text_token_num = text_token_num
        self.text_eos_id = text_eos_id
        self.text_phase = True  # start in the text phase

    def __call__(self, input_ids, scores):
        assert input_ids.size(0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
        neg_inf = torch.finfo(scores.dtype).min
        if self.text_phase:
            # Text phase: forbid every speech token.
            scores[..., self.text_token_num:] = neg_inf
        else:
            # Speech phase: forbid every text token.
            scores[..., :self.text_token_num] = neg_inf
        # Fix: the original used `torch.isin(input_ids, self.text_eos_id)`
        # directly as an `if` condition; that is a multi-element bool tensor
        # and raises "Boolean value of Tensor with more than one element is
        # ambiguous" for any sequence longer than one token. Reduce with
        # `.any()`. A stray debug `print(input_ids.shape)` was also removed.
        if self.text_phase and (input_ids == self.text_eos_id).any():
            self.text_phase = False
        return scores
class S2SStopCriteria(StoppingCriteria):
    """Stopping criterion for the Speech-to-Speech task (batch_size == 1 only).

    Generation stops once the text EOS token has appeared and, from its first
    occurrence onward, the speech EOS token has been generated more than once.
    """

    def __init__(self, text_eos_id: int, speech_eos_id: int):
        self.text_eos_id = text_eos_id
        self.speech_eos_id = speech_eos_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        seq = input_ids.reshape(-1)  # treat the (1, seq_len) batch as a flat sequence
        text_eos_mask = seq == self.text_eos_id
        if not text_eos_mask.any():
            # Still in the text phase: never stop before the text EOS shows up.
            return False
        # Index of the first text EOS occurrence.
        first_text_eos = int(text_eos_mask.nonzero(as_tuple=True)[0][0])
        # Count speech EOS tokens from that point on; >1 means both the text
        # and the speech stream have finished.
        tail = seq[first_text_eos:]
        return bool((tail == self.speech_eos_id).sum() > 1)
class MaxTokenStopper(StoppingCriteria):
    """Stop generation once the sequence reaches ``max_tokens`` tokens."""

    def __init__(self, max_tokens):
        # Hard cap on total sequence length (prompt + generated tokens).
        self.max_tokens = max_tokens

    # TODO(wsy): intended to allow changing max_tokens mid-run; reported as
    # possibly having no effect — verify against the generation loop.
    def change_max_tokens(self, max_tokens):
        self.max_tokens = max_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids is (batch, seq_len); dim 1 is the current sequence length.
        current_len = input_ids.shape[1]
        return current_len >= self.max_tokens
class InterruptStopper(StoppingCriteria):
    """Externally-triggered stop switch.

    Some outside component sets ``self.stop = True`` to interrupt generation.
    The flag is deliberately NOT reset after firing (a reset line existed in
    the original but was commented out).
    """

    def __init__(self):
        self.stop = False  # flipped to True from outside to interrupt

    def __call__(self, input_ids, scores, **kwargs):
        # Mirrors the original `== True` comparison exactly.
        return self.stop == True  # noqa: E712