|
import copy |
|
import json |
|
import os |
|
from typing import Optional, Union |
|
|
|
import librosa |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from datasets import Audio |
|
from safetensors.torch import load, load_model |
|
from torch import nn |
|
from .configuring_diva import DiVAConfig |
|
from transformers import ( |
|
AutoProcessor, |
|
AutoTokenizer, |
|
LlamaForCausalLM, |
|
PreTrainedModel, |
|
WhisperForConditionalGeneration, |
|
) |
|
|
|
|
|
class WhisperConnector(nn.Module):
    """Map Whisper encoder states to a fixed set of LLM-space virtual tokens.

    A learned set of query tokens is passed through ``self.decoder`` (which
    cross-attends to the encoder output), and the resulting hidden states are
    linearly projected into the LLM embedding width.

    Args:
        input_dim: Width of the query tokens / decoder hidden states
            (1280, matching Whisper large-v3, by default).
        output_dim: Width of the projected virtual tokens (4096 by default).
        num_query_tokens: Number of learned query tokens (448 by default).

    ``self.decoder`` is left as ``None`` here and must be assigned (for
    example a copy of a Whisper decoder) before :meth:`forward` is called.
    """

    def __init__(self, input_dim=1280, output_dim=4096, num_query_tokens=448):
        super().__init__()
        self.decoder = None
        self.projection = nn.Linear(input_dim, output_dim)
        self.query_tokens = nn.Parameter(torch.randn(num_query_tokens, input_dim))

    def forward(self, x, output_device="cuda:1"):
        """Project encoder states ``x`` to virtual tokens on ``output_device``.

        Args:
            x: Encoder hidden states, given to the decoder as
                ``encoder_hidden_states``; ``x.shape[0]`` is the batch size.
            output_device: Device the result is moved to.

        Returns:
            When the projection's input width equals the decoder width:
            ``projection(decoder_states)``, i.e. ``(batch, queries, output_dim)``
            assuming the decoder preserves the query length. When the
            projection expects a wider input (e.g. 5120 over 1280-wide states),
            consecutive decoder states are concatenated to that width first;
            for batch size 1 the batch dimension is dropped, matching the
            historical 2-D ``(tokens, output_dim)`` output.

        Raises:
            RuntimeError: If ``self.decoder`` has not been assigned.
        """
        if self.decoder is None:
            raise RuntimeError(
                "WhisperConnector.decoder is not set; assign a decoder before calling forward()"
            )
        bsz = x.shape[0]
        query_tokens = self.query_tokens[None, :, :].expand(bsz, -1, -1)
        decoder_out = self.decoder(inputs_embeds=query_tokens, encoder_hidden_states=x)
        hidden = decoder_out[0]

        # Read the width from the weight itself: checkpoint loading replaces
        # ``projection.weight`` wholesale, which does not update ``in_features``.
        proj_in = self.projection.weight.shape[-1]
        if proj_in != hidden.shape[-1]:
            # Group consecutive decoder states so each row is ``proj_in`` wide.
            stacked = hidden.reshape(bsz, -1, proj_in)
            if bsz == 1:
                stacked = stacked[0]  # legacy unbatched (tokens, proj_in) layout
            virtual_tokens = self.projection(stacked)
        else:
            virtual_tokens = self.projection(hidden)
        return virtual_tokens.to(output_device)
