{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "cdad6b21-030a-40d3-9b31-a229e5b6196d", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer, AutoConfig, CLIPImageProcessor" ] }, { "cell_type": "code", "execution_count": 2, "id": "1f832710-0e8c-42ec-b581-1b15fd2a6acc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2024-01-25 14:31:58,511] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "from model import LlavaPhiForCausalLM" ] }, { "cell_type": "code", "execution_count": 3, "id": "9e68f1d4-1ae3-4d45-b818-4600218d2215", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e5e13e666e3a43d4ad26cc70904abee8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/3 [00:00 {text} \"\n", " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n", " return tokens\n", " \n", "\n", "class WhisperWithProjection:\n", " def __init__(self, projection_dim, device):\n", " self.device = device\n", " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n", " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n", " self.model.config.forced_decoder_ids = None\n", " # self.audio_language_connector = AudioLanguageConnector(projection_dim)\n", " \n", " def __call__(self, audio):\n", " input_features = self.processor(audio[\"array\"],\n", " sampling_rate=audio[\"sampling_rate\"],\n", " return_tensors=\"pt\").input_features\n", " # generate token ids\n", " predicted_ids = self.model.generate(input_features.to(self.device))\n", " # decode token ids to text \n", " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n", "\n", " # audio_embeddings = self.audio_language_connector(transcription)\n", " return transcription" ] }, { "cell_type": "code", "execution_count": 10, "id": "a2757c91-2ec1-4fe7-9216-03740bf80061", "metadata": {}, "outputs": [], "source": [ "IGNORE_INDEX = -100\n", "IMAGE_TOKEN_INDEX = -200\n", "DEFAULT_IMAGE_TOKEN = \"\"\n", "DEFAULT_IMAGE_PATCH_TOKEN = \"\"\n", "DEFAULT_IM_START_TOKEN = \"\"\n", "DEFAULT_IM_END_TOKEN = \"\"\n", "\n", "from conversation import conv_templates, SeparatorStyle\n", "\n", "class MultiModalPhi2:\n", " def __init__(self, modelname_or_path=\"RaviNaik/Llava-Phi2\",\n", " temperature=0.2,\n", " max_new_tokens=1024,\n", " device=\"cuda:0\"):\n", " self.model_name = modelname_or_path\n", " self.temperature = temperature\n", " self.max_new_tokens = max_new_tokens\n", " self.device = device\n", " self.disable_torch_init()\n", " self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n", " self.load_pretrained_model()\n", " \n", " def disable_torch_init(self):\n", " \"\"\"\n", " Disable the redundant torch default initialization to accelerate model creation.\n", " \"\"\"\n", " setattr(torch.nn.Linear, \"reset_parameters\", lambda self: None)\n", " setattr(torch.nn.LayerNorm, \"reset_parameters\", lambda self: None)\n", " \n", " def load_pretrained_model(self):\n", " self.model = LlavaPhiForCausalLM.from_pretrained(self.model_name, device_map=self.device)\n", " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n", " self.image_processor = 
CLIPImageProcessor.from_pretrained(self.model_name)\n", "        mm_use_im_start_end = getattr(self.model.config, \"mm_use_im_start_end\", False)\n", "        mm_use_im_patch_token = getattr(self.model.config, \"mm_use_im_patch_token\", True)\n", "        if mm_use_im_patch_token:\n", "            self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n", "        if mm_use_im_start_end:\n", "            self.tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n", "\n", "    def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):\n", "        # tokenize the prompt around the <image> placeholder and splice image_token_index between the chunks\n", "        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]\n", "\n", "        def insert_separator(X, sep):\n", "            return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]\n", "\n", "        input_ids = []\n", "        offset = 0\n", "        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:\n", "            offset = 1\n", "            input_ids.append(prompt_chunks[0][0])\n", "        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):\n", "            input_ids.extend(x[offset:])\n", "\n", "        if return_tensors is not None:\n", "            if return_tensors == 'pt':\n", "                return torch.tensor(input_ids, dtype=torch.long)\n", "            raise ValueError(f'Unsupported tensor type: {return_tensors}')\n", "        return input_ids\n", "\n", "    def __call__(self, text, audio, image):\n", "        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n", "        conv = conv_templates[\"phi-2_v0\"].copy()\n", "        conv.append_message(conv.roles[0], qs)\n", "        conv.append_message(conv.roles[1], None)\n", "        prompt = conv.get_prompt()\n", "\n", "        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].to(self.device)\n", "\n", "        input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)\n", "        if audio is not None:\n", "            # transcribe the audio with Whisper and append its token ids to the prompt\n", "            audio_transcript = self.whisper_w_proj(audio)\n", "            audio_embed = self.tokenizer(audio_transcript, return_tensors='pt')[\"input_ids\"]\n", "            input_ids = torch.concat([input_ids, audio_embed], dim=1)\n", "        input_ids = input_ids.to(self.device)\n", "\n", "        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n", "\n", "        with torch.inference_mode():\n", "            output_ids = self.model.generate(\n", "                input_ids,\n", "                images=image_tensor,\n", "                do_sample=True,\n", "                temperature=self.temperature,\n", "                max_new_tokens=self.max_new_tokens,\n", "                eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token\n", "                pad_token_id=self.tokenizer.eos_token_id,  # Pad token\n", "                use_cache=True,\n", "            )\n", "\n", "        input_token_len = input_ids.shape[1]\n", "        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n", "        if n_diff_input_output > 0:\n", "            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n", "        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n", "        outputs = outputs.strip()\n", "        if outputs.endswith(stop_str):\n", "            outputs = outputs[:-len(stop_str)]\n", "        outputs = outputs.strip()\n", "        return outputs" ] }, { "cell_type": "code", "execution_count": 11, "id": "cc47e6a0-3544-4a60-930f-ccae87ef945a", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ef56077307d4cef907e25b092061611", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards:   0%|          | 0/3 [00:00