{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FOUND on a masked input image\n",
    "\n",
    "Loads one image and its mask, multiplies the normalized image by the mask, and runs `FoundModel` on the masked input.\n",
    "\n",
    "Figures are written to the working directory (the repository root after the import cell): `gt.png`, `masked_image.png`, `masked_pred.png`, `scribble.png`.\n",
    "\n",
    "**Requirements:** a CUDA device, the decoder weights at `DECODER_WEIGHTS`, and the sample files under `./notebooks/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import argparse\n",
    "from pathlib import Path\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from torchvision import transforms as T\n",
    "\n",
    "try:\n",
    "    from torchvision.transforms import InterpolationMode\n",
    "\n",
    "    BICUBIC = InterpolationMode.BICUBIC\n",
    "except ImportError:\n",
    "    # Older torchvision: fall back to the PIL constant.\n",
    "    BICUBIC = Image.BICUBIC\n",
    "\n",
    "# The project modules and the relative paths below are resolved from the\n",
    "# repository root, one level above this notebook's folder. The guard keeps\n",
    "# this cell idempotent: re-running it must not climb one directory higher.\n",
    "if Path.cwd().name == \"notebooks\":\n",
    "    os.chdir(\"..\")\n",
    "if os.getcwd() not in sys.path:\n",
    "    sys.path.insert(0, os.getcwd())\n",
    "\n",
    "from model import FoundModel\n",
    "from misc import load_config\n",
    "from datasets.utils import unnormalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH_TO_IMG = \"./notebooks/0409.jpg\"\n",
    "GT = \"./notebooks/0409.png\"\n",
    "# NOTE(review): SCRIBBLE is not read anywhere below; the mask is loaded from GT.\n",
    "SCRIBBLE = \"./notebooks/11965.png\"\n",
    "SCRIBBLE_FIGURE = \"./notebooks/scribble.png\"\n",
    "DECODER_WEIGHTS = \"./outputs/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8/decoder_weights_niter500.pt\"\n",
    "\n",
    "IMG_SIZE = 224\n",
    "DEVICE = \"cuda\"\n",
    "BINARIZE_MASK = False  # True: threshold the mask with (mask > 0)\n",
    "INVERT_MASK = False  # True: replace the mask by max(mask) - mask\n",
    "\n",
    "NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _preprocess(img, img_size, normalize=True):\n",
    "    \"\"\"Resize, center-crop and tensorize a PIL image.\n",
    "\n",
    "    Args:\n",
    "        img: PIL image to transform.\n",
    "        img_size: size passed to both T.Resize and T.CenterCrop.\n",
    "        normalize: when True (default), NORMALIZE is applied after T.ToTensor.\n",
    "\n",
    "    Returns:\n",
    "        The transformed tensor.\n",
    "    \"\"\"\n",
    "    steps = [T.Resize(img_size, BICUBIC), T.CenterCrop(img_size), T.ToTensor()]\n",
    "    if normalize:\n",
    "        steps.append(NORMALIZE)\n",
    "    return T.Compose(steps)(img)\n",
    "\n",
    "\n",
    "def _preprocess_scribble(img, img_size):\n",
    "    \"\"\"Same pipeline as _preprocess but without NORMALIZE (used for the mask).\"\"\"\n",
    "    return _preprocess(img, img_size, normalize=False)\n",
    "\n",
    "\n",
    "def to_uint8_image(tensor):\n",
    "    \"\"\"Convert a channel-first float tensor into a channel-last uint8 array.\n",
    "\n",
    "    A leading batch dimension of size 1 is dropped first. Values are clipped\n",
    "    to [0, 1] before scaling by 255 so that out-of-range floats cannot wrap\n",
    "    around in the uint8 cast.\n",
    "    \"\"\"\n",
    "    if tensor.dim() == 4:\n",
    "        tensor = tensor.squeeze(0)\n",
    "    array = tensor.permute(1, 2, 0).detach().cpu().numpy()\n",
    "    return (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)\n",
    "\n",
    "\n",
    "def show_and_save(image, out_path, cmap=None):\n",
    "    \"\"\"Display an array without axes and save the figure to out_path.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(image, cmap=cmap)\n",
    "    ax.axis(\"off\")\n",
    "    fig.savefig(out_path, bbox_inches=\"tight\", pad_inches=0)\n",
    "\n",
    "\n",
    "def make_border(im, bordersize=5, value=(0, 0, 0)):\n",
    "    \"\"\"Pad an image with a constant-colour border on all four sides.\n",
    "\n",
    "    Args:\n",
    "        im: image array passed to cv2.copyMakeBorder.\n",
    "        bordersize: border width in pixels, identical on every side.\n",
    "        value: border colour.\n",
    "\n",
    "    Returns:\n",
    "        The padded image.\n",
    "    \"\"\"\n",
    "    return cv2.copyMakeBorder(\n",
    "        im,\n",
    "        top=bordersize,\n",
    "        bottom=bordersize,\n",
    "        left=bordersize,\n",
    "        right=bordersize,\n",
    "        borderType=cv2.BORDER_CONSTANT,\n",
    "        value=list(value),\n",
    "    )\n",
    "\n",
    "\n",
    "def read_image(path):\n",
    "    \"\"\"Read an image with OpenCV, convert it to RGB and add a border.\n",
    "\n",
    "    Args:\n",
    "        path: path of the image file.\n",
    "\n",
    "    Returns:\n",
    "        The RGB image padded by make_border.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if cv2.imread cannot read the file.\n",
    "    \"\"\"\n",
    "    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)\n",
    "    if image is None:\n",
    "        # cv2.imread signals failure by returning None instead of raising.\n",
    "        raise FileNotFoundError(f\"cv2.imread could not read {path!r}\")\n",
    "    if image.ndim == 2:\n",
    "        # IMREAD_UNCHANGED keeps single-channel files 2-D; BGR2RGB would reject them.\n",
    "        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)\n",
    "    else:\n",
    "        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    return make_border(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the image and the mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = Image.open(PATH_TO_IMG).convert(\"RGB\")\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The mask multiplied into the image is loaded from GT.\n",
    "gt = Image.open(GT).convert(\"P\")\n",
    "gt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_t = _preprocess(img, IMG_SIZE)\n",
    "img_t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scribble = _preprocess_scribble(gt, IMG_SIZE)\n",
    "if BINARIZE_MASK:\n",
    "    scribble = (scribble > 0).float()  # threshold to [0,1]\n",
    "if INVERT_MASK:\n",
    "    scribble = torch.max(scribble) - scribble  # inverted scribble\n",
    "scribble = scribble.to(DEVICE)\n",
    "scribble.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_and_save(to_uint8_image(scribble), \"gt.png\", cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mask the input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_t = img_t.to(DEVICE) * scribble\n",
    "# Add a leading batch dimension before the forward pass.\n",
    "inputs = masked_t[None, :, :, :]\n",
    "inputs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# unnormalize gets the un-batched tensor because the conversion to an image\n",
    "# array permutes exactly three dimensions.\n",
    "# NOTE(review): confirm datasets.utils.unnormalize accepts a (C, H, W) tensor.\n",
    "img_init = unnormalize(masked_t)\n",
    "show_and_save(to_uint8_image(img_init), \"masked_image.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = FoundModel(\n",
    "    vit_model=\"dino\",\n",
    "    vit_arch=\"vit_small\",\n",
    "    vit_patch_size=8,\n",
    "    enc_type_feats=\"k\",\n",
    "    bkg_type_feats=\"k\",\n",
    "    bkg_th=0.3,\n",
    ")\n",
    "\n",
    "# Load weights\n",
    "model.decoder_load_weights(DECODER_WEIGHTS)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward step\n",
    "with torch.no_grad():\n",
    "    preds, _unused, shape_f, att = model.forward_step(inputs, for_eval=True)\n",
    "\n",
    "# Apply FOUND: upsample by the patch size, crop to the input size, threshold at 0.5.\n",
    "h, w = img_t.shape[-2:]\n",
    "preds_up = F.interpolate(\n",
    "    preds, scale_factor=model.vit_patch_size, mode=\"bilinear\", align_corners=False\n",
    ")[..., :h, :w]\n",
    "preds_up = (torch.sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()\n",
    "preds_up.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_and_save(preds_up.cpu().squeeze().numpy(), \"masked_pred.png\", cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scribble figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kept under its own name so the PIL image in `img` is not overwritten.\n",
    "scribble_fig = read_image(SCRIBBLE_FIGURE)\n",
    "show_and_save(scribble_fig, \"scribble.png\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tarmak",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}