File size: 25,211 Bytes
982b37b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import typing as tp
import torch
import numpy as np

from ..utils import utils
from ..modules.conditioners import (
    ClassifierFreeGuidanceDropout,
    ConditioningAttributes,
    ConditionType,
)
from .lm import LMModel

logger = logging.getLogger(name=__name__)

# Conditioning tensors keyed by a string (presumably the condition name — confirm against condition_provider).
ConditionTensors = tp.Dict[str, ConditionType]
# Either a single set of condition tensors or a pair of them. NOTE(review): the "CFG" name suggests a
# (conditional, unconditional) pair for classifier-free guidance — confirm with users of this alias.
CFGConditions = tp.Union[
    ConditionTensors,
    tp.Tuple[ConditionTensors, ConditionTensors],
]


class MagnetLMModel(LMModel):
    """Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT).

    Args:
        subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0.
                                When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5.
        compression_model_framerate (int): frame rate of the audio tokenizer.
        segment_duration (int): Sample length in seconds.
        span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens,
                        for both training and inference. Defaults to 3.
        **kwargs: Additional parameters for the LMModel. All of them are forwarded to ``LMModel.__init__``;
                  in addition 'causal', 'num_heads', 'device' and 'dtype' are read directly here and must be present.
    """
    def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50,
                 segment_duration: int = 10, span_len: int = 3, **kwargs):
        super().__init__(**kwargs)
        self.causal = kwargs['causal']
        self.subcodes_context = subcodes_context
        self.span_len = span_len
        # Attention masks are built once, for a fixed sequence length of framerate * segment_duration tokens.
        self._build_attn_masks(compression_model_framerate=compression_model_framerate,
                               segment_duration=segment_duration,
                               num_heads=kwargs['num_heads'],
                               device=kwargs['device'], dtype=kwargs['dtype'])

    def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Creates a restricted attention mask (local attention map) where the context
           is determined by self.subcodes_context.

        Args:
            seq_len (int): token sequence length.
            device (torch.device): device of the output tensor.
            dtype (torch.dtype): data type of the output tensor.
        Returns:
            torch.Tensor: A [seq_len, seq_len] additive mask that is 0 where |query_pos - key_pos| <= subcodes_context
                and -inf elsewhere. It is non-causal: neighbours on both sides are attended.
        """
        queries_pos = torch.arange(seq_len, device=device).view(-1, 1)
        keys_pos = torch.arange(seq_len, device=device).view(1, -1)

        delta = queries_pos - keys_pos
        valid = torch.abs(delta) <= self.subcodes_context
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))

    def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
                         device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
        """Creates a restricted attention mask given the stage (codebook index).

        Args:
            stage (int): The codebook index. Takes values in [0, n_q).
            seq_len (int): Token sequence length.
            num_heads (int): Num transformer attention heads.
            device (torch.device): device of the output tensor.
            dtype (torch.dtype): data type of the output tensor.
        Returns:
            torch.Tensor or None: A [1, num_heads, L, L] restricted attention mask, where L is seq_len rounded up
                to a multiple of 8 (the padded rows/columns are filled with 0). Returns None when the stage
                attention is unrestricted, i.e. for stage 0 or when subcodes_context <= -1.
        """
        sa_mask = None

        if stage > 0 and self.subcodes_context > -1:
            # parallel - non-causal - with restricted subcodes context
            sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)

        if sa_mask is not None:
            # Repeat for each attention head: [T, T] -> [1, num_heads, T, T]
            sa_mask = sa_mask.repeat((1, num_heads, 1, 1))

            # align8 to enable memory efficient attention
            MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
            seq_len_aligned = \
                int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR

            sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
            sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
            sa_mask = sa_mask_aligned

        return sa_mask

    def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
                          device: torch.device, dtype: torch.dtype):
        """Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q) range,
           either a local attention map or None is stored as an entry in the self.attn_mask_per_stage list.

        Args:
            compression_model_framerate (int): The frame rate of the tokenizer.
            segment_duration (int): Sample length in seconds.
            num_heads (int): Num transformer attention heads.
            device (torch.device): device of the output tensor.
            dtype (torch.dtype): data type of the output tensor.
        """
        seq_len = compression_model_framerate * segment_duration
        self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
                                                          device, dtype) for stage in range(self.n_q)]

    @torch.no_grad()
    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions: tp.Optional[tp.List[ConditioningAttributes]] = None,
                 num_samples: tp.Optional[int] = None,
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,
                 top_k: int = 250,
                 top_p: float = 0.0,
                 cfg_coef: tp.Optional[float] = None,
                 two_step_cfg: tp.Optional[bool] = None,
                 remove_prompts: bool = False,
                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                 **kwargs) -> torch.Tensor:
        """Generate tokens with MAGNeT's iterative decoding.

        Keeps the ``LMModel.generate`` signature, but rejects the arguments MAGNeT does not support
        (``cfg_coef``, ``two_step_cfg``, ``remove_prompts``, ``check``). Everything else is forwarded to
        ``_generate_magnet``, including MAGNeT-specific keyword arguments passed through ``**kwargs``
        (``max_cfg_coef``, ``min_cfg_coef``, ``decoding_steps``, ``anneal_temp``, ``span_scoring``,
        ``span_arrangement``). Unknown keyword arguments raise ``TypeError`` there.

        Args:
            prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
            conditions (list of ConditioningAttributes, optional): List of conditions. None means no conditions.
            num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
            max_gen_len (int): Maximum generation length.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Initial sampling temperature.
            top_k (int): k for "top-k" sampling.
            top_p (float): p for "top-p" sampling.
            callback (Callback): Callback function to report generation progress, called as callback(step, total).
        Returns:
            torch.Tensor: Generated tokens.
        """
        assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
        assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
        assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
        assert check is False, "MAGNeT currently doesn't support the check arg."
        # Call the MAGNeT-specific generation method
        return self._generate_magnet(prompt=prompt,
                                     conditions=conditions,
                                     num_samples=num_samples,
                                     max_gen_len=max_gen_len,
                                     use_sampling=use_sampling,
                                     temp=temp,
                                     top_k=top_k,
                                     top_p=top_p,
                                     callback=callback, **kwargs)

    @torch.no_grad()
    def _generate_magnet(self,
                         prompt: tp.Optional[torch.Tensor] = None,
                         conditions: tp.Optional[tp.List[ConditioningAttributes]] = None,
                         num_samples: tp.Optional[int] = None,
                         max_gen_len: int = 256,
                         use_sampling: bool = True,
                         temp: float = 3.0,
                         top_k: int = 0,
                         top_p: float = 0.9,
                         callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                         max_cfg_coef: float = 10.0,
                         min_cfg_coef: float = 1.0,
                         decoding_steps: tp.Sequence[int] = (20, 10, 10, 10),
                         anneal_temp: bool = True,
                         span_scoring='max',
                         span_arrangement='nonoverlap') -> torch.Tensor:
        """Generate audio tokens given textual conditions, and optionally given audio prompts,
        by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels.

        Args:
            prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
            conditions (list of ConditioningAttributes): List of conditions. None or empty means unconditional.
            num_samples (int): Number of samples to generate when no prompt and no conditions are given.
            max_gen_len (int): Maximum generation length.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Initial sampling temperature.
            top_k (int): k for "top-k" sampling.
            top_p (float): p for "top-p" sampling.
            callback (Callback): Callback function to report generation progress.
            max_cfg_coef (float): Initial coefficient used for classifier free guidance.
            min_cfg_coef (float): Final coefficient used for classifier free guidance.
            decoding_steps (sequence of ints): The number of iterative decoding steps for each RVQ codebook.
                Must contain at least n_q entries; only the first n_q are used.
            anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
            span_scoring (str): Use the maximum probability of each span ('max')
                                or the product of probabilities ('prod').
            span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
                                    in the masking scheme.
        Returns:
            torch.Tensor: Generated tokens.
        Raises:
            ValueError: If decoding_steps has fewer than n_q entries.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent: every batch-size source that was provided must agree.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        if prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        if conditions:
            possible_num_samples.append(len(conditions))
        if not possible_num_samples:
            possible_num_samples.append(1)
        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
        num_samples = possible_num_samples[0]

        decoding_steps = list(decoding_steps)
        if len(decoding_steps) < self.n_q:
            # Otherwise the trailing codebooks would never be decoded and would keep their mask tokens.
            raise ValueError(f"decoding_steps must hold one step count per codebook ({self.n_q}), "
                             f"got {decoding_steps}")
        stage_steps = decoding_steps[:self.n_q]
        # Only the steps that actually run count towards the progress total reported to the callback.
        total_steps = sum(stage_steps)

        # below we create set of conditions: one conditional and one unconditional
        # to do that we merge the regular condition together with the null condition
        # we then do 1 forward pass instead of 2.
        cfg_conditions: tp.Optional[ConditionTensors]
        if conditions:
            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
            conditions = conditions + null_conditions
            tokenized = self.condition_provider.tokenize(conditions)
            cfg_conditions = self.condition_provider(tokenized)
        else:
            cfg_conditions = {}

        if prompt is None:
            assert num_samples > 0
            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

        B, K, prompt_length = prompt.shape
        start_offset = prompt_length
        assert start_offset < max_gen_len

        mask_id = self.special_token_id

        # we generate codes with a fixed sequence length, starting from all-mask tokens
        shape = (B, K, max_gen_len)

        gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device)
        # filling the gen_codes with the prompt if needed
        gen_codes[..., :start_offset] = prompt
        # No delay/interleaving pattern is applied here: the decoded sequence is the code tensor itself [B, K, S].
        gen_sequence = gen_codes

        curr_step = 0
        for stage, n_steps in enumerate(stage_steps):
            gen_sequence, curr_step = self._generate_stage(gen_sequence,
                                                           cfg_conditions,
                                                           stage=stage,
                                                           device=device,
                                                           prompt_length=prompt_length,
                                                           prompt=prompt,
                                                           temp=temp,
                                                           max_cfg_coef=max_cfg_coef,
                                                           min_cfg_coef=min_cfg_coef,
                                                           top_k=top_k,
                                                           top_p=top_p,
                                                           timesteps=n_steps,
                                                           anneal_temp=anneal_temp,
                                                           span_scoring=span_scoring,
                                                           use_sampling=use_sampling,
                                                           span_arrangement=span_arrangement,
                                                           curr_step=curr_step,
                                                           total_steps=total_steps,
                                                           callback=callback)

        return gen_sequence

    @torch.no_grad()
    def _generate_stage(self,
                        gen_sequence: torch.Tensor,
                        condition_tensors: tp.Optional[ConditionTensors],
                        stage: int,
                        device: torch.device,
                        prompt_length: int = 0,
                        prompt: tp.Optional[torch.Tensor] = None,
                        use_sampling: bool = True,
                        temp: float = 3.0,
                        max_cfg_coef: float = 10.0,
                        min_cfg_coef: float = 1.0,
                        top_k: int = 0,
                        top_p: float = 0.0,
                        timesteps: int = 10,
                        anneal_temp: bool = True,
                        span_scoring: str = 'max',
                        span_arrangement: str = 'nonoverlap',
                        curr_step: int = 0,
                        total_steps: int = 0,
                        callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]:
        """Generate audio tokens of a single RVQ level (stage), given the previously generated stages,
           and the textual conditions.

        Args:
            gen_sequence (torch.Tensor): Previously generated tokens, shape [B, K, T].
            condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors. When non-empty,
                the batch is duplicated and the logits are split into (conditional, unconditional) halves.
            stage (int): RVQ level to generate.
            device (torch.device): device of the output tensor.
            prompt_length (int): Temporal length of the audio prompt.
            prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Initial sampling temperature.
            max_cfg_coef (float): Initial coefficient used for classifier free guidance.
            min_cfg_coef (float): Final coefficient used for classifier free guidance.
            top_k (int): k for "top-k" sampling.
            top_p (float): p for "top-p" sampling.
            timesteps (int): Number of iterative decoding steps.
            anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
            span_scoring (str): Use the maximum probability of each span ('max')
                                or the product of probabilities ('prod').
            span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
                                    in the masking scheme.
            curr_step (int): Global iterative decoding step counter.
            total_steps (int): Total decoding steps.
            callback (Callback): Callback function to report generation progress.
        Returns:
            tuple(torch.Tensor, int): Generated tokens and the current decoding step counter. With 'nonoverlap'
                span masking, the returned sequence is trimmed to a multiple of span_len.
        Raises:
            NotImplementedError: If span masking is used with an unknown span_scoring.
        """
        B, _, T = gen_sequence.shape
        shape = (B, 1, T)  # generating a single codebook per stage

        mask_id = self.special_token_id
        stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device)

        assert span_arrangement == 'nonoverlap' or span_arrangement == 'stride1'
        chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap'
        if chunk_masking and span_scoring not in ('max', 'prod'):
            # Fail before running any forward pass rather than after the first one.
            raise NotImplementedError(f"Unsupported span_scoring: {span_scoring}")
        # masking of the k least probable overlapping (stride 1) spans; loop-invariant
        run_lps_masking = span_arrangement == 'stride1' and self.span_len > 1

        # Score given to positions that must never be re-masked (prompt / already-fixed tokens).
        DONT_REMASK_ME_SCORE = -1e4

        model = self if self._fsdp is None else self._fsdp

        if chunk_masking:
            # span-wise scores
            n_chunks = T // self.span_len
            if T % self.span_len != 0:
                # trim sequence ending to achieve a multiple of span_len
                T = self.span_len * n_chunks
                gen_sequence = gen_sequence[..., :T]
                stage_gen_seq = stage_gen_seq[..., :T]

            chunked_shape = (B, 1, n_chunks)
            n_prompt_chunks = prompt_length // self.span_len
            scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device)
            scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE
            num_chunks_to_gen = n_chunks - n_prompt_chunks
        else:
            # token-wise scores
            scores = torch.zeros(shape, dtype=torch.float32, device=device)
            scores[..., :prompt_length] = DONT_REMASK_ME_SCORE
            gen_T = T - prompt_length

        # run MAGNeT iterative decoding for "timesteps" iterations
        for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):

            # Cosine schedule: the masked fraction goes from cos(0) = 1 down to cos(pi / 2) = 0.
            mask_p = torch.cos(timestep * math.pi * 0.5)

            if chunk_masking:
                num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1)
            else:
                num_masked = max(int((mask_p * gen_T).item()), 1)

            # masking
            if run_lps_masking:
                # masking of the k least probable overlapping (stride 1) spans
                mask = torch.concat((
                    [self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device)
                     for i in range(B)]), dim=0)
                stage_gen_seq[mask] = mask_id
            else:
                # masking of the k least probable non-overlapping spans
                masked = scores.topk(num_masked, dim=-1).indices
                if chunk_masking:
                    chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device)
                    chunks_mask = chunks_mask.scatter(2, masked, True)
                    mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1)
                    stage_gen_seq[mask] = mask_id
                else:
                    stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id)

            if prompt is not None:
                stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1)

            gen_sequence[:, [stage], :] = stage_gen_seq
            # With conditioning, duplicate the input for classifier free guidance so that both halves of the
            # condition batch get a copy; without it, run the model on the sequence as-is.
            sequence = torch.cat([gen_sequence, gen_sequence], dim=0) if condition_tensors else gen_sequence

            all_logits = model(sequence, [], condition_tensors, stage=stage)

            if condition_tensors:
                # classifier free guidance with annealing: coefficient moves from max_cfg_coef to min_cfg_coef
                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
                clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
                logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef
            else:
                logits = all_logits

            # temperature annealing - linear
            t = temp * (steps_left / timesteps) if anneal_temp else temp

            # sampling (temperature is floored at 1e-2 to avoid dividing by zero on the last step)
            logits = logits[:, stage, :, :].unsqueeze(1)
            probs = torch.softmax(logits / max(t, 1e-2), dim=-1)
            if use_sampling:
                if top_p > 0.0:
                    sampled_tokens = utils.sample_top_p(probs, p=top_p)
                elif top_k > 0:
                    sampled_tokens = utils.sample_top_k(probs, k=top_k)
                else:
                    sampled_tokens = utils.multinomial(probs, num_samples=1)
            else:
                sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True)

            # place the sampled tokens in each of the masked positions
            mask = stage_gen_seq == mask_id
            stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq)
            gen_sequence[:, [stage], :] = stage_gen_seq

            # get probs of sampled tokens
            sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0]

            # span scoring
            if chunk_masking:
                if span_scoring == 'max':
                    # max in linear space
                    scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0]
                elif span_scoring == 'prod':
                    # prod in log space
                    scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1)
                else:
                    raise NotImplementedError
            else:
                # prod in log space for lps masking (stride1)
                scores = -torch.log(sampled_probs)

            # Fix unmasked tokens by placing inf probs (-inf scores)
            if chunk_masking:
                scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE)
            else:
                scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE)

            curr_step += 1
            if callback is not None:
                callback(curr_step, total_steps)

        return gen_sequence, curr_step

    def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor:
        """Build a [1x1xT] boolean mask consisting of (possibly overlapping) spans of True values, where
           span_starts gives the initial index of each span, and the span length is self.span_len.
           Spans that would extend past T are truncated.

        Args:
            span_starts (torch.Tensor): Indices of the temporal positions at which each span starts.
            T (int): Sequence length.
            device (torch.device): device of the output tensor.
        Returns:
            torch.Tensor: Spans mask of shape [1x1xT]
        """
        mask = torch.full((1, 1, T), False, device=device)
        mask[:, :, span_starts] = True
        shifted_mask = mask.clone()
        # Extend every start position to the right by span_len - 1 tokens, one shift at a time.
        for _ in range(self.span_len - 1):
            shifted_mask = torch.concat((torch.full((1, 1, 1), False, device=device), shifted_mask[:, :, :-1]), dim=-1)
            mask = torch.logical_or(mask, shifted_mask)
        return mask

    def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor:
        """Construct a [1x1xT] boolean mask consisting of the u least probable spans,
           where the token probability is determined by -scores, and the total
           number of masked tokens is as close as possible to num_masked_trg.
           u is found using binary search.

        Args:
            scores (torch.Tensor): Per token score [-log(prob)], shape [1, 1, T].
            num_masked_trg (int): The desired amount of tokens to be masked (at least span_len is used).
        Returns:
            torch.Tensor: Spans mask of shape [1x1xT]
        """
        T = scores.shape[-1]
        device = scores.device
        scores_unfolded = scores.unfold(2, self.span_len, 1)
        # Span score is the product of probs (sum in log space)
        span_scores = scores_unfolded.sum(dim=-1)
        spans_by_scores = torch.argsort(span_scores[0, 0], descending=True)

        num_masked_trg = max(num_masked_trg, self.span_len)

        # Binary search for u - the number least probable overlapping masked spans s.t.
        # the total masking rate is the closest to num_masked_trg / T.
        # min_u: all spans disjoint; max_u: all spans maximally overlapping (stride 1).
        min_u = num_masked_trg // self.span_len
        max_u = num_masked_trg - self.span_len + 1
        mid = round(0.5 * (min_u + max_u))

        if mid == min_u or mid == max_u:
            return self._construct_spans_mask(spans_by_scores[:mid], T, device)

        while mid > min_u and mid < max_u:
            mask = self._construct_spans_mask(spans_by_scores[:mid], T, device)
            n_masked = mask.sum()
            if n_masked > num_masked_trg:
                max_u = mid
                mid = round(0.5 * (min_u + max_u))
            else:
                min_u = mid
                mid = round(0.5 * (min_u + max_u))

        # The returned mask is the one from the last probed u, which may slightly over- or undershoot the target.
        return mask