In [None]:
# Smoke-test cell.
print("hello", end="\n")

In [None]:
import os
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Work from the repository root so the local `model`, `misc` and `datasets`
# modules import and the "./notebooks/..." paths used below resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("notebooks"):
    os.chdir("..")

from PIL import Image
from model import FoundModel
from misc import load_config
from torchvision import transforms as T

# Per-channel mean/std normalisation applied to RGB input tensors.
NORMALIZE = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

In [None]:
# Input image, its ground-truth mask, and a scribble file (all relative to repo root).
PATH_TO_IMG, GT, SCRIBBLE = (
    "./notebooks/0409.jpg",
    "./notebooks/0409.png",
    "./notebooks/11965.png",
)

In [None]:
# Input image, forced to 3-channel RGB and displayed inline.
img = Image.open(PATH_TO_IMG).convert("RGB")
img

In [None]:
# Mask used as the "scribble" below, loaded as a palette ("P") image.
# NOTE(review): this reads GT, not SCRIBBLE; SCRIBBLE is not used anywhere in
# this notebook — confirm the ground-truth mask is the intended input here.
scr = Image.open(GT).convert("P")
scr

In [None]:
# Bicubic resampling constant: prefer torchvision's InterpolationMode enum and
# fall back to the PIL constant on torchvision versions that lack it.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC
    
def _preprocess(img, img_size):
    """Turn a PIL image into a normalised, square model-input tensor.

    Args:
        img: PIL image (RGB in this notebook).
        img_size: an int; the image is resized with bicubic interpolation
            (shorter side -> ``img_size``) and then centre-cropped to
            ``img_size`` x ``img_size``.

    Returns:
        Float tensor of shape (C, img_size, img_size) with NORMALIZE applied.
    """
    steps = [
        T.Resize(img_size, BICUBIC),
        T.CenterCrop(img_size),
        T.ToTensor(),
        NORMALIZE,
    ]
    pipeline = T.Compose(steps)
    return pipeline(img)

In [None]:
# Normalised (3, 224, 224) image tensor, kept on CPU as `img_t`; a GPU copy
# becomes the model input. The batch dimension is added after masking.
img_t = _preprocess(img, img_size=224)
inputs = img_t.to(device="cuda")
inputs.size()

In [None]:
# `scribble` is otherwise first created further down (via `_preprocess_scribble`),
# so build it here from `scr` for Restart & Run All to work. Same
# Resize -> CenterCrop -> ToTensor pipeline, sized to match `img_t`.
# (Pillow resamples "P"-mode images with nearest-neighbour regardless of BICUBIC.)
img_size = img_t.shape[-1]
scribble = T.Compose(
    [T.Resize(img_size, BICUBIC), T.CenterCrop(img_size), T.ToTensor()]
)(scr).to("cuda")
scribble.shape

In [None]:
# Mask the normalised image with the scribble and add a batch dimension,
# giving (1, 3, H, W). Built from `img_t` rather than `inputs` so re-running
# this cell does not re-apply the mask or stack another batch dimension.
# NOTE(review): masking happens after NORMALIZE, so masked-out pixels become 0 in
# normalised space (i.e. the mean colour once unnormalised), not black.
m_i = (img_t.to("cuda") * scribble.to("cuda")).unsqueeze(0)
inputs = m_i
inputs.shape

In [None]:
from datasets.utils import unnormalize
img_init = unnormalize(m_i)
img_init.shape

In [None]:
import cv2
import numpy as np

# Convert the unnormalised masked image into an H x W x 3 uint8 array for display.
# `m_i` carries a leading batch dimension of size 1, and permute(1, 2, 0) needs a
# 3-D tensor; squeeze(0) removes that dimension and is a no-op on a 3-D
# (3, H, W) tensor.
ten = img_init.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
# Clip first: float values outside [0, 1] would overflow the uint8 cast.
ten = (np.clip(ten, 0.0, 1.0) * 255).astype(np.uint8)
ten.shape

In [None]:
# Show the masked input and save it without axes or padding.
fig, ax = plt.subplots()
ax.imshow(ten)
ax.axis('off')
fig.savefig('masked_image.png', bbox_inches='tight', pad_inches=0)

In [None]:
# Ground-truth mask as a palette ("P") image, displayed inline.
gt = Image.open(GT).convert("P")
gt

In [None]:
# Resampling constants: torchvision's InterpolationMode when available, PIL
# constants otherwise. BICUBIC is re-bound here so this cell stays self-contained.
try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
    NEAREST = InterpolationMode.NEAREST
except ImportError:
    BICUBIC = Image.BICUBIC
    NEAREST = Image.NEAREST


