shangeth committed
Commit: 8e999ad
Parent: 8252cc5

generate_meta update

Files changed (2):
  1. config.json +6 -1
  2. model.py +10 -9
config.json CHANGED
@@ -1,13 +1,18 @@
 {
+  "architectures": [
+    "SpeechLLMModel"
+  ],
   "audio_enc_dim": 1280,
   "audio_encoder_name": "facebook/hubert-xlarge-ll60k",
   "audio_processor_name": "facebook/hubert-large-ls960-ft",
   "auto_map": {
-    "AutoConfig": "config.SpeechLLMModelConfig"
+    "AutoConfig": "config.SpeechLLMModelConfig",
+    "AutoModel": "model.SpeechLLMModel"
   },
   "llm_dim": 2048,
   "llm_model_checkpoint": "hf_repo/llm_model_checkpoint",
   "llm_model_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "model_type": "custom_model",
+  "torch_dtype": "float32",
   "transformers_version": "4.38.2"
 }
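
For context, the new "AutoModel" entry in auto_map (together with the "architectures" field) is what lets transformers resolve the repo's custom classes at load time. A minimal loading sketch, assuming the checkpoint is hosted on the Hub; the repo id below is a placeholder, not taken from this commit:

from transformers import AutoConfig, AutoModel

repo_id = "user/speechllm-checkpoint"  # placeholder repo id

# auto_map points transformers at config.SpeechLLMModelConfig and
# model.SpeechLLMModel inside the repo, so loading the custom code
# requires trust_remote_code=True.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
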
model.py CHANGED
@@ -52,11 +52,11 @@ class SpeechLLMModel(PreTrainedModel):
         self.llm_model = self.llm_model.merge_and_unload()
 
 
-    def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
-        batch_size = mel.shape[0]
+    def encode(self, speech, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
+        batch_size = speech.shape[0]
 
         with torch.no_grad():
-            speech_embeds = self.audio_encoder(mel)
+            speech_embeds = self.audio_encoder(speech)
             embedder = self.llm_model.model.embed_tokens
             pre_prompt_embeds = embedder(pre_tokenized_ids)
             post_prompt_embeds = embedder(post_tokenized_ids)
@@ -72,12 +72,12 @@ class SpeechLLMModel(PreTrainedModel):
         ], 1).to(combined_embeds.device).to(torch.int64)
         return combined_embeds, atts, label_ids
 
-    def forward(self, wav_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
-        combined_embeds, atts, label_ids = self.encode(wav_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
+    def forward(self, audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
+        combined_embeds, atts, label_ids = self.encode(audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
         outputs = self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask)
         return outputs
 
-    def generate_meta(self, audio_path, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
+    def generate_meta(self, audio_path=None, audio_tensor=None, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
         device = self.audio_encoder.return_device()
         pre_speech_prompt = f'''Instruction:
 {instruction}
@@ -90,14 +90,15 @@ Output:'''
         output_prompt = '\n<s>'
 
         with torch.no_grad():
-            wav_tensor, sr = torchaudio.load(audio_path)
-            wav_tensor = self.audio_processor(wav_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
 
+            if audio_tensor == None and audio_path != None:
+                audio_tensor, sr = torchaudio.load(audio_path)
+                audio_tensor = self.audio_processor(audio_tensor.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
             pre_tokenized_ids = self.llm_tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
             post_tokenized_ids = self.llm_tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
             output_tokenized_ids = self.llm_tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
 
-            combined_embeds, atts, label_ids = self.encode(wav_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))
+            combined_embeds, atts, label_ids = self.encode(audio_tensor.to(device), pre_tokenized_ids.to(device), post_tokenized_ids.to(device), output_tokenized_ids.to(device))
 
             out = self.llm_model.generate(
                 inputs_embeds=combined_embeds,
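
The reworked generate_meta accepts either a file path or an already-preprocessed tensor. A usage sketch, assuming a model loaded as above; the file path is a placeholder, and, like the method itself, it assumes 16 kHz mono audio:

import torchaudio

# Option 1: pass a path and let generate_meta load and preprocess the audio.
output = model.generate_meta(audio_path="sample.wav")  # placeholder path

# Option 2: pass a preprocessed tensor, mirroring the method's own internals.
wav, sr = torchaudio.load("sample.wav")
wav = model.audio_processor(wav.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
output = model.generate_meta(audio_tensor=wav)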