|
import glob
import logging
import os
from argparse import Namespace
from pathlib import Path
from typing import Optional

import torch
import torchaudio
from huggingface_hub import HfApi
from PIL import Image
from safetensors import safe_open
from torch import nn
from transformers import AutoTokenizer

from src.model.modules import voicecraft
from src.model.modules.gemma import GemmaForCausalLM, KVCache
from src.model.modules.imagecraftconfig import ImageCraftConfig
from src.model.modules.imagecraftprocessor import ImageCraftProcessor
from src.model.modules.siglip import SiglipVisionModel
from src.model.modules.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
    tokenize_audio,
    tokenize_text,
)
from src.utils import tools
from src.utils.image_utils import is_valid_image
from src.utils.model_utils import get_config, get_model_inputs
from src.utils.util import (
    replace_numbers_with_words,
    sample_top_p,
    save_to_buffer,
    save_to_file,
    split_line_to_sentences,
)

logger = logging.getLogger(__name__)
|
|
|
|
|
class ImageCraftMultiModalProjector(nn.Module): |
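    """Projects SigLIP vision features into the language model's embedding
    space with a single linear layer (a PaliGemma-style projector)."""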
|
def __init__(self, config: ImageCraftConfig): |
|
super().__init__() |
|
self.linear = nn.Linear( |
|
config.vision_config.hidden_size, |
|
config.vision_config.projection_dim, |
|
bias=True, |
|
) |
|
|
|
def forward(self, image_features): |
|
hidden_states = self.linear(image_features) |
|
return hidden_states |
|
|
|
|
|
class ImageCraft(nn.Module): |
|
config_class = ImageCraftConfig |
|
|
|
def __init__(self, config: ImageCraftConfig): |
|
        super().__init__()
|
self.config = config |
|
self.vision_tower = SiglipVisionModel(config.vision_config) |
|
self.multi_modal_projector = ImageCraftMultiModalProjector(config) |
|
self.vocab_size = config.text_config.vocab_size |
|
|
|
self.language_model = GemmaForCausalLM(config.text_config) |
|
|
|
self.pad_token_id = ( |
|
self.config.pad_token_id if self.config.pad_token_id is not None else -1 |
|
) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
"google/paligemma-3b-pt-224", padding_side="right" |
|
) |
|
assert tokenizer.padding_side == "right" |
|
|
|
num_image_tokens = config.vision_config.num_image_tokens |
|
image_size = config.vision_config.image_size |
|
self.processor = ImageCraftProcessor(tokenizer, num_image_tokens, image_size) |
|
|
|
self.text_tokenizer = None |
|
|
|
self.voicecraft_model = None |
|
self.audio_tokenizer = None |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
def tie_weights(self): |
|
return self.language_model.tie_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
pixel_values: torch.FloatTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
kv_cache: Optional[KVCache] = None, |
|
    ) -> dict:
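        # Inference-time forward pass: embed the prompt tokens, encode the
        # image with the vision tower, project the patch features into the
        # text embedding space, splice them over the <image> placeholder
        # tokens, and run the merged sequence through the Gemma decoder.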
|
|
|
assert torch.all(attention_mask == 1), "The input cannot be padded" |
|
|
|
|
|
|
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) |
|
|
|
|
|
|
|
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype)) |
|
|
|
image_features = self.multi_modal_projector(selected_image_feature) |
|
|
|
|
|
inputs_embeds, attention_mask, position_ids = ( |
|
self._merge_input_ids_with_image_features( |
|
image_features, inputs_embeds, input_ids, attention_mask, kv_cache |
|
) |
|
) |
|
|
|
outputs = self.language_model( |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
inputs_embeds=inputs_embeds, |
|
kv_cache=kv_cache, |
|
) |
|
|
|
return outputs |
|
|
|
def _merge_input_ids_with_image_features( |
|
self, |
|
image_features: torch.Tensor, |
|
inputs_embeds: torch.Tensor, |
|
input_ids: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
kv_cache: Optional[KVCache] = None, |
|
): |
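        """Build the merged embedding sequence (text plus projected image
        features), along with the attention mask and position ids for the
        current decoding step."""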
|
_, _, embed_dim = image_features.shape |
|
batch_size, sequence_length = input_ids.shape |
|
dtype, device = inputs_embeds.dtype, inputs_embeds.device |
|
|
|
        # The Gemma decoder multiplies its input embeddings by
        # sqrt(hidden_size); pre-dividing the image features here cancels
        # that scaling for the image positions.
        scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
|
|
|
|
final_embedding = torch.zeros( |
|
batch_size, |
|
sequence_length, |
|
embed_dim, |
|
dtype=inputs_embeds.dtype, |
|
device=inputs_embeds.device, |
|
) |
|
|
|
text_mask = (input_ids != self.config.image_token_index) & ( |
|
input_ids != self.pad_token_id |
|
) |
|
|
|
image_mask = input_ids == self.config.image_token_index |
|
|
|
pad_mask = input_ids == self.pad_token_id |
|
|
|
|
|
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim) |
|
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim) |
|
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim) |
|
|
|
|
|
final_embedding = torch.where( |
|
text_mask_expanded, inputs_embeds, final_embedding |
|
) |
|
|
|
final_embedding = final_embedding.masked_scatter( |
|
image_mask_expanded, scaled_image_features |
|
) |
|
|
|
final_embedding = torch.where( |
|
pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding |
|
) |
|
|
|
|
|
|
|
        q_len = inputs_embeds.shape[1]

        # PaliGemma-style prefix-LM attention: the image + prompt prefix
        # attends bidirectionally during prefill, and each generated token
        # attends to the entire KV cache during decoding, so no positions
        # are ever masked out (the mask stays all zeros rather than -inf).
|
|
|
if kv_cache is None or kv_cache.num_items() == 0: |
|
|
|
|
|
causal_mask = torch.full( |
|
(batch_size, q_len, q_len), fill_value=0, dtype=dtype, device=device |
|
) |
|
else: |
|
|
|
assert q_len == 1 |
|
kv_len = kv_cache.num_items() + q_len |
|
|
|
|
|
causal_mask = torch.full( |
|
(batch_size, q_len, kv_len), fill_value=0, dtype=dtype, device=device |
|
) |
|
|
|
|
|
|
|
causal_mask = causal_mask.unsqueeze(1) |
|
|
|
if kv_cache is not None and kv_cache.num_items() > 0: |
|
|
|
position_ids = attention_mask.cumsum(-1)[:, -1] |
|
if position_ids.dim() == 1: |
|
position_ids = position_ids.unsqueeze(0) |
|
else: |
|
|
|
|
|
position_ids = ( |
|
(attention_mask.cumsum(-1)) |
|
.masked_fill_((attention_mask == 0), 1) |
|
.to(device) |
|
) |
|
|
|
return final_embedding, causal_mask, position_ids |
|
|
|
def _generate_caption(self, image, max_tokens=100, do_sample=False): |
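        """Generate a caption for `image`, either greedily or with
        temperature + top-p sampling, reusing the KV cache across steps."""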
|
prompt = "caption en" |
|
image = ( |
|
image.convert("RGB") |
|
if is_valid_image(image) |
|
else Image.open(image).convert("RGB") |
|
) |
|
|
|
inputs = get_model_inputs( |
|
processor=self.processor, prompt=prompt, image=image, device=self.device |
|
) |
|
|
|
image.close() |
|
|
|
input_ids = inputs["input_ids"] |
|
attention_mask = inputs["attention_mask"] |
|
pixel_values = inputs["pixel_values"] |
|
|
|
kv_cache = KVCache() |
|
|
|
stop_token = self.processor.tokenizer.eos_token_id |
|
generated_tokens = [] |
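
        # Autoregressive decoding loop: the first pass prefills the KV cache
        # with the full image + prompt sequence; each subsequent pass feeds
        # only the newly sampled token and extends the attention mask by one.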
|
|
|
for _ in range(max_tokens): |
|
outputs = self( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
attention_mask=attention_mask, |
|
kv_cache=kv_cache, |
|
) |
|
kv_cache = outputs["kv_cache"] |
|
next_token_logits = outputs["logits"][:, -1, :] |
|
if do_sample: |
|
next_token_logits = torch.softmax( |
|
next_token_logits / self.config.temperature, dim=-1 |
|
) |
|
next_token = sample_top_p(next_token_logits, self.config.top_p) |
|
else: |
|
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
|
assert next_token.size() == (1, 1) |
|
next_token = next_token.squeeze(0) |
|
generated_tokens.append(next_token) |
|
if next_token.item() == stop_token: |
|
break |
|
input_ids = next_token.unsqueeze(-1) |
|
attention_mask = torch.cat( |
|
[attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1 |
|
) |
|
|
|
generated_tokens = torch.cat(generated_tokens, dim=-1) |
|
decoded_text = self.processor.tokenizer.decode( |
|
generated_tokens, skip_special_tokens=True |
|
) |
|
        # Keep only the text after the first line break, if the model
        # emitted one.
        parts = decoded_text.split("\n", 1)
        decoded_text = parts[1] if len(parts) > 1 else decoded_text
|
|
|
return decoded_text.rstrip(" .").strip().capitalize() + "." |
|
|
|
def _generate_speech(self, text: str, output_type="file"): |
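        """Synthesize `text` with VoiceCraft zero-shot TTS, using a reference
        voice recording and its transcript as the acoustic prompt so the
        generated speech keeps that voice."""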
|
|
|
sentences = split_line_to_sentences(text) |
|
|
|
voice_audio = ( |
|
f"media/voicecraft/voices/{self.config.voicecraft_config.voice_audio_path}" |
|
) |
|
voice_transcript = self.config.voicecraft_config.voice_audio_transcript |
|
cut_off_sec = self.config.voicecraft_config.cut_off_sec |
|
|
|
decode_config = { |
|
"top_k": self.config.voicecraft_config.top_k, |
|
"top_p": self.config.voicecraft_config.top_p, |
|
"temperature": self.config.voicecraft_config.temperature, |
|
"stop_repetition": self.config.voicecraft_config.stop_repetition, |
|
"kvcache": self.config.voicecraft_config.kvcache, |
|
"codec_audio_sr": self.config.voicecraft_config.codec_audio_sr, |
|
"codec_sr": self.config.voicecraft_config.codec_sr, |
|
"silence_tokens": self.config.voicecraft_config.silence_tokens, |
|
"sample_batch_size": self.config.voicecraft_config.sample_batch_size, |
|
} |
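
        # Use at most the first cut_off_sec seconds of the reference audio
        # as the acoustic prompt.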
|
|
|
info = torchaudio.info(voice_audio) |
|
audio_dur = info.num_frames / info.sample_rate |
|
prompt_end_frame = int(min(audio_dur, cut_off_sec) * info.sample_rate) |
|
|
|
audio_tensors = [] |
|
transcript = voice_transcript |
|
|
|
for sentence in sentences: |
|
|
|
transcript += sentence + "\n" |
|
            transcript = replace_numbers_with_words(transcript).replace("  ", " ")
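
            # Phonemize the accumulated transcript and map each phoneme to
            # its id in VoiceCraft's vocabulary, skipping phonemes the model
            # was not trained on.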
|
|
|
|
|
phn2num = self.voicecraft_model.args.phn2num |
|
text_tokens = [ |
|
phn2num[phn] |
|
for phn in tokenize_text(self.text_tokenizer, text=transcript.strip()) |
|
if phn in phn2num |
|
] |
|
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) |
|
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) |
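
            # Encode the reference audio prompt into Encodec codebook frames.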
|
|
|
|
|
encoded_frames = tokenize_audio( |
|
self.audio_tokenizer, |
|
voice_audio, |
|
offset=0, |
|
num_frames=prompt_end_frame, |
|
) |
|
original_audio = encoded_frames[0][0].transpose(2, 1) |
|
            model_args = Namespace(**vars(self.voicecraft_model.args))
|
|
|
assert ( |
|
original_audio.ndim == 3 |
|
and original_audio.shape[0] == 1 |
|
and original_audio.shape[2] == model_args.n_codebooks |
|
), original_audio.shape |
|
|
|
|
|
            # "silence_tokens" may be stored in the config as a string
            # literal; parse it back into a sequence if so (eval is kept from
            # the original config format and assumes trusted config input).
            silence_tokens = (
                eval(decode_config["silence_tokens"])
                if isinstance(decode_config["silence_tokens"], str)
                else decode_config["silence_tokens"]
            )

            if decode_config["sample_batch_size"] <= 1:
|
_, gen_frames = self.voicecraft_model.inference_tts( |
|
text_tokens.to(self.device), |
|
text_tokens_lens.to(self.device), |
|
original_audio[..., : model_args.n_codebooks].to( |
|
self.device |
|
), |
|
top_k=decode_config["top_k"], |
|
top_p=decode_config["top_p"], |
|
temperature=decode_config["temperature"], |
|
stop_repetition=decode_config["stop_repetition"], |
|
kvcache=decode_config["kvcache"], |
|
                silence_tokens=silence_tokens,
|
) |
|
else: |
|
_, gen_frames = self.voicecraft_model.inference_tts_batch( |
|
text_tokens.to(self.device), |
|
text_tokens_lens.to(self.device), |
|
original_audio[..., : model_args.n_codebooks].to( |
|
self.device |
|
), |
|
top_k=decode_config["top_k"], |
|
top_p=decode_config["top_p"], |
|
temperature=decode_config["temperature"], |
|
stop_repetition=decode_config["stop_repetition"], |
|
kvcache=decode_config["kvcache"], |
|
batch_size=decode_config["sample_batch_size"], |
|
                silence_tokens=silence_tokens,
|
) |
|
gen_sample = self.audio_tokenizer.decode([(gen_frames, None)]) |
|
gen_audio = gen_sample[0].cpu() |
|
audio_tensors.append(gen_audio) |
|
|
|
output = None |
|
|
|
if output_type == "file": |
|
output = save_to_file(audio_tensors, decode_config["codec_audio_sr"]) |
|
else: |
|
output = save_to_buffer(audio_tensors, decode_config["codec_audio_sr"]) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
return output |
|
|
|
@torch.inference_mode() |
|
def generate( |
|
self, |
|
image, |
|
max_tokens=30, |
|
do_sample=False, |
|
output_type="file", |
|
): |
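        """Caption the image, then synthesize the caption as speech.

        Returns a (transcript, speech) pair, where speech is whatever
        save_to_file / save_to_buffer produce (a file or an in-memory
        buffer, depending on `output_type`).
        """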
|
transcript = self._generate_caption(image, max_tokens, do_sample) |
|
speech = self._generate_speech(transcript, output_type) |
|
return transcript, speech |
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
model_path=None, |
|
): |
|
api = HfApi() |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
env_config = tools.load_config() |
|
pretrained_dir = env_config["pretrained_dir"] |
|
imagecraft_cache_dir = f"{pretrained_dir}/imagecraft" |
|
voicecraft_cache_dir = f"{pretrained_dir}/voicecraft" |
|
|
|
state_dict = {} |
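
        # model_path may be a local checkpoint file or a Hugging Face Hub
        # repo id; Hub repos are snapshot-downloaded and their *.safetensors
        # shards merged into a single state dict.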
|
|
|
if Path(model_path).is_file(): |
|
checkpoint = torch.load(model_path, weights_only=False) |
|
state_dict = checkpoint["state_dict"] |
|
|
|
else: |
|
|
|
model_path = api.snapshot_download( |
|
repo_id=model_path, |
|
repo_type="model", |
|
cache_dir=imagecraft_cache_dir, |
|
local_files_only=False, |
|
) |
|
|
|
safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors")) |
|
|
|
for safetensors_file in safetensors_files: |
|
with safe_open(safetensors_file, framework="pt", device="cpu") as f: |
|
for key in f.keys(): |
|
state_dict[key] = f.get_tensor(key) |
|
|
|
imagecraft_config = get_config() |
|
|
|
model = cls(imagecraft_config).to(device) |
|
|
|
|
|
model.load_state_dict(state_dict, strict=False) |
|
|
|
|
|
model.tie_weights() |
|
|
|
model = model.eval() |
|
|
|
|
|
|
|
model.voicecraft_model = voicecraft.VoiceCraft.from_pretrained( |
|
f"pyp1/VoiceCraft_{model.config.voicecraft_config.model_name.replace('.pth', '')}", |
|
cache_dir=voicecraft_cache_dir, |
|
) |
|
|
|
encodec_fn = f"{voicecraft_cache_dir}/{model.config.voicecraft_config.encodec}" |
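
        # Download the Encodec weights from the VoiceCraft Hub repo if they
        # are not already cached locally.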
|
|
|
if not os.path.exists(encodec_fn): |
|
os.system( |
|
f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{model.config.voicecraft_config.encodec}" |
|
) |
|
os.system(f"mv {model.config.voicecraft_config.encodec} {encodec_fn}") |
|
|
|
model.audio_tokenizer = AudioTokenizer( |
|
signature=encodec_fn, |
|
device=device, |
|
) |
|
|
|
model.text_tokenizer = TextTokenizer(backend="espeak") |
|
|
|
model.voicecraft_model.to(device) |
|
|
|
return model |
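

# Minimal usage sketch (the checkpoint path and image path below are
# illustrative placeholders, not values shipped with this module):
#
#   model = ImageCraft.from_pretrained("path/to/imagecraft-checkpoint.pt")
#   transcript, speech = model.generate("path/to/image.jpg", output_type="file")
#   print(transcript)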
|
|