OSUM-EChat / patches /custom_speech_repetition_penalty.py
xlgeng's picture
开始部署
841f290
raw
history blame contribute delete
929 Bytes
from transformers.generation.logits_process import LogitsProcessor
class SpeechOnlyRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    """Repetition penalty that only touches speech tokens, and only on demand.

    Token ids in ``[0, speech_token_num)`` are treated as speech tokens.
    While ``speech_phase`` is False the processor returns the scores
    untouched; while it is True, every speech token id that already appears
    in ``input_ids`` has its score made less likely by ``penalty``.
    """

    def __init__(self, speech_token_num: int, penalty: float = 1.2):
        """Store the configuration.

        Args:
            speech_token_num: Exclusive upper bound of the speech token id
                range; ids ``0 <= id < speech_token_num`` are penalized.
            penalty: Penalty factor. Values above 1 discourage repetition.
        """
        self.speech_token_num = speech_token_num
        self.penalty = penalty
        # Must be toggled by the caller (see ``set_phase``); starts in the
        # text phase, where no penalty is applied.
        self.speech_phase = False

    def set_phase(self, speech_phase: bool):
        """Switch between the speech phase (True) and the text phase (False)."""
        self.speech_phase = speech_phase

    def __call__(self, input_ids, scores):
        """Penalize already-seen speech tokens in ``scores``.

        Args:
            input_ids: Integer tensor of token ids, indexed by batch row.
            scores: Score tensor indexed as ``[batch_row, token_id]``.
                It is modified in place during the speech phase.

        Returns:
            The ``scores`` tensor (same object that was passed in).
        """
        if not self.speech_phase:
            # Text phase: leave the scores alone.
            return scores

        # Speech phase: apply the repetition penalty to speech tokens only.
        for batch_idx in range(input_ids.size(0)):
            row = input_ids[batch_idx]
            in_speech_range = (row >= 0) & (row < self.speech_token_num)
            speech_ids = row[in_speech_range].unique()
            if speech_ids.numel() == 0:
                continue

            seen_scores = scores[batch_idx, speech_ids]
            # Dividing a negative score by penalty > 1 would move it toward
            # zero and *raise* the token's probability, so negative scores
            # are multiplied while non-negative scores are divided.
            scores[batch_idx, speech_ids] = (seen_scores * self.penalty).where(
                seen_scores < 0, seen_scores / self.penalty
            )
        return scores