Feature Extraction
Transformers
Safetensors
diva
custom_code
Helw150 commited on
Commit
d41acc6
1 Parent(s): 11bcf74

Make More Generic; Reduce Config Size

Browse files
Files changed (2) hide show
  1. config.json +2 -128
  2. modeling_diva.py +29 -31
config.json CHANGED
@@ -1,138 +1,12 @@
1
  {
2
  "model_type": "diva",
 
 
3
  "architectures": [ "DiVAModel" ],
4
  "auto_map": {
5
  "AutoConfig": "configuring_diva.DiVAConfig",
6
  "AutoModel": "modeling_diva.DiVAModel"
7
  },
8
  "vocab_size": 128256,
9
- "decoder": {
10
- "architectures": [
11
- "LlamaForCausalLM"
12
- ],
13
- "attention_bias": false,
14
- "attention_dropout": 0,
15
- "bos_token_id": 128000,
16
- "eos_token_id": 128001,
17
- "hidden_act": "silu",
18
- "hidden_size": 4096,
19
- "initializer_range": 0.02,
20
- "intermediate_size": 14336,
21
- "max_position_embeddings": 8192,
22
- "model_type": "llama",
23
- "num_attention_heads": 32,
24
- "num_hidden_layers": 32,
25
- "num_key_value_heads": 8,
26
- "pretraining_tp": 1,
27
- "rms_norm_eps": 1e-05,
28
- "rope_scaling": null,
29
- "rope_theta": 500000,
30
- "tie_word_embeddings": false,
31
- "torch_dtype": "bfloat16",
32
- "transformers_version": "4.40.0.dev0",
33
- "use_cache": true,
34
- "vocab_size": 128256
35
- },
36
- "encoder": {
37
- "_name_or_path": "openai/whisper-large-v3",
38
- "activation_dropout": 0,
39
- "activation_function": "gelu",
40
- "add_cross_attention": false,
41
- "apply_spec_augment": false,
42
- "architectures": [
43
- "WhisperForConditionalGeneration"
44
- ],
45
- "attention_dropout": 0,
46
- "bad_words_ids": null,
47
- "begin_suppress_tokens": [
48
- 220,
49
- 50257
50
- ],
51
- "bos_token_id": 50257,
52
- "chunk_size_feed_forward": 0,
53
- "classifier_proj_size": 256,
54
- "cross_attention_hidden_size": null,
55
- "d_model": 1280,
56
- "decoder_attention_heads": 20,
57
- "decoder_ffn_dim": 5120,
58
- "decoder_layerdrop": 0,
59
- "decoder_layers": 32,
60
- "decoder_start_token_id": 50258,
61
- "diversity_penalty": 0,
62
- "do_sample": false,
63
- "dropout": 0,
64
- "early_stopping": false,
65
- "encoder_attention_heads": 20,
66
- "encoder_ffn_dim": 5120,
67
- "encoder_layerdrop": 0,
68
- "encoder_layers": 32,
69
- "encoder_no_repeat_ngram_size": 0,
70
- "eos_token_id": 50257,
71
- "exponential_decay_length_penalty": null,
72
- "finetuning_task": null,
73
- "forced_bos_token_id": null,
74
- "forced_eos_token_id": null,
75
- "id2label": {
76
- "0": "LABEL_0",
77
- "1": "LABEL_1"
78
- },
79
- "init_std": 0.02,
80
- "is_decoder": false,
81
- "is_encoder_decoder": true,
82
- "label2id": {
83
- "LABEL_0": 0,
84
- "LABEL_1": 1
85
- },
86
- "length_penalty": 1,
87
- "mask_feature_length": 10,
88
- "mask_feature_min_masks": 0,
89
- "mask_feature_prob": 0,
90
- "mask_time_length": 10,
91
- "mask_time_min_masks": 2,
92
- "mask_time_prob": 0.05,
93
- "max_length": 448,
94
- "max_source_positions": 1500,
95
- "max_target_positions": 448,
96
- "median_filter_width": 7,
97
- "min_length": 0,
98
- "model_type": "whisper",
99
- "no_repeat_ngram_size": 0,
100
- "num_beam_groups": 1,
101
- "num_beams": 1,
102
- "num_hidden_layers": 32,
103
- "num_mel_bins": 128,
104
- "num_return_sequences": 1,
105
- "output_attentions": false,
106
- "output_hidden_states": false,
107
- "output_scores": false,
108
- "pad_token_id": 50256,
109
- "prefix": null,
110
- "problem_type": null,
111
- "pruned_heads": {},
112
- "remove_invalid_values": false,
113
- "repetition_penalty": 1,
114
- "return_dict": true,
115
- "return_dict_in_generate": false,
116
- "scale_embedding": false,
117
- "sep_token_id": null,
118
- "suppress_tokens": null,
119
- "task_specific_params": null,
120
- "temperature": 1,
121
- "tf_legacy_loss": false,
122
- "tie_encoder_decoder": false,
123
- "tie_word_embeddings": true,
124
- "tokenizer_class": null,
125
- "top_k": 50,
126
- "top_p": 1,
127
- "torch_dtype": "float16",
128
- "torchscript": false,
129
- "transformers_version": "4.38.2",
130
- "typical_p": 1,
131
- "use_bfloat16": false,
132
- "use_cache": true,
133
- "use_weighted_layer_sum": false,
134
- "vocab_size": 51866
135
- },
136
- "time_dialation": 4,
137
  "transformers_version": "4.38.2"
138
  }
 
1
  {
2
  "model_type": "diva",
3
+ "reference_encoder": "openai/whisper-large-v3",
4
+ "reference_decoder": "meta-llama/Meta-Llama-3-8B-Instruct",
5
  "architectures": [ "DiVAModel" ],
6
  "auto_map": {
7
  "AutoConfig": "configuring_diva.DiVAConfig",
8
  "AutoModel": "modeling_diva.DiVAModel"
9
  },
10
  "vocab_size": 128256,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  "transformers_version": "4.38.2"
12
  }
modeling_diva.py CHANGED
@@ -10,13 +10,13 @@ import torch.nn.functional as F
10
  from datasets import Audio
11
  from safetensors.torch import load, load_model
12
  from torch import nn
13
- from .configuring_diva import DiVAConfig
14
  from transformers import (
15
  AutoProcessor,
16
  AutoTokenizer,
17
- LlamaForCausalLM,
18
  PreTrainedModel,
19
- WhisperForConditionalGeneration,
20
  )
21
 
22
 
@@ -51,11 +51,9 @@ class DiVAModel(PreTrainedModel):
51
  super().__init__(DiVAConfig.from_dict(config_dict))
52
  if speech_encoder_device is None:
53
  speech_encoder_device = "cuda:0"
54
- whisper = WhisperForConditionalGeneration.from_pretrained(
55
- "openai/whisper-large-v3"
56
- )
57
  connector = WhisperConnector()
58
- connector.decoder = copy.deepcopy(whisper.model.decoder)
59
  if via_path is not None:
60
  with open(via_path, "rb") as f:
61
  sd = load(f.read())
@@ -83,25 +81,25 @@ class DiVAModel(PreTrainedModel):
83
  )
84
 
85
  self.connector = connector.to(speech_encoder_device)
86
- self.whisper_encoder = whisper.model.encoder.to(speech_encoder_device)
87
- self.llama_decoder = LlamaForCausalLM.from_pretrained(
88
- "meta-llama/Meta-Llama-3-8B-Instruct",
89
  device_map=device_map,
90
  torch_dtype=torch.float16,
91
  )
92
- self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
93
  self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
94
  self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
95
- self.llama_decoder.model.embed_tokens.weight.device
96
  )
97
 
98
  self.pre_user_suffix = torch.tensor(
99
  self.tokenizer.encode(
100
  "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
101
  )
102
- ).to(self.llama_decoder.model.embed_tokens.weight.device)
103
  self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
104
- self.llama_decoder.model.embed_tokens.weight.device
105
  )
106
  self.speech_encoder_device = speech_encoder_device
107
 
@@ -161,18 +159,18 @@ class DiVAModel(PreTrainedModel):
161
  ]
162
  virt_tokens = self.connector(
163
  hidden_states,
164
- output_device=self.llama_decoder.model.embed_tokens.weight.device,
165
  ).squeeze()
166
 
167
- prefix_embed = self.llama_decoder.model.embed_tokens(prefix_text_tokens)
168
- suffix_embed = self.llama_decoder.model.embed_tokens(suffix_text_tokens)
169
  inputs_embeds = torch.cat(
170
  [prefix_embed, virt_tokens, suffix_embed], axis=0
171
  ).unsqueeze(0)
172
 
173
- outputs = self.llama_decoder(
174
  inputs_embeds=inputs_embeds.to(
175
- self.llama_decoder.model.embed_tokens.weight.device
176
  ).half(),
177
  return_dict=True,
178
  output_hidden_states=True,
@@ -197,7 +195,7 @@ class DiVAModel(PreTrainedModel):
197
  ]
198
  virt_tokens = self.connector(
199
  hidden_states,
200
- output_device=self.llama_decoder.model.embed_tokens.weight.device,
201
  )
202
  bsz = virt_tokens.shape[0]
203
 
@@ -227,9 +225,9 @@ class DiVAModel(PreTrainedModel):
227
  )
228
  else:
229
  prefix = self.prefix
230
- prefix_embed = self.llama_decoder.model.embed_tokens(prefix).expand(bsz, -1, -1)
231
  suffix = self.final_header
232
- suffix_embed = self.llama_decoder.model.embed_tokens(suffix).expand(bsz, -1, -1)
233
  inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
234
  outs = [[] for i in range(bsz)]
235
  complete = [False] * bsz
@@ -238,9 +236,9 @@ class DiVAModel(PreTrainedModel):
238
  i = 0
239
  while not all(complete) and len(outs[0]) < max_new_tokens:
240
  past_key_values = outputs.past_key_values if outputs else None
241
- outputs = self.llama_decoder(
242
  inputs_embeds=inputs_embeds.to(
243
- self.llama_decoder.model.embed_tokens.weight.device
244
  ).half(),
245
  return_dict=True,
246
  output_hidden_states=True,
@@ -268,7 +266,7 @@ class DiVAModel(PreTrainedModel):
268
  if out == 128009:
269
  complete[token_index] = True
270
 
271
- next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(-1, 1))
272
  inputs_embeds = next_embed
273
  return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
274
 
@@ -287,7 +285,7 @@ class DiVAModel(PreTrainedModel):
287
  ]
288
  virt_tokens = self.connector(
289
  hidden_states,
290
- output_device=self.llama_decoder.model.embed_tokens.weight.device,
291
  ).squeeze()
292
 
293
  if text_prompt != None and text_prompt != "":
@@ -300,9 +298,9 @@ class DiVAModel(PreTrainedModel):
300
  )
301
  else:
302
  prefix = self.prefix
303
- prefix_embed = self.llama_decoder.model.embed_tokens(prefix)
304
  suffix = self.final_header
305
- suffix_embed = self.llama_decoder.model.embed_tokens(suffix)
306
  inputs_embeds = torch.cat(
307
  [prefix_embed, virt_tokens, suffix_embed], axis=0
308
  ).unsqueeze(0)
@@ -312,9 +310,9 @@ class DiVAModel(PreTrainedModel):
312
  i = 0
313
  while greedy != 128009 and len(outs) < max_new_tokens:
314
  past_key_values = outputs.past_key_values if outputs else None
315
- outputs = self.llama_decoder(
316
  inputs_embeds=inputs_embeds.to(
317
- self.llama_decoder.model.embed_tokens.weight.device
318
  ).half(),
319
  return_dict=True,
320
  output_hidden_states=True,
@@ -337,7 +335,7 @@ class DiVAModel(PreTrainedModel):
337
  else:
338
  greedy = next_token_logits.argmax()
339
  outs.append(greedy)
340
- next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
341
  inputs_embeds = next_embed
342
  yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
343
  "<|eot_id|>", ""
 
10
  from datasets import Audio
11
  from safetensors.torch import load, load_model
12
  from torch import nn
13
+ from configuring_diva import DiVAConfig
14
  from transformers import (
15
  AutoProcessor,
16
  AutoTokenizer,
17
+ AutoModelForCausalLM,
18
  PreTrainedModel,
19
+ WhisperModel,
20
  )
21
 
22
 
 
51
  super().__init__(DiVAConfig.from_dict(config_dict))
52
  if speech_encoder_device is None:
53
  speech_encoder_device = "cuda:0"
54
+ whisper = WhisperModel.from_pretrained(config_dict["reference_encoder"])
 
 
55
  connector = WhisperConnector()
56
+ connector.decoder = copy.deepcopy(whisper.decoder)
57
  if via_path is not None:
58
  with open(via_path, "rb") as f:
59
  sd = load(f.read())
 
81
  )
