File size: 3,191 Bytes
841f290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.stopping_criteria import StoppingCriteria

class ASRLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the ASR task.

    Masks every speech position: logits at indices >= ``text_token_num``
    are forced to the dtype minimum so only text tokens can be sampled.
    """

    def __init__(self, text_token_num: int):
        # number of text tokens at the front of the vocabulary
        self.text_token_num = text_token_num

    def __call__(self, input_ids, scores):
        neg_inf = torch.finfo(scores.dtype).min
        # block everything past the text vocabulary
        scores[..., self.text_token_num:] = neg_inf
        return scores

class TTSLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the TTS task.

    Sets the logits of all text positions (indices < ``text_token_num``)
    to the dtype minimum, so only speech tokens can be sampled.
    """

    def __init__(self, text_token_num: int):
        # number of text tokens at the front of the vocabulary
        self.text_token_num = text_token_num

    def __call__(self, input_ids, scores):
        neg_inf = torch.finfo(scores.dtype).min
        # block the whole text vocabulary
        scores[..., :self.text_token_num] = neg_inf
        return scores

class S2SLogitsProcessor(LogitsProcessor):
    """LogitsProcessor for the Speech-to-Speech task; supports batch_size=1 only.

    Generation runs in two phases. While ``text_phase`` is True, only text
    tokens (indices < ``text_token_num``) may be sampled. Once the text EOS
    token appears anywhere in ``input_ids``, the processor permanently flips
    to the speech phase and only speech tokens (indices >= ``text_token_num``)
    may be sampled.

    Args:
        text_token_num: number of text tokens at the front of the vocabulary.
        text_eos_id: token id whose appearance ends the text phase.
    """

    def __init__(self, text_token_num: int, text_eos_id: int):
        self.text_token_num = text_token_num
        self.text_eos_id = text_eos_id
        # stateful flag: starts in the text phase, flips once text EOS is seen
        self.text_phase = True

    def __call__(self, input_ids, scores):
        # Phase state is per-sequence, so only a single sequence is supported.
        assert input_ids.size(0) == 1, "ERROR: S2SSpeechLogitsProcessor only support bs=1 now"
        neg_inf = torch.finfo(scores.dtype).min
        if self.text_phase:
            # text phase: block all speech positions
            scores[..., self.text_token_num:] = neg_inf
        else:
            # speech phase: block all text positions
            scores[..., :self.text_token_num] = neg_inf

        # Fix: torch.isin returns a bool tensor the same shape as input_ids;
        # truth-testing a multi-element tensor raises, so reduce with .any()
        # (matching the idiom used by S2SStopCriteria below).
        if self.text_phase and torch.isin(input_ids, self.text_eos_id).any():
            self.text_phase = False

        return scores

class S2SStopCriteria(StoppingCriteria):
    """Stopping criterion for the Speech-to-Speech task; supports batch_size=1 only.

    Generation stops once the text EOS token has been produced and, from its
    first occurrence onward, more than one speech EOS token has appeared in
    the sequence.
    """

    def __init__(self, text_eos_id: int, speech_eos_id: int):
        self.text_eos_id = text_eos_id
        self.speech_eos_id = speech_eos_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        seq = input_ids.reshape(-1)
        text_eos_mask = seq == self.text_eos_id
        if not text_eos_mask.any():
            return False
        # position of the first text EOS in the flattened sequence
        first_text_eos = text_eos_mask.nonzero(as_tuple=True)[0][0].item()
        tail = seq[first_text_eos:]
        # stop only after MORE than one speech EOS appears past the text EOS
        return bool((tail == self.speech_eos_id).sum() > 1)

class MaxTokenStopper(StoppingCriteria):
    """Stops generation once the sequence length reaches ``max_tokens``."""

    def __init__(self, max_tokens):
        self.max_tokens = max_tokens

    # TODO@wsy: intended to allow updating max_tokens on the fly, but it
    # does not seem to take effect — keep an eye on this later.
    def change_max_tokens(self, max_tokens):
        self.max_tokens = max_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # compare the current sequence length against the configured limit
        current_len = input_ids.shape[1]
        return current_len >= self.max_tokens

class InterruptStopper(StoppingCriteria):
    """Stopping criterion that lets an external caller interrupt generation.

    A caller sets ``stop`` to True from outside; the next invocation of this
    criterion then halts generation. The flag is deliberately NOT reset here
    — the caller must set it back to False before reusing the stopper.
    """

    def __init__(self):
        # external interrupt flag; expected to be a plain bool
        self.stop = False

    def __call__(self, input_ids, scores, **kwargs):
        # Idiom fix: return the flag directly instead of `if x == True`.
        # The old commented-out reset line used `==` where `=` was meant
        # and would have been a no-op even if uncommented; resetting is
        # intentionally left to the caller.
        return self.stop