import preprocess
from shared import CustomTokens
from dataclasses import dataclass, field


@dataclass
class SegmentationArguments:
    pause_threshold: int = field(default=2, metadata={
        'help': 'When the pause between words is at least this threshold, force the start of a new segment'})
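# e.g. (illustrative): SegmentationArguments() keeps the default threshold of 2,
# while SegmentationArguments(pause_threshold=4) only splits on pauses of at least 4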


# WORDS TO ALWAYS HAVE ON THEIR OWN
# always_split_re = re.compile(r'\[\w+\]')
# e.g., [Laughter], [Applause], [Music]
always_split = [
    CustomTokens.MUSIC.value,
    CustomTokens.APPLAUSE.value,
    CustomTokens.LAUGHTER.value
]


def get_overlapping_chunks_of_tokens(tokens, size, overlap):
    """Yield chunks of up to `size` tokens, stepping `size - overlap + 1` tokens
    at a time so that consecutive chunks overlap."""
    for i in range(0, len(tokens), size-overlap+1):
        yield tokens[i:i+size]
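# e.g. (illustrative): list(get_overlapping_chunks_of_tokens(list(range(10)), size=4, overlap=2))
# -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]  (step of size - overlap + 1 = 3)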


# Cap input segments at model_max_length - SAFETY_TOKENS tokens (safety margin)
SAFETY_TOKENS = 12


# TODO play around with this?
OVERLAP_TOKEN_PERCENTAGE = 0.5  # 0.25
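# e.g., for a tokenizer with model_max_length = 512 (an assumption; it depends on
# the model), max_q_size = 512 - 12 = 500 and buffer_size = 0.5 * 500 = 250 tokens
# (see generate_segments below)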


def add_labels_to_words(words, sponsor_segments):
    """Attach a 'category' to each word: the category of the sponsor segment whose
    [start, end] range contains the word's start time, or None if none matches."""

    # TODO binary search
    for word in words:
        word['category'] = None
        for sponsor_segment in sponsor_segments:
            if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
                word['category'] = sponsor_segment['category']

    # TODO use extract_segment with mapping function?
    # TODO remove sponsor segments that contain mostly empty space?

    return words
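# Example (illustrative): given
#   words = [{'text': 'check', 'start': 10.0, 'end': 10.3},
#            {'text': 'next',  'start': 31.0, 'end': 31.2}]
#   sponsor_segments = [{'start': 9.5, 'end': 20.0, 'category': 'sponsor'}]
# add_labels_to_words(words, sponsor_segments) sets 'category' to 'sponsor' on the
# first word and to None on the second.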


def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
    segments = generate_segments(words, tokenizer, segmentation_args)

    labelled_segments = list(
        map(lambda x: add_labels_to_words(x, sponsor_segments), segments))

    return labelled_segments


def word_start(word):
    return word['start']


def word_end(word):
    return word.get('end', word['start'])


def generate_segments(words, tokenizer, segmentation_args):
    """Split words into segments: first on always-split tokens and long pauses,
    then into chunks that fit within the model's token limit, with overlap."""
    first_pass_segments = []

    for index, word in enumerate(words):
        # Get length of tokenized word
        cleaned = preprocess.clean_text(word['text'])
        word['num_tokens'] = len(
            tokenizer(cleaned, add_special_tokens=False, truncation=True).input_ids)

        add_new_segment = index == 0
        if not add_new_segment:

            if word['text'] in always_split or words[index-1]['text'] in always_split:
                add_new_segment = True

            # Pause between words is at least the threshold, so start a new segment
            elif word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
                add_new_segment = True

        if add_new_segment:  # New segment
            first_pass_segments.append([word])

        else:  # Add to current segment
            first_pass_segments[-1].append(word)

    max_q_size = tokenizer.model_max_length - SAFETY_TOKENS

    # Maximum number of tokens carried over (overlap) into the next segment
    buffer_size = OVERLAP_TOKEN_PERCENTAGE * max_q_size  # alternatively: tokenizer.model_max_length

    # In second pass, we split those segments if too big
    second_pass_segments = []
    for segment in first_pass_segments:
        current_segment_num_tokens = 0
        current_segment = []
        for word in segment:
            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
            if new_seg:
                # Adding this word's tokens would exceed the limit,
                # so save the current segment and start a new one
                second_pass_segments.append(current_segment.copy())

            # Add tokens to current segment
            current_segment.append(word)
            current_segment_num_tokens += word['num_tokens']

            if new_seg:
                # Just started a new segment, so trim words from the front until at
                # most buffer_size tokens remain (these become the overlapping context)
                while current_segment_num_tokens > buffer_size and current_segment:
                    first_word = current_segment.pop(0)
                    current_segment_num_tokens -= first_word['num_tokens']

        if current_segment: # Add remaining segment
            second_pass_segments.append(current_segment.copy())

    # Cleaning up, delete 'num_tokens' from each word
    for segment in second_pass_segments:
        for word in segment:  
            word.pop('num_tokens', None)
    
    return second_pass_segments
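
# Illustrative usage sketch (assumes a Hugging Face tokenizer such as
# transformers.AutoTokenizer and word dicts with 'text'/'start'/'end' keys):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('t5-small')
#   words = [
#       {'text': 'hello', 'start': 0.0, 'end': 0.4},
#       {'text': 'world', 'start': 0.5, 'end': 0.9},
#       {'text': 'again', 'start': 5.0, 'end': 5.4},  # long pause -> new segment
#   ]
#   segments = generate_segments(words, tokenizer, SegmentationArguments())
#   # -> two segments: [hello, world] and [again]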


def extract_segment(words, start, end, map_function=None):
    """Extracts all words with time in [start, end]"""
    
    a = binary_search(words, 0, len(words), start, True)
    b = min(binary_search(words, 0, len(words), end, False) + 1, len(words))

    to_transform = map_function is not None and callable(map_function)
    
    return [
        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
    ]


def binary_search(words, start_index, end_index, time, below):
    """Return the first index of a word whose start time (below=True) or
    end time (below=False) is at least `time` (end_index if there is none)."""
    if start_index >= end_index:
        return end_index
    
    middle_index = (start_index + end_index ) // 2

    middle_time = word_start(words[middle_index]) if below else word_end(words[middle_index])

    if time <= middle_time:
        return binary_search(words, start_index, middle_index, time, below)
    else:
        return binary_search(words, middle_index + 1, end_index, time, below)
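

if __name__ == '__main__':
    # Minimal self-check (illustrative only): exercise binary_search and
    # extract_segment on synthetic words; the timing values are made up.
    demo_words = [
        {'text': 'a', 'start': 0.0, 'end': 0.5},
        {'text': 'b', 'start': 1.0, 'end': 1.5},
        {'text': 'c', 'start': 2.0, 'end': 2.5},
        {'text': 'd', 'start': 3.0, 'end': 3.5},
    ]
    # First index whose start time is at least 1.5 -> 2 ('c')
    print(binary_search(demo_words, 0, len(demo_words), 1.5, True))
    # Words falling within [0.8, 2.2], mapped to their text -> ['b', 'c']
    print(extract_segment(demo_words, 0.8, 2.2, map_function=lambda w: w['text']))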