{ "cells": [ { "cell_type": "markdown", "id": "399f2fcf-9241-4910-a30d-6ca19880d0ad", "metadata": {}, "source": [ "## Import" ] }, { "cell_type": "code", "execution_count": 15, "id": "97e68340-0096-475e-8ed8-22f5d627e3ad", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from fairseq import utils, tasks\n", "from fairseq import checkpoint_utils\n", "from utils.eval_utils import eval_step\n", "from tasks.mm_tasks import ImageGenTask\n", "from models.unival import UnIVALModel\n", "from PIL import Image\n", "from torchvision import transforms\n", "import time\n", "\n", "\n", "# turn on cuda if GPU is available\n", "use_cuda = torch.cuda.is_available()\n", "# use fp16 only when GPU is available\n", "use_fp16 = True if use_cuda else False" ] }, { "cell_type": "code", "execution_count": 16, "id": "719cef65-c00c-4c9c-90b2-e660b386c3d5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ ".register_task_cls(cls)>" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Register caption task\n", "tasks.register_task('image_gen', ImageGenTask)\n" ] }, { "cell_type": "markdown", "id": "cc9c1d7b-898b-4ac4-adf3-832891d9e4be", "metadata": {}, "source": [ "### Load model " ] }, { "cell_type": "code", "execution_count": 12, "id": "568bb6ea-eef9-4024-98e6-35e74b5ffeec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "self.sample_patch_num 784\n", "self.sample_audio_patch_num None\n", "self.sample_video_patch_num None\n", "self.with_cls False\n", "Frozen image bn \n", "Loading: all_resnext101\n", "use bn: \n", "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", "load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n", "\n", "RAM memory % used: 10.5\n", "RAM Used (GB): 19.574349824\n", "encoder\n", "RAM memory % used: 10.5\n", "decoder\n", "RAM memory % used: 10.5\n", "ofa\n", "Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n" ] } ], "source": [ "# Load pretrained ckpt & config\n", "clip_model_path='/data/mshukor/data/ofa/clip/ViT-B-16.pt'\n", "vqgan_model_path='/data/mshukor/data/ofa/vqgan/last.ckpt'\n", "vqgan_config_path='/data/mshukor/data/ofa/vqgan/model.yaml'\n", "\n", "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofa_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n", "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_long/checkpoint_best.pt'\n", "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_base_best.pt'\n", "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_large_best.pt'\n", "\n", "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n", "checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best.pt'\n", "\n", "\n", "\n", "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n", "\n", "gen_images_path='results/image_gen/'\n", "\n", "overrides = {\"bpe_dir\": \"utils/BPE\",\n", " \"eval_cider\": False,\n", " \"beam\": 24,\n", " \"max_len_b\": 1024,\n", " \"max_len_a\": 0,\n", " \"min_len\": 1024,\n", " \"sampling_topk\": 256,\n", " \"constraint_range\": \"50265,58457\",\n", " \"clip_model_path\": 
clip_model_path,\n", " \"vqgan_model_path\": vqgan_model_path,\n", " \"vqgan_config_path\": vqgan_config_path,\n", " \"seed\": 42,\n", " \"video_model_path\": video_model_path, \n", " \"resnet_model_path\": resnet_model_path,\n", " \"gen_images_path\":gen_images_path,\n", " \"patch_image_size\": 256,\n", " \"temperature\": 1.5,\n", " }\n", "\n", "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", " utils.split_paths(checkpoint_path),\n", " arg_overrides=overrides\n", ")\n", "\n", "task.cfg.sampling_times = 2\n", "# Move models to GPU\n", "for model in models:\n", " model.eval()\n", " if use_fp16:\n", " model.half()\n", " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", " model.cuda()\n", " model.prepare_for_inference_(cfg)\n", "\n", "# Initialize generator\n", "generator = task.build_generator(models, cfg.generation)\n", "\n", "# Text preprocess\n", "bos_item = torch.LongTensor([task.src_dict.bos()])\n", "eos_item = torch.LongTensor([task.src_dict.eos()])\n", "pad_idx = task.src_dict.pad()" ] }, { "cell_type": "markdown", "id": "5e4a45ec-bce1-495b-8033-3b574367b360", "metadata": {}, "source": [ "### Preprocess" ] }, { "cell_type": "code", "execution_count": 13, "id": "9f2e7e32-c9a0-43b3-bf86-2419d9f7dfe0", "metadata": {}, "outputs": [], "source": [ "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", " s = task.tgt_dict.encode_line(\n", " line=task.bpe.encode(text),\n", " add_if_not_exist=False,\n", " append_eos=False\n", " ).long()\n", " if length is not None:\n", " s = s[:length]\n", " if append_bos:\n", " s = torch.cat([bos_item, s])\n", " if append_eos:\n", " s = torch.cat([s, eos_item])\n", " return s\n", "\n", "\n", "# Construct input for image generation task\n", "def construct_sample(query: str):\n", " code_mask = torch.tensor([True])\n", " src_text = encode_text(\" what is the complete image? 
caption: {}\".format(query), append_bos=True,\n", " append_eos=True).unsqueeze(0)\n", " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", " sample = {\n", " \"id\": np.array(['42']),\n", " \"net_input\": {\n", " \"src_tokens\": src_text,\n", " \"src_lengths\": src_length,\n", " \"code_masks\": code_mask\n", " }\n", " }\n", " return sample\n", "\n", "\n", "# Function to turn FP32 to FP16\n", "def apply_half(t):\n", " if t.dtype is torch.float32:\n", " return t.to(dtype=torch.half)\n", " return t\n", "\n", "\n", "# Function for image generation\n", "def image_generation(caption):\n", " sample = construct_sample(caption)\n", " sample = utils.move_to_cuda(sample) if use_cuda else sample\n", " sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n", " print('|Start|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n", " with torch.no_grad():\n", " result, scores = eval_step(task, generator, models, sample)\n", "\n", " # return top-4 results (ranked by clip)\n", " images = [result[i]['image'] for i in range(4)]\n", " pic_size = 256\n", " retImage = Image.new('RGB', (pic_size * 2, pic_size * 2))\n", " print('|FINISHED|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n", " for i in range(4):\n", " loc = ((i % 2) * pic_size, int(i / 2) * pic_size)\n", " retImage.paste(images[i], loc)\n", " return retImage" ] }, { "cell_type": "markdown", "id": "44dec799-c5c2-4d22-8b08-7a7ca2cdf3c9", "metadata": {}, "source": [ "### Inference" ] }, { "cell_type": "code", "execution_count": 14, "id": "02d5cd7a-8d63-4fa4-9da1-d4b79ec01445", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "|Start| 2023-06-29 12:57:39 A brown horse in the street\n", "|FINISHED| 2023-06-29 12:59:03 A brown horse in the street\n" ] } ], "source": [ "query = \"A brown horse in the street\"\n", "# query = \"Cattle grazing on grass near a lake surrounded by mountain.\"\n", "# query = 'A street scene with a double-decker bus on the road.'\n", "# query = 'A path.'\n", "\n", "\n", "retImage = image_generation(query)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1a8a1654-1f17-41c7-b410-c7491a96dcee", "metadata": {}, "outputs": [], "source": [ "retImage.save(f'{query}.png')" ] } ], "metadata": { "kernelspec": { "display_name": "ofa", "language": "python", "name": "ofa" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 5 }