{ "cells": [ { "cell_type": "markdown", "id": "55df6d0d-71cf-4110-81ed-7c0d3ce58e43", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "0abe9574-05f7-4684-b586-033827b89c32", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "74e70729-b658-4ffd-9d8b-ae42a2d1b212", "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "\n", "import torch\n", "import numpy as np\n", "from fairseq import utils, tasks\n", "from fairseq import checkpoint_utils\n", "from utils.eval_utils import eval_step\n", "from tasks.mm_tasks.caption import CaptionTask\n", "from models.unival import UnIVALModel\n", "from PIL import Image\n", "\n", "from torchvision.transforms import functional as F\n", "from torchvision.transforms import InterpolationMode\n", "\n", "from matplotlib import pyplot as plt\n", "\n", "# use CUDA if a GPU is available\n", "use_cuda = torch.cuda.is_available()\n", "# use fp16 only when a GPU is available\n", "use_fp16 = False" ] }, { "cell_type": "code", "execution_count": 3, "id": "ce03a870-2852-410e-97c4-59461d08f60a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<function fairseq.tasks.register_task.<locals>.register_task_cls(cls)>" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Register the video_caption task\n", "tasks.register_task('video_caption', CaptionTask)" ] }, { "cell_type": "markdown", "id": "58361680-3e90-4fff-962e-2ff67c1e7289", "metadata": {}, "source": [ "### Load model" ] }, { "cell_type": "code", "execution_count": 80, "id": "adb79611-7563-4fb6-a576-f31050f8438e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "self.sample_patch_num 784\n", "self.sample_audio_patch_num None\n", "self.sample_video_patch_num None\n", "self.with_cls False\n", "Loading: all_resnext101\n", "use bn: \n", "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", "unival\n", "getattr(args, \"stop_on_max_len\", False) False\n" ] } ], "source": [ "# Load the pretrained checkpoint & config\n", "\n", "checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_video_caption_stage_1/checkpoint_best.pt'\n", "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", "\n", "overrides = {\"eval_cider\": False, \"beam\": 5, \"max_len_b\": 22, \"no_repeat_ngram_size\": 3, \"seed\": 7, \"unnormalized\": False,\n", "             \"bpe_dir\": \"utils/BPE\", \"video_model_path\": video_model_path}\n", "\n", "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", "    utils.split_paths(checkpoint_path),\n", "    arg_overrides=overrides\n", ")\n", "\n", "# Set eval mode, optionally cast to fp16, and move models to GPU\n", "for model in models:\n", "    model.eval()\n", "    if use_fp16:\n", "        model.half()\n", "    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", "        model.cuda()\n", "    model.prepare_for_inference_(cfg)\n", "\n", "# Initialize the generator\n", "generator = task.build_generator(models, cfg.generation)" ] }, { "cell_type": "markdown", "id": "e79aad39-1424-47d5-8cd4-6ab77ea46fb4", "metadata": {}, "source": [ "### Preprocess" ] }, { "cell_type": "code", "execution_count": 81, "id": "576a3e84-a6aa-446d-adab-fef9499318fc", "metadata": {}, "outputs": [], "source": [ "# Video frame transforms\n", "from torchvision import transforms\n", "mean = [0.5, 0.5, 0.5]\n", "std = [0.5, 0.5, 0.5]\n", "\n", "type_transform = 
transforms.Lambda(lambda x: x.float().div(255.0))\n", "patch_video_resize_transform = transforms.Compose([\n", "    transforms.CenterCrop(cfg.task.patch_frame_size),\n", "    type_transform,\n", "    transforms.Normalize(mean=mean, std=std),\n", "])\n", "\n", "# Video reading\n", "from data.video_utils import VIDEO_READER_FUNCS\n", "\n", "video_reader = VIDEO_READER_FUNCS['decord']\n", "\n", "def process_video(video_path, max_num_frames=16, num_frames=16, sample_type='rand'):\n", "\n", "    # read and subsample frames from the video file\n", "    data_path = os.path.join(video_path)\n", "\n", "    frames, frame_indices, video_duration = video_reader(\n", "        data_path, num_frames, sample_type, max_num_frames=max_num_frames\n", "    )\n", "\n", "    patch_video = patch_video_resize_transform(frames)\n", "    patch_video = patch_video.permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)\n", "\n", "    return patch_video.unsqueeze(0)  # add batch dimension\n", "\n", "\n", "# Text preprocessing\n", "bos_item = torch.LongTensor([task.src_dict.bos()])\n", "eos_item = torch.LongTensor([task.src_dict.eos()])\n", "pad_idx = task.src_dict.pad()\n", "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", "    s = task.tgt_dict.encode_line(\n", "        line=task.bpe.encode(text),\n", "        add_if_not_exist=False,\n", "        append_eos=False\n", "    ).long()\n", "    if length is not None:\n", "        s = s[:length]\n", "    if append_bos:\n", "        s = torch.cat([bos_item, s])\n", "    if append_eos:\n", "        s = torch.cat([s, eos_item])\n", "    return s\n", "\n", "# Construct an input sample for the video captioning task\n", "def construct_sample(video_path):\n", "\n", "    patch_video = process_video(video_path, max_num_frames=16, num_frames=cfg.task.num_frames, sample_type=cfg.task.sample_type)\n", "    # dummy image input; the video is passed through patch_videos\n", "    patch_image = torch.zeros((3, cfg.task.patch_image_size, cfg.task.patch_image_size))\n", "\n", "    patch_type = torch.tensor([1])  # modality type id (video)\n", "    patch_mask = torch.tensor([True])\n", "    src_text = encode_text(\" what does the video describe?\", append_bos=True, append_eos=True).unsqueeze(0)\n", "    src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", "    sample = {\n", "        \"id\": np.array(['42']),\n", "        \"net_input\": {\n", "            \"src_tokens\": src_text,\n", "            \"src_lengths\": src_length,\n", "            \"patch_videos\": patch_video,\n", "            \"patch_images\": patch_image,\n", "            \"patch_masks\": patch_mask,\n", "            \"patch_types\": patch_type,\n", "        }\n", "    }\n", "    return sample\n", "\n", "# Cast FP32 tensors to FP16\n", "def apply_half(t):\n", "    if t.dtype is torch.float32:\n", "        return t.to(dtype=torch.half)\n", "    return t" ] }, { "cell_type": "markdown", "id": "f96f776e-9aa0-4271-b881-311851cc033c", "metadata": {}, "source": [ "### Inference" ] }, { "cell_type": "code", "execution_count": 157, "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 3, 16, 384, 384])\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 157, "metadata": {}, "output_type": "execute_result" } ], "source": [ "save_dir = '/home/mshukor/ofa_adastra'\n", "\n", "video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7019.mp4' # a man is sitting in a chair and talking\n", "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7038.mp4' # a person is cooking something in a pan\n", "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7021.mp4' # a group of people are playing baseball\n", "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7068.mp4' # a man and a woman are talking to each 
other\n", "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7017.mp4' # a person is playing a video game\n", "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7014.mp4' # a girl is singing on the voice\n", "\n", "# video_path = '/data/mshukor/data/video/msrvtt/examples/video1065.mp4'\n", "\n", "# limitations\n", "video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7055.mp4' # a man is driving a car\n", "\n", "sample = construct_sample(video_path)\n", "sample = utils.move_to_cuda(sample) if use_cuda else sample\n", "sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample" ] }, { "cell_type": "code", "execution_count": 158, "id": "4651039c-b8c0-4687-871e-b42cb13b2984", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1], device='cuda:0')\n", "torch.Size([1, 2048, 1, 12, 12])\n" ] } ], "source": [ "from utils.eval_utils import eval_caption\n", "\n", "with torch.no_grad():\n", "    result, scores = eval_caption(task, generator, models, sample)" ] }, { "cell_type": "code", "execution_count": 159, "id": "712150d4-f28c-4538-870f-b33f775725d5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a man is driving a car\n" ] } ], "source": [ "caption = result[0]['caption']\n", "print(caption)\n", "\n", "from IPython.display import Video\n", "Video(video_path, embed=True)" ] } ], "metadata": { "kernelspec": { "display_name": "ofa", "language": "python", "name": "ofa" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 5 }