82
 
83
  self.connector = connector.to(speech_encoder_device)
84
+ self.whisper_encoder = whisper.encoder.to(speech_encoder_device)
85
+ self.llm_decoder = AutoModelForCausalLM.from_pretrained(
86
+ config_dict["reference_decoder"],
87
  device_map=device_map,
88
  torch_dtype=torch.float16,
89
  )
90
+ self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
91
  self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
92
  self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
93
+ self.llm_decoder.model.embed_tokens.weight.device
94
  )
95
 
96
  self.pre_user_suffix = torch.tensor(
97
  self.tokenizer.encode(
98
  "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
99
  )
100
+ ).to(self.llm_decoder.model.embed_tokens.weight.device)
101
  self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
102
+ self.llm_decoder.model.embed_tokens.weight.device
103
  )
104
  self.speech_encoder_device = speech_encoder_device
105
 
 
159
  ]
160
  virt_tokens = self.connector(
161
  hidden_states,
162
+ output_device=self.llm_decoder.model.embed_tokens.weight.device,
163
  ).squeeze()
164
 
165
+ prefix_embed = self.llm_decoder.model.embed_tokens(prefix_text_tokens)
166
+ suffix_embed = self.llm_decoder.model.embed_tokens(suffix_text_tokens)
167
  inputs_embeds = torch.cat(
168
  [prefix_embed, virt_tokens, suffix_embed], axis=0
169
  ).unsqueeze(0)
170
 
171
+ outputs = self.llm_decoder(
172
  inputs_embeds=inputs_embeds.to(
173
+ self.llm_decoder.model.embed_tokens.weight.device
174
  ).half(),
175
  return_dict=True,
176
  output_hidden_states=True,
 
195
  ]
196
  virt_tokens = self.connector(
197
  hidden_states,
198
+ output_device=self.llm_decoder.model.embed_tokens.weight.device,
199
  )
200
  bsz = virt_tokens.shape[0]
201
 
 
225
  )
226
  else:
227
  prefix = self.prefix
228
+ prefix_embed = self.llm_decoder.model.embed_tokens(prefix).expand(bsz, -1, -1)
229
  suffix = self.final_header
230
+ suffix_embed = self.llm_decoder.model.embed_tokens(suffix).expand(bsz, -1, -1)
231
  inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
232
  outs = [[] for i in range(bsz)]
233
  complete = [False] * bsz
 
236
  i = 0
237
  while not all(complete) and len(outs[0]) < max_new_tokens:
238
  past_key_values = outputs.past_key_values if outputs else None
239
+ outputs = self.llm_decoder(
240
  inputs_embeds=inputs_embeds.to(
241
+ self.llm_decoder.model.embed_tokens.weight.device
242
  ).half(),
243
  return_dict=True,
244
  output_hidden_states=True,
 
266
  if out == 128009:
267
  complete[token_index] = True
268
 
269
+ next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(-1, 1))
270
  inputs_embeds = next_embed
271
  return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
272
 
 
285
  ]
286
  virt_tokens = self.connector(
287
  hidden_states,
288
+ output_device=self.llm_decoder.model.embed_tokens.weight.device,
289
  ).squeeze()
290
 
291
  if text_prompt != None and text_prompt != "":
 
298
  )
299
  else:
300
  prefix = self.prefix
301
+ prefix_embed = self.llm_decoder.model.embed_tokens(prefix)
302
  suffix = self.final_header
303
+ suffix_embed = self.llm_decoder.model.embed_tokens(suffix)
304
  inputs_embeds = torch.cat(
305
  [prefix_embed, virt_tokens, suffix_embed], axis=0
306
  ).unsqueeze(0)
 
310
  i = 0
311
  while greedy != 128009 and len(outs) < max_new_tokens:
312
  past_key_values = outputs.past_key_values if outputs else None
313
+ outputs = self.llm_decoder(
314
  inputs_embeds=inputs_embeds.to(
315
+ self.llm_decoder.model.embed_tokens.weight.device
316
  ).half(),
317
  return_dict=True,
318
  output_hidden_states=True,
 
335
  else:
336
  greedy = next_token_logits.argmax()
337
  outs.append(greedy)
338
+ next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(1, 1))
339
  inputs_embeds = next_embed
340
  yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
341
  "<|eot_id|>", ""