File size: 929 Bytes
841f290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers.generation.logits_process import LogitsProcessor

class SpeechOnlyRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, speech_token_num, penalty=1.2):
        self.speech_token_num = speech_token_num
        self.penalty = penalty
        self.speech_phase = False  # 你需要在外部控制这个变量

    def set_phase(self, speech_phase: bool):
        self.speech_phase = speech_phase

    def __call__(self, input_ids, scores):
        if not self.speech_phase:
            # text阶段,什么都不做
            return scores
        # speech阶段,只对speech token做重复抑制
        for batch_idx in range(input_ids.size(0)):
            generated = input_ids[batch_idx].tolist()
            for token_id in set(generated):
                if 0 <= token_id < self.speech_token_num:
                    scores[batch_idx, token_id] /= self.penalty
        return scores