shangeth commited on
Commit
90b0faf
1 Parent(s): 1d3796d

generate_meta update

Browse files
Files changed (2) hide show
  1. config.json +6 -1
  2. model.py +10 -9
config.json CHANGED
@@ -1,13 +1,18 @@
1
  {
 
 
 
2
  "audio_enc_dim": 1024,
3
  "audio_encoder_name": "microsoft/wavlm-large",
4
  "audio_processor_name": "microsoft/wavlm-base",
5
  "auto_map": {
6
- "AutoConfig": "config.SpeechLLMModelConfig"
 
7
  },
8
  "llm_dim": 2048,
9
  "llm_model_checkpoint": "hf_repo/llm_model_checkpoint",
10
  "llm_model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
11
  "model_type": "custom_model",
 
12
  "transformers_version": "4.41.2"
13
  }
 
1
  {
2
+ "architectures": [
3
+ "SpeechLLMModel"
4
+ ],
5
  "audio_enc_dim": 1024,
6
  "audio_encoder_name": "microsoft/wavlm-large",
7
  "audio_processor_name": "microsoft/wavlm-base",
8
  "auto_map": {
9
+ "AutoConfig": "config.SpeechLLMModelConfig",
10
+ "AutoModel": "model.SpeechLLMModel"
11
  },
12
  "llm_dim": 2048,
13
  "llm_model_checkpoint": "hf_repo/llm_model_checkpoint",
14
  "llm_model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
15
  "model_type": "custom_model",
16
+ "torch_dtype": "float32",
17
  "transformers_version": "4.41.2"
18
  }
model.py CHANGED
@@ -61,11 +61,11 @@ class SpeechLLMModel(PreTrainedModel):
61
  self.llm_model = get_peft_model(self.llm_model, peft_config)
62
  self.llm_model = self.llm_model.merge_and_unload()
63
 
64
- def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
65
- batch_size = mel.shape[0]
66
 
67
  with torch.no_grad():
68
- speech_embeds = self.audio_encoder(mel)
69
  speech_embeds = self.connector(speech_embeds)
70
 
71
  embedder = self.llm_model.model.embed_tokens
@@ -83,12 +83,12 @@ class SpeechLLMModel(PreTrainedModel):
83
  ], 1).to(combined_embeds.device).to(torch.int64)
84
  return combined_embeds, atts, label_ids
85
 
86
- def forward(self, wav_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
87
- combined_embeds, atts, label_ids = self.encode(wav_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
88
  outputs = self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask)
89
  return outputs
90
 
91
- def generate_meta(self, audio_path, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
92
  device = self.audio_encoder.return_device()
93
  pre_speech_prompt = f'''Instruction:
94
  {instruction}
@@ -101,14 +101,15 @@ Output:'''
101
  output_prompt = '\n<s>'
102
 
103
  with torch.no_grad():
104
- wav_tensor, sr = torchaudio.load(audio_path)
105
- wav_tensor = self.audio_processor(wav_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
 
106
 
107
  pre_tokenized_ids = self.llm_tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
108
  post_tokenized_ids = self.llm_tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
109
  output_tokenized_ids = self.llm_tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
110
 
111
- combined_embeds, atts, label_ids = self.encode(wav_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))
112
 
113
  out = self.llm_model.generate(
114
  inputs_embeds=combined_embeds,
 
61
  self.llm_model = get_peft_model(self.llm_model, peft_config)
62
  self.llm_model = self.llm_model.merge_and_unload()
63
 
64
+ def encode(self, speech, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
65
+ batch_size = speech.shape[0]
66
 
67
  with torch.no_grad():
68
+ speech_embeds = self.audio_encoder(speech)
69
  speech_embeds = self.connector(speech_embeds)
70
 
71
  embedder = self.llm_model.model.embed_tokens
 
83
  ], 1).to(combined_embeds.device).to(torch.int64)
84
  return combined_embeds, atts, label_ids
85
 
86
+ def forward(self, audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
87
+ combined_embeds, atts, label_ids = self.encode(audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
88
  outputs = self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask)
89
  return outputs
90
 
91
+ def generate_meta(self, audio_path=None, audio_tensor=None, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
92
  device = self.audio_encoder.return_device()
93
  pre_speech_prompt = f'''Instruction:
94
  {instruction}
 
101
  output_prompt = '\n<s>'
102
 
103
  with torch.no_grad():
104
+ if audio_tensor == None and audio_path != None:
105
+ audio_tensor, sr = torchaudio.load(audio_path)
106
+ audio_tensor = self.audio_processor(audio_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
107
 
108
  pre_tokenized_ids = self.llm_tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
109
  post_tokenized_ids = self.llm_tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
110
  output_tokenized_ids = self.llm_tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
111
 
112
+ combined_embeds, atts, label_ids = self.encode(audio_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))
113
 
114
  out = self.llm_model.generate(
115
  inputs_embeds=combined_embeds,