apolinario committed on
Commit 48fa639
1 Parent(s): bdc1819

upload clipseg

clipseg/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ This license does not apply to the model weights.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
clipseg/Quickstart.ipynb ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import requests\n",
11
+ "\n",
12
+ "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
13
+ "! unzip -d weights -j weights.zip\n",
14
+ "from models.clipseg import CLIPDensePredT\n",
15
+ "from PIL import Image\n",
16
+ "from torchvision import transforms\n",
17
+ "from matplotlib import pyplot as plt\n",
18
+ "\n",
19
+ "# load model\n",
20
+ "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
21
+ "model.eval();\n",
22
+ "\n",
23
+ "# non-strict, because we only stored decoder weights (not CLIP weights)\n",
24
+ "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "Load and normalize `example_image.jpg`. You can also load through an URL."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "# load and normalize image\n",
41
+ "input_image = Image.open('example_image.jpg')\n",
42
+ "\n",
43
+ "# or load from URL...\n",
44
+ "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
45
+ "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
46
+ "\n",
47
+ "transform = transforms.Compose([\n",
48
+ " transforms.ToTensor(),\n",
49
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
50
+ " transforms.Resize((352, 352)),\n",
51
+ "])\n",
52
+ "img = transform(input_image).unsqueeze(0)"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {},
58
+ "source": [
59
+ "Predict and visualize (this might take a few seconds if running without GPU support)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
69
+ "\n",
70
+ "# predict\n",
71
+ "with torch.no_grad():\n",
72
+ " preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
73
+ "\n",
74
+ "# visualize prediction\n",
75
+ "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
76
+ "[a.axis('off') for a in ax.flatten()]\n",
77
+ "ax[0].imshow(input_image)\n",
78
+ "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
79
+ "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
80
+ ]
81
+ }
82
+ ],
83
+ "metadata": {
84
+ "interpreter": {
85
+ "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
86
+ },
87
+ "kernelspec": {
88
+ "display_name": "Python 3",
89
+ "language": "python",
90
+ "name": "python3"
91
+ },
92
+ "language_info": {
93
+ "codemirror_mode": {
94
+ "name": "ipython",
95
+ "version": 3
96
+ },
97
+ "file_extension": ".py",
98
+ "mimetype": "text/x-python",
99
+ "name": "python",
100
+ "nbconvert_exporter": "python",
101
+ "pygments_lexer": "ipython3",
102
+ "version": "3.8.10"
103
+ }
104
+ },
105
+ "nbformat": 4,
106
+ "nbformat_minor": 4
107
+ }
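The notebook above only visualizes the sigmoid heatmaps. As a hedged follow-up sketch (not part of the committed notebook), this is one way to turn a prediction into a binary mask at the original resolution; it reuses the notebook's `preds` and `input_image` variables, and the 0.5 threshold is an arbitrary placeholder:

```python
# Sketch only: threshold the first prompt's heatmap into a binary mask.
# `preds` and `input_image` come from the notebook cells above; 0.5 is a placeholder threshold.
import torch
import torch.nn.functional as F

heatmap = torch.sigmoid(preds[0][0])                  # (352, 352) probabilities for prompt 0
heatmap = F.interpolate(heatmap[None, None],          # resize back to the input image size
                        size=input_image.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
                        mode='bilinear', align_corners=False)[0, 0]
binary_mask = heatmap > 0.5                           # bool tensor; True = pixel assigned to the prompt
```

In practice the threshold is tuned per use case; the evaluation code in this upload reports metrics at several fixed thresholds (e.g. 0.1-0.3).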
clipseg/Readme.md ADDED
@@ -0,0 +1,84 @@
1
+ # Image Segmentation Using Text and Image Prompts
2
+ This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
3
+
4
+ **The Paper has been accepted to CVPR 2022!**
5
+
6
+ <img src="overview.png" alt="drawing" height="200em"/>
7
+
8
+ The system allows creating segmentation models without training, based on:
9
+ - An arbitrary text query
10
+ - Or an image with a mask highlighting stuff or an object.
11
+
12
+ ### Quick Start
13
+
14
+ In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you have downloaded the `rd64-uni.pth` weights, either manually or via the git LFS extension.
15
+ It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
16
+ (please note that the VM does not use a GPU, thus inference takes a few seconds).
17
+
18
+
19
+ ### Dependencies
20
+ This code base depends on PyTorch, torchvision and CLIP (`pip install git+https://github.com/openai/CLIP.git`).
21
+ Additional dependencies are hidden for double blind review.
22
+
23
+
24
+ ### Datasets
25
+
26
+ * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
27
+ * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
28
+ * `PascalZeroShot`: Wrapper class for PascalZeroShot
29
+ * `COCOWrapper`: Wrapper class for COCO.
30
+
31
+ ### Models
32
+
33
+ * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
34
+ * `ViTDensePredT`: baseline variant (ViTSeg) that uses an ImageNet-trained ViT backbone instead of CLIP, with the same transformer-based decoder.
35
+
36
+ ### Third Party Dependencies
37
+ For some of the datasets, third-party dependencies are required. Run the following commands in the `third_party` folder.
38
+ ```bash
39
+ git clone https://github.com/cvlab-yonsei/JoEm
40
+ git clone https://github.com/Jia-Research-Lab/PFENet.git
41
+ git clone https://github.com/ChenyunWu/PhraseCutDataset.git
42
+ git clone https://github.com/juhongm999/hsnet.git
43
+ ```
44
+
45
+ ### Weights
46
+
47
+ The MIT license does not apply to these weights.
48
+
49
+ We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
50
+ ```
51
+ wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
52
+ unzip -d weights -j weights.zip
53
+ ```
54
+
55
+
56
+ ### Training and Evaluation
57
+
58
+ To train, use the `training.py` script with an experiment file and an experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by the `configuration` and the first `individual_configurations` entries. Model weights will be written to `logs/`.
59
+
60
+ For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment using its `test_configuration` and the first configuration in `individual_configurations`.
61
+
62
+
63
+ ### Usage of PFENet Wrappers
64
+
65
+ In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
66
+ `git clone https://github.com/Jia-Research-Lab/PFENet.git `
67
+
68
+
69
+ ### License
70
+
71
+ The source code files in this repository (excluding model weights) are released under the MIT license.
72
+
73
+ ### Citation
74
+ ```
75
+ @InProceedings{lueddecke22_cvpr,
76
+ author = {L\"uddecke, Timo and Ecker, Alexander},
77
+ title = {Image Segmentation Using Text and Image Prompts},
78
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
79
+ month = {June},
80
+ year = {2022},
81
+ pages = {7086-7096}
82
+ }
83
+
84
+ ```
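For convenience, the Quickstart notebook above condenses to the following minimal sketch. It assumes the weights were downloaded and unpacked into `weights/` as shown in the Weights section; the prompts are placeholders:

```python
import torch
from PIL import Image
from torchvision import transforms
from models.clipseg import CLIPDensePredT

# decoder-only checkpoint, hence strict=False (the CLIP backbone is loaded separately)
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location='cpu'), strict=False)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((352, 352)),
])
img = transform(Image.open('example_image.jpg')).unsqueeze(0)

prompts = ['a glass', 'wood']  # placeholder text queries
with torch.no_grad():
    preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
heatmaps = torch.sigmoid(preds)  # one (1, 352, 352) heatmap per prompt
```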
clipseg/Tables.ipynb ADDED
@@ -0,0 +1,349 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2\n",
11
+ "\n",
12
+ "import clip\n",
13
+ "from evaluation_utils import norm, denorm\n",
14
+ "from general_utils import *\n",
15
+ "from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "# PhraseCut"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
50
+ "tab1 = pc[['name'] + cols]\n",
51
+ "for k in cols:\n",
52
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
53
+ "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
54
+ "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
55
+ "print(tab1.to_latex(header=False, index=False))"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "For 0.1 threshold"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
72
+ "tab1 = pc[['name'] + cols]\n",
73
+ "for k in cols:\n",
74
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
75
+ "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
76
+ "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
77
+ "print(tab1.to_latex(header=False, index=False))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {},
83
+ "source": [
84
+ "# One-shot"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "metadata": {},
90
+ "source": [
91
+ "### Pascal"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
119
+ "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
120
+ "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
121
+ "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
122
+ "\n",
123
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
124
+ "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
125
+ "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
126
+ "\n",
127
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
128
+ "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
129
+ "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "metadata": {},
135
+ "source": [
136
+ "#### Pascal Zero-shot (in one-shot setting)\n",
137
+ "\n",
138
+ "Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
148
+ "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
149
+ "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
150
+ "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
151
+ "\n",
152
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
153
+ "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
154
+ "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
155
+ "\n",
156
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
157
+ "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
158
+ "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "# without fixed thresholds...\n",
168
+ "\n",
169
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
170
+ "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
171
+ "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
172
+ "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
173
+ "\n",
174
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
175
+ "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
176
+ "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {},
182
+ "source": [
183
+ "### COCO"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
202
+ "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
203
+ "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
204
+ "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
205
+ "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
206
+ "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
207
+ "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {},
213
+ "source": [
214
+ "# Zero-shot"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": [
232
+ "\n",
233
+ "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
234
+ "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
235
+ "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
236
+ "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "metadata": {},
242
+ "source": [
243
+ "# Ablation"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
262
+ "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
263
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
264
+ "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {},
271
+ "outputs": [],
272
+ "source": [
273
+ "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "metadata": {},
288
+ "source": [
289
+ "# Generalization"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "generalization = experiment('experiments/generalize.yaml').dataframe()"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "print(\n",
317
+ " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
318
+ " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
319
+ " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
320
+ " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
321
+ ")"
322
+ ]
323
+ }
324
+ ],
325
+ "metadata": {
326
+ "interpreter": {
327
+ "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
328
+ },
329
+ "kernelspec": {
330
+ "display_name": "env2",
331
+ "language": "python",
332
+ "name": "env2"
333
+ },
334
+ "language_info": {
335
+ "codemirror_mode": {
336
+ "name": "ipython",
337
+ "version": 3
338
+ },
339
+ "file_extension": ".py",
340
+ "mimetype": "text/x-python",
341
+ "name": "python",
342
+ "nbconvert_exporter": "python",
343
+ "pygments_lexer": "ipython3",
344
+ "version": "3.8.8"
345
+ }
346
+ },
347
+ "nbformat": 4,
348
+ "nbformat_minor": 4
349
+ }
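Most cells in `Tables.ipynb` repeat one idiom: select metric columns, average a slice of runs, scale to percent, and join the values into a LaTeX row. A small illustrative helper capturing that pattern (the `experiment(...).dataframe()` loader comes from the repository's `general_utils` and is not reproduced here):

```python
import pandas as pd

def latex_row(label: str, df: pd.DataFrame, cols, rows=slice(None)) -> str:
    """Average the selected rows of the metric columns, convert to percent,
    and emit a single LaTeX table row (mirrors the notebook's print statements)."""
    values = df.loc[:, cols].iloc[rows].mean(0).values
    return label + ' & ' + ' & '.join(f'{v * 100:.1f}' for v in values) + ' \\\\'

# e.g.: latex_row('CLIPSeg (PC+)', pas,
#                 ['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap'], rows=slice(0, 4))
```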
clipseg/Visual_Feature_Engineering.ipynb ADDED
@@ -0,0 +1,366 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Systematic"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "%load_ext autoreload\n",
17
+ "%autoreload 2\n",
18
+ "\n",
19
+ "import clip\n",
20
+ "from evaluation_utils import norm, denorm\n",
21
+ "from general_utils import *\n",
22
+ "from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
23
+ "\n",
24
+ "clip_device = 'cuda'\n",
25
+ "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
26
+ "clip_model.eval();\n",
27
+ "\n",
28
+ "from models.clipseg import CLIPDensePredTMasked\n",
29
+ "\n",
30
+ "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
31
+ "clip_mask_model.eval();"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
41
+ " text_class_labels=True, image_size=352, min_area=0.1,\n",
42
+ " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "plot_data(lvis)"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "from collections import defaultdict\n",
61
+ "import json\n",
62
+ "\n",
63
+ "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
64
+ "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
65
+ "\n",
66
+ "objects_per_image = defaultdict(lambda : set())\n",
67
+ "for ann in lvis_raw['annotations']:\n",
68
+ " objects_per_image[ann['image_id']].add(ann['category_id'])\n",
69
+ " \n",
70
+ "for ann in lvis_val_raw['annotations']:\n",
71
+ " objects_per_image[ann['image_id']].add(ann['category_id']) \n",
72
+ " \n",
73
+ "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
74
+ "\n",
75
+ "del lvis_raw, lvis_val_raw"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "#bs = 32\n",
85
+ "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "from general_utils import get_batch\n",
95
+ "from functools import partial\n",
96
+ "from evaluation_utils import img_preprocess\n",
97
+ "import torch\n",
98
+ "\n",
99
+ "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
100
+ "\n",
101
+ " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
102
+ "\n",
103
+ " all_prompts = []\n",
104
+ " \n",
105
+ " with torch.no_grad():\n",
106
+ " valid_sims = []\n",
107
+ " torch.manual_seed(571)\n",
108
+ " \n",
109
+ " if type(batches_or_dataset) == list:\n",
110
+ " loader = batches_or_dataset # already loaded\n",
111
+ " max_iter = float('inf')\n",
112
+ " else:\n",
113
+ " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
114
+ " max_iter = 50\n",
115
+ " \n",
116
+ " global batch\n",
117
+ " for i_batch, (batch, batch_y) in enumerate(loader):\n",
118
+ " \n",
119
+ " if i_batch >= max_iter: break\n",
120
+ " \n",
121
+ " processed_batch = process(batch)\n",
122
+ " if type(processed_batch) == dict:\n",
123
+ " \n",
124
+ " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
125
+ " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
126
+ " else:\n",
127
+ " processed_batch = process(batch).to(clip_device)\n",
128
+ " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
129
+ " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
130
+ " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
131
+ " \n",
132
+ " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
133
+ " bs = len(batch[0])\n",
134
+ " for j in range(bs):\n",
135
+ " \n",
136
+ " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
137
+ " support_image = basename(lvis.samples[c][sid])\n",
138
+ " \n",
139
+ " img_objs = [o for o in objects_per_image[int(support_image)]]\n",
140
+ " img_objs = [o.replace('_', ' ') for o in img_objs]\n",
141
+ " \n",
142
+ " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
143
+ " if o != batch_y[2][j]]\n",
144
+ " \n",
145
+ " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
146
+ " all_prompts += [prompts]\n",
147
+ " \n",
148
+ " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
149
+ " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
150
+ "\n",
151
+ " global logits\n",
152
+ " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
153
+ "\n",
154
+ " global sim\n",
155
+ " sim = torch.softmax(logits, dim=-1)\n",
156
+ " \n",
157
+ " valid_sims += [sim]\n",
158
+ " \n",
159
+ " #valid_sims = torch.stack(valid_sims)\n",
160
+ " return valid_sims, all_prompts\n",
161
+ " \n",
162
+ "\n",
163
+ "def new_img_preprocess(x):\n",
164
+ " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
165
+ " \n",
166
+ "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
167
+ "get_similarities(lvis, lambda x: x[1]);"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "preprocessing_functions = [\n",
177
+ "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
178
+ "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
179
+ "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
180
+ "# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
181
+ "# ['add red outline', partial(img_preprocess, outline=True)],\n",
182
+ " \n",
183
+ "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
184
+ "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
185
+ "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
186
+ "# ['BG blur', partial(img_preprocess, blur=3)],\n",
187
+ "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
188
+ " \n",
189
+ "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
190
+ "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
191
+ " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
192
+ " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
193
+ "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
194
+ "]\n",
195
+ "\n",
196
+ "preprocessing_functions = preprocessing_functions\n",
197
+ "\n",
198
+ "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
199
+ "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "for j in range(1):\n",
218
+ " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "from pandas import DataFrame\n",
228
+ "tab = dict()\n",
229
+ "for j, (name, _) in enumerate(preprocessing_functions):\n",
230
+ " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
231
+ " \n",
232
+ " \n",
233
+ "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {},
239
+ "source": [
240
+ "# Visual"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": null,
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "from evaluation_utils import denorm, norm"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "def load_sample(filename, filename2):\n",
259
+ " from os.path import join\n",
260
+ " bp = expanduser('~/cloud/resources/sample_images')\n",
261
+ " tf = transforms.Compose([\n",
262
+ " transforms.ToTensor(),\n",
263
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
264
+ " transforms.Resize(224),\n",
265
+ " transforms.CenterCrop(224)\n",
266
+ " ])\n",
267
+ " tf2 = transforms.Compose([\n",
268
+ " transforms.ToTensor(),\n",
269
+ " transforms.Resize(224),\n",
270
+ " transforms.CenterCrop(224)\n",
271
+ " ])\n",
272
+ " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
273
+ " inp1[1] = inp1[1].unsqueeze(0)\n",
274
+ " inp1[2] = inp1[2][:1] \n",
275
+ " return inp1\n",
276
+ "\n",
277
+ "def all_preprocessing(inp1):\n",
278
+ " return [\n",
279
+ " img_preprocess(inp1),\n",
280
+ " img_preprocess(inp1, colorize=True),\n",
281
+ " img_preprocess(inp1, outline=True), \n",
282
+ " img_preprocess(inp1, blur=3),\n",
283
+ " img_preprocess(inp1, bg_fac=0.1),\n",
284
+ " #img_preprocess(inp1, bg_fac=0.5),\n",
285
+ " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
286
+ " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
287
+ " ]\n",
288
+ "\n"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "from torchvision import transforms\n",
298
+ "from PIL import Image\n",
299
+ "from matplotlib import pyplot as plt\n",
300
+ "from evaluation_utils import img_preprocess\n",
301
+ "import clip\n",
302
+ "\n",
303
+ "images_queries = [\n",
304
+ " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
305
+ " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
306
+ "]\n",
307
+ "\n",
308
+ "\n",
309
+ "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
310
+ "\n",
311
+ "for j, (images, objects) in enumerate(images_queries):\n",
312
+ " \n",
313
+ " joint_image = all_preprocessing(images)\n",
314
+ " \n",
315
+ " joint_image = torch.stack(joint_image)[:,0]\n",
316
+ " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
317
+ " image_features = clip_model.encode_image(joint_image)\n",
318
+ " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
319
+ " \n",
320
+ " prompts = [f'a photo of a {obj}'for obj in objects]\n",
321
+ " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
322
+ " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
323
+ " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
324
+ " sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
325
+ "\n",
326
+ " for i, img in enumerate(joint_image):\n",
327
+ " ax[2*j, i].axis('off')\n",
328
+ " \n",
329
+ " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
330
+ " ax[2*j+ 1, i].grid(True)\n",
331
+ " \n",
332
+ " ax[2*j + 1, i].set_ylim(0,1)\n",
333
+ " ax[2*j + 1, i].set_yticklabels([])\n",
334
+ " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
335
+ "# ax[1, i].set_xticklabels(objects, rotation=90)\n",
336
+ " for k in range(len(sim[i])):\n",
337
+ " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
338
+ " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
339
+ "\n",
340
+ "plt.tight_layout()\n",
341
+ "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
342
+ ]
343
+ }
344
+ ],
345
+ "metadata": {
346
+ "kernelspec": {
347
+ "display_name": "env2",
348
+ "language": "python",
349
+ "name": "env2"
350
+ },
351
+ "language_info": {
352
+ "codemirror_mode": {
353
+ "name": "ipython",
354
+ "version": 3
355
+ },
356
+ "file_extension": ".py",
357
+ "mimetype": "text/x-python",
358
+ "name": "python",
359
+ "nbconvert_exporter": "python",
360
+ "pygments_lexer": "ipython3",
361
+ "version": "3.8.8"
362
+ }
363
+ },
364
+ "nbformat": 4,
365
+ "nbformat_minor": 4
366
+ }
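Both parts of `Visual_Feature_Engineering.ipynb` ultimately score a visually engineered image (background blurred, darkened, or cropped) against a set of text prompts with CLIP. A stripped-down sketch of that scoring step using only the public CLIP API; the image path and prompts are placeholders:

```python
import clip
import torch
from PIL import Image

device = 'cpu'  # or 'cuda'
model, preprocess = clip.load('ViT-B/16', device=device)

image = preprocess(Image.open('example.jpg')).unsqueeze(0).to(device)  # placeholder image
prompts = ['a photo of a jug', 'a photo of a knife', 'a photo of nothing']

with torch.no_grad():
    img_feat = model.encode_image(image)
    txt_feat = model.encode_text(clip.tokenize(prompts).to(device))
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
    logits = model.logit_scale.exp() * img_feat @ txt_feat.T
    probs = logits.softmax(dim=-1)  # how strongly the (engineered) image matches each prompt
```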
clipseg/datasets/coco_wrapper.py ADDED
@@ -0,0 +1,99 @@
1
+ import pickle
2
+ from types import new_class
3
+ import torch
4
+ import numpy as np
5
+ import os
6
+ import json
7
+
8
+ from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
9
+ from random import shuffle, seed as set_seed
10
+ from PIL import Image
11
+
12
+ from itertools import combinations
13
+ from torchvision import transforms
14
+ from torchvision.transforms.transforms import Resize
15
+
16
+ from datasets.utils import blend_image_segmentation
17
+ from general_utils import get_from_repository
18
+
19
+ COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
20
+
21
+ class COCOWrapper(object):
22
+
23
+ def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
24
+ with_class_label=False):
25
+ super().__init__()
26
+
27
+ self.mask = mask
28
+ self.with_class_label = with_class_label
29
+ self.negative_prob = negative_prob
30
+
31
+ from third_party.hsnet.data.coco import DatasetCOCO
32
+
33
+ get_from_repository('COCO-20i', ['COCO-20i.tar'])
34
+
35
+ foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')
36
+
37
+ def build_img_metadata_classwise(self):
38
+ with open(foldpath % (self.split, self.fold), 'rb') as f:
39
+ img_metadata_classwise = pickle.load(f)
40
+ return img_metadata_classwise
41
+
42
+
43
+ DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
44
+ # DatasetCOCO.read_mask = read_mask
45
+
46
+ mean = [0.485, 0.456, 0.406]
47
+ std = [0.229, 0.224, 0.225]
48
+ transform = transforms.Compose([
49
+ transforms.Resize((image_size, image_size)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize(mean, std)
52
+ ])
53
+
54
+ self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False)
55
+
56
+ self.all_classes = [self.coco.class_ids]
57
+ self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))
58
+
59
+ def __len__(self):
60
+ return len(self.coco)
61
+
62
+ def __getitem__(self, i):
63
+ sample = self.coco[i]
64
+
65
+ label_name = COCO_CLASSES[int(sample['class_id'])]
66
+
67
+ img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0]
68
+
69
+ if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
70
+ new_class_id = sample['class_id']
71
+ while new_class_id == sample['class_id']:
72
+ sample2 = self.coco[torch.randint(0, len(self), (1,)).item()]
73
+ new_class_id = sample2['class_id']
74
+ img_s = sample2['support_imgs'][0]
75
+ seg_s = torch.zeros_like(seg_s)
76
+
77
+ mask = self.mask
78
+ if mask == 'separate':
79
+ supp = (img_s, seg_s)
80
+ elif mask == 'text_label':
81
+ # DEPRECATED
82
+ supp = [int(sample['class_id'])]
83
+ elif mask == 'text':
84
+ supp = [label_name]
85
+ else:
86
+ if mask.startswith('text_and_'):
87
+ mask = mask[9:]
88
+ label_add = [label_name]
89
+ else:
90
+ label_add = []
91
+
92
+ supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask)
93
+
94
+ if self.with_class_label:
95
+ label = (torch.zeros(0), sample['class_id'],)
96
+ else:
97
+ label = (torch.zeros(0), )
98
+
99
+ return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label
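A hedged usage sketch for `COCOWrapper` (assuming the COCO-20i data and the `third_party/hsnet` checkout described in the Readme are in place; the parameter values are illustrative, not the repository's defaults):

```python
from torch.utils.data import DataLoader
from datasets.coco_wrapper import COCOWrapper

# text-prompt variant: the support is just the class name, the target is the query mask
dataset = COCOWrapper('val', fold=0, image_size=352, mask='text', with_class_label=True)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

(query_img, support_text), (query_mask, _, class_id) = next(iter(loader))
```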
clipseg/datasets/pascal_classes.json ADDED
@@ -0,0 +1 @@
1
+ [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
clipseg/datasets/pascal_zeroshot.py ADDED
@@ -0,0 +1,60 @@
1
+ from os.path import expanduser
2
+ import torch
3
+ import json
4
+ import torchvision
5
+ from general_utils import get_from_repository
6
+ from general_utils import log
7
+ from torchvision import transforms
8
+
9
+ PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
10
+ ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
11
+ ['chair.n.01', 'pot_plant.n.01']]
12
+
13
+
14
+ class PascalZeroShot(object):
15
+
16
+ def __init__(self, split, n_unseen, image_size=224) -> None:
17
+ super().__init__()
18
+
19
+ import sys
20
+ sys.path.append('third_party/JoEm')
21
+ from third_party.JoEm.data_loader.dataset import VOCSegmentation
22
+ from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
23
+
24
+ self.pascal_classes = VOC
25
+ self.image_size = image_size
26
+
27
+ self.transform = transforms.Compose([
28
+ transforms.Resize((image_size, image_size)),
29
+ ])
30
+
31
+ if split == 'train':
32
+ self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
33
+ split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
34
+ ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
35
+ elif split == 'val':
36
+ self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
37
+ split=split, transform=False,
38
+ ignore_bg=False, ignore_unseen=False)
39
+
40
+ self.unseen_idx = get_unseen_idx(n_unseen)
41
+
42
+ def __len__(self):
43
+ return len(self.voc)
44
+
45
+ def __getitem__(self, i):
46
+
47
+ sample = self.voc[i]
48
+ label = sample['label'].long()
49
+ all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
50
+ class_indices = [l for l in all_labels]
51
+ class_names = [self.pascal_classes[l] for l in all_labels]
52
+
53
+ image = self.transform(sample['image'])
54
+
55
+ label = transforms.Resize((self.image_size, self.image_size),
56
+ interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
57
+
58
+ return (image,), (label, )
59
+
60
+
clipseg/datasets/pfe_dataset.py ADDED
@@ -0,0 +1,129 @@
1
+ from os.path import expanduser
2
+ import torch
3
+ import json
4
+ from general_utils import get_from_repository
5
+ from datasets.lvis_oneshot3 import blend_image_segmentation
6
+ from general_utils import log
7
+
8
+ PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
9
+
10
+
11
+ class PFEPascalWrapper(object):
12
+
13
+ def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
14
+ import sys
15
+ # sys.path.append(expanduser('~/projects/new_one_shot'))
16
+ from third_party.PFENet.util.dataset import SemData
17
+
18
+ get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
19
+
20
+ self.p_negative = p_negative
21
+ self.size = size
22
+ self.mode = mode
23
+ self.image_size = image_size
24
+
25
+ if label_support in {True, False}:
26
+ log.warning('label_support argument is deprecated. Use mask instead.')
27
+ #raise ValueError()
28
+
29
+ self.mask = mask
30
+
31
+ value_scale = 255
32
+ mean = [0.485, 0.456, 0.406]
33
+ mean = [item * value_scale for item in mean]
34
+ std = [0.229, 0.224, 0.225]
35
+ std = [item * value_scale for item in std]
36
+
37
+ import third_party.PFENet.util.transform as transform
38
+
39
+ if mode == 'val':
40
+ data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
41
+
42
+ data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
43
+ data_transform += [
44
+ transform.ToTensor(),
45
+ transform.Normalize(mean=mean, std=std)
46
+ ]
47
+
48
+
49
+ elif mode == 'train':
50
+ data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
51
+
52
+ assert image_size != 'original'
53
+
54
+ data_transform = [
55
+ transform.RandScale([0.9, 1.1]),
56
+ transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
57
+ transform.RandomGaussianBlur(),
58
+ transform.RandomHorizontalFlip(),
59
+ transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
60
+ transform.ToTensor(),
61
+ transform.Normalize(mean=mean, std=std)
62
+ ]
63
+
64
+ data_transform = transform.Compose(data_transform)
65
+
66
+ self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
67
+ data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
68
+
69
+ self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
70
+
71
+ # verify that subcls_list always has length 1
72
+ # assert len(set([len(d[4]) for d in self.dataset])) == 1
73
+
74
+ print('actual length', len(self.dataset.data_list))
75
+
76
+ def __len__(self):
77
+ if self.mode == 'val':
78
+ return len(self.dataset.data_list)
79
+ else:
80
+ return len(self.dataset.data_list)
81
+
82
+ def __getitem__(self, index):
83
+ if self.dataset.mode == 'train':
84
+ image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
85
+ elif self.dataset.mode == 'val':
86
+ image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
87
+ ori_label = torch.from_numpy(ori_label).unsqueeze(0)
88
+
89
+ if self.image_size != 'original':
90
+ longerside = max(ori_label.size(1), ori_label.size(2))
91
+ backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
92
+ backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
93
+ label = backmask.clone().long()
94
+ else:
95
+ label = label.unsqueeze(0)
96
+
97
+ # assert label.shape == (473, 473)
98
+
99
+ if self.p_negative > 0:
100
+ if torch.rand(1).item() < self.p_negative:
101
+ while True:
102
+ idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
103
+ _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
104
+ if subcls_list[0] != subcls_list_tmp[0]:
105
+ break
106
+
107
+ s_x = s_x[0]
108
+ s_y = (s_y == 1)[0]
109
+ label_fg = (label == 1).float()
110
+ val_mask = (label != 255).float()
111
+
112
+ class_id = self.class_list[subcls_list[0]]
113
+
114
+ label_name = PASCAL_CLASSES[class_id][0]
115
+ label_add = ()
116
+ mask = self.mask
117
+
118
+ if mask == 'text':
119
+ support = ('a photo of a ' + label_name + '.',)
120
+ elif mask == 'separate':
121
+ support = (s_x, s_y)
122
+ else:
123
+ if mask.startswith('text_and_'):
124
+ label_add = (label_name,)
125
+ mask = mask[9:]
126
+
127
+ support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
128
+
129
+ return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
clipseg/datasets/phrasecut.py ADDED
@@ -0,0 +1,335 @@