import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T

import gradio as gr

from model import FoundModel
from misc import load_config

# Maximum image side used at inference time; larger images are resized first
MAX_IM_SIZE = 512
# Standard ImageNet normalization statistics (also used for DINO pre-training)
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# Cache the example predictions so the demo does not recompute them on every visit
CACHE = True


def blend_images(bg, fg, alpha=0.5):
    """Alpha-blend a foreground overlay onto a background image (both PIL images)."""
    bg = bg.convert("RGBA")
    fg = fg.convert("RGBA")
    return Image.blend(bg, fg, alpha=alpha)
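
# Example usage (illustrative):
#   overlay = blend_images(photo, mask_img, alpha=0.5)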


def predict(img_input):
    config_path = "configs/found_DUTS-TR.yaml"
    model_weights = "data/weights/decoder_weights.pt"

    # Load the configuration describing the backbone and the FOUND head
    config = load_config(config_path)

    # Build the model: a frozen self-supervised ViT backbone plus the FOUND head
    model = FoundModel(
        vit_model=config.model["pre_training"],
        vit_arch=config.model["arch"],
        vit_patch_size=config.model["patch_size"],
        enc_type_feats=config.found["feats"],
        bkg_type_feats=config.found["feats"],
        bkg_th=config.found["bkg_th"],
    )
    # Load the trained decoder weights and switch to inference mode
    model.decoder_load_weights(model_weights)
    model.eval()
    print(f"Model {model_weights} loaded correctly.")

    # Load the image and force three channels
    img_pil = Image.open(img_input)
    img = img_pil.convert("RGB")

    # Image transformations
    transforms = [T.ToTensor()]
    # If either side exceeds MAX_IM_SIZE, T.Resize scales the *shorter* side
    # to MAX_IM_SIZE, preserving the aspect ratio
    if img.size[0] > MAX_IM_SIZE or img.size[1] > MAX_IM_SIZE:
        transforms.append(T.Resize(MAX_IM_SIZE))
    transforms.append(NORMALIZE)
    t = T.Compose(transforms)

    # Add a batch dimension: (3, H, W) -> (1, 3, H, W)
    img_t = t(img)[None, :, :, :]

    # Forward step: predictions are logits at patch resolution
    with torch.no_grad():
        preds, _, _, _ = model.forward_step(img_t, for_eval=True)

    # Upsample the patch-level logits back to pixel resolution, crop to the
    # input size, and binarize with a 0.5 threshold on the sigmoid
    sigmoid = nn.Sigmoid()
    h, w = img_t.shape[-2:]
    preds_up = F.interpolate(
        preds, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False
    )[..., :h, :w]
    preds_up = (sigmoid(preds_up.detach()) > 0.5).squeeze(0).float()

    # Overlay the binary mask on the (resized) input image
    return blend_images(
        img_pil.resize([img_t.shape[-1], img_t.shape[-2]]),  # PIL resize takes (W, H)
        T.ToPILImage()(preds_up),
    )
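

# Standalone usage (illustrative; assumes the weights and example images are present):
#   out = predict("data/examples/VOC_000030.jpg")
#   out.save("found_overlay.png")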


title = "FOUND - unsupervised object localization"
description = (
    'Gradio Demo for our CVPR 2023 paper "Unsupervised Object Localization: '
    'Observing the Background to Discover Objects".\n'
    "The app is <i>running on CPUs</i>; inference times are therefore longer than "
    "on GPU (FOUND runs at 80 FPS on a V100 GPU).\n"
    "Please see below for more details."
)
article = """
<h1 align="center">Unsupervised Object Localization: Observing the Background to Discover Objects</h1>
## Highlights
- Single **conv 1 x 1** layer trained to extract information from DINO [1] features (see the sketch below).
- **No supervision**.
- Trained only for **2 epochs** on the dataset DUTS-TR.
- Inference runs at **80 FPS** on a V100 GPU.
- No post-processing applied in results here.
<i> Images provided are taken from VOC07 [2], ECSSD [3] and DUT-OMRON [4].</i>
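
A minimal sketch of this head (illustrative only; `feat_dim = 384` assumes a ViT-S backbone, and the real implementation is `FoundModel` in the repository):

```python
import torch
import torch.nn as nn

feat_dim = 384                                # DINO ViT-S feature dim (assumption)
head = nn.Conv2d(feat_dim, 1, kernel_size=1)  # the single trained conv 1 x 1 layer

# feats: frozen DINO patch features, shape (B, feat_dim, H/P, W/P) for patch size P
feats = torch.randn(1, feat_dim, 28, 28)
logits = head(feats)                          # one object/background logit per patch
mask = torch.sigmoid(logits) > 0.5            # binary localization mask
```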
## Citation
```
@inproceedings{simeoni2023found,
author = {Siméoni, Oriane and Sekkat, Chloé and Puy, Gilles and Vobecky, Antonin and Zablocki, Éloi and Pérez, Patrick},
title = {Unsupervised Object Localization: Observing the Background to Discover Objects},
booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR}},
year = {2023},
}
```
### References
[1] M. Caron et al. Emerging properties in self-supervised vision transformers, ICCV 2021
[2] M. Everingham et al. The PASCAL Visual Object Classes Challenge 2007 (VOC2007) Results
[3] J. Shi et al. Hierarchical image saliency detection on extended CSSD, IEEE TPAMI 2016
[4] C. Yang et al. Saliency detection via graph-based manifold ranking, CVPR 2013
"""

examples = [
    "data/examples/VOC_000030.jpg",
    "data/examples/ECSSD_0010.png",
    "data/examples/VOC07_000038.jpg",
    "data/examples/VOC07_000075.jpg",
    "data/examples/DUT-OMRON_im103.png",
]

iface = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    article=article,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Image(label="Unsupervised object localization", type="pil"),
    examples=examples,
    cache_examples=CACHE,
)
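
# `enable_queue=True` in launch() works on older Gradio releases; newer versions
# enable queuing via `iface.queue()` instead (adjust to your installed Gradio version).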
iface.launch(
    show_error=True,
    enable_queue=True,
    inline=True,
)