File size: 3,090 Bytes
3d9a1f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98bcccd
3d9a1f8
 
98bcccd
 
 
 
3d9a1f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
from typing import Dict, List, Optional, Union
import numpy as np
import transformers
from transformers.tokenization_utils_base import AudioInput
from transformers.models.seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
from transformers.utils import TensorType
from transformers.feature_extraction_utils import BatchFeature
from transformers import AutoFeatureExtractor


def build_audio_tokens(
    text: List[str],
    audio_features: Union[Dict, List[List[np.ndarray]]],
    audio_token="<|audio|>",
) -> List[str]:
    """Expand audio placeholders in ``text`` into per-clip embedding tokens.

    For batch sample ``i`` and clip ``j``, the first remaining occurrence of
    ``audio_token`` in ``text[i]`` is replaced by ``f"<|audio_{j+1}|>"``
    repeated ``get_num_embeddings(num_frames)`` times. Here ``num_frames`` is
    the first dimension of that clip's feature array.

    Args:
        text: One prompt string per batch sample. It is modified in place.
        audio_features: Either a nested ``[sample][clip] -> array`` list, or a
            mapping (such as the ``BatchFeature`` produced by
            ``MllamaAudioFeatureExtractor``) holding that data under the
            ``"audio_features"`` key.
        audio_token: Placeholder string to expand.

    Returns:
        The same ``text`` list object, with placeholders expanded.

    Note:
        With the packed (zero-padded) array, every clip slot reports the padded
        frame count. Token counts are therefore based on the padded length.
        Padding clip slots for which no placeholder remains leave the text
        unchanged, because ``str.replace`` with count 1 is a no-op when the
        placeholder is absent.
    """
    if not isinstance(audio_features, list):
        audio_features = audio_features['audio_features']
    # Iterate instead of reading ``.shape[0]``: nested lists have no ``.shape``,
    # while iteration works for lists, numpy arrays and tensors alike.
    for i, clips in enumerate(audio_features):
        for j, clip in enumerate(clips):
            num_tokens = get_num_embeddings(clip.shape[0])
            text[i] = text[i].replace(audio_token, f"<|audio_{j+1}|>" * num_tokens, 1)
    return text
            
def get_num_embeddings(
    num_framses,
    adapter_kernel_size: int = 3,
    adapter_stride: int = 2,
    num_adapter_layers: int = 2,
    extra_tokens: int = 2,
) -> int:
    """Return the number of audio embedding tokens produced for a clip.

    The frame count goes through ``num_adapter_layers`` rounds of the 1-D
    convolution output-length formula
    ``floor((L + 2 * pad - adapter_kernel_size) / adapter_stride + 1)``,
    with ``pad = adapter_stride // 2``. Then ``extra_tokens`` is added.
    The defaults reproduce the previous hard-coded computation: two layers,
    kernel 3, stride 2 and two extra tokens.

    Args:
        num_framses: Number of input feature frames. The misspelled name is kept
            so existing keyword callers keep working.
        adapter_kernel_size: Convolution kernel size applied at each layer.
        adapter_stride: Convolution stride applied at each layer. It also
            determines the padding (``adapter_stride // 2``).
        num_adapter_layers: How many times the length formula is applied.
        extra_tokens: Constant added to the final length.
            NOTE(review): presumably begin/end marker tokens; confirm with the model.

    Returns:
        The resulting token count.

    Raises:
        ValueError: If ``num_adapter_layers`` is negative.
    """
    if num_adapter_layers < 0:
        raise ValueError(f"num_adapter_layers must be >= 0, got {num_adapter_layers}")
    pad = adapter_stride // 2
    seq_len = num_framses
    for _ in range(num_adapter_layers):
        # Same expression as the original per-layer step (true division then floor).
        seq_len = math.floor((seq_len + 2 * pad - adapter_kernel_size) / adapter_stride + 1)
    return seq_len + extra_tokens

