Commit
•
48fa639
1
Parent(s):
bdc1819
upload clipseg
Browse files- clipseg/LICENSE +21 -0
- clipseg/Quickstart.ipynb +107 -0
- clipseg/Readme.md +84 -0
- clipseg/Tables.ipynb +349 -0
- clipseg/Visual_Feature_Engineering.ipynb +366 -0
- clipseg/datasets/coco_wrapper.py +99 -0
- clipseg/datasets/pascal_classes.json +1 -0
- clipseg/datasets/pascal_zeroshot.py +60 -0
- clipseg/datasets/pfe_dataset.py +129 -0
- clipseg/datasets/phrasecut.py +335 -0
- clipseg/datasets/utils.py +68 -0
- clipseg/environment.yml +15 -0
- clipseg/evaluation_utils.py +292 -0
- clipseg/example_image.jpg +0 -0
- clipseg/experiments/ablation.yaml +84 -0
- clipseg/experiments/coco.yaml +101 -0
- clipseg/experiments/pascal_1shot.yaml +101 -0
- clipseg/experiments/phrasecut.yaml +80 -0
- clipseg/general_utils.py +272 -0
- clipseg/metrics.py +271 -0
- clipseg/models/clipseg.py +552 -0
- clipseg/models/vitseg.py +286 -0
- clipseg/overview.png +0 -0
- clipseg/score.py +453 -0
- clipseg/setup.py +30 -0
- clipseg/training.py +266 -0
- clipseg/weights/rd64-uni.pth +3 -0
clipseg/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
This license does not apply to the model weights.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
clipseg/Quickstart.ipynb
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"import requests\n",
|
11 |
+
"\n",
|
12 |
+
"! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
|
13 |
+
"! unzip -d weights -j weights.zip\n",
|
14 |
+
"from models.clipseg import CLIPDensePredT\n",
|
15 |
+
"from PIL import Image\n",
|
16 |
+
"from torchvision import transforms\n",
|
17 |
+
"from matplotlib import pyplot as plt\n",
|
18 |
+
"\n",
|
19 |
+
"# load model\n",
|
20 |
+
"model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
|
21 |
+
"model.eval();\n",
|
22 |
+
"\n",
|
23 |
+
"# non-strict, because we only stored decoder weights (not CLIP weights)\n",
|
24 |
+
"model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "markdown",
|
29 |
+
"metadata": {},
|
30 |
+
"source": [
|
31 |
+
"Load and normalize `example_image.jpg`. You can also load through an URL."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"# load and normalize image\n",
|
41 |
+
"input_image = Image.open('example_image.jpg')\n",
|
42 |
+
"\n",
|
43 |
+
"# or load from URL...\n",
|
44 |
+
"# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
|
45 |
+
"# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
|
46 |
+
"\n",
|
47 |
+
"transform = transforms.Compose([\n",
|
48 |
+
" transforms.ToTensor(),\n",
|
49 |
+
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
50 |
+
" transforms.Resize((352, 352)),\n",
|
51 |
+
"])\n",
|
52 |
+
"img = transform(input_image).unsqueeze(0)"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "markdown",
|
57 |
+
"metadata": {},
|
58 |
+
"source": [
|
59 |
+
"Predict and visualize (this might take a few seconds if running without GPU support)"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": null,
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [],
|
67 |
+
"source": [
|
68 |
+
"prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
|
69 |
+
"\n",
|
70 |
+
"# predict\n",
|
71 |
+
"with torch.no_grad():\n",
|
72 |
+
" preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
|
73 |
+
"\n",
|
74 |
+
"# visualize prediction\n",
|
75 |
+
"_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
|
76 |
+
"[a.axis('off') for a in ax.flatten()]\n",
|
77 |
+
"ax[0].imshow(input_image)\n",
|
78 |
+
"[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
|
79 |
+
"[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
|
80 |
+
]
|
81 |
+
}
|
82 |
+
],
|
83 |
+
"metadata": {
|
84 |
+
"interpreter": {
|
85 |
+
"hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
|
86 |
+
},
|
87 |
+
"kernelspec": {
|
88 |
+
"display_name": "Python 3",
|
89 |
+
"language": "python",
|
90 |
+
"name": "python3"
|
91 |
+
},
|
92 |
+
"language_info": {
|
93 |
+
"codemirror_mode": {
|
94 |
+
"name": "ipython",
|
95 |
+
"version": 3
|
96 |
+
},
|
97 |
+
"file_extension": ".py",
|
98 |
+
"mimetype": "text/x-python",
|
99 |
+
"name": "python",
|
100 |
+
"nbconvert_exporter": "python",
|
101 |
+
"pygments_lexer": "ipython3",
|
102 |
+
"version": "3.8.10"
|
103 |
+
}
|
104 |
+
},
|
105 |
+
"nbformat": 4,
|
106 |
+
"nbformat_minor": 4
|
107 |
+
}
|
clipseg/Readme.md
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Image Segmentation Using Text and Image Prompts
|
2 |
+
This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
|
3 |
+
|
4 |
+
**The Paper has been accepted to CVPR 2022!**
|
5 |
+
|
6 |
+
<img src="overview.png" alt="drawing" height="200em"/>
|
7 |
+
|
8 |
+
The systems allows to create segmentation models without training based on:
|
9 |
+
- An arbitrary text query
|
10 |
+
- Or an image with a mask highlighting stuff or an object.
|
11 |
+
|
12 |
+
### Quick Start
|
13 |
+
|
14 |
+
In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension.
|
15 |
+
It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
|
16 |
+
(please note that the VM does not use a GPU, thus inference takes a few seconds).
|
17 |
+
|
18 |
+
|
19 |
+
### Dependencies
|
20 |
+
This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
|
21 |
+
Additional dependencies are hidden for double blind review.
|
22 |
+
|
23 |
+
|
24 |
+
### Datasets
|
25 |
+
|
26 |
+
* `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
|
27 |
+
* `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
|
28 |
+
* `PascalZeroShot`: Wrapper class for PascalZeroShot
|
29 |
+
* `COCOWrapper`: Wrapper class for COCO.
|
30 |
+
|
31 |
+
### Models
|
32 |
+
|
33 |
+
* `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
|
34 |
+
* `ViTDensePredT`: CLIPSeg model with transformer-based decoder.
|
35 |
+
|
36 |
+
### Third Party Dependencies
|
37 |
+
For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
|
38 |
+
```bash
|
39 |
+
git clone https://github.com/cvlab-yonsei/JoEm
|
40 |
+
git clone https://github.com/Jia-Research-Lab/PFENet.git
|
41 |
+
git clone https://github.com/ChenyunWu/PhraseCutDataset.git
|
42 |
+
git clone https://github.com/juhongm999/hsnet.git
|
43 |
+
```
|
44 |
+
|
45 |
+
### Weights
|
46 |
+
|
47 |
+
The MIT license does not apply to these weights.
|
48 |
+
|
49 |
+
We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
|
50 |
+
```
|
51 |
+
wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
|
52 |
+
unzip -d weights -j weights.zip
|
53 |
+
```
|
54 |
+
|
55 |
+
|
56 |
+
### Training and Evaluation
|
57 |
+
|
58 |
+
To train use the `training.py` script with experiment file and experiment id parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment which is defined by the `configuration` and first `individual_configurations` parameters. Model weights will be written in `logs/`.
|
59 |
+
|
60 |
+
For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will train the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`.
|
61 |
+
|
62 |
+
|
63 |
+
### Usage of PFENet Wrappers
|
64 |
+
|
65 |
+
In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
|
66 |
+
`git clone https://github.com/Jia-Research-Lab/PFENet.git `
|
67 |
+
|
68 |
+
|
69 |
+
### License
|
70 |
+
|
71 |
+
The source code files in this repository (excluding model weights) are released under MIT license.
|
72 |
+
|
73 |
+
### Citation
|
74 |
+
```
|
75 |
+
@InProceedings{lueddecke22_cvpr,
|
76 |
+
author = {L\"uddecke, Timo and Ecker, Alexander},
|
77 |
+
title = {Image Segmentation Using Text and Image Prompts},
|
78 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
79 |
+
month = {June},
|
80 |
+
year = {2022},
|
81 |
+
pages = {7086-7096}
|
82 |
+
}
|
83 |
+
|
84 |
+
```
|
clipseg/Tables.ipynb
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"%load_ext autoreload\n",
|
10 |
+
"%autoreload 2\n",
|
11 |
+
"\n",
|
12 |
+
"import clip\n",
|
13 |
+
"from evaluation_utils import norm, denorm\n",
|
14 |
+
"from general_utils import *\n",
|
15 |
+
"from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"cell_type": "markdown",
|
20 |
+
"metadata": {},
|
21 |
+
"source": [
|
22 |
+
"# PhraseCut"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
|
50 |
+
"tab1 = pc[['name'] + cols]\n",
|
51 |
+
"for k in cols:\n",
|
52 |
+
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
53 |
+
"tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
|
54 |
+
"tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
|
55 |
+
"print(tab1.to_latex(header=False, index=False))"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "markdown",
|
60 |
+
"metadata": {},
|
61 |
+
"source": [
|
62 |
+
"For 0.1 threshold"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
|
72 |
+
"tab1 = pc[['name'] + cols]\n",
|
73 |
+
"for k in cols:\n",
|
74 |
+
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
75 |
+
"tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
|
76 |
+
"tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
|
77 |
+
"print(tab1.to_latex(header=False, index=False))"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "markdown",
|
82 |
+
"metadata": {},
|
83 |
+
"source": [
|
84 |
+
"# One-shot"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "markdown",
|
89 |
+
"metadata": {},
|
90 |
+
"source": [
|
91 |
+
"### Pascal"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": null,
|
97 |
+
"metadata": {},
|
98 |
+
"outputs": [],
|
99 |
+
"source": [
|
100 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": null,
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [],
|
108 |
+
"source": [
|
109 |
+
"pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": null,
|
115 |
+
"metadata": {},
|
116 |
+
"outputs": [],
|
117 |
+
"source": [
|
118 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
119 |
+
"tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
|
120 |
+
"print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
121 |
+
"print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
122 |
+
"\n",
|
123 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
124 |
+
"tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
|
125 |
+
"print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
126 |
+
"\n",
|
127 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
|
128 |
+
"tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
|
129 |
+
"print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "markdown",
|
134 |
+
"metadata": {},
|
135 |
+
"source": [
|
136 |
+
"#### Pascal Zero-shot (in one-shot setting)\n",
|
137 |
+
"\n",
|
138 |
+
"Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"cell_type": "code",
|
143 |
+
"execution_count": null,
|
144 |
+
"metadata": {},
|
145 |
+
"outputs": [],
|
146 |
+
"source": [
|
147 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
148 |
+
"tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
|
149 |
+
"print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
150 |
+
"print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
151 |
+
"\n",
|
152 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
153 |
+
"tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
|
154 |
+
"print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
155 |
+
"\n",
|
156 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
|
157 |
+
"tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
|
158 |
+
"print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": null,
|
164 |
+
"metadata": {},
|
165 |
+
"outputs": [],
|
166 |
+
"source": [
|
167 |
+
"# without fixed thresholds...\n",
|
168 |
+
"\n",
|
169 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
170 |
+
"tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
|
171 |
+
"print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
172 |
+
"print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
173 |
+
"\n",
|
174 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
175 |
+
"tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
|
176 |
+
"print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "markdown",
|
181 |
+
"metadata": {},
|
182 |
+
"source": [
|
183 |
+
"### COCO"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": null,
|
189 |
+
"metadata": {},
|
190 |
+
"outputs": [],
|
191 |
+
"source": [
|
192 |
+
"coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "code",
|
197 |
+
"execution_count": null,
|
198 |
+
"metadata": {},
|
199 |
+
"outputs": [],
|
200 |
+
"source": [
|
201 |
+
"tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
|
202 |
+
"tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
|
203 |
+
"tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
|
204 |
+
"print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
|
205 |
+
"print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
206 |
+
"print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
|
207 |
+
"print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"metadata": {},
|
213 |
+
"source": [
|
214 |
+
"# Zero-shot"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "code",
|
219 |
+
"execution_count": null,
|
220 |
+
"metadata": {},
|
221 |
+
"outputs": [],
|
222 |
+
"source": [
|
223 |
+
"zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
|
224 |
+
]
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"cell_type": "code",
|
228 |
+
"execution_count": null,
|
229 |
+
"metadata": {},
|
230 |
+
"outputs": [],
|
231 |
+
"source": [
|
232 |
+
"\n",
|
233 |
+
"tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
|
234 |
+
"print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
|
235 |
+
"print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
|
236 |
+
"print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "markdown",
|
241 |
+
"metadata": {},
|
242 |
+
"source": [
|
243 |
+
"# Ablation"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": null,
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "code",
|
257 |
+
"execution_count": null,
|
258 |
+
"metadata": {},
|
259 |
+
"outputs": [],
|
260 |
+
"source": [
|
261 |
+
"tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
|
262 |
+
"for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
|
263 |
+
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
264 |
+
"tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"execution_count": null,
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"cell_type": "code",
|
278 |
+
"execution_count": null,
|
279 |
+
"metadata": {},
|
280 |
+
"outputs": [],
|
281 |
+
"source": [
|
282 |
+
"print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "markdown",
|
287 |
+
"metadata": {},
|
288 |
+
"source": [
|
289 |
+
"# Generalization"
|
290 |
+
]
|
291 |
+
},
|
292 |
+
{
|
293 |
+
"cell_type": "code",
|
294 |
+
"execution_count": null,
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"generalization = experiment('experiments/generalize.yaml').dataframe()"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": null,
|
304 |
+
"metadata": {},
|
305 |
+
"outputs": [],
|
306 |
+
"source": [
|
307 |
+
"gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"cell_type": "code",
|
312 |
+
"execution_count": null,
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"print(\n",
|
317 |
+
" 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
|
318 |
+
" 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
|
319 |
+
" 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
|
320 |
+
" 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
|
321 |
+
")"
|
322 |
+
]
|
323 |
+
}
|
324 |
+
],
|
325 |
+
"metadata": {
|
326 |
+
"interpreter": {
|
327 |
+
"hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
|
328 |
+
},
|
329 |
+
"kernelspec": {
|
330 |
+
"display_name": "env2",
|
331 |
+
"language": "python",
|
332 |
+
"name": "env2"
|
333 |
+
},
|
334 |
+
"language_info": {
|
335 |
+
"codemirror_mode": {
|
336 |
+
"name": "ipython",
|
337 |
+
"version": 3
|
338 |
+
},
|
339 |
+
"file_extension": ".py",
|
340 |
+
"mimetype": "text/x-python",
|
341 |
+
"name": "python",
|
342 |
+
"nbconvert_exporter": "python",
|
343 |
+
"pygments_lexer": "ipython3",
|
344 |
+
"version": "3.8.8"
|
345 |
+
}
|
346 |
+
},
|
347 |
+
"nbformat": 4,
|
348 |
+
"nbformat_minor": 4
|
349 |
+
}
|
clipseg/Visual_Feature_Engineering.ipynb
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Systematic"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"%load_ext autoreload\n",
|
17 |
+
"%autoreload 2\n",
|
18 |
+
"\n",
|
19 |
+
"import clip\n",
|
20 |
+
"from evaluation_utils import norm, denorm\n",
|
21 |
+
"from general_utils import *\n",
|
22 |
+
"from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
|
23 |
+
"\n",
|
24 |
+
"clip_device = 'cuda'\n",
|
25 |
+
"clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
|
26 |
+
"clip_model.eval();\n",
|
27 |
+
"\n",
|
28 |
+
"from models.clipseg import CLIPDensePredTMasked\n",
|
29 |
+
"\n",
|
30 |
+
"clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
|
31 |
+
"clip_mask_model.eval();"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
|
41 |
+
" text_class_labels=True, image_size=352, min_area=0.1,\n",
|
42 |
+
" min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"plot_data(lvis)"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": null,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"from collections import defaultdict\n",
|
61 |
+
"import json\n",
|
62 |
+
"\n",
|
63 |
+
"lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
|
64 |
+
"lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
|
65 |
+
"\n",
|
66 |
+
"objects_per_image = defaultdict(lambda : set())\n",
|
67 |
+
"for ann in lvis_raw['annotations']:\n",
|
68 |
+
" objects_per_image[ann['image_id']].add(ann['category_id'])\n",
|
69 |
+
" \n",
|
70 |
+
"for ann in lvis_val_raw['annotations']:\n",
|
71 |
+
" objects_per_image[ann['image_id']].add(ann['category_id']) \n",
|
72 |
+
" \n",
|
73 |
+
"objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
|
74 |
+
"\n",
|
75 |
+
"del lvis_raw, lvis_val_raw"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": null,
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [],
|
83 |
+
"source": [
|
84 |
+
"#bs = 32\n",
|
85 |
+
"#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": null,
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"from general_utils import get_batch\n",
|
95 |
+
"from functools import partial\n",
|
96 |
+
"from evaluation_utils import img_preprocess\n",
|
97 |
+
"import torch\n",
|
98 |
+
"\n",
|
99 |
+
"def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
|
100 |
+
"\n",
|
101 |
+
" # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
|
102 |
+
"\n",
|
103 |
+
" all_prompts = []\n",
|
104 |
+
" \n",
|
105 |
+
" with torch.no_grad():\n",
|
106 |
+
" valid_sims = []\n",
|
107 |
+
" torch.manual_seed(571)\n",
|
108 |
+
" \n",
|
109 |
+
" if type(batches_or_dataset) == list:\n",
|
110 |
+
" loader = batches_or_dataset # already loaded\n",
|
111 |
+
" max_iter = float('inf')\n",
|
112 |
+
" else:\n",
|
113 |
+
" loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
|
114 |
+
" max_iter = 50\n",
|
115 |
+
" \n",
|
116 |
+
" global batch\n",
|
117 |
+
" for i_batch, (batch, batch_y) in enumerate(loader):\n",
|
118 |
+
" \n",
|
119 |
+
" if i_batch >= max_iter: break\n",
|
120 |
+
" \n",
|
121 |
+
" processed_batch = process(batch)\n",
|
122 |
+
" if type(processed_batch) == dict:\n",
|
123 |
+
" \n",
|
124 |
+
" # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
|
125 |
+
" image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
|
126 |
+
" else:\n",
|
127 |
+
" processed_batch = process(batch).to(clip_device)\n",
|
128 |
+
" processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
|
129 |
+
" #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
|
130 |
+
" image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
|
131 |
+
" \n",
|
132 |
+
" image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
|
133 |
+
" bs = len(batch[0])\n",
|
134 |
+
" for j in range(bs):\n",
|
135 |
+
" \n",
|
136 |
+
" c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
|
137 |
+
" support_image = basename(lvis.samples[c][sid])\n",
|
138 |
+
" \n",
|
139 |
+
" img_objs = [o for o in objects_per_image[int(support_image)]]\n",
|
140 |
+
" img_objs = [o.replace('_', ' ') for o in img_objs]\n",
|
141 |
+
" \n",
|
142 |
+
" other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
|
143 |
+
" if o != batch_y[2][j]]\n",
|
144 |
+
" \n",
|
145 |
+
" prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
|
146 |
+
" all_prompts += [prompts]\n",
|
147 |
+
" \n",
|
148 |
+
" text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
|
149 |
+
" text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
|
150 |
+
"\n",
|
151 |
+
" global logits\n",
|
152 |
+
" logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
|
153 |
+
"\n",
|
154 |
+
" global sim\n",
|
155 |
+
" sim = torch.softmax(logits, dim=-1)\n",
|
156 |
+
" \n",
|
157 |
+
" valid_sims += [sim]\n",
|
158 |
+
" \n",
|
159 |
+
" #valid_sims = torch.stack(valid_sims)\n",
|
160 |
+
" return valid_sims, all_prompts\n",
|
161 |
+
" \n",
|
162 |
+
"\n",
|
163 |
+
"def new_img_preprocess(x):\n",
|
164 |
+
" return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
|
165 |
+
" \n",
|
166 |
+
"#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
|
167 |
+
"get_similarities(lvis, lambda x: x[1]);"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"execution_count": null,
|
173 |
+
"metadata": {},
|
174 |
+
"outputs": [],
|
175 |
+
"source": [
|
176 |
+
"preprocessing_functions = [\n",
|
177 |
+
"# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
|
178 |
+
"# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
|
179 |
+
"# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
|
180 |
+
"# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
|
181 |
+
"# ['add red outline', partial(img_preprocess, outline=True)],\n",
|
182 |
+
" \n",
|
183 |
+
"# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
|
184 |
+
"# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
|
185 |
+
"# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
|
186 |
+
"# ['BG blur', partial(img_preprocess, blur=3)],\n",
|
187 |
+
"# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
|
188 |
+
" \n",
|
189 |
+
"# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
|
190 |
+
"# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
|
191 |
+
" ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
|
192 |
+
" ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
|
193 |
+
"# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
|
194 |
+
"]\n",
|
195 |
+
"\n",
|
196 |
+
"preprocessing_functions = preprocessing_functions\n",
|
197 |
+
"\n",
|
198 |
+
"base, base_p = get_similarities(lvis, lambda x: x[1])\n",
|
199 |
+
"outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"cell_type": "code",
|
204 |
+
"execution_count": null,
|
205 |
+
"metadata": {},
|
206 |
+
"outputs": [],
|
207 |
+
"source": [
|
208 |
+
"outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"execution_count": null,
|
214 |
+
"metadata": {},
|
215 |
+
"outputs": [],
|
216 |
+
"source": [
|
217 |
+
"for j in range(1):\n",
|
218 |
+
" print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": null,
|
224 |
+
"metadata": {},
|
225 |
+
"outputs": [],
|
226 |
+
"source": [
|
227 |
+
"from pandas import DataFrame\n",
|
228 |
+
"tab = dict()\n",
|
229 |
+
"for j, (name, _) in enumerate(preprocessing_functions):\n",
|
230 |
+
" tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
|
231 |
+
" \n",
|
232 |
+
" \n",
|
233 |
+
"print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "markdown",
|
238 |
+
"metadata": {},
|
239 |
+
"source": [
|
240 |
+
"# Visual"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": null,
|
246 |
+
"metadata": {},
|
247 |
+
"outputs": [],
|
248 |
+
"source": [
|
249 |
+
"from evaluation_utils import denorm, norm"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "code",
|
254 |
+
"execution_count": null,
|
255 |
+
"metadata": {},
|
256 |
+
"outputs": [],
|
257 |
+
"source": [
|
258 |
+
"def load_sample(filename, filename2):\n",
|
259 |
+
" from os.path import join\n",
|
260 |
+
" bp = expanduser('~/cloud/resources/sample_images')\n",
|
261 |
+
" tf = transforms.Compose([\n",
|
262 |
+
" transforms.ToTensor(),\n",
|
263 |
+
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
264 |
+
" transforms.Resize(224),\n",
|
265 |
+
" transforms.CenterCrop(224)\n",
|
266 |
+
" ])\n",
|
267 |
+
" tf2 = transforms.Compose([\n",
|
268 |
+
" transforms.ToTensor(),\n",
|
269 |
+
" transforms.Resize(224),\n",
|
270 |
+
" transforms.CenterCrop(224)\n",
|
271 |
+
" ])\n",
|
272 |
+
" inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
|
273 |
+
" inp1[1] = inp1[1].unsqueeze(0)\n",
|
274 |
+
" inp1[2] = inp1[2][:1] \n",
|
275 |
+
" return inp1\n",
|
276 |
+
"\n",
|
277 |
+
"def all_preprocessing(inp1):\n",
|
278 |
+
" return [\n",
|
279 |
+
" img_preprocess(inp1),\n",
|
280 |
+
" img_preprocess(inp1, colorize=True),\n",
|
281 |
+
" img_preprocess(inp1, outline=True), \n",
|
282 |
+
" img_preprocess(inp1, blur=3),\n",
|
283 |
+
" img_preprocess(inp1, bg_fac=0.1),\n",
|
284 |
+
" #img_preprocess(inp1, bg_fac=0.5),\n",
|
285 |
+
" #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
|
286 |
+
" img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
|
287 |
+
" ]\n",
|
288 |
+
"\n"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": null,
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"from torchvision import transforms\n",
|
298 |
+
"from PIL import Image\n",
|
299 |
+
"from matplotlib import pyplot as plt\n",
|
300 |
+
"from evaluation_utils import img_preprocess\n",
|
301 |
+
"import clip\n",
|
302 |
+
"\n",
|
303 |
+
"images_queries = [\n",
|
304 |
+
" [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
|
305 |
+
" [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
|
306 |
+
"]\n",
|
307 |
+
"\n",
|
308 |
+
"\n",
|
309 |
+
"_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
|
310 |
+
"\n",
|
311 |
+
"for j, (images, objects) in enumerate(images_queries):\n",
|
312 |
+
" \n",
|
313 |
+
" joint_image = all_preprocessing(images)\n",
|
314 |
+
" \n",
|
315 |
+
" joint_image = torch.stack(joint_image)[:,0]\n",
|
316 |
+
" clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
|
317 |
+
" image_features = clip_model.encode_image(joint_image)\n",
|
318 |
+
" image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
|
319 |
+
" \n",
|
320 |
+
" prompts = [f'a photo of a {obj}'for obj in objects]\n",
|
321 |
+
" text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
|
322 |
+
" text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
|
323 |
+
" logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
|
324 |
+
" sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
|
325 |
+
"\n",
|
326 |
+
" for i, img in enumerate(joint_image):\n",
|
327 |
+
" ax[2*j, i].axis('off')\n",
|
328 |
+
" \n",
|
329 |
+
" ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
|
330 |
+
" ax[2*j+ 1, i].grid(True)\n",
|
331 |
+
" \n",
|
332 |
+
" ax[2*j + 1, i].set_ylim(0,1)\n",
|
333 |
+
" ax[2*j + 1, i].set_yticklabels([])\n",
|
334 |
+
" ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
|
335 |
+
"# ax[1, i].set_xticklabels(objects, rotation=90)\n",
|
336 |
+
" for k in range(len(sim[i])):\n",
|
337 |
+
" ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
|
338 |
+
" ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
|
339 |
+
"\n",
|
340 |
+
"plt.tight_layout()\n",
|
341 |
+
"plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
|
342 |
+
]
|
343 |
+
}
|
344 |
+
],
|
345 |
+
"metadata": {
|
346 |
+
"kernelspec": {
|
347 |
+
"display_name": "env2",
|
348 |
+
"language": "python",
|
349 |
+
"name": "env2"
|
350 |
+
},
|
351 |
+
"language_info": {
|
352 |
+
"codemirror_mode": {
|
353 |
+
"name": "ipython",
|
354 |
+
"version": 3
|
355 |
+
},
|
356 |
+
"file_extension": ".py",
|
357 |
+
"mimetype": "text/x-python",
|
358 |
+
"name": "python",
|
359 |
+
"nbconvert_exporter": "python",
|
360 |
+
"pygments_lexer": "ipython3",
|
361 |
+
"version": "3.8.8"
|
362 |
+
}
|
363 |
+
},
|
364 |
+
"nbformat": 4,
|
365 |
+
"nbformat_minor": 4
|
366 |
+
}
|
clipseg/datasets/coco_wrapper.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
from types import new_class
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
|
8 |
+
from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
|
9 |
+
from random import shuffle, seed as set_seed
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from itertools import combinations
|
13 |
+
from torchvision import transforms
|
14 |
+
from torchvision.transforms.transforms import Resize
|
15 |
+
|
16 |
+
from datasets.utils import blend_image_segmentation
|
17 |
+
from general_utils import get_from_repository
|
18 |
+
|
19 |
+
COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
|
20 |
+
|
21 |
+
class COCOWrapper(object):
|
22 |
+
|
23 |
+
def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
|
24 |
+
with_class_label=False):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.mask = mask
|
28 |
+
self.with_class_label = with_class_label
|
29 |
+
self.negative_prob = negative_prob
|
30 |
+
|
31 |
+
from third_party.hsnet.data.coco import DatasetCOCO
|
32 |
+
|
33 |
+
get_from_repository('COCO-20i', ['COCO-20i.tar'])
|
34 |
+
|
35 |
+
foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')
|
36 |
+
|
37 |
+
def build_img_metadata_classwise(self):
|
38 |
+
with open(foldpath % (self.split, self.fold), 'rb') as f:
|
39 |
+
img_metadata_classwise = pickle.load(f)
|
40 |
+
return img_metadata_classwise
|
41 |
+
|
42 |
+
|
43 |
+
DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
|
44 |
+
# DatasetCOCO.read_mask = read_mask
|
45 |
+
|
46 |
+
mean = [0.485, 0.456, 0.406]
|
47 |
+
std = [0.229, 0.224, 0.225]
|
48 |
+
transform = transforms.Compose([
|
49 |
+
transforms.Resize((image_size, image_size)),
|
50 |
+
transforms.ToTensor(),
|
51 |
+
transforms.Normalize(mean, std)
|
52 |
+
])
|
53 |
+
|
54 |
+
self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False)
|
55 |
+
|
56 |
+
self.all_classes = [self.coco.class_ids]
|
57 |
+
self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
return len(self.coco)
|
61 |
+
|
62 |
+
def __getitem__(self, i):
|
63 |
+
sample = self.coco[i]
|
64 |
+
|
65 |
+
label_name = COCO_CLASSES[int(sample['class_id'])]
|
66 |
+
|
67 |
+
img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0]
|
68 |
+
|
69 |
+
if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
|
70 |
+
new_class_id = sample['class_id']
|
71 |
+
while new_class_id == sample['class_id']:
|
72 |
+
sample2 = self.coco[torch.randint(0, len(self), (1,)).item()]
|
73 |
+
new_class_id = sample2['class_id']
|
74 |
+
img_s = sample2['support_imgs'][0]
|
75 |
+
seg_s = torch.zeros_like(seg_s)
|
76 |
+
|
77 |
+
mask = self.mask
|
78 |
+
if mask == 'separate':
|
79 |
+
supp = (img_s, seg_s)
|
80 |
+
elif mask == 'text_label':
|
81 |
+
# DEPRECATED
|
82 |
+
supp = [int(sample['class_id'])]
|
83 |
+
elif mask == 'text':
|
84 |
+
supp = [label_name]
|
85 |
+
else:
|
86 |
+
if mask.startswith('text_and_'):
|
87 |
+
mask = mask[9:]
|
88 |
+
label_add = [label_name]
|
89 |
+
else:
|
90 |
+
label_add = []
|
91 |
+
|
92 |
+
supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask)
|
93 |
+
|
94 |
+
if self.with_class_label:
|
95 |
+
label = (torch.zeros(0), sample['class_id'],)
|
96 |
+
else:
|
97 |
+
label = (torch.zeros(0), )
|
98 |
+
|
99 |
+
return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label
|
clipseg/datasets/pascal_classes.json
ADDED
@@ -0,0 +1 @@
|
|
|
1 |
+
[{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
|
clipseg/datasets/pascal_zeroshot.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import expanduser
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
import torchvision
|
5 |
+
from general_utils import get_from_repository
|
6 |
+
from general_utils import log
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
|
10 |
+
['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
|
11 |
+
['chair.n.01', 'pot_plant.n.01']]
|
12 |
+
|
13 |
+
|
14 |
+
class PascalZeroShot(object):
|
15 |
+
|
16 |
+
def __init__(self, split, n_unseen, image_size=224) -> None:
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
import sys
|
20 |
+
sys.path.append('third_party/JoEm')
|
21 |
+
from third_party.JoEm.data_loader.dataset import VOCSegmentation
|
22 |
+
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
23 |
+
|
24 |
+
self.pascal_classes = VOC
|
25 |
+
self.image_size = image_size
|
26 |
+
|
27 |
+
self.transform = transforms.Compose([
|
28 |
+
transforms.Resize((image_size, image_size)),
|
29 |
+
])
|
30 |
+
|
31 |
+
if split == 'train':
|
32 |
+
self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
|
33 |
+
split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
|
34 |
+
ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
|
35 |
+
elif split == 'val':
|
36 |
+
self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
|
37 |
+
split=split, transform=False,
|
38 |
+
ignore_bg=False, ignore_unseen=False)
|
39 |
+
|
40 |
+
self.unseen_idx = get_unseen_idx(n_unseen)
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return len(self.voc)
|
44 |
+
|
45 |
+
def __getitem__(self, i):
|
46 |
+
|
47 |
+
sample = self.voc[i]
|
48 |
+
label = sample['label'].long()
|
49 |
+
all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
|
50 |
+
class_indices = [l for l in all_labels]
|
51 |
+
class_names = [self.pascal_classes[l] for l in all_labels]
|
52 |
+
|
53 |
+
image = self.transform(sample['image'])
|
54 |
+
|
55 |
+
label = transforms.Resize((self.image_size, self.image_size),
|
56 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
|
57 |
+
|
58 |
+
return (image,), (label, )
|
59 |
+
|
60 |
+
|
clipseg/datasets/pfe_dataset.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import expanduser
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from general_utils import get_from_repository
|
5 |
+
from datasets.lvis_oneshot3 import blend_image_segmentation
|
6 |
+
from general_utils import log
|
7 |
+
|
8 |
+
PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
|
9 |
+
|
10 |
+
|
11 |
+
class PFEPascalWrapper(object):
|
12 |
+
|
13 |
+
def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
|
14 |
+
import sys
|
15 |
+
# sys.path.append(expanduser('~/projects/new_one_shot'))
|
16 |
+
from third_party.PFENet.util.dataset import SemData
|
17 |
+
|
18 |
+
get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
|
19 |
+
|
20 |
+
self.p_negative = p_negative
|
21 |
+
self.size = size
|
22 |
+
self.mode = mode
|
23 |
+
self.image_size = image_size
|
24 |
+
|
25 |
+
if label_support in {True, False}:
|
26 |
+
log.warning('label_support argument is deprecated. Use mask instead.')
|
27 |
+
#raise ValueError()
|
28 |
+
|
29 |
+
self.mask = mask
|
30 |
+
|
31 |
+
value_scale = 255
|
32 |
+
mean = [0.485, 0.456, 0.406]
|
33 |
+
mean = [item * value_scale for item in mean]
|
34 |
+
std = [0.229, 0.224, 0.225]
|
35 |
+
std = [item * value_scale for item in std]
|
36 |
+
|
37 |
+
import third_party.PFENet.util.transform as transform
|
38 |
+
|
39 |
+
if mode == 'val':
|
40 |
+
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
|
41 |
+
|
42 |
+
data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
|
43 |
+
data_transform += [
|
44 |
+
transform.ToTensor(),
|
45 |
+
transform.Normalize(mean=mean, std=std)
|
46 |
+
]
|
47 |
+
|
48 |
+
|
49 |
+
elif mode == 'train':
|
50 |
+
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
|
51 |
+
|
52 |
+
assert image_size != 'original'
|
53 |
+
|
54 |
+
data_transform = [
|
55 |
+
transform.RandScale([0.9, 1.1]),
|
56 |
+
transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
|
57 |
+
transform.RandomGaussianBlur(),
|
58 |
+
transform.RandomHorizontalFlip(),
|
59 |
+
transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
|
60 |
+
transform.ToTensor(),
|
61 |
+
transform.Normalize(mean=mean, std=std)
|
62 |
+
]
|
63 |
+
|
64 |
+
data_transform = transform.Compose(data_transform)
|
65 |
+
|
66 |
+
self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
|
67 |
+
data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
|
68 |
+
|
69 |
+
self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
|
70 |
+
|
71 |
+
# verify that subcls_list always has length 1
|
72 |
+
# assert len(set([len(d[4]) for d in self.dataset])) == 1
|
73 |
+
|
74 |
+
print('actual length', len(self.dataset.data_list))
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
if self.mode == 'val':
|
78 |
+
return len(self.dataset.data_list)
|
79 |
+
else:
|
80 |
+
return len(self.dataset.data_list)
|
81 |
+
|
82 |
+
def __getitem__(self, index):
|
83 |
+
if self.dataset.mode == 'train':
|
84 |
+
image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
|
85 |
+
elif self.dataset.mode == 'val':
|
86 |
+
image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
|
87 |
+
ori_label = torch.from_numpy(ori_label).unsqueeze(0)
|
88 |
+
|
89 |
+
if self.image_size != 'original':
|
90 |
+
longerside = max(ori_label.size(1), ori_label.size(2))
|
91 |
+
backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
|
92 |
+
backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
|
93 |
+
label = backmask.clone().long()
|
94 |
+
else:
|
95 |
+
label = label.unsqueeze(0)
|
96 |
+
|
97 |
+
# assert label.shape == (473, 473)
|
98 |
+
|
99 |
+
if self.p_negative > 0:
|
100 |
+
if torch.rand(1).item() < self.p_negative:
|
101 |
+
while True:
|
102 |
+
idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
|
103 |
+
_, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
|
104 |
+
if subcls_list[0] != subcls_list_tmp[0]:
|
105 |
+
break
|
106 |
+
|
107 |
+
s_x = s_x[0]
|
108 |
+
s_y = (s_y == 1)[0]
|
109 |
+
label_fg = (label == 1).float()
|
110 |
+
val_mask = (label != 255).float()
|
111 |
+
|
112 |
+
class_id = self.class_list[subcls_list[0]]
|
113 |
+
|
114 |
+
label_name = PASCAL_CLASSES[class_id][0]
|
115 |
+
label_add = ()
|
116 |
+
mask = self.mask
|
117 |
+
|
118 |
+
if mask == 'text':
|
119 |
+
support = ('a photo of a ' + label_name + '.',)
|
120 |
+
elif mask == 'separate':
|
121 |
+
support = (s_x, s_y)
|
122 |
+
else:
|
123 |
+
if mask.startswith('text_and_'):
|
124 |
+
label_add = (label_name,)
|
125 |
+
mask = mask[9:]
|
126 |
+
|
127 |
+
support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
|
128 |
+
|
129 |
+
return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
|
clipseg/datasets/phrasecut.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|