Spaces:
Running
on
Zero
Running
on
Zero
from transformers.generation.logits_process import LogitsProcessor | |
class SpeechOnlyRepetitionPenaltyLogitsProcessor(LogitsProcessor): | |
def __init__(self, speech_token_num, penalty=1.2): | |
self.speech_token_num = speech_token_num | |
self.penalty = penalty | |
self.speech_phase = False # 你需要在外部控制这个变量 | |
def set_phase(self, speech_phase: bool): | |
self.speech_phase = speech_phase | |
def __call__(self, input_ids, scores): | |
if not self.speech_phase: | |
# text阶段,什么都不做 | |
return scores | |
# speech阶段,只对speech token做重复抑制 | |
for batch_idx in range(input_ids.size(0)): | |
generated = input_ids[batch_idx].tolist() | |
for token_id in set(generated): | |
if 0 <= token_id < self.speech_token_num: | |
scores[batch_idx, token_id] /= self.penalty | |
return scores |