{ "cells": [ { "cell_type": "code", "execution_count": 4, "id": "215cfd2f-62b0-4a86-a407-777a1d32597f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2024-01-24 15:18:49,948] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "from PIL import Image\n", "import requests\n", "\n", "import torch\n", "from torch import nn\n", "from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n", "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n", "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer" ] }, { "cell_type": "code", "execution_count": 5, "id": "2244e8f3-fcc7-4309-9d4d-fea557f89f79", "metadata": {}, "outputs": [], "source": [ "from llava_phi import LlavaPhiForCausalLM" ] }, { "cell_type": "code", "execution_count": 3, "id": "587883e1-3419-4b14-b16b-38fabbc8bfaa", "metadata": {}, "outputs": [], "source": [ "# model = LlavaPhiForCausalLM.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "0e27a7db-e2ab-4d65-b21d-497222e318ad", "metadata": {}, "outputs": [], "source": [ "# processor = AutoProcessor.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "663efdd8-ea21-4231-a2ae-bcc0fb47b46a", "metadata": {}, "outputs": [], "source": [ "# prompt = \"\\nUSER: What's the content of the image?\\nASSISTANT:\"\n", "# url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n", "# image = Image.open(requests.get(url, stream=True).raw)" ] }, { "cell_type": "code", "execution_count": 6, "id": "f622609f-f6a7-4ec1-ac35-c1d33d9436ca", "metadata": {}, "outputs": [], "source": [ "# # Generate\n", "# generate_ids = model.generate(**inputs, max_length=30)\n", "# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]" ] }, { "cell_type": "code", "execution_count": 6, "id": "45f5ba72-2e41-4ccc-84c1-97d542ebee63", "metadata": {}, "outputs": [], "source": [ "from llava_phi.model.builder import load_pretrained_model\n", "from llava_phi.mm_utils import tokenizer_image_token, get_model_name_from_path\n", "from llava_phi.utils import disable_torch_init\n", "from llava_phi.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n", "from llava_phi.conversation import conv_templates, SeparatorStyle" ] }, { "cell_type": "code", "execution_count": 11, "id": "b98ac5d3-5503-4430-81d1-19a4f8d6bd75", "metadata": {}, "outputs": [], "source": [ "model_path = \"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\"\n", "model_name = get_model_name_from_path(model_path)" ] }, { "cell_type": "code", "execution_count": 12, "id": "42fd5721-75a7-475b-bd30-5ee23aeaac64", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'llavaPhi-v0-3b-finetune_checkpoint-4000'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_name" ] }, { "cell_type": "code", "execution_count": 13, "id": "8c2076b5-3bfc-48fd-917b-5dfd06fc532f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The argument `trust_remote_code` is to be used with Auto classes. 
It has no effect here and is ignored.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "load llaVA-Phi MLLM!!!\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20b86f2c01744081b537620c8780f12e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00\\nWhat's the content of the image? ASSISTANT:\"" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt" ] }, { "cell_type": "code", "execution_count": 17, "id": "a89cc181-2214-4844-b966-164a41744e54", "metadata": {}, "outputs": [], "source": [ "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n", "image = Image.open(requests.get(url, stream=True).raw)\n", "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n", "\n", "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n", "\n", "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2" ] }, { "cell_type": "code", "execution_count": 25, "id": "0d519851-64d4-4cf5-b2eb-19474f9aa260", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 55])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_ids.shape" ] }, { "cell_type": "code", "execution_count": 24, "id": "1694ff36-f214-4ed3-b2f3-d3dbd0a1a25b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "from datasets import load_dataset\n", "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n", "audio = audio_ds[0][\"audio\"]\n", "\n", "whisper_w_proj = WhisperWithProjection(projection_dim=512)\n", "audio_embed = whisper_w_proj(audio)[\"input_ids\"]" ] }, { "cell_type": "code", "execution_count": 28, "id": "9c4a9fae-d6ed-4fc2-ba02-97df64cddd93", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 33]), device(type='cpu'))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "audio_embed.shape, audio_embed.device" ] }, { "cell_type": "code", "execution_count": 29, "id": "c3fffe29-98fb-4f4b-ac51-4bdda9e46752", "metadata": {}, "outputs": [], "source": [ "input_ids = torch.concat([input_ids, audio_embed.to(\"cuda:0\")], dim=1)" ] }, { "cell_type": "code", "execution_count": 30, "id": "5dee1ec8-2db2-4f65-99e8-d34bd2735c9c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 88])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_ids.shape" ] }, { "cell_type": "code", "execution_count": 31, "id": "96033b43-4f57-4f0c-bcf7-37b57ca02e47", "metadata": {}, "outputs": [], "source": [ "with torch.inference_mode():\n", " output_ids = model.generate(\n", " input_ids,\n", " images=image_tensor,\n", " do_sample=True,\n", " temperature=0.2,\n", " max_new_tokens=1024,\n", " eos_token_id=tokenizer.eos_token_id, # End of sequence token\n", " pad_token_id=tokenizer.eos_token_id, # Pad 
token\n", " use_cache=True,\n", " )" ] }, { "cell_type": "code", "execution_count": 32, "id": "741e8da5-0d18-4c11-b559-76054ce4ca3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "is a Japanese character from the story of Jesus, who is a Chinese monk who is also known for his teachings. The story is based on the story of the story of Jesus Christ, and it is a representation of the story of Jesus and the story of Jesus Christ.\n" ] } ], "source": [ "input_token_len = input_ids.shape[1]\n", "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n", "if n_diff_input_output > 0:\n", " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n", "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n", "outputs = outputs.strip()\n", "if outputs.endswith(stop_str):\n", " outputs = outputs[:-len(stop_str)]\n", "outputs = outputs.strip()\n", "print(outputs)" ] }, { "cell_type": "code", "execution_count": 20, "id": "69d494d4-d768-4645-b4d6-5c455791b50d", "metadata": {}, "outputs": [], "source": [ "# image" ] }, { "cell_type": "code", "execution_count": null, "id": "8a340856-a13f-4b18-9911-126a4ba37816", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "3c56fdea-c7a1-4e67-9832-e2ed077d8704", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 52, "id": "89e84d39-8ed8-45db-ae82-27c156ee6dd1", "metadata": {}, "outputs": [], "source": [ "class AudioLanguageConnector:\n", " def __init__(self, projection_dim):\n", " model_name = \"microsoft/phi-2\"\n", " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n", " self.phi2_tokenizer.max_length = projection_dim\n", "\n", " def __call__(self, text):\n", " text = f\" {text} \"\n", " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n", " return tokens\n", " \n", "\n", "class WhisperWithProjection:\n", " def __init__(self, projection_dim, device):\n", " self.device = device\n", " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n", " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n", " self.model.config.forced_decoder_ids = None\n", " self.audio_language_connector = AudioLanguageConnector(projection_dim)\n", " \n", " def __call__(self, audio):\n", " input_features = self.processor(audio[\"array\"],\n", " sampling_rate=audio[\"sampling_rate\"],\n", " return_tensors=\"pt\").input_features\n", " # generate token ids\n", " predicted_ids = self.model.generate(input_features.to(self.device))\n", " # decode token ids to text \n", " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n", "\n", " audio_embeddings = self.audio_language_connector(transcription)\n", " return audio_embeddings.to(self.device)" ] }, { "cell_type": "code", "execution_count": 53, "id": "75e24be0-b236-4047-83ef-5c344e262476", "metadata": {}, "outputs": [], "source": [ "class MultiModalPhi2:\n", " def __init__(self, model_path=\"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\",\n", " temperature=0.2,\n", " max_new_tokens=1024,\n", " device=\"cuda\"):\n", " self.temperature = temperature\n", " self.max_new_tokens = max_new_tokens\n", " self.device = device\n", " model_name = 
get_model_name_from_path(model_path)\n", "        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name, device_map=device)\n", "        self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n", "        \n", "        \n", "    def __call__(self, text, audio, image):\n", "        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n", "        conv = conv_templates[\"default\"].copy()\n", "        conv.append_message(conv.roles[0], qs)\n", "        conv.append_message(conv.roles[1], None)\n", "        prompt = conv.get_prompt()\n", "\n", "        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n", "        \n", "        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n", "\n", "        audio_embed = self.whisper_w_proj(audio)[\"input_ids\"]\n", "        \n", "        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n", "\n", "        input_ids = torch.concat([input_ids, audio_embed], dim=1)\n", "\n", "        with torch.inference_mode():\n", "            output_ids = self.model.generate(\n", "                input_ids,\n", "                images=image_tensor,\n", "                do_sample=True,\n", "                temperature=self.temperature,\n", "                max_new_tokens=self.max_new_tokens,\n", "                eos_token_id=self.tokenizer.eos_token_id, # End of sequence token\n", "                pad_token_id=self.tokenizer.eos_token_id, # Pad token\n", "                use_cache=True,\n", "            )\n", "\n", "        input_token_len = input_ids.shape[1]\n", "        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n", "        if n_diff_input_output > 0:\n", "            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n", "        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n", "        outputs = outputs.strip()\n", "        if outputs.endswith(stop_str):\n", "            outputs = outputs[:-len(stop_str)]\n", "        outputs = outputs.strip()\n", "        return outputs" ] },
{ "cell_type": "code", "execution_count": 54, "id": "4efdbad4-d88a-4477-a3a0-f5591cd0b172", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The argument `trust_remote_code` is to be used with Auto classes. 
It has no effect here and is ignored.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "load llaVA-Phi MLLM!!!\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "492c17cf54f34d4d9e4f288fc9e72e79", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLlava-Phi-Checkpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/transformers/modeling_utils.py:2376\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m 2372\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 2373\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m 2374\u001b[0m \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m 2375\u001b[0m \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2376\u001b[0m \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2377\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2378\u001b[0m save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n", "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m 251\u001b[0m tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m 252\u001b[0m filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m 253\u001b[0m metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 254\u001b[0m ):\n\u001b[1;32m 255\u001b[0m \u001b[38;5;250m 
\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 256\u001b[0m \u001b[38;5;124;03m Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "model.save_pretrained(\"Llava-Phi-Checkpoint\")" ] }, { "cell_type": "code", "execution_count": null, "id": "fa0bec34-a148-4340-a30c-6f09dd5e71ca", "metadata": {}, "outputs": [], "source": [ "model.push_to_hub(\"RaviNaik/Llava-Phi2\")" ] }, { "cell_type": "code", "execution_count": 73, "id": "382f74b0-2967-408a-badc-a90918810d74", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/RaviNaik/Llava-Phi2/commit/fa8f7240058241243f6bdc3d6ab44bb691f76e39', commit_message='Upload tokenizer', commit_description='', oid='fa8f7240058241243f6bdc3d6ab44bb691f76e39', pr_url=None, pr_revision=None, pr_num=None)" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.push_to_hub(\"RaviNaik/Llava-Phi2\")" ] }, { "cell_type": "code", "execution_count": null, "id": "b851459b-d3ac-4fb8-99b6-17a648adc41f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }