import torch
from torch import nn
import torchaudio
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, HubertModel, AutoFeatureExtractor, AutoModel
from .config import SpeechLLMModelConfig
from peft import LoraConfig, get_peft_model

class TransformerAudioEncoder(nn.Module):
    """Thin wrapper around a pretrained transformer speech encoder (WavLM-Large by default)."""
    def __init__(self, model_name='microsoft/wavlm-large', finetune=False):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Freeze the encoder unless it is explicitly being fine-tuned.
        for param in self.encoder.parameters():
            param.requires_grad = finetune

    def forward(self, x):
        # (batch, samples) -> (batch, frames, hidden_dim)
        return self.encoder(x).last_hidden_state

    def return_device(self):
        return next(self.parameters()).device


class CNNConnector(nn.Module):
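    """Projects speech-encoder features (in_channels) to the LLM embedding size (out_channels),
    downsampling the time axis by roughly a factor of k with three 1-D convolutions."""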
    def __init__(self, in_channels, out_channels, k=2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels//2, kernel_size=5,
                      stride=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(out_channels//2, out_channels, kernel_size=5,
                      stride=k, padding=0),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=5,
                      stride=1, padding=0),
        )

    def forward(self, x):
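        # Conv1d expects (batch, channels, time); transpose in, convolve, transpose back.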
        return self.layer(x.transpose(1,2)).transpose(1,2)
    
    
class SpeechLLMModel(PreTrainedModel):
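    """Speech-and-text model: audio features are encoded, projected through the CNN
    connector, and spliced between text prompt embeddings before the decoder-only LLM."""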
    config_class = SpeechLLMModelConfig

    def __init__(self, config):
        super().__init__(config)
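        # Three components: a pretrained speech encoder, a CNN connector that maps its
        # features into the LLM embedding space, and a decoder-only LLM with its tokenizer.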
        self.audio_processor = AutoFeatureExtractor.from_pretrained(config.audio_processor_name)
        self.audio_encoder = TransformerAudioEncoder(config.audio_encoder_name)
        self.connector = CNNConnector(config.audio_enc_dim, config.llm_dim)

        self.llm_model = AutoModelForCausalLM.from_pretrained(config.llm_model_checkpoint)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)

        # self.llm_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        # self.llm_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

        # peft_config = LoraConfig(
        #     r=8,
        #     lora_alpha=16,
        #     target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
        #     lora_dropout=0.05,
        #     task_type="CAUSAL_LM",
        # )
        # self.llm_model = get_peft_model(self.llm_model, peft_config)

    def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
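        """Embed the prompts, encode the audio, concatenate everything into one sequence,
        and build labels that are ignored (-100) everywhere except the output span."""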
        batch_size = mel.shape[0]

        with torch.no_grad():
            # Frozen audio path: run the speech encoder and connector without gradients.
            speech_embeds = self.audio_encoder(mel)
            speech_embeds = self.connector(speech_embeds)
            
        embedder = self.llm_model.model.embed_tokens
        pre_prompt_embeds = embedder(pre_tokenized_ids)
        post_prompt_embeds = embedder(post_tokenized_ids)
        output_prompt_embeds = embedder(output_tokenized_ids)

        # Sequence layout: [pre-prompt | speech | post-prompt | output]
        combined_embeds = torch.cat([pre_prompt_embeds, speech_embeds, post_prompt_embeds, output_prompt_embeds], dim=1)
        atts = torch.ones(combined_embeds.size()[:-1], dtype=torch.long).to(combined_embeds.device)

        # Labels: -100 (ignored by the cross-entropy loss) for every position before the output span.
        input_token_length = pre_tokenized_ids.shape[1] + speech_embeds.shape[1] + post_tokenized_ids.shape[1]
        label_ids = torch.cat([
            torch.ones([batch_size, input_token_length], device=combined_embeds.device) * -100,
            output_tokenized_ids
        ], 1).to(combined_embeds.device).to(torch.int64)
        return combined_embeds, atts, label_ids

    def forward(self, wav_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
        combined_embeds, atts, label_ids = self.encode(wav_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
        # Default to the attention mask built in encode(), and pass the labels so the LLM returns a loss.
        if attention_mask is None:
            attention_mask = atts
        outputs = self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask, labels=label_ids)
        return outputs
    
    def generate_meta(self, audio_path, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
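        """Run end-to-end inference on a single audio file and return the decoded text."""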
        device = self.audio_encoder.return_device()
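        # Prompt template: the instruction, the audio wrapped in <speech> tags, then an output header.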
        pre_speech_prompt = f'''Instruction:
{instruction}

Input: 
<speech>'''
        post_speech_prompt = f'''</speech>

Output:'''
        output_prompt = '\n<s>'

        with torch.no_grad():
            wav_tensor, sr = torchaudio.load(audio_path)
            if sr != 16000:
                # The feature extractor and speech encoder expect 16 kHz audio.
                wav_tensor = torchaudio.functional.resample(wav_tensor, sr, 16000)
            wav_tensor = self.audio_processor(wav_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values

            pre_tokenized_ids = self.llm_tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
            post_tokenized_ids = self.llm_tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
            output_tokenized_ids = self.llm_tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]

            combined_embeds, atts, label_ids = self.encode(wav_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))

            out = self.llm_model.generate(
                inputs_embeds=combined_embeds,
                max_new_tokens=max_new_tokens,
            ).cpu().tolist()[0]

            output_text = self.llm_tokenizer.decode(out, skip_special_tokens=True)
            return output_text
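

# Minimal usage sketch (the checkpoint path below is a placeholder, not a confirmed
# release name; it assumes a saved model whose config defines the fields used above):
#
#   model = SpeechLLMModel.from_pretrained("path/to/speechllm-checkpoint")
#   print(model.generate_meta("sample.wav"))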