class MllamaAudioFeatureExtractor(SeamlessM4TFeatureExtractor):
    """SeamlessM4T feature extractor for batches with several audio clips per sample.

    Each clip is featurized independently by the parent extractor. The
    resulting 2-D feature arrays are then zero-padded into one dense array of
    shape ``(batch_size, max_num_clips, max_frames, feature_dim)``.
    """

    def __call__(
        self,
        batch_audio_clips: List[List[AudioInput]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        *,
        sampling_rate: int = 16000,
    ) -> BatchFeature:
        """Featurize every clip and return them packed under ``"audio_features"``.

        Args:
            batch_audio_clips: For each batch sample, a list of raw audio clips.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type`` to
                convert the packed array.
            sampling_rate: Sampling rate of the clips, forwarded to the parent
                extractor. Defaults to 16000, the value previously hard-coded.

        Returns:
            A ``BatchFeature`` whose ``"audio_features"`` entry is the output
            of ``pack_audio_clips``.
        """
        audio_features = [
            [self._extract_clip_features(clip, sampling_rate) for clip in clips]
            for clips in batch_audio_clips
        ]
        packed_audio_features = self.pack_audio_clips(audio_features)
        return BatchFeature(
            data={"audio_features": packed_audio_features},
            tensor_type=return_tensors,
        )

    def _extract_clip_features(self, audio_clip: AudioInput, sampling_rate: int) -> np.ndarray:
        """Run the parent extractor on one clip and return its first ``input_features`` entry."""
        # Kept as a method rather than inlined in a comprehension: zero-argument
        # super() does not work inside comprehensions before Python 3.12.
        features = super().__call__(audio_clip, sampling_rate=sampling_rate, return_attention_mask=False)
        # NOTE(review): assumes a single clip is featurized as a batch of one,
        # so index 0 is this clip's features; confirm for multi-channel input.
        return features['input_features'][0]

    def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
        """Zero-pad per-clip feature arrays into one dense float32 array.

        Args:
            batch_audio_clips: ``[sample][clip]`` nested list of 2-D arrays shaped
                ``(num_frames, feature_dim)``. ``num_frames`` may differ between
                clips; ``feature_dim`` must not.

        Returns:
            Array of shape ``(batch_size, max_num_clips, max_frames, feature_dim)``.
            Missing clips, and frames past each clip's length, are zero.

        Raises:
            ValueError: If there are no clips at all, a clip is not 2-D, or the
                clips disagree on ``feature_dim``.
        """
        all_clips = [clip for clips in batch_audio_clips for clip in clips]
        if not all_clips:
            raise ValueError("pack_audio_clips requires at least one audio clip")
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        for clip in all_clips:
            if clip.ndim != 2:
                raise ValueError(
                    f"each audio clip must be 2-D (sequence length x feature dimension), got shape {clip.shape}"
                )
        # Taken from the first clip found anywhere, so samples without clips are fine.
        feature_dim = all_clips[0].shape[1]
        for clip in all_clips:
            if clip.shape[1] != feature_dim:
                raise ValueError(
                    f"all audio clips must share feature dimension {feature_dim}, got shape {clip.shape}"
                )

        batch_size = len(batch_audio_clips)
        max_num_clips = max(len(clips) for clips in batch_audio_clips)
        max_frames = max(clip.shape[0] for clip in all_clips)

        stacked_audio_clips = np.zeros((batch_size, max_num_clips, max_frames, feature_dim), dtype=np.float32)
        for i, clips in enumerate(batch_audio_clips):
            for j, clip in enumerate(clips):
                stacked_audio_clips[i, j, :clip.shape[0], :] = clip
        return stacked_audio_clips

# Module-import side effects: register the extractor with AutoFeatureExtractor
# and attach it to the top-level `transformers` namespace.
# NOTE(review): the first argument here is a plain string, whereas register()
# typically expects a config class. Presumably this is relied on so that
# name-based lookup can find the class; confirm with the installed
# transformers version.
AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
# Makes `transformers.MllamaAudioFeatureExtractor` resolvable as an attribute
# (presumably for lookups of the form getattr(transformers, class_name); TODO confirm).
transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor