apolinario committed on
Commit
48fa639
·
1 Parent(s): bdc1819

upload clipseg

Browse files
clipseg/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ This license does not apply to the model weights.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
clipseg/Quickstart.ipynb ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import requests\n",
11
+ "\n",
12
+ "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
13
+ "! unzip -d weights -j weights.zip\n",
14
+ "from models.clipseg import CLIPDensePredT\n",
15
+ "from PIL import Image\n",
16
+ "from torchvision import transforms\n",
17
+ "from matplotlib import pyplot as plt\n",
18
+ "\n",
19
+ "# load model\n",
20
+ "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
21
+ "model.eval();\n",
22
+ "\n",
23
+ "# non-strict, because we only stored decoder weights (not CLIP weights)\n",
24
+ "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "Load and normalize `example_image.jpg`. You can also load the image from a URL."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "# load and normalize image\n",
41
+ "input_image = Image.open('example_image.jpg')\n",
42
+ "\n",
43
+ "# or load from URL...\n",
44
+ "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
45
+ "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
46
+ "\n",
47
+ "transform = transforms.Compose([\n",
48
+ " transforms.ToTensor(),\n",
49
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
50
+ " transforms.Resize((352, 352)),\n",
51
+ "])\n",
52
+ "img = transform(input_image).unsqueeze(0)"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {},
58
+ "source": [
59
+ "Predict and visualize (this might take a few seconds if running without GPU support)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
69
+ "\n",
70
+ "# predict\n",
71
+ "with torch.no_grad():\n",
72
+ " preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
73
+ "\n",
74
+ "# visualize prediction\n",
75
+ "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
76
+ "[a.axis('off') for a in ax.flatten()]\n",
77
+ "ax[0].imshow(input_image)\n",
78
+ "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
79
+ "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
80
+ ]
81
+ }
82
+ ],
83
+ "metadata": {
84
+ "interpreter": {
85
+ "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
86
+ },
87
+ "kernelspec": {
88
+ "display_name": "Python 3",
89
+ "language": "python",
90
+ "name": "python3"
91
+ },
92
+ "language_info": {
93
+ "codemirror_mode": {
94
+ "name": "ipython",
95
+ "version": 3
96
+ },
97
+ "file_extension": ".py",
98
+ "mimetype": "text/x-python",
99
+ "name": "python",
100
+ "nbconvert_exporter": "python",
101
+ "pygments_lexer": "ipython3",
102
+ "version": "3.8.10"
103
+ }
104
+ },
105
+ "nbformat": 4,
106
+ "nbformat_minor": 4
107
+ }
clipseg/Readme.md ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image Segmentation Using Text and Image Prompts
2
+ This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
3
+
4
+ **The Paper has been accepted to CVPR 2022!**
5
+
6
+ <img src="overview.png" alt="drawing" height="200em"/>
7
+
8
+ The system allows creating segmentation models without training, based on:
9
+ - An arbitrary text query
10
+ - Or an image with a mask highlighting stuff or an object.
11
+
12
+ ### Quick Start
13
+
14
+ In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you have downloaded the `rd64-uni.pth` weights, either manually or via the git lfs extension.
15
+ It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
16
+ (please note that the VM does not use a GPU, thus inference takes a few seconds).
17
+
18
+
19
+ ### Dependencies
20
+ This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
21
+ Additional dependencies are hidden for double blind review.
22
+
23
+
24
+ ### Datasets
25
+
26
+ * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
27
+ * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
28
+ * `PascalZeroShot`: Wrapper class for PascalZeroShot
29
+ * `COCOWrapper`: Wrapper class for COCO.
30
+
31
+ ### Models
32
+
33
+ * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
34
+ * `ViTDensePredT`: ViTSeg baseline model (ViT backbone instead of CLIP) with transformer-based decoder.
35
+
36
+ ### Third Party Dependencies
37
+ For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
38
+ ```bash
39
+ git clone https://github.com/cvlab-yonsei/JoEm
40
+ git clone https://github.com/Jia-Research-Lab/PFENet.git
41
+ git clone https://github.com/ChenyunWu/PhraseCutDataset.git
42
+ git clone https://github.com/juhongm999/hsnet.git
43
+ ```
44
+
45
+ ### Weights
46
+
47
+ The MIT license does not apply to these weights.
48
+
49
+ We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
50
+ ```
51
+ wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
52
+ unzip -d weights -j weights.zip
53
+ ```
54
+
55
+
56
+ ### Training and Evaluation
57
+
58
+ To train, use the `training.py` script with the experiment file and the experiment id as parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment, which is defined by the `configuration` and the first entry of the `individual_configurations` parameters. Model weights will be written to `logs/`.
59
+
60
+ For evaluation, use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`.
61
+
62
+
63
+ ### Usage of PFENet Wrappers
64
+
65
+ In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
66
+ `git clone https://github.com/Jia-Research-Lab/PFENet.git `
67
+
68
+
69
+ ### License
70
+
71
+ The source code files in this repository (excluding model weights) are released under MIT license.
72
+
73
+ ### Citation
74
+ ```
75
+ @InProceedings{lueddecke22_cvpr,
76
+ author = {L\"uddecke, Timo and Ecker, Alexander},
77
+ title = {Image Segmentation Using Text and Image Prompts},
78
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
79
+ month = {June},
80
+ year = {2022},
81
+ pages = {7086-7096}
82
+ }
83
+
84
+ ```
clipseg/Tables.ipynb ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2\n",
11
+ "\n",
12
+ "import clip\n",
13
+ "from evaluation_utils import norm, denorm\n",
14
+ "from general_utils import *\n",
15
+ "from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "# PhraseCut"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
50
+ "tab1 = pc[['name'] + cols]\n",
51
+ "for k in cols:\n",
52
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
53
+ "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
54
+ "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
55
+ "print(tab1.to_latex(header=False, index=False))"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "For 0.1 threshold"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
72
+ "tab1 = pc[['name'] + cols]\n",
73
+ "for k in cols:\n",
74
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
75
+ "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
76
+ "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
77
+ "print(tab1.to_latex(header=False, index=False))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {},
83
+ "source": [
84
+ "# One-shot"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "metadata": {},
90
+ "source": [
91
+ "### Pascal"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
119
+ "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
120
+ "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
121
+ "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
122
+ "\n",
123
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
124
+ "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
125
+ "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
126
+ "\n",
127
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
128
+ "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
129
+ "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "metadata": {},
135
+ "source": [
136
+ "#### Pascal Zero-shot (in one-shot setting)\n",
137
+ "\n",
138
+ "Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
148
+ "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
149
+ "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
150
+ "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
151
+ "\n",
152
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
153
+ "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
154
+ "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
155
+ "\n",
156
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
157
+ "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
158
+ "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "# without fixed thresholds...\n",
168
+ "\n",
169
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
170
+ "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
171
+ "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
172
+ "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
173
+ "\n",
174
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
175
+ "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
176
+ "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {},
182
+ "source": [
183
+ "### COCO"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
202
+ "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
203
+ "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
204
+ "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
205
+ "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
206
+ "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
207
+ "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {},
213
+ "source": [
214
+ "# Zero-shot"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": [
232
+ "\n",
233
+ "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
234
+ "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
235
+ "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
236
+ "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "metadata": {},
242
+ "source": [
243
+ "# Ablation"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
262
+ "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
263
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
264
+ "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {},
271
+ "outputs": [],
272
+ "source": [
273
+ "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "metadata": {},
288
+ "source": [
289
+ "# Generalization"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "generalization = experiment('experiments/generalize.yaml').dataframe()"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "print(\n",
317
+ " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
318
+ " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
319
+ " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
320
+ " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
321
+ ")"
322
+ ]
323
+ }
324
+ ],
325
+ "metadata": {
326
+ "interpreter": {
327
+ "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
328
+ },
329
+ "kernelspec": {
330
+ "display_name": "env2",
331
+ "language": "python",
332
+ "name": "env2"
333
+ },
334
+ "language_info": {
335
+ "codemirror_mode": {
336
+ "name": "ipython",
337
+ "version": 3
338
+ },
339
+ "file_extension": ".py",
340
+ "mimetype": "text/x-python",
341
+ "name": "python",
342
+ "nbconvert_exporter": "python",
343
+ "pygments_lexer": "ipython3",
344
+ "version": "3.8.8"
345
+ }
346
+ },
347
+ "nbformat": 4,
348
+ "nbformat_minor": 4
349
+ }
clipseg/Visual_Feature_Engineering.ipynb ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Systematic"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "%load_ext autoreload\n",
17
+ "%autoreload 2\n",
18
+ "\n",
19
+ "import clip\n",
20
+ "from evaluation_utils import norm, denorm\n",
21
+ "from general_utils import *\n",
22
+ "from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
23
+ "\n",
24
+ "clip_device = 'cuda'\n",
25
+ "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
26
+ "clip_model.eval();\n",
27
+ "\n",
28
+ "from models.clipseg import CLIPDensePredTMasked\n",
29
+ "\n",
30
+ "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
31
+ "clip_mask_model.eval();"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
41
+ " text_class_labels=True, image_size=352, min_area=0.1,\n",
42
+ " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "plot_data(lvis)"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "from collections import defaultdict\n",
61
+ "import json\n",
62
+ "\n",
63
+ "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
64
+ "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
65
+ "\n",
66
+ "objects_per_image = defaultdict(lambda : set())\n",
67
+ "for ann in lvis_raw['annotations']:\n",
68
+ " objects_per_image[ann['image_id']].add(ann['category_id'])\n",
69
+ " \n",
70
+ "for ann in lvis_val_raw['annotations']:\n",
71
+ " objects_per_image[ann['image_id']].add(ann['category_id']) \n",
72
+ " \n",
73
+ "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
74
+ "\n",
75
+ "del lvis_raw, lvis_val_raw"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "#bs = 32\n",
85
+ "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "from general_utils import get_batch\n",
95
+ "from functools import partial\n",
96
+ "from evaluation_utils import img_preprocess\n",
97
+ "import torch\n",
98
+ "\n",
99
+ "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
100
+ "\n",
101
+ " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
102
+ "\n",
103
+ " all_prompts = []\n",
104
+ " \n",
105
+ " with torch.no_grad():\n",
106
+ " valid_sims = []\n",
107
+ " torch.manual_seed(571)\n",
108
+ " \n",
109
+ " if type(batches_or_dataset) == list:\n",
110
+ " loader = batches_or_dataset # already loaded\n",
111
+ " max_iter = float('inf')\n",
112
+ " else:\n",
113
+ " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
114
+ " max_iter = 50\n",
115
+ " \n",
116
+ " global batch\n",
117
+ " for i_batch, (batch, batch_y) in enumerate(loader):\n",
118
+ " \n",
119
+ " if i_batch >= max_iter: break\n",
120
+ " \n",
121
+ " processed_batch = process(batch)\n",
122
+ " if type(processed_batch) == dict:\n",
123
+ " \n",
124
+ " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
125
+ " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
126
+ " else:\n",
127
+ " processed_batch = process(batch).to(clip_device)\n",
128
+ " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
129
+ " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
130
+ " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
131
+ " \n",
132
+ " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
133
+ " bs = len(batch[0])\n",
134
+ " for j in range(bs):\n",
135
+ " \n",
136
+ " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
137
+ " support_image = basename(lvis.samples[c][sid])\n",
138
+ " \n",
139
+ " img_objs = [o for o in objects_per_image[int(support_image)]]\n",
140
+ " img_objs = [o.replace('_', ' ') for o in img_objs]\n",
141
+ " \n",
142
+ " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
143
+ " if o != batch_y[2][j]]\n",
144
+ " \n",
145
+ " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
146
+ " all_prompts += [prompts]\n",
147
+ " \n",
148
+ " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
149
+ " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
150
+ "\n",
151
+ " global logits\n",
152
+ " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
153
+ "\n",
154
+ " global sim\n",
155
+ " sim = torch.softmax(logits, dim=-1)\n",
156
+ " \n",
157
+ " valid_sims += [sim]\n",
158
+ " \n",
159
+ " #valid_sims = torch.stack(valid_sims)\n",
160
+ " return valid_sims, all_prompts\n",
161
+ " \n",
162
+ "\n",
163
+ "def new_img_preprocess(x):\n",
164
+ " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
165
+ " \n",
166
+ "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
167
+ "get_similarities(lvis, lambda x: x[1]);"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "preprocessing_functions = [\n",
177
+ "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
178
+ "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
179
+ "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
180
+ "# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
181
+ "# ['add red outline', partial(img_preprocess, outline=True)],\n",
182
+ " \n",
183
+ "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
184
+ "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
185
+ "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
186
+ "# ['BG blur', partial(img_preprocess, blur=3)],\n",
187
+ "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
188
+ " \n",
189
+ "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
190
+ "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
191
+ " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
192
+ " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
193
+ "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
194
+ "]\n",
195
+ "\n",
196
+ "preprocessing_functions = preprocessing_functions\n",
197
+ "\n",
198
+ "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
199
+ "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "for j in range(1):\n",
218
+ " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "from pandas import DataFrame\n",
228
+ "tab = dict()\n",
229
+ "for j, (name, _) in enumerate(preprocessing_functions):\n",
230
+ " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
231
+ " \n",
232
+ " \n",
233
+ "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {},
239
+ "source": [
240
+ "# Visual"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": null,
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "from evaluation_utils import denorm, norm"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "def load_sample(filename, filename2):\n",
259
+ " from os.path import join\n",
260
+ " bp = expanduser('~/cloud/resources/sample_images')\n",
261
+ " tf = transforms.Compose([\n",
262
+ " transforms.ToTensor(),\n",
263
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
264
+ " transforms.Resize(224),\n",
265
+ " transforms.CenterCrop(224)\n",
266
+ " ])\n",
267
+ " tf2 = transforms.Compose([\n",
268
+ " transforms.ToTensor(),\n",
269
+ " transforms.Resize(224),\n",
270
+ " transforms.CenterCrop(224)\n",
271
+ " ])\n",
272
+ " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
273
+ " inp1[1] = inp1[1].unsqueeze(0)\n",
274
+ " inp1[2] = inp1[2][:1] \n",
275
+ " return inp1\n",
276
+ "\n",
277
+ "def all_preprocessing(inp1):\n",
278
+ " return [\n",
279
+ " img_preprocess(inp1),\n",
280
+ " img_preprocess(inp1, colorize=True),\n",
281
+ " img_preprocess(inp1, outline=True), \n",
282
+ " img_preprocess(inp1, blur=3),\n",
283
+ " img_preprocess(inp1, bg_fac=0.1),\n",
284
+ " #img_preprocess(inp1, bg_fac=0.5),\n",
285
+ " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
286
+ " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
287
+ " ]\n",
288
+ "\n"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "from torchvision import transforms\n",
298
+ "from PIL import Image\n",
299
+ "from matplotlib import pyplot as plt\n",
300
+ "from evaluation_utils import img_preprocess\n",
301
+ "import clip\n",
302
+ "\n",
303
+ "images_queries = [\n",
304
+ " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
305
+ " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
306
+ "]\n",
307
+ "\n",
308
+ "\n",
309
+ "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
310
+ "\n",
311
+ "for j, (images, objects) in enumerate(images_queries):\n",
312
+ " \n",
313
+ " joint_image = all_preprocessing(images)\n",
314
+ " \n",
315
+ " joint_image = torch.stack(joint_image)[:,0]\n",
316
+ " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
317
+ " image_features = clip_model.encode_image(joint_image)\n",
318
+ " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
319
+ " \n",
320
+ " prompts = [f'a photo of a {obj}'for obj in objects]\n",
321
+ " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
322
+ " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
323
+ " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
324
+ " sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
325
+ "\n",
326
+ " for i, img in enumerate(joint_image):\n",
327
+ " ax[2*j, i].axis('off')\n",
328
+ " \n",
329
+ " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
330
+ " ax[2*j+ 1, i].grid(True)\n",
331
+ " \n",
332
+ " ax[2*j + 1, i].set_ylim(0,1)\n",
333
+ " ax[2*j + 1, i].set_yticklabels([])\n",
334
+ " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
335
+ "# ax[1, i].set_xticklabels(objects, rotation=90)\n",
336
+ " for k in range(len(sim[i])):\n",
337
+ " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
338
+ " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
339
+ "\n",
340
+ "plt.tight_layout()\n",
341
+ "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
342
+ ]
343
+ }
344
+ ],
345
+ "metadata": {
346
+ "kernelspec": {
347
+ "display_name": "env2",
348
+ "language": "python",
349
+ "name": "env2"
350
+ },
351
+ "language_info": {
352
+ "codemirror_mode": {
353
+ "name": "ipython",
354
+ "version": 3
355
+ },
356
+ "file_extension": ".py",
357
+ "mimetype": "text/x-python",
358
+ "name": "python",
359
+ "nbconvert_exporter": "python",
360
+ "pygments_lexer": "ipython3",
361
+ "version": "3.8.8"
362
+ }
363
+ },
364
+ "nbformat": 4,
365
+ "nbformat_minor": 4
366
+ }
clipseg/datasets/coco_wrapper.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from types import new_class
3
+ import torch
4
+ import numpy as np
5
+ import os
6
+ import json
7
+
8
+ from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
9
+ from random import shuffle, seed as set_seed
10
+ from PIL import Image
11
+
12
+ from itertools import combinations
13
+ from torchvision import transforms
14
+ from torchvision.transforms.transforms import Resize
15
+
16
+ from datasets.utils import blend_image_segmentation
17
+ from general_utils import get_from_repository
18
+
19
# Maps a zero-based class index (0-79) to its COCO category name.
# COCOWrapper.__getitem__ looks up int(sample['class_id']) here to obtain the text label.
COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
20
+
21
class COCOWrapper(object):
    """One-shot segmentation dataset built on the third-party HSNet ``DatasetCOCO``.

    Each item pairs a query image with a "support" description of the target
    class. The form of that description is selected by ``mask``:

    * ``'separate'``: support image and support mask as two entries,
    * ``'text'``: the class name from ``COCO_CLASSES``,
    * ``'text_label'``: the integer class id (deprecated),
    * anything else: forwarded as ``mode`` to ``blend_image_segmentation``;
      a ``'text_and_'`` prefix additionally prepends the class name.
    """

    def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
                 with_class_label=False):
        """Download the data if needed and set up the wrapped ``DatasetCOCO``.

        Args:
            split: split name forwarded to ``DatasetCOCO``.
            fold: fold index forwarded to ``DatasetCOCO``.
            image_size: side length that images are resized to (square).
            aug: accepted for interface compatibility; not used here.
            mask: support representation, see the class docstring.
            negative_prob: probability of replacing the support with one of a
                different class (and an all-zero support mask).
            with_class_label: if true, the class id is appended to the target tuple.
        """
        super().__init__()

        self.mask = mask
        self.with_class_label = with_class_label
        self.negative_prob = negative_prob

        from third_party.hsnet.data.coco import DatasetCOCO

        get_from_repository('COCO-20i', ['COCO-20i.tar'])

        # Fold files are resolved relative to this module rather than the working directory.
        foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')

        def build_img_metadata_classwise(self):
            with open(foldpath % (self.split, self.fold), 'rb') as fold_file:
                return pickle.load(fold_file)

        # Replace the loader's metadata reader so that it uses the path above.
        DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
        # DatasetCOCO.read_mask = read_mask

        preprocessing = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, preprocessing, split, 1, False)

        self.all_classes = [self.coco.class_ids]
        self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.coco)

    def __getitem__(self, i):
        """Return ``((query_img, *support), (query_mask, empty_tensor[, class_id]))``."""
        sample = self.coco[i]
        class_id = sample['class_id']

        label_name = COCO_CLASSES[int(class_id)]

        support_img = sample['support_imgs'][0]
        support_seg = sample['support_masks'][0]

        use_negative = self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob
        if use_negative:
            # Redraw until the other sample belongs to a different class, then
            # use its image together with an empty support mask.
            other_class_id = class_id
            while other_class_id == class_id:
                other = self.coco[torch.randint(0, len(self), (1,)).item()]
                other_class_id = other['class_id']
            support_img = other['support_imgs'][0]
            support_seg = torch.zeros_like(support_seg)

        mode = self.mask
        if mode == 'separate':
            supp = (support_img, support_seg)
        elif mode == 'text_label':
            # DEPRECATED
            supp = [int(class_id)]
        elif mode == 'text':
            supp = [label_name]
        else:
            text_part = []
            if mode.startswith('text_and_'):
                # Strip the 9-character 'text_and_' prefix to obtain the blend mode.
                mode = mode[9:]
                text_part = [label_name]

            supp = text_part + blend_image_segmentation(support_img, support_seg, mode=mode)

        inputs = (sample['query_img'],) + tuple(supp)

        targets = (sample['query_mask'].unsqueeze(0), torch.zeros(0))
        if self.with_class_label:
            targets = targets + (class_id,)

        return inputs, targets
clipseg/datasets/pascal_classes.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
clipseg/datasets/pascal_zeroshot.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import expanduser
2
+ import torch
3
+ import json
4
+ import torchvision
5
+ from general_utils import get_from_repository
6
+ from general_utils import log
7
+ from torchvision import transforms
8
+
9
# Groups of two WordNet-style synset names each.
# PhraseCut (remove_classes=('zs', stop)) flattens the first `stop` groups and
# passes every name to nltk's wordnet.synset.
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
                         ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
                         ['chair.n.01', 'pot_plant.n.01']]
12
+
13
+
14
class PascalZeroShot(object):
    """Zero-shot PASCAL VOC segmentation dataset backed by the third-party JoEm loader.

    Items are ``((image,), (label,))`` where both entries are resized to
    ``image_size`` x ``image_size`` (the label with nearest-neighbour
    interpolation so that class ids are not blended).
    """

    def __init__(self, split, n_unseen, image_size=224) -> None:
        """Create the underlying ``VOCSegmentation`` dataset.

        Args:
            split: ``'train'`` or ``'val'``.
            n_unseen: forwarded to JoEm's ``get_seen_idx`` / ``get_unseen_idx``.
            image_size: side length of the returned (square) image and label.

        Raises:
            ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
        """
        super().__init__()

        # Fail fast: previously an unknown split left `self.voc` undefined and
        # only surfaced later as an AttributeError in __len__/__getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        import sys
        sys.path.append('third_party/JoEm')
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
                                       ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
        else:
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=False,
                                       ignore_bg=False, ignore_unseen=False)

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.voc)

    def __getitem__(self, i):
        """Return ``((image,), (label,))`` for sample ``i``, both resized."""
        sample = self.voc[i]
        label = sample['label'].long()

        image = self.transform(sample['image'])

        # Nearest-neighbour keeps the label values discrete; Resize needs a
        # leading channel dimension, which is removed again afterwards.
        label = transforms.Resize((self.image_size, self.image_size),
                                  interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]

        return (image,), (label, )
59
+
60
+
clipseg/datasets/pfe_dataset.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import expanduser
2
+ import torch
3
+ import json
4
+ from general_utils import get_from_repository
5
+ from datasets.lvis_oneshot3 import blend_image_segmentation
6
+ from general_utils import log
7
+
8
# Maps a PASCAL class id to its list of synonyms, read from datasets/pascal_classes.json.
# The path is relative to the current working directory (unchanged from before).
# The context manager closes the file handle, which the bare open() call leaked.
with open('datasets/pascal_classes.json') as _pascal_classes_file:
    PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(_pascal_classes_file)}
9
+
10
+
11
class PFEPascalWrapper(object):
    """One-shot PASCAL dataset built on the third-party PFENet ``SemData`` loader.

    Each item is ``((image, [label_name,] *support), (label_fg, val_mask, subcls))``.
    The support representation is selected by ``mask``:

    * ``'text'``: the prompt ``'a photo of a <class name>.'``,
    * ``'separate'``: support image and boolean support mask,
    * anything else: forwarded as mode to ``blend_image_segmentation``;
      a ``'text_and_'`` prefix additionally inserts the class name.
    """

    def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
        """Download the data if needed and set up the wrapped ``SemData``.

        Args:
            mode: ``'train'`` or ``'val'``; selects data list and transforms.
            split: forwarded to ``SemData``.
            mask: support representation, see the class docstring.
            image_size: crop/resize size, or ``'original'`` (``'val'`` only).
            label_support: deprecated; only triggers a warning.
            size: stored on the instance; not used in this class.
            p_negative: probability of replacing the support by a sample of another class.
            aug: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        # Fail fast: an unknown mode previously surfaced later as a NameError
        # on the undefined `data_transform`.
        if mode not in ('train', 'val'):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

        from third_party.PFENet.util.dataset import SemData

        get_from_repository('PascalVOC2012', ['Pascal5i.tar'])

        self.p_negative = p_negative
        self.size = size
        self.mode = mode
        self.image_size = image_size

        if label_support in {True, False}:
            log.warning('label_support argument is deprecated. Use mask instead.')
            #raise ValueError()

        self.mask = mask

        # The PFENet transforms below receive mean/std scaled by 255.
        value_scale = 255
        mean = [0.485, 0.456, 0.406]
        mean = [item * value_scale for item in mean]
        std = [0.229, 0.224, 0.225]
        std = [item * value_scale for item in std]

        import third_party.PFENet.util.transform as transform

        if mode == 'val':
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')

            data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
            data_transform += [
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        else:
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')

            assert image_size != 'original'

            data_transform = [
                transform.RandScale([0.9, 1.1]),
                transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
                transform.RandomGaussianBlur(),
                transform.RandomHorizontalFlip(),
                transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        data_transform = transform.Compose(data_transform)

        self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
                               data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)

        self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list

        # verify that subcls_list always has length 1
        # assert len(set([len(d[4]) for d in self.dataset])) == 1

        print('actual length', len(self.dataset.data_list))

    def __len__(self):
        """Return the number of entries in the wrapped dataset's data list."""
        return len(self.dataset.data_list)

    def _sample_negative_support(self, subcls_list):
        """Return support image/mask of a random sample whose first class differs from ``subcls_list[0]``."""
        while True:
            idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
            # Slice instead of unpacking a fixed number of values: the wrapped
            # dataset yields 5 items in 'train' mode and 6 in 'val' mode, so the
            # former six-value unpacking raised ValueError during training.
            s_x, s_y, subcls_list_tmp = self.dataset[idx][2:5]
            if subcls_list[0] != subcls_list_tmp[0]:
                return s_x, s_y

    def __getitem__(self, index):
        """Return ``((image, [label_name,] *support), (label_fg, val_mask, subcls_list[0]))``."""
        if self.dataset.mode == 'train':
            image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
        elif self.dataset.mode == 'val':
            image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
            ori_label = torch.from_numpy(ori_label).unsqueeze(0)

            if self.image_size != 'original':
                # Pad the original label to a square filled with the ignore value 255.
                longerside = max(ori_label.size(1), ori_label.size(2))
                # NOTE(review): the hard-coded .cuda() requires a GPU and puts the
                # label on a different device than the other outputs; kept as is
                # because callers may rely on it -- confirm before changing.
                backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
                backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
                label = backmask.clone().long()
            else:
                label = label.unsqueeze(0)

            # assert label.shape == (473, 473)

        if self.p_negative > 0 and torch.rand(1).item() < self.p_negative:
            s_x, s_y = self._sample_negative_support(subcls_list)

        s_x = s_x[0]
        s_y = (s_y == 1)[0]
        label_fg = (label == 1).float()
        # Pixels labelled 255 are excluded from the validity mask.
        val_mask = (label != 255).float()

        class_id = self.class_list[subcls_list[0]]

        label_name = PASCAL_CLASSES[class_id][0]
        label_add = ()
        mask = self.mask

        if mask == 'text':
            support = ('a photo of a ' + label_name + '.',)
        elif mask == 'separate':
            support = (s_x, s_y)
        else:
            if mask.startswith('text_and_'):
                label_add = (label_name,)
                # Strip the 9-character 'text_and_' prefix to obtain the blend mode.
                mask = mask[9:]

            support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)

        return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
clipseg/datasets/phrasecut.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+
6
+ from os.path import join, isdir, isfile, expanduser
7
+ from PIL import Image
8
+
9
+ from torchvision import transforms
10
+ from torchvision.transforms.transforms import Resize
11
+
12
+ from torch.nn import functional as nnf
13
+ from general_utils import get_from_repository
14
+
15
+ from skimage.draw import polygon2mask
16
+
17
+
18
+
19
def random_crop_slices(origin_size, target_size):
    """Draw a random crop window and return it as ``(row_slice, column_slice)``.

    ``origin_size`` and ``target_size`` are indexed as ``(height, width)``; the
    target must fit into the origin in both dimensions.
    """
    fits_vertically = origin_size[0] >= target_size[0]
    assert fits_vertically and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'

    # torch.randint excludes the upper bound, hence the +1 to allow the last valid offset.
    max_top = origin_size[0] - target_size[0]
    max_left = origin_size[1] - target_size[1]
    top = torch.randint(0, max_top + 1, (1,)).item()
    left = torch.randint(0, max_left + 1, (1,)).item()

    rows = slice(top, top + target_size[0])
    columns = slice(left, left + target_size[1])
    return rows, columns
27
+
28
+
29
def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
    """Randomly search for a crop that contains enough foreground pixels.

    Args:
        seg: 2-D array (anything with ``astype`` and ``shape``); non-zero entries count as foreground.
        image_size: ``(height, width)`` of the crop.
        iterations: maximum number of random crops to try.
        min_frac: if given, a crop is acceptable when its foreground pixel count
            exceeds ``min_frac`` times the area of the *whole* ``seg``;
            otherwise any crop with at least one foreground pixel is acceptable.
        best_of: if ``None``, the first acceptable crop is returned. Otherwise
            up to ``best_of`` acceptable crops are collected and the one with
            the most foreground is returned.

    Returns:
        ``(row_slice, column_slice, failed)``. ``failed`` is ``False`` for an
        acceptable crop. If none was found, the best rejected crop is returned
        with ``failed=True`` (both slices are ``None`` when no crop was drawn).
    """
    seg = seg.astype('bool')

    min_sum = 0
    if min_frac is not None:
        #min_sum = seg.sum() * min_frac
        min_sum = seg.shape[0] * seg.shape[1] * min_frac

    accepted = []  # (foreground, row_slice, column_slice); only filled when best_of is set
    best_rejected = float('-inf'), None, None

    for _ in range(iterations):
        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
        foreground = seg[sl_y, sl_x].sum()

        if foreground > min_sum:
            if best_of is None:
                return sl_y, sl_x, False

            accepted.append((foreground, sl_y, sl_x))
            if len(accepted) >= best_of:
                break
        elif foreground > best_rejected[0]:
            best_rejected = foreground, sl_y, sl_x

    if accepted:
        # Also reached when the iteration budget ran out before `best_of`
        # acceptable crops were collected; those crops used to be thrown away
        # in favour of a rejected one (or of (None, None, True)).
        _, sl_y, sl_x = max(accepted, key=lambda crop: crop[0])
        return sl_y, sl_x, False

    # No acceptable crop: return the best rejected one and flag the failure.
    return best_rejected[1:] + (best_rejected[0] <= min_sum,)
66
+
67
+
68
class PhraseCut(object):
    """PhraseCut dataset of (image, phrase, segmentation) samples.

    A sample is identified by ``(image id, phrase index)``. ``__getitem__``
    returns ``((img, *support), (seg, empty_tensor, index))`` where the support
    part depends on ``with_visual`` and ``mask``:

    * ``with_visual=False``: ``mask`` must be ``'text'``; the support is the phrase.
    * ``with_visual=True``: the support is built from another sample with the
      same phrase (or zero tensors if the phrase occurs only once) and ends
      with a boolean telling whether such a sample was available.
    """

    def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
                 min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
        """Load the annotations and build the list of sample ids.

        Args:
            split: split name forwarded to ``RefVGLoader``.
            image_size: side length of the returned (square) image and mask.
            negative_prob: probability of replacing the phrase by a different
                one and the target mask by zeros.
            aug: accepted for interface compatibility; not used here.
            aug_color: apply ``ColorJitter`` to the image.
            aug_crop: take a random square crop (see ``find_crop``) before resizing.
            min_size: if truthy, keep only samples whose summed box area relative
                to the image area exceeds this value.
            remove_classes: ``None`` or a pair ``(kind, argument)`` with kind
                ``'pas5i'``, ``'zs'`` or ``'aff'``; removes samples whose phrase
                refers to the selected classes/words.
            with_visual: add a visual support sample to each item.
            only_visual: keep only samples whose phrase occurs more than once
                (requires ``with_visual``).
            mask: support representation, see ``__getitem__``.

        Raises:
            ValueError: if ``remove_classes[0]`` is not a known kind.
        """
        super().__init__()

        self.negative_prob = negative_prob
        self.image_size = image_size
        self.with_visual = with_visual
        self.only_visual = only_visual
        self.phrase_form = '{}'
        self.mask = mask
        self.aug_crop = aug_crop

        if aug_color:
            self.aug_color = transforms.Compose([
                transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
            ])
        else:
            self.aug_color = None

        get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
            isdir(join(local_dir, 'VGPhraseCut_v0')),
            isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
            isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
            len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
        ]))

        from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
        self.refvg_loader = RefVGLoader(split=split)

        # img_ids where the size in the annotations does not match actual size
        invalid_img_ids = {150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
                           150333, 286065, 285814, 498187, 285761, 498042}

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.normalize = transforms.Normalize(mean, std)

        self.sample_ids = [(i, j)
                           for i in self.refvg_loader.img_ids
                           for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
                           if i not in invalid_img_ids]

        from nltk.stem import WordNetLemmatizer
        wnl = WordNetLemmatizer()

        # Filter by class (if remove_classes is set)
        if remove_classes is not None:
            from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
            from nltk.corpus import wordnet

            print('remove pascal classes...')

            get_data = self.refvg_loader.get_img_ref_data  # shortcut
            keep_sids = None

            if remove_classes[0] == 'pas5i':
                subset_id = remove_classes[1]
                from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
                avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]

            elif remove_classes[0] == 'zs':
                stop = remove_classes[1]

                from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS

                avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
                print(avoid)

            elif remove_classes[0] == 'aff':
                # Plain words instead of synsets: a phrase is dropped if any of
                # its space-separated tokens is in this list.
                avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
                         'ride', 'rides', 'riding',
                         'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
                         'swim', 'swims', 'swimming',
                         'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
                keep_sids = [(i, j) for i, j in self.sample_ids if
                             all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]

            else:
                # Previously an unknown kind ran into a NameError on `avoid` below.
                raise ValueError(f'unknown remove_classes kind: {remove_classes[0]!r}')

            print('avoid classes:', avoid)

            if keep_sids is None:
                # Synset-based removal: collect the lemmas returned by
                # traverse_lemmas_hypo for every synset to avoid.
                all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
                all_lemmas = list(set(all_lemmas))
                all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
                all_lemmas = set(all_lemmas)

                # divide into multi word and single word
                all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
                all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)

                # Multi-word lemmas are matched as substrings of the phrase,
                # single-word lemmas against the lemmatized words of the phrase.
                phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
                remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
                                  if any(l in phrase for l in all_lemmas_m) or
                                  len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
                )
                keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]

            print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
            removed_ids = set(self.sample_ids) - set(keep_sids)

            print('Examples of removed', len(removed_ids))
            for i, j in list(removed_ids)[:20]:
                print(i, get_data(i)['phrases'][j])

            self.sample_ids = keep_sids

        # Group the sample ids by phrase; groupby needs the list sorted by the same key.
        from itertools import groupby
        samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
                             for i, j in self.sample_ids]
        samples_by_phrase = sorted(samples_by_phrase)
        samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])

        self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}

        self.all_phrases = list(set(self.samples_by_phrase.keys()))

        if self.only_visual:
            assert self.with_visual
            self.sample_ids = [(i, j) for i, j in self.sample_ids
                               if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]

        # Filter by size (if min_size is set)
        sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
        image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
        # Summed s[2] * s[3] over the boxes of a sample, relative to the image area.
        # NOTE(review): presumably s[2]/s[3] are box width/height -- confirm against the loader.
        self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]

        if min_size:
            print('filter by size')

            # NOTE(review): self.sizes is not filtered along with sample_ids and
            # therefore refers to the unfiltered list afterwards (unchanged behaviour).
            self.sample_ids = [sample_id for sample_id, frac in zip(self.sample_ids, self.sizes) if frac > min_size]

        self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))

    def __len__(self):
        """Return the number of (image id, phrase index) samples."""
        return len(self.sample_ids)

    def load_sample(self, sample_i, j):
        """Load image, mask and phrase for phrase ``j`` of image ``sample_i``.

        Returns:
            ``(img, seg, phrase)``: the normalized image and the mask, both
            resized to ``image_size`` x ``image_size``, and the formatted phrase.
        """
        img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)

        polys_phrase0 = img_ref_data['gt_Polygons'][j]
        phrase = img_ref_data['phrases'][j]
        phrase = self.phrase_form.format(phrase)

        # Rasterize every polygon and merge them into a single mask.
        masks = []
        for polys in polys_phrase0:
            for poly in polys:
                poly = [p[::-1] for p in poly]  # swap x,y
                masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]

        seg = np.stack(masks).max(0)
        # The context manager closes the image file as soon as the pixels are copied.
        with Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')) as image_file:
            img = np.array(image_file)

        min_shape = min(img.shape[:2])

        if self.aug_crop:
            # Square crop with the shorter image side as edge length; the
            # "failed" flag of find_crop is not needed here.
            sly, slx, _ = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
        else:
            sly, slx = slice(0, None), slice(0, None)

        seg = seg[sly, slx]
        img = img[sly, slx]

        seg = seg.astype('uint8')
        seg = torch.from_numpy(seg).view(1, 1, *seg.shape)

        if img.ndim == 2:
            # Two-dimensional (single channel) image: replicate to three channels.
            img = np.dstack([img] * 3)

        img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()

        seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
        img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]

        img = img / 255.0

        if self.aug_color is not None:
            img = self.aug_color(img)

        img = self.normalize(img)

        return img, seg, phrase

    def __getitem__(self, i):
        """Return ``((img, *support), (seg, empty_tensor, i))`` for sample ``i``."""
        sample_i, j = self.sample_ids[i]

        img, seg, phrase = self.load_sample(sample_i, j)

        if self.negative_prob > 0:
            if torch.rand((1,)).item() < self.negative_prob:

                # Negative sample: a different phrase with an empty target mask.
                # NOTE(review): this loop does not terminate if all_phrases
                # contains only the current phrase.
                new_phrase = None
                while new_phrase is None or new_phrase == phrase:
                    idx = torch.randint(0, len(self.all_phrases), (1,)).item()
                    new_phrase = self.all_phrases[idx]
                phrase = new_phrase
                seg = torch.zeros_like(seg)

        if self.with_visual:
            # find a corresponding visual image
            if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
                idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
                other_sample = self.samples_by_phrase[phrase][idx]
                img_s, seg_s, _ = self.load_sample(*other_sample)

                from datasets.utils import blend_image_segmentation

                if self.mask in {'separate', 'text_and_separate'}:
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [img_s, seg_s, True]
                else:
                    if self.mask.startswith('text_and_'):
                        # Strip the 9-character 'text_and_' prefix to obtain the blend mode.
                        mask_mode = self.mask[9:]
                        label_add = [phrase]
                    else:
                        mask_mode = self.mask
                        label_add = []

                    masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
                    vis_s = label_add + [masked_img_s, True]

            else:
                # phrase is unique
                vis_s = torch.zeros_like(img)

                if self.mask in {'separate', 'text_and_separate'}:
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
                elif self.mask.startswith('text_and_'):
                    vis_s = [phrase, vis_s, False]
                else:
                    vis_s = [vis_s, False]
        else:
            assert self.mask == 'text'
            vis_s = [phrase]

        seg = seg.unsqueeze(0).float()

        data_x = (img,) + tuple(vis_s)

        return data_x, (seg, torch.zeros(0), i)
329
+
330
+
331
class PhraseCutPlus(PhraseCut):
    """``PhraseCut`` preset with visual support enabled and 20% negative samples."""

    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
        """Forward all arguments to ``PhraseCut`` and fix ``negative_prob=0.2`` and ``with_visual=True``."""
        preset = dict(negative_prob=0.2, with_visual=True)
        forwarded = dict(
            image_size=image_size,
            aug=aug,
            aug_color=aug_color,
            aug_crop=aug_crop,
            min_size=min_size,
            remove_classes=remove_classes,
            only_visual=only_visual,
            mask=mask,
        )
        super().__init__(split, **forwarded, **preset)
clipseg/datasets/utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
# Modes for which NumPy inputs are converted to torch tensors up front.
_TORCH_INPUT_MODES = frozenset({'blur_highlight', 'blur3_highlight', 'blur3_highlight01',
                                'blur_highlight_random', 'crop'})

# Modes that are computed with NumPy (they rely on ndarray.astype).
_NUMPY_MODES = frozenset({'overlay', 'highlight', 'highlight2', 'shape', 'concat', 'image_only',
                          'image_black', None, 'separate', 'separate_img_black',
                          'separate_seg_ones', 'separate_both_black'})


def _as_numpy(value):
    """Return torch tensors as NumPy arrays (detached, on CPU); pass anything else through."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def blend_image_segmentation(img, seg, mode, image_size=224):
    """Combine an image and its segmentation mask into a visual prompt.

    Args:
        img: image as NumPy array or torch tensor; the broadcasting below
            (``seg[None, :, :]``) expects channels first.
        seg: two-dimensional mask as NumPy array or torch tensor.
        mode: name of the blending variant (see the branches below) or ``None``
            to return the image unchanged.
        image_size: output size for the ``'crop'`` and ``'crop_blur_highlight'`` modes.

    Returns:
        A list of NumPy arrays: one entry for most modes, two (image and mask)
        for the ``'separate*'`` modes.

    Raises:
        ValueError: if ``mode`` is not known.
    """
    if mode in _TORCH_INPUT_MODES:
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)

        if isinstance(seg, np.ndarray):
            seg = torch.from_numpy(seg)
    elif mode in _NUMPY_MODES:
        # These branches call ndarray.astype, which torch tensors do not have;
        # tensor inputs (as passed by the dataset wrappers) used to fail here
        # with an AttributeError.
        img, seg = _as_numpy(img), _as_numpy(seg)

    if mode == 'overlay':
        out = [(img * seg).astype('float32')]
    elif mode == 'highlight':
        out = [(img * seg[None, :, :] * 0.85 + 0.15 * img).astype('float32')]
    elif mode == 'highlight2':
        img = img / 2
        out = [((img+0.1) * seg[None, :, :] + 0.3 * img).astype('float32')]
    elif mode in ('blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random'):
        from evaluation_utils import img_preprocess
        if mode == 'blur_highlight':
            blur, bg_fac = 1, 0.5
        elif mode == 'blur3_highlight':
            blur, bg_fac = 3, 0.5
        elif mode == 'blur3_highlight01':
            blur, bg_fac = 3, 0.1
        else:
            # Random blur strength in {0, 1, 2} and background factor in [0.1, 0.9).
            blur = 0 + torch.randint(0, 3, (1,)).item()
            bg_fac = 0.1 + 0.8*torch.rand(1).item()
        out = [img_preprocess((None, [img], [seg]), blur=blur, bg_fac=bg_fac).numpy()[0] - 0.01]
    elif mode == 'crop':
        from evaluation_utils import img_preprocess
        out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()]
    elif mode == 'crop_blur_highlight':
        from evaluation_utils import img_preprocess
        out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()]
    elif mode == 'crop_blur_highlight352':
        from evaluation_utils import img_preprocess
        out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()]
    elif mode == 'shape':
        out = [np.stack([seg[:, :]]*3).astype('float32')]
    elif mode == 'concat':
        out = [np.concatenate([img, seg[None, :, :]]).astype('float32')]
    elif mode == 'image_only' or mode is None:
        out = [img.astype('float32')]
    elif mode == 'image_black':
        out = [img.astype('float32')*0]
    elif mode == 'separate':
        out = [img.astype('float32'), seg.astype('int64')]
    elif mode == 'separate_img_black':
        out = [img.astype('float32')*0, seg.astype('int64')]
    elif mode == 'separate_seg_ones':
        out = [img.astype('float32'), np.ones_like(seg).astype('int64')]
    elif mode == 'separate_both_black':
        out = [img.astype('float32')*0, seg.astype('int64')*0]
    else:
        raise ValueError(f'invalid mode: {mode}')

    return out
clipseg/environment.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: clipseg-environment
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ dependencies:
6
+ - numpy
7
+ - scipy
8
+ - matplotlib-base
9
+ - pip
10
+ - pip:
11
+ - --find-links https://download.pytorch.org/whl/torch_stable.html
12
+ - torch==1.10.0+cpu
13
+ - torchvision==0.11.1+cpu
14
+ - opencv-python
15
+ - git+https://github.com/openai/CLIP.git
clipseg/evaluation_utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.functional import Tensor
2
+ from general_utils import load_model
3
+ from torch.utils.data import DataLoader
4
+ import torch
5
+ import numpy as np
6
+
7
def denorm(img):
    """Undo the mean/std normalization and clip the result to ``[0, 1]``.

    Args:
        img: channels-first image with three channels, as torch tensor or NumPy array.

    Returns:
        The de-normalized image; a NumPy array if the input was one, otherwise a tensor.
    """
    np_input = isinstance(img, np.ndarray)
    if np_input:
        img = torch.from_numpy(img)

    # Create the statistics on the device of the input so that non-CPU
    # tensors do not hit a device mismatch.
    device = img.device if isinstance(img, torch.Tensor) else None
    mean = torch.tensor([0.485, 0.456, 0.406], device=device)
    std = torch.tensor([0.229, 0.224, 0.225], device=device)

    img_denorm = (img*std[:,None,None]) + mean[:,None,None]

    if np_input:
        return np.clip(img_denorm.numpy(), 0, 1)
    return torch.clamp(img_denorm, 0, 1)
25
+
26
+
27
def norm(img):
    """Normalize a channels-first, three-channel image tensor with fixed mean/std."""
    # Create the statistics on the device of the input so that non-CPU
    # tensors do not hit a device mismatch.
    device = img.device if isinstance(img, torch.Tensor) else None
    mean = torch.tensor([0.485, 0.456, 0.406], device=device)
    std = torch.tensor([0.229, 0.224, 0.225], device=device)
    return (img - mean[:,None,None]) / std[:,None,None]
31
+
32
+
33
def fast_iou_curve(p, g):
    """Compute the IoU for 50 evenly spaced probability thresholds in ``[0, 1]``.

    Args:
        p: one-dimensional tensor of logits (the sigmoid is applied here).
        g: one-dimensional tensor of ground-truth values aligned with ``p``
            (non-zero entries are summed as positives).

    Returns:
        ``(vals, scores)``: the thresholds as NumPy array and a list with one
        IoU value (zero-dimensional tensor) per threshold. An entry is NaN
        when the IoU denominator is zero.
    """
    # Sort once and reuse indices and values (the input used to be sorted twice).
    order = p.sort()
    g = g[order.indices]
    p = torch.sigmoid(order.values)
    total = len(g)

    scores = []
    vals = np.linspace(0, 1, 50)

    for q in vals:
        # p is sorted ascending, so everything from the first element above
        # the threshold onwards is predicted positive.
        above = torch.where(p > q)[0]
        n = int(above[0]) if len(above) > 0 else total

        fn = g[:n].sum()
        tp = g[n:].sum()
        fp = total - n - tp

        scores += [tp / (tp + fn + fp)]

    return vals, scores
64
+
65
+
66
def fast_rp_curve(p, g, step=100000):
    """Compute recall and precision at thresholds sub-sampled from the predictions.

    Args:
        p: one-dimensional tensor of logits (the sigmoid is applied here).
        g: one-dimensional tensor of ground-truth values aligned with ``p``
            (non-zero entries are summed as positives).
        step: every ``step``-th sorted probability is used as a threshold;
            the default matches the previously hard-coded stride.

    Returns:
        ``(recalls, precisions)``: two lists with one zero-dimensional tensor
        per threshold. An entry is NaN when its denominator is zero.
    """
    # Sort once and reuse indices and values (the input used to be sorted twice).
    order = p.sort()
    g = g[order.indices]
    p = torch.sigmoid(order.values)
    total = len(g)

    precisions, recalls = [], []

    for q in p[::step]:
        # p is sorted ascending, so everything from the first element above
        # the threshold onwards is predicted positive.
        above = torch.where(p > q)[0]
        n = int(above[0]) if len(above) > 0 else total

        fn = g[:n].sum()
        tp = g[n:].sum()
        fp = total - n - tp

        precisions += [tp / (tp + fp)]
        recalls += [tp / (tp + fn)]

    return recalls, precisions
98
+
99
+
100
+ # Image processing
101
+
102
def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
                   brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
    """Build masked/highlighted versions of the images in `batch`.

    Images are read from `batch[1]` and the matching masks from `batch[2]`
    (tensors or numpy arrays). For every pair, the masked region is kept and
    the background is replaced by a (possibly blurred, grayscale and/or
    `bg_fac`-scaled) copy of the image. Optional steps are applied in this
    order: rectangle around the mask, crop around the mask, red colorization,
    red outline.

    The input images are not modified; brightness scaling works on a copy.

    Args:
        batch: indexable whose items 1 and 2 are iterables of images and masks.
            Images are presumably channel-first (3, H, W) and masks (H, W),
            as required by the indexing below; confirm against callers.
        blur: sigma of the 15x15 Gaussian blur applied to the background
            (0 disables blurring).
        grayscale: use only channel 1 of the background image.
        center_context: if not None, crop around the mask with this context
            and resize to `image_size`.
        rect: draw a rectangle around the mask's bounding box.
        rect_color: rectangle color, one value per channel.
        rect_width: rectangle line width in pixels.
        brightness: factor the image is multiplied with.
        bg_fac: factor the background is multiplied with.
        colorize: tint the masked region red on a grayscale image.
        outline: draw a red contour around the mask.
        image_size: output size used together with `center_context`.

    Returns:
        Tensor stacking all processed images along a new first dimension.
    """
    # OpenCV is only required by the blur, colorize and outline options
    if blur > 0 or colorize or outline:
        import cv2

    rw = rect_width

    out = []
    for img, mask in zip(batch[1], batch[2]):

        img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
        mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)

        # out-of-place: `.cpu()`/`from_numpy` may share memory with the caller's data
        img = img * brightness
        img_bl = img
        if blur > 0:  # best 5
            img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)

        if grayscale:
            img_bl = img_bl[1][None]

        # keep the object, replace the background
        img_inp = img*mask + (bg_fac) * img_bl * (1-mask)

        if rect:
            _, bbox = crop_mask(img, mask, context=0.1)
            color = torch.tensor(rect_color)[:,None,None]
            # vertical edges
            img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = color
            img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = color
            # horizontal edges
            # NOTE(review): these start 1 px (not `rw`) before the edge, unlike the vertical ones
            img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = color
            img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = color

        if center_context is not None:
            img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)

        if colorize:
            img_gray = denorm(img)
            img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
            img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
            img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
            img_inp = norm(img_inp)

        if outline:
            cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            outline_img = np.zeros(mask.shape, dtype=np.uint8)
            cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
            outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
            img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
            img_inp = norm(img_inp)

        out += [img_inp]

    return torch.stack(out)
155
+
156
+
157
def object_crop(img, mask, context=0.0, square=False, image_size=224):
    """Crop `img` around `mask`, pad the crop to a square and resize it.

    The crop comes from `crop_mask`, is padded via `pad_to_square`
    (channel-first) and then interpolated to `image_size` x `image_size`.

    Returns:
        The resized crop (the bounding box from `crop_mask` is discarded).
    """
    cropped, _ = crop_mask(img, mask, context=context, square=square)
    squared = pad_to_square(cropped, channel_dim=0)
    # interpolate expects a batch dimension; add it and strip it again
    batched = squared.unsqueeze(0)
    resized = torch.nn.functional.interpolate(batched, (image_size, image_size))
    return resized.squeeze(0)
162
+
163
+
164
def crop_mask(img, mask, context=0.0, square=False):
    """Crop `img` to the bounding box of the non-zero region of `mask`.

    Args:
        img: tensor whose trailing two dimensions equal `mask.shape`
            (channel-first, C x H x W).
        mask: 2-D tensor (H x W); its maximal entries delimit the box.
        context: fraction of the box extent added as margin on each side,
            clamped to the image borders.
        square: if True, additionally stretch the shorter side of the box to
            the length of the longer one.

    Returns:
        Tuple `(img_crop, bbox)` with `bbox = [x0, x1, y0, y1]` (column range
        followed by row range, end-exclusive) and
        `img_crop = img[:, y0:y1, x0:x1]`.
    """
    assert img.shape[1:] == mask.shape

    n_rows, n_cols = mask.size(0), mask.size(1)
    col_profile = mask.max(0).values  # one entry per column
    row_profile = mask.max(1).values  # one entry per row

    # first/last column and row containing the maximum; each end offset is taken
    # relative to the length of the profile it was computed from
    bbox = [col_profile.argmax(), n_cols - col_profile.flip(0).argmax()]
    bbox += [row_profile.argmax(), n_rows - row_profile.flip(0).argmax()]
    bbox = [int(x) for x in bbox]

    row_extent, col_extent = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])

    # enlarge the box by the requested context, staying inside the image
    bbox[0] = int(max(0, bbox[0] - context * col_extent))
    bbox[1] = int(min(n_cols, bbox[1] + context * col_extent))
    bbox[2] = int(max(0, bbox[2] - context * row_extent))
    bbox[3] = int(min(n_rows, bbox[3] + context * row_extent))

    if square:
        row_extent, col_extent = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
        # NOTE(review): the start is shifted by half of the longer side (not half
        # of the difference), so the box is presumably not centered; kept as-is.
        if col_extent > row_extent:
            bbox[2] = int(max(0, (bbox[2] - 0.5*col_extent)))
            bbox[3] = bbox[2] + col_extent
        else:
            bbox[0] = int(max(0, (bbox[0] - 0.5*row_extent)))
            bbox[1] = bbox[0] + row_extent

    img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
    return img_crop, bbox
197
+
198
+
199
def pad_to_square(img, channel_dim=2, fill=0):
    """Pad an image tensor with a constant so that height equals width.

    The padding is split evenly between both sides of the shorter dimension;
    when the difference is odd, the extra pixel goes to the right/bottom.

    Args:
        img: 3-D image tensor, either H x W x C (`channel_dim=2`) or
            C x H x W (`channel_dim=0`).
        channel_dim: position of the channel dimension, 0 or 2.
        fill: constant value used for the padded pixels.

    Returns:
        The padded tensor, in the same layout as the input.

    Raises:
        ValueError: if `channel_dim` is neither 0 nor 2.
    """
    from torch.nn.functional import pad

    if channel_dim == 2:
        # work channel-first internally
        img = img.permute(2, 0, 1)
    elif channel_dim != 0:
        raise ValueError('invalid channel_dim')

    h, w = img.shape[1:]
    pad_left = pad_right = pad_top = pad_bottom = 0

    if h > w:
        pad_left = (h - w) // 2
        pad_right = h - w - pad_left
    elif w > h:
        pad_top = (w - h) // 2
        pad_bottom = w - h - pad_top

    # torch's pad takes (left, right, top, bottom) for the last two dimensions
    img_padded = pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=fill)

    if channel_dim == 2:
        img_padded = img_padded.permute(1, 2, 0)

    return img_padded
230
+
231
+
232
+ # qualitative
233
+
234
def split_sentence(inp, limit=9):
    """Insert line breaks into `inp` roughly every `limit` characters.

    The text is split at single spaces. Every word is emitted followed by a
    space; once the running line length (words plus one space each) exceeds
    `limit`, a newline is added and the count restarts. No newline is added
    after the last word.
    """
    words = inp.split(' ')
    last_index = len(words) - 1

    pieces = []
    line_length = 0
    for index, word in enumerate(words):
        line_length += len(word) + 1
        pieces.append(word + ' ')
        # break the line, except after the final word
        if line_length > limit and index != last_index:
            pieces.append('\n')
            line_length = 0

    return ''.join(pieces)
246
+
247
+
248
+ from matplotlib import pyplot as plt
249
+
250
+
251
def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
    """Plot images next to one column of predictions per entry in `preds`.

    Column 0 shows the input images; column `1 + j` shows `preds[j][i][0]`
    for image `i`. If `labels` is given, an additional first row holds the
    (line-wrapped) label of each prediction column.

    Args:
        imgs: sequence of images accepted by `imshow`.
        *preds: one indexable of predictions per column; `preds[j][i][0]`
            must be a tensor (it is detached and moved to the CPU).
        labels: optional column titles, one per prediction column.
        scale: scaling factor for figure size and label font size.
        cmap: colormap for the prediction maps.
        aps: optional per-image values printed as 'AP' on the column whose
            label matches the ground-truth label.
        gt_labels: optional per-image ground-truth labels; a prediction whose
            column label equals it is framed in red (requires `labels`).
        vmax: upper color limit. None uses 1, 'per_prompt' the maximum of the
            column, 'per_image' the maximum of the row; any other value is
            used directly.
    """
    row_off = 0 if labels is None else 1
    # squeeze=False keeps `ax` 2-D even for a single row or a single column
    _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds),
                         figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)),
                         squeeze=False)
    for a in ax.flatten():
        a.axis('off')

    if labels is not None:
        for j, label in enumerate(labels):
            t_new = split_sentence(label, limit=6)
            ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale)

    for i in range(len(imgs)):
        ax[i + row_off,0].imshow(imgs[i])
        for j in range(len(preds)):
            img = preds[j][i][0].detach().cpu().numpy()

            # highlight the prediction that belongs to the ground-truth label
            if labels is not None and gt_labels is not None and labels[j] == gt_labels[i]:
                print(j, labels[j], gt_labels[i])
                edgecolor = 'red'
                if aps is not None:
                    ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
            else:
                edgecolor = 'k'

            # Rectangle takes (xy, width, height): width is the number of columns
            rect = plt.Rectangle([0,0], img.shape[1], img.shape[0], facecolor="none",
                                 edgecolor=edgecolor, linewidth=3)
            ax[i + row_off,1 + j].add_patch(rect)

            if vmax is None:
                this_vmax = 1
            elif vmax == 'per_prompt':
                this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))])
            elif vmax == 'per_image':
                this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))])
            else:
                # explicit upper limit
                this_vmax = vmax

            ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
clipseg/example_image.jpg ADDED
clipseg/experiments/ablation.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000 # <-##########################################
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.phrasecut.PhraseCut # <-----------------
20
+ split_mode: pascal_test
21
+ split: train
22
+ mask: text_and_crop_blur_highlight352
23
+ image_size: 352
24
+ negative_prob: 0.2
25
+ mix_text_max: 0.5
26
+
27
+ # general
28
+ mix: True # <-----------------
29
+ prompt: shuffle+
30
+ norm_cond: True
31
+ mix_text_min: 0.0
32
+ with_visual: True
33
+
34
+ # model
35
+ version: 'ViT-B/16'
36
+ extract_layers: [3, 7, 9]
37
+ reduce_dim: 64
38
+ depth: 3
39
+ fix_shift: False # <-##########################################
40
+
41
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
42
+ amp: True
43
+
44
+ test_configuration_common:
45
+ normalize: True
46
+ image_size: 352
47
+ batch_size: 32
48
+ sigmoid: True
49
+ split: test
50
+ label_support: True
51
+
52
+ test_configuration:
53
+
54
+ -
55
+ name: pc
56
+ metric: metrics.FixedIntervalMetrics
57
+ test_dataset: phrasecut
58
+ mask: text
59
+
60
+ -
61
+ name: pc-vis
62
+ metric: metrics.FixedIntervalMetrics
63
+ test_dataset: phrasecut
64
+ mask: crop_blur_highlight352
65
+ with_visual: True
66
+ visual_only: True
67
+
68
+
69
+ columns: [name,
70
+ pc_fgiou_best, pc_miou_best, pc_fgiou_0.5,
71
+ pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5,
72
+ duration]
73
+
74
+
75
+ individual_configurations:
76
+
77
+ - {name: rd64-uni}
78
+ - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003}
79
+ - {name: rd64-no-negatives, negative_prob: 0.0}
80
+ - {name: rd64-neg0.5, negative_prob: 0.5}
81
+ - {name: rd64-no-visual, with_visual: False, mix: False}
82
+ - {name: rd16-uni, reduce_dim: 16}
83
+ - {name: rd64-layer3, extract_layers: [3], depth: 1}
84
+ - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}}
clipseg/experiments/coco.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.coco_wrapper.COCOWrapper
20
+ # split_mode: pascal_test
21
+ split: train
22
+ mask: text_and_blur3_highlight01
23
+ image_size: 352
24
+ normalize: True
25
+ pre_crop_image_size: [sample, 1, 1.5]
26
+ aug: 1new
27
+
28
+ # general
29
+ mix: True
30
+ prompt: shuffle+
31
+ norm_cond: True
32
+ mix_text_min: 0.0
33
+
34
+ # model
35
+ out: 1
36
+ extract_layers: [3, 7, 9]
37
+ reduce_dim: 64
38
+ depth: 3
39
+ fix_shift: False
40
+
41
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
42
+ amp: True
43
+
44
+ test_configuration_common:
45
+ normalize: True
46
+ image_size: 352
47
+ # max_iterations: 10
48
+ batch_size: 8
49
+ sigmoid: True
50
+ test_dataset: coco
51
+ metric: metrics.FixedIntervalMetrics
52
+
53
+ test_configuration:
54
+
55
+ -
56
+ name: coco_t
57
+ mask: text
58
+
59
+ -
60
+ name: coco_h
61
+ mask: blur3_highlight01
62
+
63
+ -
64
+ name: coco_h2
65
+ mask: crop_blur_highlight352
66
+
67
+
68
+ columns: [i, name,
69
+ coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5,
70
+ coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5,
71
+ coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t,
72
+ train_loss, duration, date
73
+ ]
74
+
75
+ individual_configurations:
76
+
77
+
78
+ - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
79
+ - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
80
+ - {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
81
+ - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
82
+
83
+
84
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
85
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
86
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
87
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
88
+
89
+
90
+ # ViT
91
+ - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
92
+ - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
93
+ - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
94
+ - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
95
+
96
+
97
+ # BASELINE
98
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
99
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
100
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
101
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
clipseg/experiments/pascal_1shot.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000 # <-##########################################
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.phrasecut.PhraseCut
20
+ split_mode: pascal_test
21
+ mode: train
22
+ mask: text_and_crop_blur_highlight352
23
+ image_size: 352
24
+ normalize: True
25
+ pre_crop_image_size: [sample, 1, 1.5]
26
+ aug: 1new
27
+ with_visual: True
28
+ split: train
29
+
30
+ # general
31
+ mix: True
32
+ prompt: shuffle+
33
+ norm_cond: True
34
+ mix_text_min: 0.0
35
+
36
+ # model
37
+ out: 1
38
+ version: 'ViT-B/16'
39
+ extract_layers: [3, 7, 9]
40
+ reduce_dim: 64
41
+ depth: 3
42
+
43
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
44
+ amp: True
45
+
46
+ test_configuration_common:
47
+ normalize: True
48
+ image_size: 352
49
+ metric: metrics.FixedIntervalMetrics
50
+ batch_size: 1
51
+ test_dataset: pascal
52
+ sigmoid: True
53
+ # max_iterations: 250
54
+
55
+ test_configuration:
56
+
57
+ -
58
+ name: pas_t
59
+ mask: text
60
+
61
+ -
62
+ name: pas_h
63
+ mask: blur3_highlight01
64
+
65
+ -
66
+ name: pas_h2
67
+ mask: crop_blur_highlight352
68
+
69
+
70
+ columns: [name,
71
+ pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct,
72
+ pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct,
73
+ pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t,
74
+ train_loss, duration, date
75
+ ]
76
+
77
+ individual_configurations:
78
+
79
+ - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}}
80
+ - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}}
81
+ - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}}
82
+ - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}}
83
+
84
+
85
+ - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}}
86
+ - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}}
87
+ - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}}
88
+ - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}}
89
+
90
+
91
+ # baseline
92
+ - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}}
93
+ - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}}
94
+ - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}}
95
+ - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}}
96
+
97
+ # ViT
98
+ - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}}
99
+ - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}}
100
+ - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}}
101
+ - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}}
clipseg/experiments/phrasecut.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.phrasecut.PhraseCut # <-----------------
20
+ split_mode: pascal_test
21
+ split: train
22
+ mask: text_and_crop_blur_highlight352
23
+ image_size: 352
24
+ normalize: True
25
+ pre_crop_image_size: [sample, 1, 1.5]
26
+ aug: 1new
27
+
28
+ # general
29
+ mix: False # <-----------------
30
+ prompt: shuffle+
31
+ norm_cond: True
32
+ mix_text_min: 0.0
33
+
34
+ # model
35
+ out: 1
36
+ extract_layers: [3, 7, 9]
37
+ reduce_dim: 64
38
+ depth: 3
39
+ fix_shift: False
40
+
41
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
42
+ amp: True
43
+
44
+ test_configuration_common:
45
+ normalize: True
46
+ image_size: 352
47
+ batch_size: 32
48
+ # max_iterations: 5
49
+ # max_iterations: 150
50
+
51
+ test_configuration:
52
+
53
+ -
54
+ name: pc # old: phrasecut
55
+ metric: metrics.FixedIntervalMetrics
56
+ test_dataset: phrasecut
57
+ split: test
58
+ mask: text
59
+ label_support: True
60
+ sigmoid: True
61
+
62
+
63
+ columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date]
64
+
65
+
66
+ individual_configurations:
67
+
68
+ # important ones
69
+
70
+
71
+ - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5}
72
+
73
+ # this was accidentally trained using old mask
74
+ - {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01}
75
+ - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False}
76
+ # this was accidentally trained using old mask
77
+ - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01}
78
+
79
+ - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003}
80
+ - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001}
clipseg/general_utils.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import inspect
3
+ import torch
4
+ import os
5
+ import sys
6
+ import yaml
7
+ from shutil import copy, copytree
8
+ from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename
9
+
10
+
11
class Logger:
    """Minimal logger stand-in: every attribute resolves to the built-in `print`.

    This allows calls such as `log.info(...)`, `log.hint(...)` or
    `log.warning(...)` without configuring a logging framework.
    """

    def __getattr__(self, level):
        # only reached for names that are not regular attributes
        return print


# module-wide logger instance
log = Logger()
17
+
18
def training_config_from_cli_args():
    """Build the training configuration selected on the command line.

    `sys.argv[1]` is the file name of a YAML file inside `experiments/` and
    `sys.argv[2]` the index into its `individual_configurations` list. The
    selected individual configuration overrides the shared `configuration`.

    Returns:
        AttributeDict with the merged configuration.
    """
    experiment_name = sys.argv[1]
    experiment_id = int(sys.argv[2])

    # context manager closes the file handle deterministically
    with open(f'experiments/{experiment_name}') as f:
        yaml_config = yaml.load(f, Loader=yaml.SafeLoader)

    config = yaml_config['configuration']
    config = {**config, **yaml_config['individual_configurations'][experiment_id]}
    return AttributeDict(config)
28
+
29
+
30
def score_config_from_cli_args():
    """Build the evaluation configuration selected on the command line.

    `sys.argv[1]` is the file name of a YAML file inside `experiments/` and
    `sys.argv[2]` the index into its `individual_configurations`. If
    `test_configuration` is a list, `sys.argv[3]` selects the entry to use.

    Later sources override earlier ones: `test_configuration_common`, then
    the selected `test_configuration`, then the optional `test_configuration`
    of the individual configuration.

    Returns:
        Tuple `(config, train_checkpoint_id)`: the merged AttributeDict and
        the `name` of the selected individual configuration.
    """
    experiment_name = sys.argv[1]
    experiment_id = int(sys.argv[2])

    # context manager closes the file handle deterministically
    with open(f'experiments/{experiment_name}') as f:
        yaml_config = yaml.load(f, Loader=yaml.SafeLoader)

    individual_config = yaml_config['individual_configurations'][experiment_id]
    config = yaml_config['test_configuration_common']

    if isinstance(yaml_config['test_configuration'], list):
        test_id = int(sys.argv[3])
        config = {**config, **yaml_config['test_configuration'][test_id]}
    else:
        config = {**config, **yaml_config['test_configuration']}

    # per-experiment overrides of the test settings
    if 'test_configuration' in individual_config:
        config = {**config, **individual_config['test_configuration']}

    train_checkpoint_id = individual_config['name']

    config = AttributeDict(config)
    return config, train_checkpoint_id
52
+
53
+
54
def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository',
                        local_dir='~/datasets'):
    """ Copies files from the repository to a local folder and extracts archives.

    Nothing is copied if `<local_dir>/<local_name>` already exists and the
    optional integrity check passes.

    local_name: name of the dataset folder created inside local_dir
    repo_files: list of filenames or list of tuples [filename, target path]
    integrity_check: optional callable receiving the local dataset folder and
        returning a truthy value if the data is complete. An exception raised
        by it counts as a failed check.
    repo_dir: root folder of the repository
    local_dir: root folder for local datasets

    e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']])
    will create a folder 'MyDataset' in local_dir, and extract the content of
    '<repo_dir>/data/dataset1.tar' to <local_dir>/MyDataset/other/path.
    """

    local_dir = realpath(join(expanduser(local_dir), local_name))

    # the dataset counts as present only if its folder is available ...
    dataset_exists = isdir(local_dir)

    # ... and the custom integrity check (if any) succeeds
    if integrity_check is not None:
        try:
            integrity_ok = integrity_check(local_dir)
        except Exception:
            # KeyboardInterrupt/SystemExit are deliberately not swallowed
            integrity_ok = False

        if integrity_ok:
            log.hint('Passed custom integrity check')
        else:
            log.hint('Custom integrity check failed')

        dataset_exists = dataset_exists and integrity_ok

    if not dataset_exists:

        repo_dir = realpath(expanduser(repo_dir))

        for filename in repo_files:

            if isinstance(filename, str):
                origin, target = filename, filename
                archive_target = join(local_dir, basename(origin))
                extract_target = join(local_dir)
            else:
                origin, target = filename
                archive_target = join(local_dir, dirname(target), basename(origin))
                extract_target = join(local_dir, dirname(target))

            archive_origin = join(repo_dir, origin)

            log.hint(f'copy: {archive_origin} to {archive_target}')

            # make sure the path exists
            os.makedirs(dirname(archive_target), exist_ok=True)

            if os.path.isfile(archive_target):
                # only copy if size differs
                if os.path.getsize(archive_target) != os.path.getsize(archive_origin):
                    log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}')
                    copy(archive_origin, archive_target)
            else:
                copy(archive_origin, archive_target)

            extract_archive(archive_target, extract_target, noarchive_ok=True)

            # concurrent processes might have deleted the file
            if os.path.isfile(archive_target):
                os.remove(archive_target)
121
+
122
+
123
def extract_archive(filename, target_folder=None, noarchive_ok=False):
    """Extract a tar or zip archive using the `tar`/`unzip` command line tools.

    Args:
        filename: path of the archive; the tool is chosen by its ending
            ('.tgz', '.tar', '.tar.gz' or 'zip').
        target_folder: folder to extract into; if None, the tool's default
            (the current working directory) is used.
        noarchive_ok: if True, silently return for files with another ending.

    Raises:
        ValueError: for an unsupported ending when `noarchive_ok` is False.

    A non-zero exit status of the tool is not raised; its output is printed.
    """
    from subprocess import run, PIPE

    # Build an argument list instead of splitting a command string, so that
    # paths containing whitespace are passed through intact.
    if filename.endswith(('.tgz', '.tar')):
        command, target_flag = ['tar', '-xf', filename], '-C'
    elif filename.endswith('.tar.gz'):
        command, target_flag = ['tar', '-xzf', filename], '-C'
    elif filename.endswith('zip'):
        command, target_flag = ['unzip', filename], '-d'
    else:
        if noarchive_ok:
            return
        raise ValueError(f'unsuppored file ending of {filename}')

    if target_folder is not None:
        command += [target_flag, str(target_folder)]

    log.hint(' '.join(command))
    result = run(command, stdout=PIPE, stderr=PIPE)
    if result.returncode != 0:
        print(result.stdout, result.stderr)
145
+
146
+
147
class AttributeDict(dict):
    """
    An extended dictionary that allows access to elements as attributes and counts
    these accesses. This way, we know if some attributes were never used.

    Reading a missing key via attribute access returns None instead of raising;
    item access (`d[k]`) keeps the regular KeyError. Accesses through
    `dict.get`, `items()` or iteration are not counted.
    """

    def __init__(self, *args, **kwargs):
        from collections import Counter
        super().__init__(*args, **kwargs)
        # stored in the instance __dict__ so the counter is not a dictionary entry
        self.__dict__['counter'] = Counter()

    def __getitem__(self, k):
        self.__dict__['counter'][k] += 1
        return super().__getitem__(k)

    def __getattr__(self, k):
        # only called when regular attribute lookup fails
        self.__dict__['counter'][k] += 1
        return super().get(k)

    def __setattr__(self, k, v):
        return super().__setitem__(k, v)

    def __delattr__(self, k):
        # `del d.key` removes the dictionary entry; raises KeyError if it is missing
        return super().__delitem__(k)

    def unused_keys(self, exceptions=()):
        """Return the keys that were never read and are not listed in `exceptions`."""
        return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions]

    def assume_no_unused_keys(self, exceptions=()):
        """Log a warning listing all keys that were never read."""
        unused = self.unused_keys(exceptions=exceptions)
        if len(unused) > 0:
            log.warning('Unused keys:', unused)
178
+
179
+
180
def get_attribute(name):
    """Resolve a dotted path such as 'package.module.attr' to the object it names.

    Everything before the last dot is imported as a module; the last component
    is looked up on that module.

    Raises:
        ValueError: if `name` is None.
    """
    from importlib import import_module

    if name is None:
        raise ValueError('The provided attribute is None')

    *module_parts, attribute_name = name.split('.')
    module = import_module('.'.join(module_parts))
    return getattr(module, attribute_name)
189
+
190
+
191
+
192
def filter_args(input_args, default_args):
    """Split `input_args` according to the keys known by `default_args`.

    Returns:
        Three AttributeDicts:
        - all keys of `default_args`, taking the value from `input_args`
          where present and the default otherwise,
        - the entries of `input_args` whose key is in `default_args`,
        - the entries of `input_args` whose key is not in `default_args`.
    """
    updated_args = {}
    for key, default in default_args.items():
        updated_args[key] = input_args[key] if key in input_args else default

    used_args, unused_args = {}, {}
    for key, value in input_args.items():
        destination = used_args if key in default_args else unused_args
        destination[key] = value

    return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args)
199
+
200
+
201
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False):
    """Instantiate a model from `logs/<checkpoint_id>` and load its weights.

    Args:
        checkpoint_id: name of the folder inside `logs/` holding `config.json`
            and the weights.
        weights_file: file name of the weights inside that folder
            (defaults to 'weights.pth').
        strict: forwarded to `load_state_dict`.
        model_args: 'from_config' to take the constructor arguments from the
            stored config (restricted to the parameters of the model class),
            or a dictionary of constructor arguments.
        with_config: if True, also return the stored config.

    Returns:
        The model, or the tuple `(model, config)` if `with_config` is set.

    Raises:
        ValueError: if `model_args` is neither 'from_config' nor a dictionary.
        FileNotFoundError: if the weights file does not exist.
    """
    with open(join('logs', checkpoint_id, 'config.json')) as f:
        config = json.load(f)

    # isinstance also accepts dict subclasses such as AttributeDict
    if model_args != 'from_config' and not isinstance(model_args, dict):
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    model_cls = get_attribute(config['model'])

    # load model
    if model_args == 'from_config':
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)

    model = model_cls(**model_args)

    if weights_file is None:
        weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
    else:
        weights_file = realpath(join('logs', checkpoint_id, weights_file))

    if not isfile(weights_file):
        raise FileNotFoundError(f'model checkpoint {weights_file} was not found')

    # Load onto the CPU so checkpoints saved on a GPU can be read on hosts
    # without CUDA; load_state_dict copies the values into the model's own
    # parameters.
    weights = torch.load(weights_file, map_location='cpu')
    for w in weights.values():
        assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
    model.load_state_dict(weights, strict=strict)

    if with_config:
        return model, config

    return model
233
+
234
+
235
class TrainingLogger(object):
    """Stores the configuration and weights of a training run below `logs/`.

    Can be used as a context manager; entering returns the logger itself and
    leaving performs no cleanup.
    """

    def __init__(self, model, log_dir, config=None, *args):
        """
        Args:
            model: model whose weights `save_weights` stores (may be None).
            log_dir: sub-folder of `logs/` for this run; None disables
                writing to disk.
            config: optional JSON-serializable configuration written to
                `<base_path>/config.json`.
            *args: accepted and ignored.
        """
        super().__init__()
        self.model = model
        self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None

        os.makedirs('logs/', exist_ok=True)

        # without a log_dir there is no run folder to create or write into
        if self.base_path is not None:
            os.makedirs(self.base_path, exist_ok=True)

            if config is not None:
                # context manager guarantees the file is flushed and closed
                with open(join(self.base_path, 'config.json'), 'w') as f:
                    json.dump(config, f)

    def iter(self, i, **kwargs):
        """Print the loss every 100 iterations if a `loss` keyword is given."""
        if i % 100 == 0 and 'loss' in kwargs:
            loss = kwargs['loss']
            print(f'iteration {i}: loss {loss:.4f}')

    def save_weights(self, only_trainable=False, weight_file='weights.pth'):
        """Save the model's state dict to `<base_path>/<weight_file>`.

        Args:
            only_trainable: if True, only store parameters that require gradients.
            weight_file: name of the target file.

        Raises:
            AttributeError: if no model was provided.
            ValueError: if the logger was created without a log_dir.
        """
        if self.model is None:
            raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.')

        if self.base_path is None:
            raise ValueError('Cannot save weights: no log_dir was provided.')

        weights_path = join(self.base_path, weight_file)

        weight_dict = self.model.state_dict()

        if only_trainable:
            weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad}

        torch.save(weight_dict, weights_path)
        log.info(f'Saved weights to {weights_path}')

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """ Nothing to clean up; exceptions raised inside the block propagate. """
        pass
clipseg/metrics.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.functional import Tensor
2
+ from general_utils import log
3
+ from collections import defaultdict
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch.nn import functional as nnf
8
+
9
+
10
class BaseMetric(object):
    """Base class for metrics: subclasses implement ``add`` and ``value``."""

    def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
                 eval_validation=True):
        """
        metric_names: iterable of names, paired with the entries of ``value()`` in ``scores()``.
        pred_range: optional (start, stop) slice applied to the second axis of the prediction.
        gt_index / pred_index: which element of the ground-truth / prediction sequence to use.
        eval_intermediate / eval_validation: flags returned by the methods of the same name.
        """
        self._names = tuple(metric_names)
        self._eval_intermediate = eval_intermediate
        self._eval_validation = eval_validation

        self._pred_range = pred_range
        self._pred_index = pred_index
        self._gt_index = gt_index

        self.predictions = []
        self.ground_truths = []

    def eval_intermediate(self):
        return self._eval_intermediate

    def eval_validation(self):
        return self._eval_validation

    def names(self):
        return self._names

    def add(self, predictions, ground_truth):
        raise NotImplementedError

    def value(self):
        raise NotImplementedError

    def scores(self):
        """Return ``value()`` unchanged if it is a dict, otherwise a list of (name, value) pairs."""
        # value() can be expensive, so it is evaluated exactly once
        value = self.value()
        if isinstance(value, dict):
            return value

        assert isinstance(value, (list, tuple))
        return list(zip(self.names(), value))

    def _get_pred_gt(self, predictions, ground_truth):
        """Select the configured prediction / ground truth and apply ``pred_range`` if set."""
        pred = predictions[self._pred_index]
        gt = ground_truth[self._gt_index]

        if self._pred_range is not None:
            pred = pred[:, self._pred_range[0]: self._pred_range[1]]

        return pred, gt
57
+
58
+
59
class FixedIntervalMetrics(BaseMetric):
    """Accumulates per-sample confusion counts at a fixed grid of thresholds.

    ``add`` stores tp/fp/fn/tn counts for every sample and threshold,
    ``value`` turns the accumulated counts into average precision and
    foreground / binary / per-class mean IoU scores.
    """

    def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
                 resize_pred=None, n_values=51, custom_threshold=None):
        """
        sigmoid: apply a sigmoid to the predictions before thresholding.
        ignore_mask: if True, a mask passed as ``gt[1]`` is not used.
        resize_to: if set, ground truth and predictions are interpolated to this size.
        resize_pred: if truthy, each prediction is resized to its ground truth's spatial size.
        n_values: number of points of ``np.linspace(0, 1, n_values)``; both end points are dropped.
        custom_threshold: optional extra threshold to report; must be one of the evaluated thresholds.
        """
        super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
        self.intersections = []
        self.unions = []
        self.sigmoid = sigmoid
        self.resize_to = resize_to
        self.resize_pred = resize_pred  # resize prediction to match ground truth
        self.class_count = defaultdict(lambda: 0)
        self.per_class = defaultdict(lambda: [0, 0])
        self.ignore_mask = ignore_mask
        self.custom_threshold = custom_threshold

        self.scores_ap = []
        self.scores_iou = []
        self.gts, self.preds = [], []
        self.classes = []

        # [1:-1] ignores 0 and 1
        self.threshold_values = np.linspace(0, 1, n_values)[1:-1]

        # one list per sample, each holding one count per threshold
        self.metrics = dict(tp=[], fp=[], fn=[], tn=[])

    def _threshold_index(self, threshold):
        """Return the index of ``threshold`` in ``threshold_values``, tolerating float round-off.

        Raises ValueError if the threshold is not on the evaluated grid.
        """
        if len(self.threshold_values) == 0:
            raise ValueError(f'{threshold} is not among the evaluated thresholds')

        idx = int(np.argmin(np.abs(self.threshold_values - threshold)))
        if not np.isclose(self.threshold_values[idx], threshold):
            raise ValueError(f'{threshold} is not among the evaluated thresholds')
        return idx

    def _counts_per_threshold(self, p, g):
        """Count tp/fp/fn/tn at every threshold for ascending-sorted scores ``p`` and aligned labels ``g``."""
        total = len(g)
        tps, fps, fns, tns = [], [], [], []

        for thresh in self.threshold_values:
            above = torch.where(p > thresh)[0]
            # p is sorted, so all elements from the first index above the threshold are predicted positive
            n = int(above[0]) if len(above) > 0 else total

            fn = int(g[:n].sum())
            tp = int(g[n:].sum())
            fns.append(fn)
            tns.append(n - fn)
            tps.append(tp)
            fps.append(total - n - tp)

        return tps, fps, fns, tns

    def add(self, pred, gt):
        """Accumulate counts for one batch.

        pred: sequence whose first element holds the batch of predictions.
        gt: sequence of (ground truth[, mask[, class ids]]); mask and class ids are optional.
        """
        pred_batch = pred[0].cpu()

        if self.sigmoid:
            pred_batch = torch.sigmoid(pred_batch)

        gt_batch = gt[0].cpu()
        mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
        cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)

        if self.resize_to is not None:
            gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
            pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)

        if isinstance(cls_batch, torch.Tensor):
            cls_batch = cls_batch.cpu().numpy().tolist()

        assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'

        for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):

            if self.resize_pred:
                predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)

            p = predictions.flatten()
            g = ground_truth.flatten()

            assert len(p) == len(g)

            if mask is not None:
                # predictions were moved to the CPU above; the mask has to live on the same device
                m = mask.flatten().bool().to(p.device)
                p = p[m]
                g = g[m]

            p_sorted = p.sort()
            p = p_sorted.values
            g = g[p_sorted.indices]

            tps, fps, fns, tns = self._counts_per_threshold(p, g)

            self.metrics['tp'] += [tps]
            self.metrics['fp'] += [fps]
            self.metrics['fn'] += [fns]
            self.metrics['tn'] += [tns]

            self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls]

    def value(self):
        """Return a dict of scores computed from all samples added so far.

        When no class ids were provided, all mean-IoU entries and
        ``summed_by_cls_statistics`` are None.
        """
        import time
        t_start = time.time()

        n_thresholds = len(self.threshold_values)

        if set(self.classes) == set([None]):
            all_classes = None
            log.warning('classes were not provided, cannot compute mIoU')
        else:
            all_classes = set(int(c) for c in self.classes)

        # sum over all samples, separately for every threshold
        summed = {k: [sum(sample[j] for sample in self.metrics[k]) for j in range(n_thresholds)]
                  for k in self.metrics.keys()}

        summed_by_cls = None
        if all_classes is not None:

            assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])
            # group by class
            metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes}
            for i in range(len(self.metrics['tp'])):
                for k in self.metrics.keys():
                    metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]]

            # sum over all instances within the classes
            summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()}

        # Compute average precision

        assert (np.array(summed['fp']) + np.array(summed['tp'])).sum(), 'no predictions is made'

        # only consider values where a prediction is made
        predicted = [j for j in range(n_thresholds) if summed['tp'][j] + summed['fp'][j] > 0]
        precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in predicted]
        recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in predicted]

        # remove duplicate recall-precision-pairs (and sort by recall value)
        recalls, precisions = zip(*sorted(set(zip(recalls, precisions)), key=lambda x: x[0]))

        # `simps` was removed from recent SciPy releases in favour of `simpson`
        try:
            from scipy.integrate import simpson
        except ImportError:
            from scipy.integrate import simps as simpson
        ap = simpson(precisions, x=recalls)

        # Compute best IoU
        fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(n_thresholds)]

        biniou_scores = [
            0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
            0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
            for j in range(n_thresholds)
        ]

        index_0p5 = self._threshold_index(0.5)
        index_0p1 = self._threshold_index(0.1)
        index_0p2 = self._threshold_index(0.2)
        index_0p3 = self._threshold_index(0.3)

        index_ct = None
        if self.custom_threshold is not None:
            index_ct = self._threshold_index(self.custom_threshold)

        if all_classes is not None:
            # mean IoU
            mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j])
                                  for c in all_classes])
                         for j in range(n_thresholds)]

            mean_iou_dict = {
                'miou_best': max(mean_ious),
                'miou_0.5': mean_ious[index_0p5],
                'miou_0.1': mean_ious[index_0p1],
                'miou_0.2': mean_ious[index_0p2],
                'miou_0.3': mean_ious[index_0p3],
                'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
                'mean_iou_ct': mean_ious[index_ct] if index_ct is not None else None,
                'mean_iou_scores': mean_ious,
            }
        else:
            # without class ids there is nothing to average over; keep the keys so the result layout is stable
            mean_iou_dict = {
                'miou_best': None,
                'miou_0.5': None,
                'miou_0.1': None,
                'miou_0.2': None,
                'miou_0.3': None,
                'miou_best_t': None,
                'mean_iou_ct': None,
                'mean_iou_scores': None,
            }

        print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')

        return {
            'ap': ap,

            # fgiou
            'fgiou_best': max(fgiou_scores),
            'fgiou_0.5': fgiou_scores[index_0p5],
            'fgiou_0.1': fgiou_scores[index_0p1],
            'fgiou_0.2': fgiou_scores[index_0p2],
            'fgiou_0.3': fgiou_scores[index_0p3],
            'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],

            # biniou
            'biniou_best': max(biniou_scores),
            'biniou_0.5': biniou_scores[index_0p5],
            'biniou_0.1': biniou_scores[index_0p1],
            'biniou_0.2': biniou_scores[index_0p2],
            'biniou_0.3': biniou_scores[index_0p3],
            'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],

            # custom threshold
            'fgiou_ct': fgiou_scores[index_ct] if index_ct is not None else None,
            'biniou_ct': biniou_scores[index_ct] if index_ct is not None else None,
            'ct': self.custom_threshold,

            # statistics
            'fgiou_scores': fgiou_scores,
            'biniou_scores': biniou_scores,
            'precision_recall_curve': sorted(set(zip(recalls, precisions))),
            'summed_statistics': summed,
            'summed_by_cls_statistics': summed_by_cls,

            # mean iou
            **mean_iou_dict
        }
267
+
268
+ # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'
269
+
270
+ # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls}
271
+
clipseg/models/clipseg.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from os.path import basename, dirname, join, isfile
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as nnf
6
+ from torch.nn.modules.activation import ReLU
7
+
8
+
9
def precompute_clip_vectors():
    """Encode CLIP text embeddings for prompts of the first 100 LVIS category names and pickle them.

    Writes a dict mapping each prompt string to its embedding to
    ``precomputed_prompt_vectors.pickle`` in the working directory.
    Requires a CUDA device and the project's dataset code.
    """

    from trails.initialization import init_dataset
    lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
                        reduce_factor=None, add_bar=False, negative_prob=0.5)

    all_names = list(lvis.category_names.values())

    import clip
    from models.clip_prompts import imagenet_templates
    clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
    prompt_vectors = {}
    for name in all_names[:100]:
        with torch.no_grad():
            conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
            text_tokens = clip.tokenize(conditionals).cuda()
            encoded = clip_model.encode_text(text_tokens).cpu()

            for prompt, vec in zip(conditionals, encoded):
                # stored as numpy arrays because CLIPDenseBase loads the file via torch.from_numpy
                prompt_vectors[prompt] = vec.numpy()

    import pickle

    with open('precomputed_prompt_vectors.pickle', 'wb') as out_file:
        pickle.dump(prompt_vectors, out_file)
33
+
34
+
35
def get_prompt_list(prompt):
    """Return the list of prompt templates selected by the mode name `prompt`.

    Known modes are 'plain', 'fixed', 'shuffle', 'shuffle+' and
    'shuffle_clip' (the latter returns `imagenet_templates` from
    models.clip_prompts). Any other value raises a ValueError.
    """
    if prompt == 'shuffle_clip':
        from models.clip_prompts import imagenet_templates
        return imagenet_templates

    shuffle_templates = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
    extended_templates = shuffle_templates + [
        'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
        'a bad photo of a {}.', 'a photo of the {}.',
    ]

    templates_by_mode = (
        ('plain', ['{}']),
        ('fixed', ['a photo of a {}.']),
        ('shuffle', shuffle_templates),
        ('shuffle+', extended_templates),
    )

    for mode, templates in templates_by_mode:
        if prompt == mode:
            return templates

    raise ValueError('Invalid value for prompt')
51
+
52
+
53
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
    """
    Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
    The mlp and layer norm come from CLIP.
    x: input, unpacked as (tokens, batch, embedding dim).
    b: residual block providing `ln_1`, `attn` (multihead attention module), `ln_2` and `mlp`.
    with_aff: if True, also return the attention weights after the softmax.
    attn_mask: None or a tuple (mask_type, mask). mask_type is 'cls_token' or 'all'; mask has one row per
        sample and one column per non-readout token and is multiplied onto the pre-softmax similarities.
    """

    x_ = b.ln_1(x)
    q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
    tgt_len, bsz, embed_dim = q.size()

    head_dim = embed_dim // b.attn.num_heads
    scaling = float(head_dim) ** -0.5

    # after the view the second axis is ordered sample-major: row index = sample * num_heads + head
    q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)

    q = q * scaling

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # batch_size * n_heads, tokens, tokens
    if attn_mask is not None:

        attn_mask_type, attn_mask = attn_mask
        n_heads = attn_output_weights.size(0) // attn_mask.size(0)
        # Every sample's mask must be repeated for all of its heads. repeat(n_heads, 1) would tile the
        # masks head-major and pair rows with the wrong sample whenever the batch holds more than one sample.
        attn_mask = attn_mask.repeat_interleave(n_heads, dim=0)

        if attn_mask_type == 'cls_token':
            # the mask only affects similarities compared to the readout-token.
            attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask

        if attn_mask_type == 'all':
            attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]

    attn_output_weights = torch.softmax(attn_output_weights, dim=-1)

    attn_output = torch.bmm(attn_output_weights, v)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = b.attn.out_proj(attn_output)

    # residual connections around attention and mlp
    x = x + attn_output
    x = x + b.mlp(b.ln_2(x))

    if with_aff:
        return x, attn_output_weights
    else:
        return x
105
+
106
+
107
class CLIPDenseBase(nn.Module):
    """Shared base of the dense prediction models: holds a frozen CLIP model and the conditioning layers.

    Subclasses are expected to set `token_shape` (used by `rescaled_pos_emb`) and may set
    `shift_vector` (added to text conditionals in `compute_conditional`).
    """

    def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
        """
        version: CLIP model name passed to `clip.load`.
        reduce_cond: if not None, output size of a frozen linear layer applied to the conditional.
        reduce_dim: output size of the FiLM layers and of the `reduce` projection.
        prompt: prompt mode name passed to `get_prompt_list`.
        n_tokens: if not None, scale conv weights such that we obtain n_tokens.
        """
        super().__init__()

        import clip

        self.clip_model, _ = clip.load(version, device='cpu', jit=False)
        self.model = self.clip_model.visual

        # if not None, scale conv weights such that we obtain n_tokens.
        self.n_tokens = n_tokens

        for p in self.clip_model.parameters():
            p.requires_grad_(False)

        # conditional
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)

        self.reduce = nn.Linear(768, reduce_dim)

        self.prompt_list = get_prompt_list(prompt)

        # precomputed prompts
        import pickle
        if isfile('precomputed_prompt_vectors.pickle'):
            # NOTE: unpickling executes arbitrary code; only keep a trusted file at this location.
            with open('precomputed_prompt_vectors.pickle', 'rb') as prompt_file:
                precomp = pickle.load(prompt_file)
            # as_tensor accepts numpy arrays as well as tensors (from_numpy rejects the latter)
            self.precomputed_prompts = {k: torch.as_tensor(v) for k, v in precomp.items()}
        else:
            self.precomputed_prompts = dict()

    def rescaled_pos_emb(self, new_size):
        """Return the positional embedding with its grid part bicubically resized to `new_size` (2-tuple)."""
        assert len(new_size) == 2

        # entry 0 belongs to the class token and is kept as is
        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
        return torch.cat([self.model.positional_embedding[:1], b])

    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
        """Run the CLIP visual transformer without gradients.

        x_inp: batch of input images.
        extract_layers: indices of the residual blocks whose outputs and attention weights are collected.
        skip: if True, stop after the last block listed in `extract_layers`.
        mask: None or a tuple (mask_layer, mask_type, mask_tensor); mask_layer is a block index or 'all'.

        Returns (embedding of the class token, list of activations, list of attention weights).
        """

        with torch.no_grad():

            if self.n_tokens is not None:
                stride2 = x_inp.shape[2] // self.n_tokens
                conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
                x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
            else:
                x = self.model.conv1(x_inp)  # shape = [*, width, grid, grid]

            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

            x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

            standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197

            # a non-standard token count needs a resized positional embedding
            if x.shape[1] != standard_n_tokens:
                new_shape = int(math.sqrt(x.shape[1]-1))
                x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None, :, :]
            else:
                x = x + self.model.positional_embedding.to(x.dtype)

            x = self.model.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND

            last_layer = max(extract_layers) if len(extract_layers) > 0 else None

            activations, affinities = [], []
            for i, res_block in enumerate(self.model.transformer.resblocks):

                attn_mask = None
                if mask is not None:
                    mask_layer, mask_type, mask_tensor = mask
                    if mask_layer == i or mask_layer == 'all':
                        size = int(math.sqrt(x.shape[0] - 1))

                        # resize the mask to the token grid and flatten it to one row per sample
                        attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))

                x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)

                if i in extract_layers:
                    affinities += [aff_per_head]
                    activations += [x]

                if skip and last_layer is not None and i == last_layer:
                    print('early skip')
                    break

            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.model.ln_post(x[:, 0, :])

            if self.model.proj is not None:
                x = x @ self.model.proj

            return x, activations, affinities

    def sample_prompts(self, words, prompt_list=None):
        """Format every word with a prompt template drawn uniformly at random (with replacement)."""

        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
        return [prompt.format(w) for prompt, w in zip(prompts, words)]

    def get_cond_vec(self, conditional, batch_size):
        """Turn `conditional` into a conditional vector.

        Accepts a single string, a list/tuple of `batch_size` strings, a 2-dimensional tensor
        (used directly) or any other tensor (treated as images and passed through `visual_forward`).
        Raises ValueError for everything else.
        """
        # compute conditional from a single string
        if isinstance(conditional, str):
            cond = self.compute_conditional(conditional)
            cond = cond.repeat(batch_size, 1)

        # compute conditional from string list/tuple
        elif isinstance(conditional, (list, tuple)) and len(conditional) > 0 and isinstance(conditional[0], str):
            assert len(conditional) == batch_size
            cond = self.compute_conditional(conditional)

        # use conditional directly
        elif isinstance(conditional, torch.Tensor) and conditional.ndim == 2:
            cond = conditional

        # compute conditional from image
        elif isinstance(conditional, torch.Tensor):
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
        else:
            raise ValueError('invalid conditional')
        return cond

    def compute_conditional(self, conditional):
        """Encode a string or a list/tuple of strings with the CLIP text encoder.

        A single string found in `precomputed_prompts` is looked up instead of encoded.
        If a `shift_vector` is set, it is added to the result.
        """
        dev = next(self.parameters()).device

        if isinstance(conditional, (list, tuple)):
            import clip
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        elif conditional in self.precomputed_prompts:
            cond = self.precomputed_prompts[conditional].float().to(dev)
        else:
            # clip is only needed when the prompt actually has to be tokenized
            import clip
            text_tokens = clip.tokenize([conditional]).to(dev)
            cond = self.clip_model.encode_text(text_tokens)[0]

        # subclasses set shift_vector; treat a missing attribute like None
        shift_vector = getattr(self, 'shift_vector', None)
        if shift_vector is not None:
            return cond + shift_vector
        else:
            return cond
273
+
274
+
275
def clip_load_untrained(version):
    """Build a CLIP model with the architecture of the pretrained ViT-B/16 but freshly initialized weights.

    The pretrained checkpoint is downloaded only to read the architecture
    hyper-parameters from the shapes in its state dict.
    """
    assert version == 'ViT-B/16'
    from os.path import expanduser
    from clip.model import CLIP
    from clip.clip import _MODELS, _download

    # Current CLIP releases require the download directory as second argument; this is the
    # directory clip.load uses by default. map_location keeps loading independent of CUDA.
    model_path = _download(_MODELS[version], expanduser('~/.cache/clip'))
    model = torch.jit.load(model_path, map_location='cpu').eval()
    state_dict = model.state_dict()

    vision_width = state_dict["visual.conv1.weight"].shape[0]
    vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
    vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
    grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
    image_resolution = vision_patch_size * grid_size
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    # third component of the key is the block index, e.g. transformer.resblocks.<idx>.attn...
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
                context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
296
+
297
+
298
class CLIPDensePredT(CLIPDenseBase):
    """Transformer decoder on top of CLIP activations that predicts a dense map for an image and a conditional."""

    def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
        """
        version: CLIP model name, 'ViT-B/32' or 'ViT-B/16'.
        extract_layers: indices of the CLIP blocks whose activations feed the decoder.
        cond_layer: decoder step at which the conditional is applied via FiLM.
        reduce_dim: decoder width; n_heads: attention heads per decoder block.
        extra_blocks: number of additional decoder blocks applied with a residual connection.
        fix_shift: load a fixed shift vector from 'shift_text_to_vis.pth' that is added to text conditionals.
        learn_trans_conv_only: freeze everything except the transposed convolution.
        upsample: add a 1x1 convolution and interpolation after the transposed convolution.
        rev_activations: feed the activations in extraction order instead of reversed.
        trans_conv: kernel size (and stride) of the transposed convolution; defaults to the CLIP patch size.
        prompt, reduce_cond, n_tokens: forwarded to CLIPDenseBase.
        limit_to_clip_only, add_calibration: only stored as attributes here.
        """

        super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None
        self.rev_activations = rev_activations

        depth = len(extract_layers)

        if add_calibration:
            self.calibration_conds = 1

        # NOTE(review): in forward this is applied to the single-channel output of trans_conv although it
        # expects reduce_dim input channels - looks like upsample=True only works for reduce_dim == 1; confirm.
        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        self.version = version

        self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]

        if fix_shift:
            # dirname(basename(__file__)) is always '', so the file used to be searched in the working
            # directory only. Keep that location first and fall back to the directory of this module.
            shift_file = 'shift_text_to_vis.pth'
            shift_path = shift_file if isfile(shift_file) else join(dirname(__file__), shift_file)
            self.shift_vector = nn.Parameter(torch.load(shift_path), requires_grad=False)
        else:
            self.shift_vector = None

        if trans_conv is None:
            trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
        else:
            # explicitly define transposed conv kernel size
            trans_conv_ks = (trans_conv, trans_conv)

        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        assert len(self.extract_layers) == depth

        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])

        # refinement and trans conv

        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        self.prompt_list = get_prompt_list(prompt)

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
        """Predict a dense map for `inp_image` given `conditional`.

        conditional: anything accepted by `get_cond_vec`.
        return_features: if True, return (prediction, CLIP class-token embedding, conditional, activations);
            otherwise a 1-tuple holding the prediction.
        mask: must be None, otherwise a ValueError is raised.
        """

        assert type(return_features) == bool

        inp_image = inp_image.to(self.model.positional_embedding.device)

        if mask is not None:
            raise ValueError('mask not supported')

        x_inp = inp_image

        bs = inp_image.shape[0]

        cond = self.get_cond_vec(conditional, bs)

        # layer 0 is always extracted in addition; it is only handed back via return_features
        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        activation1 = activations[0]
        activations = activations[1:]

        # by default the decoder starts with the deepest extracted activation
        _activations = activations[::-1] if not self.rev_activations else activations

        a = None
        for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):

            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                # FiLM conditioning
                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        a = self.trans_conv(a)

        if self.n_tokens is not None:
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            return a,
421
+
422
+
423
+
424
class CLIPDensePredTMasked(CLIPDensePredT):
    """CLIPDensePredT variant that can derive the conditional from a support image and its segmentation."""

    def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
                 prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
                 refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
        # `refine` is accepted but not forwarded to the parent class
        parent_kwargs = dict(
            version=version,
            extract_layers=extract_layers,
            cond_layer=cond_layer,
            reduce_dim=reduce_dim,
            n_heads=n_heads,
            prompt=prompt,
            extra_blocks=extra_blocks,
            reduce_cond=reduce_cond,
            fix_shift=fix_shift,
            learn_trans_conv_only=learn_trans_conv_only,
            limit_to_clip_only=limit_to_clip_only,
            upsample=upsample,
            add_calibration=add_calibration,
            n_tokens=n_tokens,
        )
        super().__init__(**parent_kwargs)

    def visual_forward_masked(self, img_s, seg_s):
        """Run the visual encoder on `img_s` with `seg_s` as 'cls_token' mask in all layers."""
        masking = ('all', 'cls_token', seg_s)
        return super().visual_forward(img_s, mask=masking)

    def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
        """Predict for `img_q`.

        Without `seg_s`, `cond_or_img_s` is used as the conditional itself; with
        `seg_s`, it is the support image from which the conditional is computed.
        """
        cond = cond_or_img_s

        if seg_s is not None:
            # second argument is a support image: encode it under the segmentation mask
            with torch.no_grad():
                cond, _, _ = self.visual_forward_masked(cond_or_img_s, seg_s)

        return super().forward(img_q, cond, return_features=return_features)
450
+
451
+
452
+
453
class CLIPDenseBaseline(CLIPDenseBase):
    """Baseline decoder: one CLIP activation, FiLM conditioning, a small MLP and a transposed convolution."""

    def __init__(self, version='ViT-B/32', cond_layer=0,
                 extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
                 reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
        """
        extract_layer: index of the CLIP block whose activation is decoded.
        reduce2_dim: hidden width of the MLP; must not be None.
        cond_layer is accepted but not used; the remaining arguments are
        stored or forwarded to CLIPDenseBase.
        """
        super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)

        self.extract_layer = extract_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.shift_vector = None

        token_shapes = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}
        self.token_shape = token_shapes[version]

        assert reduce2_dim is not None

        mlp_layers = [
            nn.Linear(reduce_dim, reduce2_dim),
            nn.ReLU(),
            nn.Linear(reduce2_dim, reduce_dim),
        ]
        self.reduce2 = nn.Sequential(*mlp_layers)

        # kernel size and stride equal the CLIP patch size of the chosen version
        kernel_sizes = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}
        patch = kernel_sizes[version]
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, patch, stride=patch)

    def forward(self, inp_image, conditional=None, return_features=False):
        """Predict a dense map for `inp_image` given `conditional`.

        Returns a 1-tuple with the prediction, or (prediction, CLIP class-token
        embedding, conditional, activations) if `return_features` is set.
        """
        x_inp = inp_image.to(self.model.positional_embedding.device)
        batch_size = x_inp.shape[0]

        cond = self.get_cond_vec(conditional, batch_size)

        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[self.extract_layer])

        feats = self.reduce(activations[0])
        feats = self.film_mul(cond) * feats + self.film_add(cond)

        if self.reduce2 is not None:
            feats = self.reduce2(feats)

        # the original model would execute a transformer block here

        feats = feats[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        side = int(math.sqrt(feats.shape[2]))
        feats = feats.view(batch_size, feats.shape[1], side, side)
        out = self.trans_conv(feats)

        if not return_features:
            return out,

        return out, visual_q, cond, activations
514
+
515
+
516
class CLIPSegMultiLabel(nn.Module):
    """Wraps a trained model and queries it once per Pascal VOC class name to build a multi-class output."""

    def __init__(self, model) -> None:
        """model: checkpoint identifier passed to `load_model`."""
        super().__init__()

        from third_party.JoEm.data_loader import VOC

        self.pascal_classes = VOC

        from general_utils import load_model
        self.clipseg = load_model(model, strict=False)

        self.clipseg.eval()

    def forward(self, x):
        """Return scores of shape (batch, 21, H, W), where H and W are the spatial size of `x`.

        Every channel starts at -10; the sigmoid of the model's prediction for the
        class name is added to the channel of that class (times 3 for 'background').
        """

        bs = x.shape[0]
        # construct output tensor directly on the input's device, sized like the input
        # instead of a fixed 352 x 352
        out = torch.full((21, bs) + tuple(x.shape[-2:]), -10.0, device=x.device)

        for class_id, class_name in enumerate(self.pascal_classes):

            fac = 3 if class_name == 'background' else 1

            with torch.no_grad():
                pred = torch.sigmoid(self.clipseg(x, class_name)[0][:, 0]) * fac

            out[class_id] += pred

        out = out.permute(1, 0, 2, 3)

        return out
550
+
551
+ # construct output tensor
552
+
clipseg/models/vitseg.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from posixpath import basename, dirname, join
3
+ # import clip
4
+ from clip.model import convert_weights
5
+ import torch
6
+ import json
7
+ from torch import nn
8
+ from torch.nn import functional as nnf
9
+ from torch.nn.modules import activation
10
+ from torch.nn.modules.activation import ReLU
11
+ from torchvision import transforms
12
+
13
+ normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
14
+
15
+ from torchvision.models import ResNet
16
+
17
+
18
def process_prompts(conditional, prompt_list, conditional_map):
    """DEPRECATED. Turn class ids into text prompts.

    Each id in `conditional` is looked up in `conditional_map` (id -> list of
    synonyms); one synonym is drawn uniformly at random, underscores are
    replaced by spaces, and the word is inserted into a template. Templates are
    drawn at random (with replacement) from `prompt_list`, or default to
    'a photo of {}' when `prompt_list` is None.
    """
    synonym_sets = [conditional_map[int(cond_id)] for cond_id in conditional]

    # randomly sample a synonym
    words = []
    for synonyms in synonym_sets:
        choice = torch.multinomial(torch.ones(len(synonyms)), 1, replacement=True).item()
        words.append(synonyms[choice])
    words = [word.replace('_', ' ') for word in words]

    if prompt_list is None:
        templates = ['a photo of {}'] * (len(words))
    else:
        template_ids = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        templates = [prompt_list[idx] for idx in template_ids]

    return [template.format(word) for template, word in zip(templates, words)]
33
+
34
+
35
class VITDenseBase(nn.Module):
    """Shared helpers for dense predictors built on a ViT backbone with text/image conditioning.

    The methods read attributes that this class does not set itself
    (`self.model`, `self.clip_model`, `self.token_shape`, `self.prompt_list`,
    `self.precomputed_prompts`); subclasses are expected to provide them.
    """

    def rescaled_pos_emb(self, new_size):
        """Return the positional embedding resampled to a `new_size` token grid.

        All rows but the first are reshaped to `self.token_shape`, resized with
        bicubic interpolation and flattened again; the first row is prepended
        unchanged. The embedding width is hard-coded to 768.
        """
        assert len(new_size) == 2

        # NOTE(review): this reads `self.model.positional_embedding`, whereas
        # `visual_forward` below reads `self.model.pos_embed` -- confirm the
        # backbone really exposes both attributes.
        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
        return torch.cat([self.model.positional_embedding[:1], b])

    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
        """Run the backbone on `x_inp` without gradients.

        Args:
            x_inp: image batch; it is resized to 384x384 before patch embedding.
            extract_layers: indices of backbone blocks whose outputs are collected.
            skip: accepted but not used.
            mask: accepted but not used.

        Returns:
            `(x, activations, None)` where `x` is the head output computed from
            the first token and `activations` holds the collected block outputs,
            each permuted with `(1, 0, 2)`.
        """
        with torch.no_grad():

            x_inp = nnf.interpolate(x_inp, (384, 384))

            x = self.model.patch_embed(x_inp)
            cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            if self.model.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = self.model.pos_drop(x + self.model.pos_embed)

            activations = []
            for i, block in enumerate(self.model.blocks):
                x = block(x)

                if i in extract_layers:
                    # permute to be compatible with CLIP
                    activations += [x.permute(1, 0, 2)]

            x = self.model.norm(x)
            x = self.model.head(self.model.pre_logits(x[:, 0]))

            # again for CLIP compatibility
            # x = x.permute(1, 0, 2)

        return x, activations, None

    def sample_prompts(self, words, prompt_list=None):
        """Insert each word into a template drawn at random (with replacement).

        Templates come from `prompt_list`, falling back to `self.prompt_list`.
        """
        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
        return [promt.format(w) for promt, w in zip(prompts, words)]

    def get_cond_vec(self, conditional, batch_size):
        """Convert `conditional` into a conditional vector.

        Supported inputs:
            * a single string: encoded once and repeated `batch_size` times,
            * a non-empty list/tuple of strings (length must equal `batch_size`),
            * a 2-d tensor: used as is,
            * any other tensor: treated as an image batch and passed through
              `visual_forward`.

        Raises:
            ValueError: for anything else, including None and empty sequences.
        """
        # compute conditional from a single string
        if isinstance(conditional, str):
            cond = self.compute_conditional(conditional)
            cond = cond.repeat(batch_size, 1)

        # compute conditional from string list/tuple
        # (the emptiness check keeps an empty sequence from raising IndexError
        # and routes it to the ValueError below instead)
        elif isinstance(conditional, (list, tuple)) and len(conditional) > 0 and isinstance(conditional[0], str):
            assert len(conditional) == batch_size
            cond = self.compute_conditional(conditional)

        # use conditional directly
        elif isinstance(conditional, torch.Tensor) and conditional.ndim == 2:
            cond = conditional

        # compute conditional from image
        elif isinstance(conditional, torch.Tensor):
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
        else:
            raise ValueError('invalid conditional')
        return cond

    def compute_conditional(self, conditional):
        """Encode text with `self.clip_model`.

        A list/tuple is tokenized and encoded as a batch. A single string is
        first looked up in `self.precomputed_prompts`; only on a miss is it
        tokenized and encoded, in which case the first (only) row is returned.
        """
        import clip

        dev = next(self.parameters()).device

        if isinstance(conditional, (list, tuple)):
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        else:
            if conditional in self.precomputed_prompts:
                cond = self.precomputed_prompts[conditional].float().to(dev)
            else:
                text_tokens = clip.tokenize([conditional]).to(dev)
                cond = self.clip_model.encode_text(text_tokens)[0]

        return cond
121
+
122
+
123
class VITDensePredT(VITDenseBase):
    """Dense predictor that decodes activations of a timm ViT with a small transformer decoder.

    The backbone (`vit_base_patch16_384`) and a CLIP `ViT-B/16` model are
    created in the constructor; precomputed prompt vectors are read from
    `precomputed_prompt_vectors.pickle` in the current working directory.
    """

    def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, process_cond=None, not_pretrained=False):
        """Build backbone, decoder and prompt configuration.

        Args:
            extract_layers: backbone block indices that feed the decoder; its
                length must equal `depth`.
            cond_layer: decoder step at which the conditional is injected.
            reduce_dim: decoder width.
            n_heads: attention heads of the decoder blocks.
            prompt: one of 'fixed', 'shuffle', 'shuffle+', 'shuffle_clip';
                selects `self.prompt_list`.
            depth: number of decoder blocks / reduce layers.
            extra_blocks: number of additional residual decoder blocks.
            reduce_cond: if set, output width of the backbone head and of a
                frozen linear layer applied to the conditional.
            learn_trans_conv_only: freeze everything except `self.trans_conv`.
            limit_to_clip_only: stored on the instance, not otherwise used here.
            upsample: create `self.upsample_proj`.
            add_calibration: if true, sets `self.calibration_conds = 1`.
            process_cond: 'clamp', a ('clamp', value) pair, or a path ending in
                '.pth'; selects `self.process_cond`.
            fix_shift, refine, not_pretrained: accepted but not read by this
                constructor.
        """
        super().__init__()
        # device = 'cpu'

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None

        if add_calibration:
            self.calibration_conds = 1

        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        import timm
        # NOTE(review): `not_pretrained` is ignored; pretrained weights are always requested.
        self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
        self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)

        # The backbone (including the replaced head) is frozen.
        for p in self.model.parameters():
            p.requires_grad_(False)

        import clip
        self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
        # del self.clip_model.visual

        self.token_shape = (14, 14)

        # conditional
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        # self.film = AVAILABLE_BLOCKS['film'](512, 128)
        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)

        # DEPRECATED
        # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}

        assert len(self.extract_layers) == depth

        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])

        trans_conv_ks = (16, 16)
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        # refinement and trans conv

        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        # NOTE(review): an unrecognised `prompt` value leaves `self.prompt_list` unset.
        if prompt == 'fixed':
            self.prompt_list = ['a photo of a {}.']
        elif prompt == 'shuffle':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif prompt == 'shuffle+':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
                                'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
                                'a bad photo of a {}.', 'a photo of the {}.']
        elif prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            self.prompt_list = imagenet_templates

        if process_cond is not None:
            if process_cond == 'clamp' or process_cond[0] == 'clamp':

                # A (name, value) pair carries its own bound; the bare string uses 0.2.
                val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2

                def clamp_vec(x):
                    return torch.clamp(x, -val, val)

                self.process_cond = clamp_vec

            elif process_cond.endswith('.pth'):

                shift = torch.load(process_cond)

                def add_shift(x):
                    return x + shift.to(x.device)

                self.process_cond = add_shift

        import pickle
        # The file is closed deterministically instead of leaking the handle.
        # NOTE: pickle executes arbitrary code on load; only use a trusted file here.
        with open('precomputed_prompt_vectors.pickle', 'rb') as prompt_file:
            precomp = pickle.load(prompt_file)
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
        """Predict a dense map for `inp_image` conditioned on `conditional`.

        Args:
            inp_image: image batch; the output is resized to its spatial size.
            conditional: anything accepted by `get_cond_vec`.
            return_features: must be a bool; when true, intermediate values
                are returned as well.
            mask: must be None.

        Returns:
            `(prediction,)` or
            `(prediction, visual_q, cond, [activation1] + activations)`.

        Raises:
            ValueError: if `mask` is given.
        """
        assert isinstance(return_features, bool)

        # inp_image = inp_image.to(self.model.positional_embedding.device)

        if mask is not None:
            raise ValueError('mask not supported')

        # x_inp = normalize(inp_image)
        x_inp = inp_image

        bs = inp_image.shape[0]

        inp_image_size = inp_image.shape[2:]

        cond = self.get_cond_vec(conditional, bs)

        # Layer 0 is always tapped in addition to the configured layers.
        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        activation1 = activations[0]
        activations = activations[1:]

        # Decode from the deepest tapped layer to the shallowest, accumulating
        # the reduced activations.
        a = None
        for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):

            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        if self.trans_conv is not None:
            a = self.trans_conv(a)

        # NOTE(review): `upsample_proj` expects `reduce_dim` input channels but
        # runs after `trans_conv`, which outputs one channel -- confirm this
        # path is ever exercised with `upsample=True`.
        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        a = nnf.interpolate(a, inp_image_size)

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            return a,
clipseg/overview.png ADDED
clipseg/score.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.functional import Tensor
2
+
3
+ import torch
4
+ import inspect
5
+ import json
6
+ import yaml
7
+ import time
8
+ import sys
9
+
10
+ from general_utils import log
11
+
12
+ import numpy as np
13
+ from os.path import expanduser, join, isfile, realpath
14
+
15
+ from torch.utils.data import DataLoader
16
+
17
+ from metrics import FixedIntervalMetrics
18
+
19
+ from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args
20
+
21
+
22
+ DATASET_CACHE = dict()
23
+
24
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
    """Instantiate the model of a training run and load its weights.

    Note that this definition replaces the `load_model` imported from
    `general_utils` at the top of this module.

    Args:
        checkpoint_id: name of the run directory below `logs/`.
        weights_file: weights file name inside the run directory
            (defaults to 'weights.pth').
        strict: forwarded to `load_state_dict`.
        model_args: 'from_config' to derive constructor arguments from the
            run's `config.json`, or a dict of constructor arguments.
        with_config: if true, return `(model, config)` instead of `model`.
        ignore_weights: if true, skip loading weights altogether.

    Raises:
        ValueError: if `model_args` is neither 'from_config' nor a dict.
        FileNotFoundError: if the weights file is missing and weights are
            not ignored.
    """
    # Validate arguments before touching the file system.
    if not isinstance(model_args, dict) and model_args != 'from_config':
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    with open(join('logs', checkpoint_id, 'config.json')) as config_file:
        config = json.load(config_file)

    model_cls = get_attribute(config['model'])

    # load model
    if not isinstance(model_args, dict):
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)

    model = model_cls(**model_args)

    weights_name = 'weights.pth' if weights_file is None else weights_file
    weights_file = realpath(join('logs', checkpoint_id, weights_name))

    if not ignore_weights:
        if not isfile(weights_file):
            raise FileNotFoundError(f'model checkpoint {weights_file} was not found')

        weights = torch.load(weights_file)
        for w in weights.values():
            assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
        model.load_state_dict(weights, strict=strict)

    if with_config:
        return model, config

    return model
57
+
58
+
59
def compute_shift2(model, datasets, seed=123, repetitions=1):
    """Pick the threshold with the best 'fgiou_scores' value over `datasets`.

    The model is put in eval mode on the GPU and run sample by sample over
    each dataset (at most `repetitions * len(dataset.dataset.data_list)`
    samples per dataset). All predictions are fed to a `FixedIntervalMetrics`
    with 51 values, and the interior threshold of `linspace(0, 1, 51)` at the
    argmax of 'fgiou_scores' is returned.
    """
    import random

    model.eval()
    model.cuda()
    random.seed(seed)

    preds, gts = [], []
    with torch.no_grad():
        for dataset in datasets:
            loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
            max_iterations = int(repetitions * len(dataset.dataset.data_list))

            for n_seen, (data_x, data_y) in enumerate(loader, start=1):
                data_x = [v if v is None else v.cuda(non_blocking=True) for v in data_x]
                data_y = [v if v is None else v.cuda(non_blocking=True) for v in data_y]

                pred, = model(data_x[0], data_x[1], data_x[2])
                preds.append(pred.detach())
                gts.append(data_y)

                # A limit of 0 means "no limit".
                if max_iterations and n_seen >= max_iterations:
                    break

    from metrics import FixedIntervalMetrics
    n_values = 51
    # Candidate thresholds exclude both end points of the interval.
    thresholds = np.linspace(0, 1, n_values)[1:-1]
    metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)

    for pred, target in zip(preds, gts):
        metric.add(pred.unsqueeze(1), target)

    fgiou_scores = metric.value()['fgiou_scores']
    return thresholds[np.argmax(fgiou_scores)]
103
+
104
+
105
def get_cached_pascal_pfe(split, config):
    """Return the 'val' `PFEPascalWrapper` for `split`, building it at most once.

    Instances are memoised in `DATASET_CACHE`, keyed by the split and the
    `image_size`, `label_support` and `mask` settings of `config`.
    """
    from datasets.pfe_dataset import PFEPascalWrapper

    cache_key = (split, config.image_size, config.label_support, config.mask)
    if cache_key not in DATASET_CACHE:
        DATASET_CACHE[cache_key] = PFEPascalWrapper(
            mode='val',
            split=split,
            mask=config.mask,
            image_size=config.image_size,
            label_support=config.label_support,
        )
    return DATASET_CACHE[cache_key]
113
+
114
+
115
+
116
+
117
def main():
    """Score a checkpoint as configured on the command line and print numeric results."""
    config, train_checkpoint_id = score_config_from_cli_args()

    metrics = score(config, train_checkpoint_id, None)

    for dataset, dataset_scores in metrics.items():
        for k, value in dataset_scores.items():
            # isinstance (rather than an exact type match) also accepts NumPy
            # scalars such as the np.mean results; bools are not treated as numbers.
            if isinstance(value, (int, float, np.floating)) and not isinstance(value, bool):
                print(dataset, f'{k:<16} {value:.3f}')
126
+
127
+
128
def score(config, train_checkpoint_id, train_config):
    """Evaluate a trained checkpoint on the dataset selected by `config.test_dataset`.

    Args:
        config: scoring configuration (wrapped into an `AttributeDict`).
        train_checkpoint_id: run directory below `logs/` holding `config.json`
            and the weights.
        train_config: ignored; the training configuration is always re-read
            from the run directory.

    Returns:
        A dict with a single key (the configured `name` or a dataset-specific
        default) mapping to a dict of scores.

    Raises:
        ValueError: if `config.test_dataset` is not one of 'pascal', 'coco',
            'phrasecut', 'pascal_zs', 'same_as_training', 'affordance'.
    """
    config = AttributeDict(config)

    print(config)

    # use training dataset and loss
    with open(f'logs/{train_checkpoint_id}/config.json') as config_file:
        train_config = AttributeDict(json.load(config_file))

    cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''

    model_cls = get_attribute(train_config['model'])

    _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)

    # Scoring-time overrides for selected constructor arguments.
    model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}

    strict_models = {'ConditionBase4', 'PFENetWrapper'}
    model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
                       weights_file=f'weights{cp_str}.pth', )

    model.eval()
    model.cuda()

    metric_args = dict()

    if 'threshold' in config:
        if config.metric.split('.')[-1] == 'SkLearnMetrics':
            metric_args['threshold'] = config.threshold

    if 'resize_to' in config:
        metric_args['resize_to'] = config.resize_to

    if 'sigmoid' in config:
        metric_args['sigmoid'] = config.sigmoid

    if 'custom_threshold' in config:
        metric_args['custom_threshold'] = config.custom_threshold

    if config.test_dataset == 'pascal':

        loss_fn = get_attribute(train_config.loss)
        # assume that if no split is specified in train_config, test on all splits,

        if 'splits' in config:
            splits = config.splits
        else:
            if 'split' in train_config and type(train_config.split) == int:
                # unless train_config has a split set, in that case assume train mode in training
                splits = [train_config.split]
                assert train_config.mode == 'train'
            else:
                splits = [0, 1, 2, 3]

        log.info('Test on these splits', splits)

        scores = dict()
        for split in splits:

            shift = config.shift if 'shift' in config else 0

            # automatic shift
            if shift == 'auto':
                shift_compute_t = time.time()
                shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
                log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')

            dataset = get_cached_pascal_pfe(split, config)

            eval_start_t = time.time()

            loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)

            assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'

            metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)

            with torch.no_grad():

                i, losses = 0, []
                for i_all, (data_x, data_y) in enumerate(loader):

                    data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
                    data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

                    if config.mask == 'separate':  # for old CondBase model
                        pred, = model(data_x[0], data_x[1], data_x[2])
                    else:
                        # assert config.mask in {'text', 'highlight'}
                        pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)

                    # loss = loss_fn(pred, data_y[0])
                    metric.add(pred.unsqueeze(1) + shift, data_y)

                    # losses += [float(loss)]

                    i += 1
                    if config.max_iterations and i >= config.max_iterations:
                        break

            #scores[split] = {m: s for m, s in zip(metric.names(), metric.value())}

            log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')

            print(metric.value()['mean_iou_scores'])

            scores[split] = metric.scores()

            log.info(f'Completed split {split}')

        key_prefix = config['name'] if 'name' in config else 'pas'

        # Only scores reported by every split can be averaged.
        all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])

        # `np.float` no longer exists in current NumPy; `np.floating` covers NumPy float scalars.
        valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.floating)) for v in scores.values())]

        return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}

    if config.test_dataset == 'coco':
        from datasets.coco_wrapper import COCOWrapper

        coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
                                   with_class_label=True)

        log.info('Dataset length', len(coco_dataset))
        loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)

        metric = get_attribute(config.metric)(resize_pred=True, **metric_args)

        shift = config.shift if 'shift' in config else 0

        with torch.no_grad():

            i, losses = 0, []
            for i_all, (data_x, data_y) in enumerate(loader):
                data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

                if config.mask == 'separate':  # for old CondBase model
                    pred, = model(data_x[0], data_x[1], data_x[2])
                else:
                    # assert config.mask in {'text', 'highlight'}
                    pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)

                metric.add([pred + shift], data_y)

                i += 1
                if config.max_iterations and i >= config.max_iterations:
                    break

        key_prefix = config['name'] if 'name' in config else 'coco'
        return {key_prefix: metric.scores()}
        #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}

    if config.test_dataset == 'phrasecut':
        from datasets.phrasecut import PhraseCut

        only_visual = config.only_visual is not None and config.only_visual
        with_visual = config.with_visual is not None and config.with_visual

        dataset = PhraseCut('test',
                            image_size=train_config.image_size,
                            mask=config.mask,
                            with_visual=with_visual, only_visual=only_visual, aug_crop=False,
                            aug_color=False)

        loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
        metric = get_attribute(config.metric)(resize_pred=True, **metric_args)

        shift = config.shift if 'shift' in config else 0

        with torch.no_grad():

            i, losses = 0, []
            for i_all, (data_x, data_y) in enumerate(loader):
                data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]

                pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
                metric.add([pred + shift], data_y)

                i += 1
                if config.max_iterations and i >= config.max_iterations:
                    break

        key_prefix = config['name'] if 'name' in config else 'phrasecut'
        return {key_prefix: metric.scores()}
        #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}

    if config.test_dataset == 'pascal_zs':
        from third_party.JoEm.model.metric import Evaluator
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
        from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS

        from models.clipseg import CLIPSegMultiLabel

        n_unseen = train_config.remove_classes[1]

        pz = PascalZeroShot('val', n_unseen, image_size=352)
        m = CLIPSegMultiLabel(model=train_config.name).cuda()
        m.eval()

        print(len(pz), n_unseen)
        print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])

        print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
        print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])

        loader = DataLoader(pz, batch_size=8)
        evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))

        for i, (data_x, data_y) in enumerate(loader):
            pred = m(data_x[0].cuda())
            evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())

            if config.max_iter is not None and i > config.max_iter:
                break

        scores = evaluator.Mean_Intersection_over_Union()
        key_prefix = config['name'] if 'name' in config else 'pas_zs'

        return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}

    elif config.test_dataset in {'same_as_training', 'affordance'}:
        loss_fn = get_attribute(train_config.loss)

        metric_cls = get_attribute(config.metric)
        metric = metric_cls(**metric_args)

        if config.test_dataset == 'same_as_training':
            dataset_cls = get_attribute(train_config.dataset)
            # Previously left undefined, which raised NameError when no `name`
            # was configured; the dataset class name is used as the default prefix.
            dataset_name = dataset_cls.__name__
        elif config.test_dataset == 'affordance':
            dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
            dataset_name = 'aff'
        else:
            dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot')
            dataset_name = 'lvis'

        _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)

        dataset_args['image_size'] = train_config.image_size  # explicitly use training image size for evaluation

        if model.__class__.__name__ == 'PFENetWrapper':
            dataset_args['image_size'] = config.image_size

        log.info('init dataset', str(dataset_cls))
        dataset = dataset_cls(**dataset_args)

        log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')

        data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)

        # explicitly set prompts
        if config.prompt == 'plain':
            model.prompt_list = ['{}']
        elif config.prompt == 'fixed':
            model.prompt_list = ['a photo of a {}.']
        elif config.prompt == 'shuffle':
            model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif config.prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            model.prompt_list = imagenet_templates

        # NOTE(review): presumably this checks which config keys were read so
        # far -- keep the config accesses above unchanged; confirm in AttributeDict.
        config.assume_no_unused_keys(exceptions=['max_iterations'])

        t_start = time.time()

        with torch.no_grad():  # TODO: switch to inference_mode (torch 1.9)
            i, losses = 0, []
            for data_x, data_y in data_loader:

                data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
                data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]

                if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
                    pred, = model(data_x[0], data_x[1], data_x[2])
                    visual_q = None
                else:
                    pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True)

                loss = loss_fn(pred, data_y[0])

                metric.add([pred], data_y)

                losses += [float(loss)]

                i += 1
                if config.max_iterations and i >= config.max_iterations:
                    break

        # scores = {m: s for m, s in zip(metric.names(), metric.value())}
        scores = metric.scores()

        keys = set(scores.keys())
        # 'mIoU' is dropped from the report whenever negative samples are in use.
        if dataset.negative_prob > 0 and 'mIoU' in keys:
            keys.remove('mIoU')

        name_mask = dataset.mask.replace('text_label', 'txt')[:3]
        name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)

        score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'

        scores = {score_name: {k: v for k, v in scores.items() if k in keys}}
        scores[score_name].update({'test_loss': np.mean(losses)})

        log.info(f'Evaluation took {time.time() - t_start:.1f}s')

        return scores
    else:
        raise ValueError('invalid test dataset')
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+ if __name__ == '__main__':
453
+ main()
clipseg/setup.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
# The long description shown for the package is the README, verbatim.
with open("README.md", mode="r", encoding="utf-8") as readme_file:
    readme = readme_file.read()

# Install-time dependencies; CLIP is referenced by its Git repository URL.
requirements = [
    "numpy",
    "scipy",
    "matplotlib",
    "torch",
    "torchvision",
    "opencv-python",
    "CLIP @ git+https://github.com/openai/CLIP.git"
]

# The `clipseg` package is sourced from the `models` directory, and the
# weight files are matched relative to it.
setup_kwargs = dict(
    name='clipseg',
    version='0.0.1',
    description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
    long_description=readme,
    long_description_content_type="text/markdown",
    url='https://github.com/timojl/clipseg',
    python_requires='>=3.9',
    install_requires=requirements,
    packages=['clipseg'],
    package_dir={'clipseg': 'models'},
    package_data={'clipseg': [
        "../weights/*.pth",
    ]},
)

setup(**setup_kwargs)
clipseg/training.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import inspect
3
+ import json
4
+ import yaml
5
+ import math
6
+ import os
7
+ import sys
8
+
9
+ from general_utils import log
10
+
11
+ import numpy as np
12
+ from functools import partial
13
+ from os.path import expanduser, join, isfile, basename
14
+
15
+ from torch.cuda.amp import autocast, GradScaler
16
+ from torch.optim.lr_scheduler import LambdaLR
17
+ from contextlib import nullcontext
18
+ from torch.utils.data import DataLoader
19
+
20
+ from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
21
+
22
+
23
def cosine_warmup_lr(i, warmup=10, max_iter=90):
    """Cosine LR with Warmup.

    Returns the learning-rate factor for step `i`: a linear ramp
    `(i + 1) / (warmup + 1)` while `i < warmup`, afterwards half a cosine
    period that goes from 1 at `i == warmup` to 0 at `i == max_iter`.
    """
    if i < warmup:
        return (i + 1) / (warmup + 1)

    # Fraction of the post-warmup schedule that has elapsed.
    progress = (i - warmup) / (max_iter - warmup)
    return 0.5 + 0.5 * math.cos(math.pi * progress)
29
+
30
+
31
def validate(model, dataset, config):
    """Run `model` over `dataset` and report the mean loss and optional metric scores.

    Args:
        model: model exposing `sample_prompts` and a forward call that accepts
            `return_features=True`; it is switched to eval mode and moved to the GPU.
        dataset: validation dataset, iterated in batches of 4 without shuffling.
        config: provides `val_metric_class`, `use_val_metric`, `loss` and
            `val_max_iterations`.

    Returns:
        `(mean_loss, metric_scores, uses_metric)`. `metric_scores` is empty
        when no metric class is configured; `uses_metric` is False exactly
        when `config.use_val_metric` is None.
    """
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

    metric_class, use_metric = config.val_metric_class, config.use_val_metric
    loss_fn = get_attribute(config.loss)

    model.eval()
    model.cuda()

    # Always bind `metric` so the checks below work when no metric class is
    # configured (previously this path raised NameError).
    metric = get_attribute(metric_class)() if metric_class is not None else None

    with torch.no_grad():

        i, losses = 0, []
        for data_x, data_y in data_loader:

            data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
            data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]

            # Validation always uses one fixed prompt template.
            prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
            pred, _, _, _ = model(data_x[0], prompts, return_features=True)

            if metric is not None:
                metric.add([pred], data_y)

            # pred = model(data_x[0], prompts)
            # loss = loss_fn(pred[0], data_y[0])
            loss = loss_fn(pred, data_y[0])
            losses += [float(loss)]

            i += 1

            # NOTE(review): `>` lets val_max_iterations + 1 batches through;
            # confirm whether `>=` was intended.
            if config.val_max_iterations is not None and i > config.val_max_iterations:
                break

    if use_metric is None:
        return np.mean(losses), {}, False
    else:
        metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
        return np.mean(losses), metric_scores, True
72
+
73
+
74
def main():
    """Train a model as described by the configuration from the command line.

    Builds model, dataset(s), optimizer, optional LR scheduler and optional
    AMP scaler from the config, then iterates over the training data until
    ``config.max_iterations`` steps were taken. Weights are written at the
    configured checkpoint iterations, whenever validation improves, and at the
    end if no ``weights.pth`` exists yet.

    The function does not return normally: it terminates via ``sys.exit(0)``
    after the last iteration or ``sys.exit(-1)`` on a NaN/Inf loss.
    """

    config = training_config_from_cli_args()

    val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')

    model_cls = get_attribute(config.model)
    _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
    model = model_cls(**model_args).cuda()

    dataset_cls = get_attribute(config.dataset)
    _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)

    dataset = dataset_cls(**dataset_args)

    log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')

    if val_interval is not None:
        # Config keys prefixed with 'val_' override the training dataset
        # arguments for the validation split (prefix stripped via k[4:]).
        dataset_val_args = {k[4:]: v for k, v in config.items() if k.startswith('val_') and k != 'val_interval'}
        _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
        print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})

        dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})

    # optimizer
    opt_cls = get_attribute(config.optimizer)
    # Compare against the same key the optimizer class was resolved from
    # (the original checked `config.optimize`, a typo for `config.optimizer`).
    if config.optimizer == 'torch.optim.SGD':
        opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
    else:
        opt_args = {}
    opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)

    if config.lr_scheduler == 'cosine':
        assert config.T_max is not None and config.eta_min is not None
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
    elif config.lr_scheduler == 'warmup_cosine':
        lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
    else:
        lr_scheduler = None

    batch_size, max_iterations = config.batch_size, config.max_iterations

    loss_fn = get_attribute(config.loss)

    if config.amp:
        log.info('Using AMP')
        autocast_fn = autocast
        scaler = GradScaler()
    else:
        # Without AMP the autocast context becomes a no-op.
        autocast_fn, scaler = nullcontext, None

    save_only_trainable = True
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)

    # disable config when hyperparam. opt. to avoid writing logs.
    tracker_config = config if not config.hyperparameter_optimization else None

    with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:

        i = 0
        # Loop over epochs until max_iterations is reached; the loop is left
        # through sys.exit() inside the batch loop.
        while True:
            for data_x, data_y in data_loader:

                # randomly mix text and visual support conditionals
                if config.mix:

                    assert config.mask.startswith('text_and')

                    with autocast_fn():
                        # data_x[1] = text label
                        prompts = model.sample_prompts(data_x[1])

                        text_cond = model.compute_conditional(prompts)
                        if model.__class__.__name__ == 'CLIPDensePredTMasked':
                            # when mask=='separate'
                            visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
                        else:
                            # data_x[2] = visual prompt
                            visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())

                    max_txt = config.mix_text_max if config.mix_text_max is not None else 1
                    # Size of this batch (kept separate from the configured
                    # batch_size used for the DataLoader).
                    cur_batch_size = text_cond.shape[0]

                    # sample weights for each element in batch
                    text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((cur_batch_size,))[:, None]
                    text_weights = text_weights.cuda()

                    if dataset.__class__.__name__ == 'PhraseCut':
                        # give full weight to text where support_image is invalid
                        visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
                        text_weights = torch.max(text_weights[:, 0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)

                    cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)

                else:
                    # no mix

                    if model.__class__.__name__ == 'CLIPDensePredTMasked':
                        # compute conditional vector using CLIP masking
                        with autocast_fn():
                            assert config.mask == 'separate'
                            cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
                    else:
                        cond = data_x[1]
                        if isinstance(cond, torch.Tensor):
                            cond = cond.cuda()

                with autocast_fn():
                    pred, _, _, _ = model(data_x[0].cuda(), cond, return_features=True)

                    loss = loss_fn(pred, data_y[0].cuda())

                    if torch.isnan(loss) or torch.isinf(loss):
                        # A NaN/Inf loss is not recoverable here: abort the run.
                        log.warning('Training stopped due to inf/nan loss.')
                        sys.exit(-1)

                    # Placeholder for an auxiliary loss term; currently always 0
                    # but still reported to the logger during validation.
                    extra_loss = 0
                    loss += extra_loss

                opt.zero_grad()

                if scaler is None:
                    loss.backward()
                    opt.step()
                else:
                    # AMP path: scale the loss, then let the scaler drive the
                    # optimizer step and update its scale factor.
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                    if i % 2000 == 0:
                        current_lr = [g['lr'] for g in opt.param_groups][0]
                        log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')

                logger.iter(i=i, loss=loss)
                i += 1

                if i >= max_iterations:

                    if not isfile(join(logger.base_path, 'weights.pth')):
                        # only write if no weights were already written
                        logger.save_weights(only_trainable=save_only_trainable)

                    sys.exit(0)

                if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
                    logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')

                if val_interval is not None and i % val_interval == val_interval - 1:

                    val_loss, val_scores, maximize = validate(model, dataset_val, config)

                    if len(val_scores) > 0:

                        score_str = ', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())

                        if maximize and val_scores[config.use_val_metric] > best_val_score:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = val_scores[config.use_val_metric]

                        # NOTE(review): best_val_score starts at -inf, so this
                        # minimizing branch cannot fire until a maximizing
                        # update has happened; confirm whether it is needed.
                        elif not maximize and val_scores[config.use_val_metric] < best_val_score:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = val_scores[config.use_val_metric]

                    else:
                        score_str = ''
                        # if no score is used, fall back to loss
                        if val_loss < best_val_loss:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_loss = val_loss

                    log.info(f'Validation loss: {val_loss}' + score_str)
                    logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
                    # validate() leaves the model in eval mode; switch back.
                    model.train()

            print('epoch complete')
264
+
265
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
clipseg/weights/rd64-uni.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13845f6cee4d54ca46f62ee19dd354822094a26e0efccc64e606be93d6a7e26f
3
+ size 4306645