|
|
|
|
|
class DiVAModel(PreTrainedModel):
    """Speech-in, text-out assistant built from Whisper and Llama 3.

    Audio is encoded by the ``openai/whisper-large-v3`` encoder, compressed
    into virtual token embeddings by a :class:`WhisperConnector`, and spliced
    between chat-template token embeddings fed to
    ``meta-llama/Meta-Llama-3-8B-Instruct``. The speech side lives on
    ``speech_encoder_device``; the Llama decoder is placed via ``device_map``.
    """

    config_class = DiVAConfig

    # Token id 128009 (Llama 3 ``<|eot_id|>``); decoding stops once it is emitted.
    _EOT_TOKEN_ID = 128009

    def __init__(
        self, via_path=None, config_dict=None, device_map=None, speech_encoder_device=None
    ):
        """Build the model, optionally loading connector weights.

        Args:
            via_path: Optional path to a safetensors file holding the connector
                weights (``query_tokens``, ``projection.*`` and ``connector.*``
                decoder keys). When ``None`` the connector keeps its random
                query/projection init and the decoder copied from Whisper.
            config_dict: Dict used to build the :class:`DiVAConfig`
                (``None`` means an empty dict).
            device_map: Passed to ``LlamaForCausalLM.from_pretrained``. When
                ``None``, the 32 Llama layers are split evenly over GPUs 1 and 2.
            speech_encoder_device: Device for the Whisper encoder and the
                connector; defaults to ``"cuda:0"``.
        """
        super().__init__(
            DiVAConfig.from_dict(config_dict if config_dict is not None else {})
        )
        if speech_encoder_device is None:
            speech_encoder_device = "cuda:0"

        whisper = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-large-v3"
        )
        connector = WhisperConnector()
        connector.decoder = copy.deepcopy(whisper.model.decoder)
        if via_path is not None:
            self._load_connector_weights(connector, via_path)

        if device_map is None:
            device_map = self._default_device_map()

        self.connector = connector.to(speech_encoder_device)
        self.whisper_encoder = whisper.model.encoder.to(speech_encoder_device)
        self.llama_decoder = LlamaForCausalLM.from_pretrained(
            "meta-llama/Meta-Llama-3-8B-Instruct",
            device_map=device_map,
            torch_dtype=torch.float16,
        )
        self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
        self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")

        embed_device = self.llama_decoder.model.embed_tokens.weight.device
        # User-turn header; presumably
        # "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n".
        self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(embed_device)
        # System-turn header, used (followed by the text prompt) only when a
        # text prompt is supplied.
        # NOTE(review): tokenizer.encode may prepend its own BOS, and
        # self.prefix starts with id 128000 again — presumably matches
        # training; verify before changing.
        self.pre_user_suffix = torch.tensor(
            self.tokenizer.encode(
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            )
        ).to(embed_device)
        # End of user turn followed by the assistant header; presumably
        # "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".
        self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
            embed_device
        )
        self.speech_encoder_device = speech_encoder_device

    @staticmethod
    def _load_connector_weights(connector, via_path):
        """Copy query tokens, projection and decoder weights from a safetensors file."""
        with open(via_path, "rb") as f:
            sd = load(f.read())
        with torch.no_grad():
            connector.query_tokens = nn.Parameter(sd["query_tokens"])
            # Stored transposed relative to nn.Linear's (out, in) weight layout.
            connector.projection.weight = nn.Parameter(sd["projection.weight"].T)
            connector.projection.bias = nn.Parameter(sd["projection.bias"])
        key_prefix = "connector."
        decoder_sd = {
            key[len(key_prefix):]: value
            for key, value in sd.items()
            if key.startswith(key_prefix)
        }
        connector.decoder.load_state_dict(decoder_sd)

    @staticmethod
    def _default_device_map(num_layers=32, num_gpus=2):
        """Split Llama layers evenly over GPUs 1..num_gpus (speech encoder defaults to cuda:0)."""
        layers_per_gpu = num_layers // num_gpus
        device_map = {"model.embed_tokens": 1, "model.norm": 1, "lm_head": num_gpus}
        device_map.update(
            {f"model.layers.{i}": 1 + i // layers_per_gpu for i in range(num_layers)}
        )
        return device_map

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config=None,
        cache_dir=None,
        **kwargs,
    ):
        """Load DiVA from a local directory or a Hugging Face Hub repo id.

        Only ``model-00001-of-00004.safetensors`` (connector weights) and
        ``config.json`` are read here; Whisper and Llama weights are fetched by
        ``__init__`` from their own repos. For Hub repos the two files are
        downloaded into this module's directory. ``model_args``, ``config`` and
        ``cache_dir`` are accepted for API compatibility but ignored.

        Recognised kwargs: ``token`` (Hub auth), ``device_map`` (default
        ``"auto"``) and ``speech_encoder_device``.
        """
        weights_name = "model-00001-of-00004.safetensors"
        config_name = "config.json"
        path = os.fspath(pretrained_model_name_or_path)
        if os.path.isdir(path):
            via_path = os.path.join(path, weights_name)
            config_path = os.path.join(path, config_name)
        else:
            from huggingface_hub import hf_hub_download

            download_kwargs = dict(
                repo_id=path,
                token=kwargs.get("token", None),
                local_dir=os.path.dirname(__file__),
            )
            # hf_hub_download returns the local path of the downloaded file.
            via_path = hf_hub_download(filename=weights_name, **download_kwargs)
            config_path = hf_hub_download(filename=config_name, **download_kwargs)

        with open(config_path, "r") as f:
            config_dict = json.load(f)

        return cls(
            via_path,
            config_dict,
            kwargs.get("device_map", "auto"),
            kwargs.get("speech_encoder_device", None),
        )

    def _audio_to_virtual_tokens(self, audio):
        """Encode 16 kHz audio and return connector outputs on the LLM embedding device."""
        inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
        input_features = inputs.input_features.to(self.speech_encoder_device)
        hidden_states = self.whisper_encoder(input_features=input_features)[
            "last_hidden_state"
        ]
        return self.connector(
            hidden_states,
            output_device=self.llama_decoder.model.embed_tokens.weight.device,
        ).squeeze()

    def _build_prompt_embeds(self, virt_tokens, text_prompt):
        """Embed ``[prompt prefix, virtual tokens, assistant header]`` and add a batch dim."""
        if text_prompt:
            user_prompt_text = torch.tensor(
                self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
                device=self.pre_user_suffix.device,
            )
            prefix = torch.cat(
                [self.pre_user_suffix, user_prompt_text, self.prefix], dim=0
            )
        else:
            prefix = self.prefix
        embed_tokens = self.llama_decoder.model.embed_tokens
        return torch.cat(
            [embed_tokens(prefix), virt_tokens, embed_tokens(self.final_header)], dim=0
        ).unsqueeze(0)

    def _iter_token_ids(
        self, audio, text_prompt, do_sample, logits_processor, max_new_tokens, temperature
    ):
        """Yield generated token ids one at a time.

        Stops after ``<|eot_id|>`` is produced (it is yielded too) or after
        ``max_new_tokens`` ids. Model calls run under ``torch.no_grad`` (not
        spanning the ``yield``) so the KV cache does not retain an autograd
        graph across steps.
        """
        embed_tokens = self.llama_decoder.model.embed_tokens
        embed_device = embed_tokens.weight.device
        llm_dtype = self.llama_decoder.dtype
        with torch.no_grad():
            virt_tokens = self._audio_to_virtual_tokens(audio)
            inputs_embeds = self._build_prompt_embeds(virt_tokens, text_prompt)

        token_ids = []
        past_key_values = None
        while len(token_ids) < max_new_tokens and (
            not token_ids or token_ids[-1] != self._EOT_TOKEN_ID
        ):
            with torch.no_grad():
                outputs = self.llama_decoder(
                    # virt_tokens are float32; cast the whole prompt to the LLM dtype.
                    inputs_embeds=inputs_embeds.to(device=embed_device, dtype=llm_dtype),
                    return_dict=True,
                    past_key_values=past_key_values,
                )
                next_token_logits = outputs.logits[-1, -1, :]
                if logits_processor:
                    # History shown to the processor: ids so far, or the
                    # assistant header before the first token. Kept on the
                    # scores' device so processors can index scores with it.
                    if token_ids:
                        history = torch.tensor(token_ids, device=next_token_logits.device)
                    else:
                        history = self.final_header.to(next_token_logits.device)
                    next_token_logits = logits_processor(
                        history.reshape(1, -1),
                        next_token_logits.reshape(1, -1),
                    ).flatten()
                if do_sample:
                    probs = F.softmax(next_token_logits.float() / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)[0]
                else:
                    next_token = next_token_logits.argmax()
                # lm_head and embed_tokens may sit on different GPUs.
                inputs_embeds = embed_tokens(next_token.reshape(1, 1).to(embed_device))
            past_key_values = outputs.past_key_values
            token_ids.append(int(next_token))
            yield token_ids[-1]

    def _decode(self, token_ids):
        """Detokenize, dropping special tokens and any literal ``<|eot_id|>``."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=True).replace(
            "<|eot_id|>", ""
        )

    def forward(self, audio, prefix_text_tokens, suffix_text_tokens, past_key_values=None):
        """Run the Llama decoder once over ``[prefix, audio virtual tokens, suffix]``.

        Args:
            audio: Raw audio for the Whisper processor, sampled at 16 kHz.
            prefix_text_tokens: Token-id tensor placed before the audio; must
                be 1-D so its embedding concatenates with the 2-D virtual tokens.
            suffix_text_tokens: Token-id tensor placed after the audio (1-D).
            past_key_values: Optional KV cache passed through to the decoder.

        Returns:
            The Llama decoder's output object (``return_dict=True``, hidden
            states included).
        """
        virt_tokens = self._audio_to_virtual_tokens(audio)
        embed_tokens = self.llama_decoder.model.embed_tokens
        embed_device = embed_tokens.weight.device
        prefix_embed = embed_tokens(prefix_text_tokens.to(embed_device))
        suffix_embed = embed_tokens(suffix_text_tokens.to(embed_device))
        inputs_embeds = torch.cat(
            [prefix_embed, virt_tokens, suffix_embed], dim=0
        ).unsqueeze(0)
        return self.llama_decoder(
            inputs_embeds=inputs_embeds.to(
                device=embed_device, dtype=self.llama_decoder.dtype
            ),
            return_dict=True,
            output_hidden_states=True,
            past_key_values=past_key_values,
        )

    def generate(
        self,
        audio,
        text_prompt,
        do_sample=False,
        logits_processor=None,
        max_new_tokens=128,
        temperature=1.0,
    ):
        """Generate a text reply to ``audio``.

        Args:
            audio: Raw audio for the Whisper processor, sampled at 16 kHz.
            text_prompt: Optional instruction text; when non-empty it is placed
                after the system header and before the user-turn header.
            do_sample: Sample from ``softmax(logits / temperature)`` instead of
                taking the argmax.
            logits_processor: Optional callable ``(input_ids, scores) -> scores``
                applied each step (e.g. a transformers ``LogitsProcessorList``).
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Softmax temperature used when ``do_sample`` is true.

        Returns:
            The decoded reply with special tokens removed.
        """
        token_ids = list(
            self._iter_token_ids(
                audio, text_prompt, do_sample, logits_processor, max_new_tokens, temperature
            )
        )
        return self._decode(token_ids)

    def generate_stream(
        self,
        audio,
        text_prompt,
        do_sample=False,
        logits_processor=None,
        max_new_tokens=128,
        temperature=1.0,
    ):
        """Like :meth:`generate`, but yield the decoded text after every token.

        Each yielded value is the full reply decoded so far. The generator's
        return value (``StopIteration.value``) is the final decoded reply.
        """
        token_ids = []
        for token_id in self._iter_token_ids(
            audio, text_prompt, do_sample, logits_processor, max_new_tokens, temperature
        ):
            token_ids.append(token_id)
            yield self._decode(token_ids)
        return self._decode(token_ids)
|
|