import torch
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)

from .model import LlavaPhiForCausalLM
from .conversation import conv_templates, SeparatorStyle

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class AudioLanguageConnector:
    """Wrap the phi-2 tokenizer to mark transcribed audio with <audio_start>/<audio_end> tags.

    Not used by WhisperWithProjection in this file (its call sites are commented out).
    """

    def __init__(self, projection_dim):
        model_name = "microsoft/phi-2"
        self.phi2_tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
        self.phi2_tokenizer.max_length = projection_dim

    def __call__(self, text):
        text = f"<audio_start> {text} <audio_end>"
        tokens = self.phi2_tokenizer(
            text, return_tensors="pt", return_attention_mask=False
        )
        return tokens


class WhisperWithProjection:
    """Transcribe raw audio with openai/whisper-tiny and return the transcript text."""

    def __init__(self, projection_dim, device):
        self.device = device
        self.processor = WhisperProcessor.from_pretrained(
            "openai/whisper-tiny", device_map=device
        )
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-tiny", device_map=device
        )
        self.model.config.forced_decoder_ids = None
        # self.audio_language_connector = AudioLanguageConnector(projection_dim)

    def __call__(self, audio):
        # `audio` is expected to look like the HF datasets Audio feature:
        # {"array": <1-D waveform>, "sampling_rate": <int>}.
        input_features = self.processor(
            audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt"
        ).input_features
        # generate token ids
        predicted_ids = self.model.generate(input_features.to(self.device))
        # decode token ids to text
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )
        # audio_embeddings = self.audio_language_connector(transcription)
        return transcription


class MultiModalPhi2:
    def __init__(
        self,
        modelname_or_path="RaviNaik/Llava-Phi2",
        temperature=0.2,
        max_new_tokens=1024,
        device="cuda:0",
    ):
        self.model_name = modelname_or_path
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.disable_torch_init()
        self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)
        self.load_pretrained_model()

    def disable_torch_init(self):
        """Disable the redundant torch default initialization to accelerate model creation."""
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def load_pretrained_model(self):
        self.model = LlavaPhiForCausalLM.from_pretrained(
            self.model_name, device_map=self.device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model_name)
        mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(
            self.model.config, "mm_use_im_patch_token", True
        )
        # Register the image placeholder tokens the checkpoint expects.
        if mm_use_im_patch_token:
            self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            self.tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )

    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        """Tokenize `prompt`, replacing every "<image>" placeholder with `image_token_index`."""
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        # Keep a single BOS token at the start if the tokenizer emitted one.
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids

    def __call__(self, text, audio, image):
        if text is None:
            text = ""

        # Build the conversation prompt; an image is referenced via placeholder tokens.
        if image is not None:
            qs = (
                DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_TOKEN
                + DEFAULT_IM_END_TOKEN
                + "\n"
                + text
            )
            conv = conv_templates["phi-2_v0"].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            input_ids = self.tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0)
            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")[
                "pixel_values"
            ].to(self.device)
        else:
            qs = text
            conv = conv_templates["phi-2_v0"].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
            image_tensor = None

        # Transcribe the audio with Whisper and append the transcript tokens to the prompt.
        if audio is not None:
            audio_transcript = self.whisper_w_proj(audio)
            audio_embed = self.tokenizer(audio_transcript, return_tensors="pt")[
                "input_ids"
            ]
            input_ids = torch.cat([input_ids, audio_embed], dim=1)

        input_ids = input_ids.to(self.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        with torch.inference_mode():
            if image is not None:
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=self.temperature,
                    max_new_tokens=self.max_new_tokens,
                    eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                    pad_token_id=self.tokenizer.eos_token_id,  # Pad token
                    use_cache=True,
                )
            else:
                output_ids = self.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=self.temperature,
                    max_new_tokens=self.max_new_tokens,
                    eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                    pad_token_id=self.tokenizer.eos_token_id,  # Pad token
                    use_cache=True,
                )

        # Warn if generate() altered the prompt tokens it echoes back.
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (
            (input_ids != output_ids[:, :input_token_len]).sum().item()
        )
        if n_diff_input_output > 0:
            print(
                f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
            )

        # Decode only the newly generated tokens and strip the conversation stop string.
        outputs = self.tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()
        return outputs
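

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not exercised elsewhere in this module).
# Because of the relative imports above, run it as a module, e.g.
# `python -m <package>.<this_module>`. The file names "scene.jpg" and
# "speech.wav", and loading them via PIL/librosa, are hypothetical stand-ins
# for whatever inputs the caller actually has; a CUDA device and the
# RaviNaik/Llava-Phi2 checkpoint are assumed to be available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import librosa
    from PIL import Image

    multimodal_phi2 = MultiModalPhi2(device="cuda:0")

    image = Image.open("scene.jpg")
    waveform, sampling_rate = librosa.load("speech.wav", sr=16000)  # Whisper expects 16 kHz audio
    audio = {"array": waveform, "sampling_rate": sampling_rate}

    answer = multimodal_phi2("What is happening in this scene?", audio, image)
    print(answer)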