def _preprocess_scribble(img, img_size, interpolation=None):
    """Resize and centre-crop a scribble/mask image into a float tensor.

    NORMALIZE is not applied, so the result can be used directly as a
    multiplicative mask.

    Args:
        img: PIL mask image ("P" mode in this notebook).
        img_size: target size for the resize (shorter side) and the square
            centre crop.
        interpolation: resampling mode for the resize. Defaults to
            nearest-neighbour so mask values are not blended at edges. Pillow
            already forces nearest-neighbour for "P" and "1" images, so the
            default only changes results for other modes (e.g. "L").

    Returns:
        Float tensor of shape (C, img_size, img_size). ``ToTensor`` divides
        8-bit values by 255, so for a "P" image the values are palette
        indices / 255 — threshold if a binary mask is needed.
    """
    if interpolation is None:
        interpolation = NEAREST
    transform = T.Compose(
        [
            T.Resize(img_size, interpolation),
            T.CenterCrop(img_size),
            T.ToTensor(),
        ]
    )
    return transform(img)

In [None]:
# Mask tensor on CPU, cropped to the same 224 x 224 size as `img_t`.
# Optional post-processing, currently disabled:
#   binarise: scribble = (scribble > 0).float()
#   invert:   scribble = torch.max(scribble) - scribble
scribble = _preprocess_scribble(img=scr, img_size=224)

In [None]:
scribble.size()

In [None]:
import cv2
import numpy as np

# H x W x C uint8 copy of the scribble for display (C is 1 for a "P"-mode mask).
tens = (scribble.detach().cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
tens.shape

In [None]:
# Show the mask in grayscale and save it without axes or padding.
fig, ax = plt.subplots()
ax.imshow(tens, cmap='gray')
ax.axis('off')
fig.savefig('gt.png', bbox_inches='tight', pad_inches=0)

In [None]:
# Mask the normalised image tensor with the scribble. `img` is a PIL image
# (and is later rebound to a NumPy array by `read_image`), so it cannot be
# multiplied by a tensor; use `img_t` and move the scribble to its device.
masked_img_t = img_t * scribble.to(img_t.device)

In [None]:
# FOUND model with a DINO ViT-small backbone (patch size 8).
model = FoundModel(
    vit_model="dino",
    vit_arch="vit_small",
    vit_patch_size=8,
    enc_type_feats="k",
    bkg_type_feats="k",
    bkg_th=0.3,
)

# Load the trained decoder weights and switch to inference mode.
# NOTE(review): `inputs` lives on CUDA but the model is never moved explicitly
# here — presumably FoundModel handles device placement; confirm.
decoder_weights = "./outputs/msl_a1.5_b1_g1_reg4-MSL-DUTS-TR-vit_small8/decoder_weights_niter500.pt"
model.decoder_load_weights(decoder_weights)
model.eval()

In [None]:
# Forward step
with torch.no_grad():
    preds, _, shape_f, att = model.forward_step(inputs, for_eval=True)

# Apply FOUND
sigmoid = nn.Sigmoid()
h, w = img_t.shape[-2:]
preds_up = F.interpolate(
    preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
)[..., :h, :w]
preds_up = (
    (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()
)

In [None]:
# Show the binary prediction in grayscale and save it without axes or padding.
fig, ax = plt.subplots()
ax.imshow(preds_up.cpu().squeeze().numpy(), cmap='gray')
ax.axis('off')
fig.savefig('masked_pred.png', bbox_inches='tight', pad_inches=0)

In [None]:
preds_up.size()

In [None]:
def read_image(path):
    """Load an image with OpenCV as RGB and pad it with a black border.

    Args:
        path: file path passed to ``cv2.imread`` (read with IMREAD_UNCHANGED).

    Returns:
        NumPy array in RGB channel order, padded on every side by
        ``make_border`` (5 px of black by default).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    if image.ndim == 2:
        # Single-channel files come back 2-D and would make COLOR_BGR2RGB fail.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return make_border(image)


def make_border(im, bordersize=5, value=(0, 0, 0)):
    """Pad an image with a constant-colour border on all four sides.

    Args:
        im: image array as returned by OpenCV.
        bordersize: border width in pixels on each side (default 5).
        value: border colour passed to ``cv2.copyMakeBorder`` (default black).

    Returns:
        New array whose height and width are each larger by ``2 * bordersize``.
    """
    return cv2.copyMakeBorder(
        im,
        top=bordersize,
        bottom=bordersize,
        left=bordersize,
        right=bordersize,
        borderType=cv2.BORDER_CONSTANT,
        value=value,
    )

In [None]:
# NOTE(review): this rebinds `img` (the PIL image loaded earlier) to a NumPy
# array, so re-running earlier cells that use `img` after this one will break.
scribble_path = "./notebooks/scribble.png"
img = read_image(scribble_path)

In [None]:
# Show the bordered scribble image and save it without axes or padding.
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
fig.savefig('scribble.png', bbox_inches='tight', pad_inches=0)