akhaliq committed
Commit 1ff2b57
1 Parent(s): 6ed0225

add examples
app.py CHANGED
@@ -42,7 +42,8 @@ def predict(dict, prompt=""):
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
  return images[0]
 
- # examples = [[dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]]
+ examples = [[dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]]
+
  css = '''
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
  #image_upload{min-height:400px}
@@ -98,10 +99,11 @@ with image_blocks as demo:
  rounded=(False, True, True, False),
  full_width=False,
  )
-
+ ex = gr.Examples(fn=predict, inputs=[image, prompt], outputs=image, cache_examples=False)
+ ex.dataset.headers = [""]
  btn.click(fn=predict, inputs=[image, prompt], outputs=image)
 
-
+
 
  gr.HTML(
  """
clipseg/LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- This license does not apply to the model weights.
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
clipseg/Quickstart.ipynb DELETED
@@ -1,107 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import torch\n",
10
- "import requests\n",
11
- "\n",
12
- "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
13
- "! unzip -d weights -j weights.zip\n",
14
- "from models.clipseg import CLIPDensePredT\n",
15
- "from PIL import Image\n",
16
- "from torchvision import transforms\n",
17
- "from matplotlib import pyplot as plt\n",
18
- "\n",
19
- "# load model\n",
20
- "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
21
- "model.eval();\n",
22
- "\n",
23
- "# non-strict, because we only stored decoder weights (not CLIP weights)\n",
24
- "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
25
- ]
26
- },
27
- {
28
- "cell_type": "markdown",
29
- "metadata": {},
30
- "source": [
31
- "Load and normalize `example_image.jpg`. You can also load through an URL."
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": null,
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "# load and normalize image\n",
41
- "input_image = Image.open('example_image.jpg')\n",
42
- "\n",
43
- "# or load from URL...\n",
44
- "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
45
- "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
46
- "\n",
47
- "transform = transforms.Compose([\n",
48
- " transforms.ToTensor(),\n",
49
- " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
50
- " transforms.Resize((352, 352)),\n",
51
- "])\n",
52
- "img = transform(input_image).unsqueeze(0)"
53
- ]
54
- },
55
- {
56
- "cell_type": "markdown",
57
- "metadata": {},
58
- "source": [
59
- "Predict and visualize (this might take a few seconds if running without GPU support)"
60
- ]
61
- },
62
- {
63
- "cell_type": "code",
64
- "execution_count": null,
65
- "metadata": {},
66
- "outputs": [],
67
- "source": [
68
- "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
69
- "\n",
70
- "# predict\n",
71
- "with torch.no_grad():\n",
72
- " preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
73
- "\n",
74
- "# visualize prediction\n",
75
- "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
76
- "[a.axis('off') for a in ax.flatten()]\n",
77
- "ax[0].imshow(input_image)\n",
78
- "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
79
- "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
80
- ]
81
- }
82
- ],
83
- "metadata": {
84
- "interpreter": {
85
- "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
86
- },
87
- "kernelspec": {
88
- "display_name": "Python 3",
89
- "language": "python",
90
- "name": "python3"
91
- },
92
- "language_info": {
93
- "codemirror_mode": {
94
- "name": "ipython",
95
- "version": 3
96
- },
97
- "file_extension": ".py",
98
- "mimetype": "text/x-python",
99
- "name": "python",
100
- "nbconvert_exporter": "python",
101
- "pygments_lexer": "ipython3",
102
- "version": "3.8.10"
103
- }
104
- },
105
- "nbformat": 4,
106
- "nbformat_minor": 4
107
- }
clipseg/Readme.md DELETED
@@ -1,84 +0,0 @@
- # Image Segmentation Using Text and Image Prompts
- This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
-
- **The Paper has been accepted to CVPR 2022!**
-
- <img src="overview.png" alt="drawing" height="200em"/>
-
- The systems allows to create segmentation models without training based on:
- An arbitrary text query
- Or an image with a mask highlighting stuff or an object.
-
- ### Quick Start
-
- In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension.
- It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
- (please note that the VM does not use a GPU, thus inference takes a few seconds).
-
-
- ### Dependencies
- This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
- Additional dependencies are hidden for double blind review.
-
-
- ### Datasets
-
- * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
- * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
- * `PascalZeroShot`: Wrapper class for PascalZeroShot
- * `COCOWrapper`: Wrapper class for COCO.
-
- ### Models
-
- * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
- * `ViTDensePredT`: CLIPSeg model with transformer-based decoder.
-
- ### Third Party Dependencies
- For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
- ```bash
- git clone https://github.com/cvlab-yonsei/JoEm
- git clone https://github.com/Jia-Research-Lab/PFENet.git
- git clone https://github.com/ChenyunWu/PhraseCutDataset.git
- git clone https://github.com/juhongm999/hsnet.git
- ```
-
- ### Weights
-
- The MIT license does not apply to these weights.
-
- We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
- ```
- wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
- unzip -d weights -j weights.zip
- ```
-
-
- ### Training and Evaluation
-
- To train use the `training.py` script with experiment file and experiment id parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment which is defined by the `configuration` and first `individual_configurations` parameters. Model weights will be written in `logs/`.
-
- For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will train the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`.
-
-
- ### Usage of PFENet Wrappers
-
- In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
- `git clone https://github.com/Jia-Research-Lab/PFENet.git `
-
-
- ### License
-
- The source code files in this repository (excluding model weights) are released under MIT license.
-
- ### Citation
- ```
- @InProceedings{lueddecke22_cvpr,
-     author = {L\"uddecke, Timo and Ecker, Alexander},
-     title = {Image Segmentation Using Text and Image Prompts},
-     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
-     month = {June},
-     year = {2022},
-     pages = {7086-7096}
- }
-
- ```
clipseg/Tables.ipynb DELETED
@@ -1,349 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2\n",
11
- "\n",
12
- "import clip\n",
13
- "from evaluation_utils import norm, denorm\n",
14
- "from general_utils import *\n",
15
- "from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
16
- ]
17
- },
18
- {
19
- "cell_type": "markdown",
20
- "metadata": {},
21
- "source": [
22
- "# PhraseCut"
23
- ]
24
- },
25
- {
26
- "cell_type": "code",
27
- "execution_count": null,
28
- "metadata": {},
29
- "outputs": [],
30
- "source": [
31
- "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": null,
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": null,
46
- "metadata": {},
47
- "outputs": [],
48
- "source": [
49
- "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
50
- "tab1 = pc[['name'] + cols]\n",
51
- "for k in cols:\n",
52
- " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
53
- "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
54
- "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
55
- "print(tab1.to_latex(header=False, index=False))"
56
- ]
57
- },
58
- {
59
- "cell_type": "markdown",
60
- "metadata": {},
61
- "source": [
62
- "For 0.1 threshold"
63
- ]
64
- },
65
- {
66
- "cell_type": "code",
67
- "execution_count": null,
68
- "metadata": {},
69
- "outputs": [],
70
- "source": [
71
- "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
72
- "tab1 = pc[['name'] + cols]\n",
73
- "for k in cols:\n",
74
- " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
75
- "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
76
- "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
77
- "print(tab1.to_latex(header=False, index=False))"
78
- ]
79
- },
80
- {
81
- "cell_type": "markdown",
82
- "metadata": {},
83
- "source": [
84
- "# One-shot"
85
- ]
86
- },
87
- {
88
- "cell_type": "markdown",
89
- "metadata": {},
90
- "source": [
91
- "### Pascal"
92
- ]
93
- },
94
- {
95
- "cell_type": "code",
96
- "execution_count": null,
97
- "metadata": {},
98
- "outputs": [],
99
- "source": [
100
- "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
101
- ]
102
- },
103
- {
104
- "cell_type": "code",
105
- "execution_count": null,
106
- "metadata": {},
107
- "outputs": [],
108
- "source": [
109
- "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
110
- ]
111
- },
112
- {
113
- "cell_type": "code",
114
- "execution_count": null,
115
- "metadata": {},
116
- "outputs": [],
117
- "source": [
118
- "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
119
- "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
120
- "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
121
- "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
122
- "\n",
123
- "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
124
- "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
125
- "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
126
- "\n",
127
- "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
128
- "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
129
- "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "metadata": {},
135
- "source": [
136
- "#### Pascal Zero-shot (in one-shot setting)\n",
137
- "\n",
138
- "Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
139
- ]
140
- },
141
- {
142
- "cell_type": "code",
143
- "execution_count": null,
144
- "metadata": {},
145
- "outputs": [],
146
- "source": [
147
- "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
148
- "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
149
- "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
150
- "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
151
- "\n",
152
- "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
153
- "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
154
- "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
155
- "\n",
156
- "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
157
- "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
158
- "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
159
- ]
160
- },
161
- {
162
- "cell_type": "code",
163
- "execution_count": null,
164
- "metadata": {},
165
- "outputs": [],
166
- "source": [
167
- "# without fixed thresholds...\n",
168
- "\n",
169
- "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
170
- "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
171
- "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
172
- "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
173
- "\n",
174
- "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
175
- "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
176
- "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
177
- ]
178
- },
179
- {
180
- "cell_type": "markdown",
181
- "metadata": {},
182
- "source": [
183
- "### COCO"
184
- ]
185
- },
186
- {
187
- "cell_type": "code",
188
- "execution_count": null,
189
- "metadata": {},
190
- "outputs": [],
191
- "source": [
192
- "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
193
- ]
194
- },
195
- {
196
- "cell_type": "code",
197
- "execution_count": null,
198
- "metadata": {},
199
- "outputs": [],
200
- "source": [
201
- "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
202
- "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
203
- "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
204
- "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
205
- "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
206
- "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
207
- "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
208
- ]
209
- },
210
- {
211
- "cell_type": "markdown",
212
- "metadata": {},
213
- "source": [
214
- "# Zero-shot"
215
- ]
216
- },
217
- {
218
- "cell_type": "code",
219
- "execution_count": null,
220
- "metadata": {},
221
- "outputs": [],
222
- "source": [
223
- "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
224
- ]
225
- },
226
- {
227
- "cell_type": "code",
228
- "execution_count": null,
229
- "metadata": {},
230
- "outputs": [],
231
- "source": [
232
- "\n",
233
- "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
234
- "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
235
- "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
236
- "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
237
- ]
238
- },
239
- {
240
- "cell_type": "markdown",
241
- "metadata": {},
242
- "source": [
243
- "# Ablation"
244
- ]
245
- },
246
- {
247
- "cell_type": "code",
248
- "execution_count": null,
249
- "metadata": {},
250
- "outputs": [],
251
- "source": [
252
- "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
253
- ]
254
- },
255
- {
256
- "cell_type": "code",
257
- "execution_count": null,
258
- "metadata": {},
259
- "outputs": [],
260
- "source": [
261
- "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
262
- "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
263
- " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
264
- "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
265
- ]
266
- },
267
- {
268
- "cell_type": "code",
269
- "execution_count": null,
270
- "metadata": {},
271
- "outputs": [],
272
- "source": [
273
- "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
274
- ]
275
- },
276
- {
277
- "cell_type": "code",
278
- "execution_count": null,
279
- "metadata": {},
280
- "outputs": [],
281
- "source": [
282
- "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
283
- ]
284
- },
285
- {
286
- "cell_type": "markdown",
287
- "metadata": {},
288
- "source": [
289
- "# Generalization"
290
- ]
291
- },
292
- {
293
- "cell_type": "code",
294
- "execution_count": null,
295
- "metadata": {},
296
- "outputs": [],
297
- "source": [
298
- "generalization = experiment('experiments/generalize.yaml').dataframe()"
299
- ]
300
- },
301
- {
302
- "cell_type": "code",
303
- "execution_count": null,
304
- "metadata": {},
305
- "outputs": [],
306
- "source": [
307
- "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
308
- ]
309
- },
310
- {
311
- "cell_type": "code",
312
- "execution_count": null,
313
- "metadata": {},
314
- "outputs": [],
315
- "source": [
316
- "print(\n",
317
- " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
318
- " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
319
- " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
320
- " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
321
- ")"
322
- ]
323
- }
324
- ],
325
- "metadata": {
326
- "interpreter": {
327
- "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
328
- },
329
- "kernelspec": {
330
- "display_name": "env2",
331
- "language": "python",
332
- "name": "env2"
333
- },
334
- "language_info": {
335
- "codemirror_mode": {
336
- "name": "ipython",
337
- "version": 3
338
- },
339
- "file_extension": ".py",
340
- "mimetype": "text/x-python",
341
- "name": "python",
342
- "nbconvert_exporter": "python",
343
- "pygments_lexer": "ipython3",
344
- "version": "3.8.8"
345
- }
346
- },
347
- "nbformat": 4,
348
- "nbformat_minor": 4
349
- }
clipseg/Visual_Feature_Engineering.ipynb DELETED
@@ -1,366 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Systematic"
8
- ]
9
- },
10
- {
11
- "cell_type": "code",
12
- "execution_count": null,
13
- "metadata": {},
14
- "outputs": [],
15
- "source": [
16
- "%load_ext autoreload\n",
17
- "%autoreload 2\n",
18
- "\n",
19
- "import clip\n",
20
- "from evaluation_utils import norm, denorm\n",
21
- "from general_utils import *\n",
22
- "from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
23
- "\n",
24
- "clip_device = 'cuda'\n",
25
- "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
26
- "clip_model.eval();\n",
27
- "\n",
28
- "from models.clipseg import CLIPDensePredTMasked\n",
29
- "\n",
30
- "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
31
- "clip_mask_model.eval();"
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": null,
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
41
- " text_class_labels=True, image_size=352, min_area=0.1,\n",
42
- " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": null,
48
- "metadata": {},
49
- "outputs": [],
50
- "source": [
51
- "plot_data(lvis)"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": null,
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "from collections import defaultdict\n",
61
- "import json\n",
62
- "\n",
63
- "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
64
- "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
65
- "\n",
66
- "objects_per_image = defaultdict(lambda : set())\n",
67
- "for ann in lvis_raw['annotations']:\n",
68
- " objects_per_image[ann['image_id']].add(ann['category_id'])\n",
69
- " \n",
70
- "for ann in lvis_val_raw['annotations']:\n",
71
- " objects_per_image[ann['image_id']].add(ann['category_id']) \n",
72
- " \n",
73
- "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
74
- "\n",
75
- "del lvis_raw, lvis_val_raw"
76
- ]
77
- },
78
- {
79
- "cell_type": "code",
80
- "execution_count": null,
81
- "metadata": {},
82
- "outputs": [],
83
- "source": [
84
- "#bs = 32\n",
85
- "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
86
- ]
87
- },
88
- {
89
- "cell_type": "code",
90
- "execution_count": null,
91
- "metadata": {},
92
- "outputs": [],
93
- "source": [
94
- "from general_utils import get_batch\n",
95
- "from functools import partial\n",
96
- "from evaluation_utils import img_preprocess\n",
97
- "import torch\n",
98
- "\n",
99
- "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
100
- "\n",
101
- " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
102
- "\n",
103
- " all_prompts = []\n",
104
- " \n",
105
- " with torch.no_grad():\n",
106
- " valid_sims = []\n",
107
- " torch.manual_seed(571)\n",
108
- " \n",
109
- " if type(batches_or_dataset) == list:\n",
110
- " loader = batches_or_dataset # already loaded\n",
111
- " max_iter = float('inf')\n",
112
- " else:\n",
113
- " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
114
- " max_iter = 50\n",
115
- " \n",
116
- " global batch\n",
117
- " for i_batch, (batch, batch_y) in enumerate(loader):\n",
118
- " \n",
119
- " if i_batch >= max_iter: break\n",
120
- " \n",
121
- " processed_batch = process(batch)\n",
122
- " if type(processed_batch) == dict:\n",
123
- " \n",
124
- " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
125
- " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
126
- " else:\n",
127
- " processed_batch = process(batch).to(clip_device)\n",
128
- " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
129
- " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
130
- " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
131
- " \n",
132
- " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
133
- " bs = len(batch[0])\n",
134
- " for j in range(bs):\n",
135
- " \n",
136
- " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
137
- " support_image = basename(lvis.samples[c][sid])\n",
138
- " \n",
139
- " img_objs = [o for o in objects_per_image[int(support_image)]]\n",
140
- " img_objs = [o.replace('_', ' ') for o in img_objs]\n",
141
- " \n",
142
- " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
143
- " if o != batch_y[2][j]]\n",
144
- " \n",
145
- " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
146
- " all_prompts += [prompts]\n",
147
- " \n",
148
- " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
149
- " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
150
- "\n",
151
- " global logits\n",
152
- " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
153
- "\n",
154
- " global sim\n",
155
- " sim = torch.softmax(logits, dim=-1)\n",
156
- " \n",
157
- " valid_sims += [sim]\n",
158
- " \n",
159
- " #valid_sims = torch.stack(valid_sims)\n",
160
- " return valid_sims, all_prompts\n",
161
- " \n",
162
- "\n",
163
- "def new_img_preprocess(x):\n",
164
- " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
165
- " \n",
166
- "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
167
- "get_similarities(lvis, lambda x: x[1]);"
168
- ]
169
- },
170
- {
171
- "cell_type": "code",
172
- "execution_count": null,
173
- "metadata": {},
174
- "outputs": [],
175
- "source": [
176
- "preprocessing_functions = [\n",
177
- "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
178
- "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
179
- "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
180
- "# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
181
- "# ['add red outline', partial(img_preprocess, outline=True)],\n",
182
- " \n",
183
- "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
184
- "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
185
- "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
186
- "# ['BG blur', partial(img_preprocess, blur=3)],\n",
187
- "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
188
- " \n",
189
- "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
190
- "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
191
- " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
192
- " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
193
- "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
194
- "]\n",
195
- "\n",
196
- "preprocessing_functions = preprocessing_functions\n",
197
- "\n",
198
- "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
199
- "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
200
- ]
201
- },
202
- {
203
- "cell_type": "code",
204
- "execution_count": null,
205
- "metadata": {},
206
- "outputs": [],
207
- "source": [
208
- "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
209
- ]
210
- },
211
- {
212
- "cell_type": "code",
213
- "execution_count": null,
214
- "metadata": {},
215
- "outputs": [],
216
- "source": [
217
- "for j in range(1):\n",
218
- " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
219
- ]
220
- },
221
- {
222
- "cell_type": "code",
223
- "execution_count": null,
224
- "metadata": {},
225
- "outputs": [],
226
- "source": [
227
- "from pandas import DataFrame\n",
228
- "tab = dict()\n",
229
- "for j, (name, _) in enumerate(preprocessing_functions):\n",
230
- " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
231
- " \n",
232
- " \n",
233
- "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
234
- ]
235
- },
236
- {
237
- "cell_type": "markdown",
238
- "metadata": {},
239
- "source": [
240
- "# Visual"
241
- ]
242
- },
243
- {
244
- "cell_type": "code",
245
- "execution_count": null,
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "from evaluation_utils import denorm, norm"
250
- ]
251
- },
252
- {
253
- "cell_type": "code",
254
- "execution_count": null,
255
- "metadata": {},
256
- "outputs": [],
257
- "source": [
258
- "def load_sample(filename, filename2):\n",
259
- " from os.path import join\n",
260
- " bp = expanduser('~/cloud/resources/sample_images')\n",
261
- " tf = transforms.Compose([\n",
262
- " transforms.ToTensor(),\n",
263
- " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
264
- " transforms.Resize(224),\n",
265
- " transforms.CenterCrop(224)\n",
266
- " ])\n",
267
- " tf2 = transforms.Compose([\n",
268
- " transforms.ToTensor(),\n",
269
- " transforms.Resize(224),\n",
270
- " transforms.CenterCrop(224)\n",
271
- " ])\n",
272
- " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
273
- " inp1[1] = inp1[1].unsqueeze(0)\n",
274
- " inp1[2] = inp1[2][:1] \n",
275
- " return inp1\n",
276
- "\n",
277
- "def all_preprocessing(inp1):\n",
278
- " return [\n",
279
- " img_preprocess(inp1),\n",
280
- " img_preprocess(inp1, colorize=True),\n",
281
- " img_preprocess(inp1, outline=True), \n",
282
- " img_preprocess(inp1, blur=3),\n",
283
- " img_preprocess(inp1, bg_fac=0.1),\n",
284
- " #img_preprocess(inp1, bg_fac=0.5),\n",
285
- " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
286
- " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
287
- " ]\n",
288
- "\n"
289
- ]
290
- },
291
- {
292
- "cell_type": "code",
293
- "execution_count": null,
294
- "metadata": {},
295
- "outputs": [],
296
- "source": [
297
- "from torchvision import transforms\n",
298
- "from PIL import Image\n",
299
- "from matplotlib import pyplot as plt\n",
300
- "from evaluation_utils import img_preprocess\n",
301
- "import clip\n",
302
- "\n",
303
- "images_queries = [\n",
304
- " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
305
- " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
306
- "]\n",
307
- "\n",
308
- "\n",
309
- "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
310
- "\n",
311
- "for j, (images, objects) in enumerate(images_queries):\n",
312
- " \n",
313
- " joint_image = all_preprocessing(images)\n",
314
- " \n",
315
- " joint_image = torch.stack(joint_image)[:,0]\n",
316
- " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
317
- " image_features = clip_model.encode_image(joint_image)\n",
318
- " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
319
- " \n",
320
- " prompts = [f'a photo of a {obj}'for obj in objects]\n",
321
- " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
322
- " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
323
- " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
324
- " sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
325
- "\n",
326
- " for i, img in enumerate(joint_image):\n",
327
- " ax[2*j, i].axis('off')\n",
328
- " \n",
329
- " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
330
- " ax[2*j+ 1, i].grid(True)\n",
331
- " \n",
332
- " ax[2*j + 1, i].set_ylim(0,1)\n",
333
- " ax[2*j + 1, i].set_yticklabels([])\n",
334
- " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
335
- "# ax[1, i].set_xticklabels(objects, rotation=90)\n",
336
- " for k in range(len(sim[i])):\n",
337
- " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
338
- " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
339
- "\n",
340
- "plt.tight_layout()\n",
341
- "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
342
- ]
343
- }
344
- ],
345
- "metadata": {
346
- "kernelspec": {
347
- "display_name": "env2",
348
- "language": "python",
349
- "name": "env2"
350
- },
351
- "language_info": {
352
- "codemirror_mode": {
353
- "name": "ipython",
354
- "version": 3
355
- },
356
- "file_extension": ".py",
357
- "mimetype": "text/x-python",
358
- "name": "python",
359
- "nbconvert_exporter": "python",
360
- "pygments_lexer": "ipython3",
361
- "version": "3.8.8"
362
- }
363
- },
364
- "nbformat": 4,
365
- "nbformat_minor": 4
366
- }
clipseg/datasets/coco_wrapper.py DELETED
@@ -1,99 +0,0 @@
1
- import pickle
2
- from types import new_class
3
- import torch
4
- import numpy as np
5
- import os
6
- import json
7
-
8
- from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
9
- from random import shuffle, seed as set_seed
10
- from PIL import Image
11
-
12
- from itertools import combinations
13
- from torchvision import transforms
14
- from torchvision.transforms.transforms import Resize
15
-
16
- from datasets.utils import blend_image_segmentation
17
- from general_utils import get_from_repository
18
-
19
- COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
20
-
21
- class COCOWrapper(object):
22
-
23
- def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
24
- with_class_label=False):
25
- super().__init__()
26
-
27
- self.mask = mask
28
- self.with_class_label = with_class_label
29
- self.negative_prob = negative_prob
30
-
31
- from third_party.hsnet.data.coco import DatasetCOCO
32
-
33
- get_from_repository('COCO-20i', ['COCO-20i.tar'])
34
-
35
- foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')
36
-
37
- def build_img_metadata_classwise(self):
38
- with open(foldpath % (self.split, self.fold), 'rb') as f:
39
- img_metadata_classwise = pickle.load(f)
40
- return img_metadata_classwise
41
-
42
-
43
- DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
44
- # DatasetCOCO.read_mask = read_mask
45
-
46
- mean = [0.485, 0.456, 0.406]
47
- std = [0.229, 0.224, 0.225]
48
- transform = transforms.Compose([
49
- transforms.Resize((image_size, image_size)),
50
- transforms.ToTensor(),
51
- transforms.Normalize(mean, std)
52
- ])
53
-
54
- self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False)
55
-
56
- self.all_classes = [self.coco.class_ids]
57
- self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))
58
-
59
- def __len__(self):
60
- return len(self.coco)
61
-
62
- def __getitem__(self, i):
63
- sample = self.coco[i]
64
-
65
- label_name = COCO_CLASSES[int(sample['class_id'])]
66
-
67
- img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0]
68
-
69
- if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
70
- new_class_id = sample['class_id']
71
- while new_class_id == sample['class_id']:
72
- sample2 = self.coco[torch.randint(0, len(self), (1,)).item()]
73
- new_class_id = sample2['class_id']
74
- img_s = sample2['support_imgs'][0]
75
- seg_s = torch.zeros_like(seg_s)
76
-
77
- mask = self.mask
78
- if mask == 'separate':
79
- supp = (img_s, seg_s)
80
- elif mask == 'text_label':
81
- # DEPRECATED
82
- supp = [int(sample['class_id'])]
83
- elif mask == 'text':
84
- supp = [label_name]
85
- else:
86
- if mask.startswith('text_and_'):
87
- mask = mask[9:]
88
- label_add = [label_name]
89
- else:
90
- label_add = []
91
-
92
- supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask)
93
-
94
- if self.with_class_label:
95
- label = (torch.zeros(0), sample['class_id'],)
96
- else:
97
- label = (torch.zeros(0), )
98
-
99
- return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label
clipseg/datasets/pascal_classes.json DELETED
@@ -1 +0,0 @@
- [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
clipseg/datasets/pascal_zeroshot.py DELETED
@@ -1,60 +0,0 @@
1
- from os.path import expanduser
2
- import torch
3
- import json
4
- import torchvision
5
- from general_utils import get_from_repository
6
- from general_utils import log
7
- from torchvision import transforms
8
-
9
- PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
10
- ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
11
- ['chair.n.01', 'pot_plant.n.01']]
12
-
13
-
14
- class PascalZeroShot(object):
15
-
16
- def __init__(self, split, n_unseen, image_size=224) -> None:
17
- super().__init__()
18
-
19
- import sys
20
- sys.path.append('third_party/JoEm')
21
- from third_party.JoEm.data_loader.dataset import VOCSegmentation
22
- from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
23
-
24
- self.pascal_classes = VOC
25
- self.image_size = image_size
26
-
27
- self.transform = transforms.Compose([
28
- transforms.Resize((image_size, image_size)),
29
- ])
30
-
31
- if split == 'train':
32
- self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
33
- split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
34
- ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
35
- elif split == 'val':
36
- self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
37
- split=split, transform=False,
38
- ignore_bg=False, ignore_unseen=False)
39
-
40
- self.unseen_idx = get_unseen_idx(n_unseen)
41
-
42
- def __len__(self):
43
- return len(self.voc)
44
-
45
- def __getitem__(self, i):
46
-
47
- sample = self.voc[i]
48
- label = sample['label'].long()
49
- all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
50
- class_indices = [l for l in all_labels]
51
- class_names = [self.pascal_classes[l] for l in all_labels]
52
-
53
- image = self.transform(sample['image'])
54
-
55
- label = transforms.Resize((self.image_size, self.image_size),
56
- interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
57
-
58
- return (image,), (label, )
59
-
60
-
clipseg/datasets/pfe_dataset.py DELETED
@@ -1,129 +0,0 @@
1
- from os.path import expanduser
2
- import torch
3
- import json
4
- from general_utils import get_from_repository
5
- from datasets.lvis_oneshot3 import blend_image_segmentation
6
- from general_utils import log
7
-
8
- PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
9
-
10
-
11
- class PFEPascalWrapper(object):
12
-
13
- def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
14
- import sys
15
- # sys.path.append(expanduser('~/projects/new_one_shot'))
16
- from third_party.PFENet.util.dataset import SemData
17
-
18
- get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
19
-
20
- self.p_negative = p_negative
21
- self.size = size
22
- self.mode = mode
23
- self.image_size = image_size
24
-
25
- if label_support in {True, False}:
26
- log.warning('label_support argument is deprecated. Use mask instead.')
27
- #raise ValueError()
28
-
29
- self.mask = mask
30
-
31
- value_scale = 255
32
- mean = [0.485, 0.456, 0.406]
33
- mean = [item * value_scale for item in mean]
34
- std = [0.229, 0.224, 0.225]
35
- std = [item * value_scale for item in std]
36
-
37
- import third_party.PFENet.util.transform as transform
38
-
39
- if mode == 'val':
40
- data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
41
-
42
- data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
43
- data_transform += [
44
- transform.ToTensor(),
45
- transform.Normalize(mean=mean, std=std)
46
- ]
47
-
48
-
49
- elif mode == 'train':
50
- data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
51
-
52
- assert image_size != 'original'
53
-
54
- data_transform = [
55
- transform.RandScale([0.9, 1.1]),
56
- transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
57
- transform.RandomGaussianBlur(),
58
- transform.RandomHorizontalFlip(),
59
- transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
60
- transform.ToTensor(),
61
- transform.Normalize(mean=mean, std=std)
62
- ]
63
-
64
- data_transform = transform.Compose(data_transform)
65
-
66
- self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
67
- data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
68
-
69
- self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
70
-
71
- # verify that subcls_list always has length 1
72
- # assert len(set([len(d[4]) for d in self.dataset])) == 1
73
-
74
- print('actual length', len(self.dataset.data_list))
75
-
76
- def __len__(self):
77
- if self.mode == 'val':
78
- return len(self.dataset.data_list)
79
- else:
80
- return len(self.dataset.data_list)
81
-
82
- def __getitem__(self, index):
83
- if self.dataset.mode == 'train':
84
- image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
85
- elif self.dataset.mode == 'val':
86
- image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
87
- ori_label = torch.from_numpy(ori_label).unsqueeze(0)
88
-
89
- if self.image_size != 'original':
90
- longerside = max(ori_label.size(1), ori_label.size(2))
91
- backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
92
- backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
93
- label = backmask.clone().long()
94
- else:
95
- label = label.unsqueeze(0)
96
-
97
- # assert label.shape == (473, 473)
98
-
99
- if self.p_negative > 0:
100
- if torch.rand(1).item() < self.p_negative:
101
- while True:
102
- idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
103
- _, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
104
- if subcls_list[0] != subcls_list_tmp[0]:
105
- break
106
-
107
- s_x = s_x[0]
108
- s_y = (s_y == 1)[0]
109
- label_fg = (label == 1).float()
110
- val_mask = (label != 255).float()
111
-
112
- class_id = self.class_list[subcls_list[0]]
113
-
114
- label_name = PASCAL_CLASSES[class_id][0]
115
- label_add = ()
116
- mask = self.mask
117
-
118
- if mask == 'text':
119
- support = ('a photo of a ' + label_name + '.',)
120
- elif mask == 'separate':
121
- support = (s_x, s_y)
122
- else:
123
- if mask.startswith('text_and_'):
124
- label_add = (label_name,)
125
- mask = mask[9:]
126
-
127
- support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
128
-
129
- return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
clipseg/datasets/phrasecut.py DELETED
@@ -1,335 +0,0 @@
1
-
2
- import torch
3
- import numpy as np
4
- import os
5
-
6
- from os.path import join, isdir, isfile, expanduser
7
- from PIL import Image
8
-
9
- from torchvision import transforms
10
- from torchvision.transforms.transforms import Resize
11
-
12
- from torch.nn import functional as nnf
13
- from general_utils import get_from_repository
14
-
15
- from skimage.draw import polygon2mask
16
-
17
-
18
-
19
- def random_crop_slices(origin_size, target_size):
20
- """Gets slices of a random crop. """
21
- assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'
22
-
23
- offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item() # range: 0 <= value < high
24
- offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item()
25
-
26
- return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1])
27
-
28
-
29
- def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
30
-
31
-
32
- best_crops = []
33
- best_crop_not_ok = float('-inf'), None, None
34
- min_sum = 0
35
-
36
- seg = seg.astype('bool')
37
-
38
- if min_frac is not None:
39
- #min_sum = seg.sum() * min_frac
40
- min_sum = seg.shape[0] * seg.shape[1] * min_frac
41
-
42
- for iteration in range(iterations):
43
- sl_y, sl_x = random_crop_slices(seg.shape, image_size)
44
- seg_ = seg[sl_y, sl_x]
45
- sum_seg_ = seg_.sum()
46
-
47
- if sum_seg_ > min_sum:
48
-
49
- if best_of is None:
50
- return sl_y, sl_x, False
51
- else:
52
- best_crops += [(sum_seg_, sl_y, sl_x)]
53
- if len(best_crops) >= best_of:
54
- best_crops.sort(key=lambda x:x[0], reverse=True)
55
- sl_y, sl_x = best_crops[0][1:]
56
-
57
- return sl_y, sl_x, False
58
-
59
- else:
60
- if sum_seg_ > best_crop_not_ok[0]:
61
- best_crop_not_ok = sum_seg_, sl_y, sl_x
62
-
63
- else:
64
- # return best segmentation found
65
- return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,)
66
-
67
-
68
- class PhraseCut(object):
69
-
70
- def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
71
- min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
72
- super().__init__()
73
-
74
- self.negative_prob = negative_prob
75
- self.image_size = image_size
76
- self.with_visual = with_visual
77
- self.only_visual = only_visual
78
- self.phrase_form = '{}'
79
- self.mask = mask
80
- self.aug_crop = aug_crop
81
-
82
- if aug_color:
83
- self.aug_color = transforms.Compose([
84
- transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
85
- ])
86
- else:
87
- self.aug_color = None
88
-
89
- get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
90
- isdir(join(local_dir, 'VGPhraseCut_v0')),
91
- isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
92
- isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
93
- len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
94
- ]))
95
-
96
- from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
97
- self.refvg_loader = RefVGLoader(split=split)
98
-
99
- # img_ids where the size in the annotations does not match actual size
100
- invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
101
- 150333, 286065, 285814, 498187, 285761, 498042])
102
-
103
- mean = [0.485, 0.456, 0.406]
104
- std = [0.229, 0.224, 0.225]
105
- self.normalize = transforms.Normalize(mean, std)
106
-
107
- self.sample_ids = [(i, j)
108
- for i in self.refvg_loader.img_ids
109
- for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
110
- if i not in invalid_img_ids]
111
-
112
-
113
- # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']]))
114
-
115
- from nltk.stem import WordNetLemmatizer
116
- wnl = WordNetLemmatizer()
117
-
118
- # Filter by class (if remove_classes is set)
119
- if remove_classes is None:
120
- pass
121
- else:
122
- from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
123
- from nltk.corpus import wordnet
124
-
125
- print('remove pascal classes...')
126
-
127
- get_data = self.refvg_loader.get_img_ref_data # shortcut
128
- keep_sids = None
129
-
130
- if remove_classes[0] == 'pas5i':
131
- subset_id = remove_classes[1]
132
- from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
133
- avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]
134
-
135
-
136
- elif remove_classes[0] == 'zs':
137
- stop = remove_classes[1]
138
-
139
- from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
140
-
141
- avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
142
- print(avoid)
143
-
144
- elif remove_classes[0] == 'aff':
145
- # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02']
146
- # all_lemmas = set(['drink', 'sit', 'ride'])
147
- avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
148
- 'ride', 'rides', 'riding',
149
- 'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
150
- 'swim', 'swims', 'swimming',
151
- 'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
152
- keep_sids = [(i, j) for i, j in self.sample_ids if
153
- all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
154
-
155
- print('avoid classes:', avoid)
156
-
157
-
158
- if keep_sids is None:
159
- all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
160
- all_lemmas = list(set(all_lemmas))
161
- all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
162
- all_lemmas = set(all_lemmas)
163
-
164
- # divide into multi word and single word
165
- all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
166
- all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)
167
-
168
- # new3
169
- phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
170
- remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
171
- if any(l in phrase for l in all_lemmas_m) or
172
- len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
173
- )
174
- keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]
175
-
176
- print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
177
- removed_ids = set(self.sample_ids) - set(keep_sids)
178
-
179
- print('Examples of removed', len(removed_ids))
180
- for i, j in list(removed_ids)[:20]:
181
- print(i, get_data(i)['phrases'][j])
182
-
183
- self.sample_ids = keep_sids
184
-
185
- from itertools import groupby
186
- samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
187
- for i, j in self.sample_ids]
188
- samples_by_phrase = sorted(samples_by_phrase)
189
- samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])
190
-
191
- self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}
192
-
193
- self.all_phrases = list(set(self.samples_by_phrase.keys()))
194
-
195
-
196
- if self.only_visual:
197
- assert self.with_visual
198
- self.sample_ids = [(i, j) for i, j in self.sample_ids
199
- if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]
200
-
201
- # Filter by size (if min_size is set)
202
- sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
203
- image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
204
- #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes]
205
- self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]
206
-
207
- if min_size:
208
- print('filter by size')
209
-
210
- self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]
211
-
212
- self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))
213
-
214
- def __len__(self):
215
- return len(self.sample_ids)
216
-
217
-
218
- def load_sample(self, sample_i, j):
219
-
220
- img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)
221
-
222
- polys_phrase0 = img_ref_data['gt_Polygons'][j]
223
- phrase = img_ref_data['phrases'][j]
224
- phrase = self.phrase_form.format(phrase)
225
-
226
- masks = []
227
- for polys in polys_phrase0:
228
- for poly in polys:
229
- poly = [p[::-1] for p in poly] # swap x,y
230
- masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]
231
-
232
- seg = np.stack(masks).max(0)
233
- img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))
234
-
235
- min_shape = min(img.shape[:2])
236
-
237
- if self.aug_crop:
238
- sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
239
- else:
240
- sly, slx = slice(0, None), slice(0, None)
241
-
242
- seg = seg[sly, slx]
243
- img = img[sly, slx]
244
-
245
- seg = seg.astype('uint8')
246
- seg = torch.from_numpy(seg).view(1, 1, *seg.shape)
247
-
248
- if img.ndim == 2:
249
- img = np.dstack([img] * 3)
250
-
251
- img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()
252
-
253
- seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
254
- img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]
255
-
256
- # img = img.permute([2,0, 1])
257
- img = img / 255.0
258
-
259
- if self.aug_color is not None:
260
- img = self.aug_color(img)
261
-
262
- img = self.normalize(img)
263
-
264
-
265
-
266
- return img, seg, phrase
267
-
268
- def __getitem__(self, i):
269
-
270
- sample_i, j = self.sample_ids[i]
271
-
272
- img, seg, phrase = self.load_sample(sample_i, j)
273
-
274
- if self.negative_prob > 0:
275
- if torch.rand((1,)).item() < self.negative_prob:
276
-
277
- new_phrase = None
278
- while new_phrase is None or new_phrase == phrase:
279
- idx = torch.randint(0, len(self.all_phrases), (1,)).item()
280
- new_phrase = self.all_phrases[idx]
281
- phrase = new_phrase
282
- seg = torch.zeros_like(seg)
283
-
284
- if self.with_visual:
285
- # find a corresponding visual image
286
- if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
287
- idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
288
- other_sample = self.samples_by_phrase[phrase][idx]
289
- #print(other_sample)
290
- img_s, seg_s, _ = self.load_sample(*other_sample)
291
-
292
- from datasets.utils import blend_image_segmentation
293
-
294
- if self.mask in {'separate', 'text_and_separate'}:
295
- # assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:]
296
- add_phrase = [phrase] if self.mask == 'text_and_separate' else []
297
- vis_s = add_phrase + [img_s, seg_s, True]
298
- else:
299
- if self.mask.startswith('text_and_'):
300
- mask_mode = self.mask[9:]
301
- label_add = [phrase]
302
- else:
303
- mask_mode = self.mask
304
- label_add = []
305
-
306
- masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
307
- vis_s = label_add + [masked_img_s, True]
308
-
309
- else:
310
- # phrase is unique
311
- vis_s = torch.zeros_like(img)
312
-
313
- if self.mask in {'separate', 'text_and_separate'}:
314
- add_phrase = [phrase] if self.mask == 'text_and_separate' else []
315
- vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
316
- elif self.mask.startswith('text_and_'):
317
- vis_s = [phrase, vis_s, False]
318
- else:
319
- vis_s = [vis_s, False]
320
- else:
321
- assert self.mask == 'text'
322
- vis_s = [phrase]
323
-
324
- seg = seg.unsqueeze(0).float()
325
-
326
- data_x = (img,) + tuple(vis_s)
327
-
328
- return data_x, (seg, torch.zeros(0), i)
329
-
330
-
331
- class PhraseCutPlus(PhraseCut):
332
-
333
- def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
334
- super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size,
335
- remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask)
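A rough usage sketch of the PhraseCut loader deleted above (not part of this commit's code): the keyword values are plausible settings taken from the constructors shown here, and it assumes the deleted clipseg package is importable and the PhraseCut images sit under the hard-coded ~/datasets/PhraseCut/VGPhraseCut_v0/images/ path.

from datasets.phrasecut import PhraseCut

# text-only supervision: each sample is ((image, phrase), (segmentation, _, index))
ds = PhraseCut('train', image_size=352, mask='text', with_visual=False, negative_prob=0.2)
data_x, (seg, _, idx) = ds[0]
img, phrase = data_x                  # with mask='text', data_x is (image, phrase)
print(img.shape, seg.shape, phrase)   # torch.Size([3, 352, 352]) torch.Size([1, 352, 352]) <phrase>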
 
clipseg/datasets/utils.py DELETED
@@ -1,68 +0,0 @@
1
-
2
- import numpy as np
3
- import torch
4
-
5
-
6
- def blend_image_segmentation(img, seg, mode, image_size=224):
7
-
8
-
9
- if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}:
10
- if isinstance(img, np.ndarray):
11
- img = torch.from_numpy(img)
12
-
13
- if isinstance(seg, np.ndarray):
14
- seg = torch.from_numpy(seg)
15
-
16
- if mode == 'overlay':
17
- out = img * seg
18
- out = [out.astype('float32')]
19
- elif mode == 'highlight':
20
- out = img * seg[None, :, :] * 0.85 + 0.15 * img
21
- out = [out.astype('float32')]
22
- elif mode == 'highlight2':
23
- img = img / 2
24
- out = (img+0.1) * seg[None, :, :] + 0.3 * img
25
- out = [out.astype('float32')]
26
- elif mode == 'blur_highlight':
27
- from evaluation_utils import img_preprocess
28
- out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01]
29
- elif mode == 'blur3_highlight':
30
- from evaluation_utils import img_preprocess
31
- out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01]
32
- elif mode == 'blur3_highlight01':
33
- from evaluation_utils import img_preprocess
34
- out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01]
35
- elif mode == 'blur_highlight_random':
36
- from evaluation_utils import img_preprocess
37
- out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01]
38
- elif mode == 'crop':
39
- from evaluation_utils import img_preprocess
40
- out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()]
41
- elif mode == 'crop_blur_highlight':
42
- from evaluation_utils import img_preprocess
43
- out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()]
44
- elif mode == 'crop_blur_highlight352':
45
- from evaluation_utils import img_preprocess
46
- out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()]
47
- elif mode == 'shape':
48
- out = [np.stack([seg[:, :]]*3).astype('float32')]
49
- elif mode == 'concat':
50
- out = [np.concatenate([img, seg[None, :, :]]).astype('float32')]
51
- elif mode == 'image_only':
52
- out = [img.astype('float32')]
53
- elif mode == 'image_black':
54
- out = [img.astype('float32')*0]
55
- elif mode is None:
56
- out = [img.astype('float32')]
57
- elif mode == 'separate':
58
- out = [img.astype('float32'), seg.astype('int64')]
59
- elif mode == 'separate_img_black':
60
- out = [img.astype('float32')*0, seg.astype('int64')]
61
- elif mode == 'separate_seg_ones':
62
- out = [img.astype('float32'), np.ones_like(seg).astype('int64')]
63
- elif mode == 'separate_both_black':
64
- out = [img.astype('float32')*0, seg.astype('int64')*0]
65
- else:
66
- raise ValueError(f'invalid mode: {mode}')
67
-
68
- return out
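blend_image_segmentation above implements the visual-prompt blending modes referenced by the mask settings in the experiment YAMLs further down. A small illustrative call on dummy arrays, assuming the deleted module is still importable:

import numpy as np
from datasets.utils import blend_image_segmentation

img = np.random.rand(3, 224, 224).astype('float32')          # channel-first RGB image
seg = (np.random.rand(224, 224) > 0.5).astype('float32')     # binary object mask

out, = blend_image_segmentation(img, seg, mode='highlight')  # object kept, background dimmed to 15%
print(out.shape)                                             # (3, 224, 224)

out, = blend_image_segmentation(img, seg, mode='concat')     # mask appended as a 4th channel
print(out.shape)                                             # (4, 224, 224)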
 
clipseg/environment.yml DELETED
@@ -1,15 +0,0 @@
1
- name: clipseg-environment
2
- channels:
3
- - conda-forge
4
- - pytorch
5
- dependencies:
6
- - numpy
7
- - scipy
8
- - matplotlib-base
9
- - pip
10
- - pip:
11
- - --find-links https://download.pytorch.org/whl/torch_stable.html
12
- - torch==1.10.0+cpu
13
- - torchvision==0.11.1+cpu
14
- - opencv-python
15
- - git+https://github.com/openai/CLIP.git
 
clipseg/evaluation_utils.py DELETED
@@ -1,292 +0,0 @@
1
- from torch.functional import Tensor
2
- from general_utils import load_model
3
- from torch.utils.data import DataLoader
4
- import torch
5
- import numpy as np
6
-
7
- def denorm(img):
8
-
9
- np_input = False
10
- if isinstance(img, np.ndarray):
11
- img = torch.from_numpy(img)
12
- np_input = True
13
-
14
- mean = torch.Tensor([0.485, 0.456, 0.406])
15
- std = torch.Tensor([0.229, 0.224, 0.225])
16
-
17
- img_denorm = (img*std[:,None,None]) + mean[:,None,None]
18
-
19
- if np_input:
20
- img_denorm = np.clip(img_denorm.numpy(), 0, 1)
21
- else:
22
- img_denorm = torch.clamp(img_denorm, 0, 1)
23
-
24
- return img_denorm
25
-
26
-
27
- def norm(img):
28
- mean = torch.Tensor([0.485, 0.456, 0.406])
29
- std = torch.Tensor([0.229, 0.224, 0.225])
30
- return (img - mean[:,None,None]) / std[:,None,None]
31
-
32
-
33
- def fast_iou_curve(p, g):
34
-
35
- g = g[p.sort().indices]
36
- p = torch.sigmoid(p.sort().values)
37
-
38
- scores = []
39
- vals = np.linspace(0, 1, 50)
40
-
41
- for q in vals:
42
-
43
- n = int(len(g) * q)
44
-
45
- valid = torch.where(p > q)[0]
46
- if len(valid) > 0:
47
- n = int(valid[0])
48
- else:
49
- n = len(g)
50
-
51
- fn = g[:n].sum()
52
- tn = n - fn
53
- tp = g[n:].sum()
54
- fp = len(g) - n - tp
55
-
56
- iou = tp / (tp + fn + fp)
57
-
58
- precision = tp / (tp + fp)
59
- recall = tp / (tp + fn)
60
-
61
- scores += [iou]
62
-
63
- return vals, scores
64
-
65
-
66
- def fast_rp_curve(p, g):
67
-
68
- g = g[p.sort().indices]
69
- p = torch.sigmoid(p.sort().values)
70
-
71
- precisions, recalls = [], []
72
- vals = np.linspace(p.min(), p.max(), 250)
73
-
74
- for q in p[::100000]:
75
-
76
- n = int(len(g) * q)
77
-
78
- valid = torch.where(p > q)[0]
79
- if len(valid) > 0:
80
- n = int(valid[0])
81
- else:
82
- n = len(g)
83
-
84
- fn = g[:n].sum()
85
- tn = n - fn
86
- tp = g[n:].sum()
87
- fp = len(g) - n - tp
88
-
89
- iou = tp / (tp + fn + fp)
90
-
91
- precision = tp / (tp + fp)
92
- recall = tp / (tp + fn)
93
-
94
- precisions += [precision]
95
- recalls += [recall]
96
-
97
- return recalls, precisions
98
-
99
-
100
- # Image processing
101
-
102
- def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
103
- brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
104
- import cv2
105
-
106
- rw = rect_width
107
-
108
- out = []
109
- for img, mask in zip(batch[1], batch[2]):
110
-
111
- img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
112
- mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)
113
-
114
- img *= brightness
115
- img_bl = img
116
- if blur > 0: # best 5
117
- img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)
118
-
119
- if grayscale:
120
- img_bl = img_bl[1][None]
121
-
122
- #img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl
123
- # img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask)
124
- img_inp = img*mask + (bg_fac) * img_bl * (1-mask)
125
-
126
- if rect:
127
- _, bbox = crop_mask(img, mask, context=0.1)
128
- img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None]
129
- img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None]
130
- img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
131
- img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
132
-
133
-
134
- if center_context is not None:
135
- img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)
136
-
137
- if colorize:
138
- img_gray = denorm(img)
139
- img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
140
- img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
141
- img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
142
- img_inp = norm(img_inp)
143
-
144
- if outline:
145
- cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
146
- outline_img = np.zeros(mask.shape, dtype=np.uint8)
147
- cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
148
- outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
149
- img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
150
- img_inp = norm(img_inp)
151
-
152
- out += [img_inp]
153
-
154
- return torch.stack(out)
155
-
156
-
157
- def object_crop(img, mask, context=0.0, square=False, image_size=224):
158
- img_crop, bbox = crop_mask(img, mask, context=context, square=square)
159
- img_crop = pad_to_square(img_crop, channel_dim=0)
160
- img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0)
161
- return img_crop
162
-
163
-
164
- def crop_mask(img, mask, context=0.0, square=False):
165
-
166
- assert img.shape[1:] == mask.shape
167
-
168
- bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()]
169
- bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()]
170
- bbox = [int(x) for x in bbox]
171
-
172
- width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
173
-
174
- # square mask
175
- if square:
176
- bbox[0] = int(max(0, bbox[0] - context * height))
177
- bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
178
- bbox[2] = int(max(0, bbox[2] - context * width))
179
- bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
180
-
181
- width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
182
- if height > width:
183
- bbox[2] = int(max(0, (bbox[2] - 0.5*height)))
184
- bbox[3] = bbox[2] + height
185
- else:
186
- bbox[0] = int(max(0, (bbox[0] - 0.5*width)))
187
- bbox[1] = bbox[0] + width
188
- else:
189
- bbox[0] = int(max(0, bbox[0] - context * height))
190
- bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
191
- bbox[2] = int(max(0, bbox[2] - context * width))
192
- bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
193
-
194
- width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
195
- img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
196
- return img_crop, bbox
197
-
198
-
199
- def pad_to_square(img, channel_dim=2, fill=0):
200
- """
201
-
202
-
203
- add padding such that a square image is returned """
204
-
205
- from torchvision.transforms.functional import pad
206
-
207
- if channel_dim == 2:
208
- img = img.permute(2, 0, 1)
209
- elif channel_dim == 0:
210
- pass
211
- else:
212
- raise ValueError('invalid channel_dim')
213
-
214
- h, w = img.shape[1:]
215
- pady1 = pady2 = padx1 = padx2 = 0
216
-
217
- if h > w:
218
- padx1 = (h - w) // 2
219
- padx2 = h - w - padx1
220
- elif w > h:
221
- pady1 = (w - h) // 2
222
- pady2 = w - h - pady1
223
-
224
- img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant')
225
-
226
- if channel_dim == 2:
227
- img_padded = img_padded.permute(1, 2, 0)
228
-
229
- return img_padded
230
-
231
-
232
- # qualitative
233
-
234
- def split_sentence(inp, limit=9):
235
- t_new, current_len = [], 0
236
- for k, t in enumerate(inp.split(' ')):
237
- current_len += len(t) + 1
238
- t_new += [t+' ']
239
- # not last
240
- if current_len > limit and k != len(inp.split(' ')) - 1:
241
- current_len = 0
242
- t_new += ['\n']
243
-
244
- t_new = ''.join(t_new)
245
- return t_new
246
-
247
-
248
- from matplotlib import pyplot as plt
249
-
250
-
251
- def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
252
-
253
- row_off = 0 if labels is None else 1
254
- _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)))
255
- [a.axis('off') for a in ax.flatten()]
256
-
257
- if labels is not None:
258
- for j in range(len(labels)):
259
- t_new = split_sentence(labels[j], limit=6)
260
- ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale)
261
-
262
-
263
- for i in range(len(imgs)):
264
- ax[i + row_off,0].imshow(imgs[i])
265
- for j in range(len(preds)):
266
- img = preds[j][i][0].detach().cpu().numpy()
267
-
268
- if gt_labels is not None and labels[j] == gt_labels[i]:
269
- print(j, labels[j], gt_labels[i])
270
- edgecolor = 'red'
271
- if aps is not None:
272
- ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
273
- else:
274
- edgecolor = 'k'
275
-
276
- rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none",
277
- edgecolor=edgecolor, linewidth=3)
278
- ax[i + row_off,1 + j].add_patch(rect)
279
-
280
- if vmax is None:
281
- this_vmax = 1
282
- elif vmax == 'per_prompt':
283
- this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))])
284
- elif vmax == 'per_image':
285
- this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))])
286
-
287
- ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)
288
-
289
-
290
- # ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max())
291
- plt.tight_layout()
292
- plt.subplots_adjust(wspace=0.05, hspace=0.05)
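norm/denorm and pad_to_square above are self-contained tensor helpers; a quick sanity check, assuming evaluation_utils is importable:

import torch
from evaluation_utils import denorm, norm, pad_to_square

x = torch.rand(3, 352, 352)                            # image in [0, 1], channels first
assert torch.allclose(denorm(norm(x)), x, atol=1e-5)   # ImageNet-statistics round trip

sq = pad_to_square(torch.rand(3, 200, 300), channel_dim=0)
print(sq.shape)                                        # torch.Size([3, 300, 300])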
 
clipseg/example_image.jpg DELETED
Binary file (91.5 kB)
 
clipseg/experiments/ablation.yaml DELETED
@@ -1,84 +0,0 @@
1
- configuration:
2
- batch_size: 64
3
- optimizer: torch.optim.AdamW
4
-
5
- lr: 0.001
6
-
7
- trainer: experiment_setup.train_loop
8
- scorer: experiment_setup.score
9
- model: models.clipseg.CLIPDensePredT
10
-
11
- lr_scheduler: cosine
12
- T_max: 20000
13
- eta_min: 0.0001
14
-
15
- max_iterations: 20000 # <-##########################################
16
- val_interval: null
17
-
18
- # dataset
19
- dataset: datasets.phrasecut.PhraseCut # <-----------------
20
- split_mode: pascal_test
21
- split: train
22
- mask: text_and_crop_blur_highlight352
23
- image_size: 352
24
- negative_prob: 0.2
25
- mix_text_max: 0.5
26
-
27
- # general
28
- mix: True # <-----------------
29
- prompt: shuffle+
30
- norm_cond: True
31
- mix_text_min: 0.0
32
- with_visual: True
33
-
34
- # model
35
- version: 'ViT-B/16'
36
- extract_layers: [3, 7, 9]
37
- reduce_dim: 64
38
- depth: 3
39
- fix_shift: False # <-##########################################
40
-
41
- loss: torch.nn.functional.binary_cross_entropy_with_logits
42
- amp: True
43
-
44
- test_configuration_common:
45
- normalize: True
46
- image_size: 352
47
- batch_size: 32
48
- sigmoid: True
49
- split: test
50
- label_support: True
51
-
52
- test_configuration:
53
-
54
- -
55
- name: pc
56
- metric: metrics.FixedIntervalMetrics
57
- test_dataset: phrasecut
58
- mask: text
59
-
60
- -
61
- name: pc-vis
62
- metric: metrics.FixedIntervalMetrics
63
- test_dataset: phrasecut
64
- mask: crop_blur_highlight352
65
- with_visual: True
66
- visual_only: True
67
-
68
-
69
- columns: [name,
70
- pc_fgiou_best, pc_miou_best, pc_fgiou_0.5,
71
- pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5,
72
- duration]
73
-
74
-
75
- individual_configurations:
76
-
77
- - {name: rd64-uni}
78
- - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003}
79
- - {name: rd64-no-negatives, negative_prob: 0.0}
80
- - {name: rd64-neg0.5, negative_prob: 0.5}
81
- - {name: rd64-no-visual, with_visual: False, mix: False}
82
- - {name: rd16-uni, reduce_dim: 16}
83
- - {name: rd64-layer3, extract_layers: [3], depth: 1}
84
- - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}}
 
clipseg/experiments/coco.yaml DELETED
@@ -1,101 +0,0 @@
1
- configuration:
2
- batch_size: 64
3
- optimizer: torch.optim.AdamW
4
-
5
- lr: 0.001
6
-
7
- trainer: experiment_setup.train_loop
8
- scorer: experiment_setup.score
9
- model: models.clipseg.CLIPDensePredT
10
-
11
- lr_scheduler: cosine
12
- T_max: 20000
13
- eta_min: 0.0001
14
-
15
- max_iterations: 20000
16
- val_interval: null
17
-
18
- # dataset
19
- dataset: datasets.coco_wrapper.COCOWrapper
20
- # split_mode: pascal_test
21
- split: train
22
- mask: text_and_blur3_highlight01
23
- image_size: 352
24
- normalize: True
25
- pre_crop_image_size: [sample, 1, 1.5]
26
- aug: 1new
27
-
28
- # general
29
- mix: True
30
- prompt: shuffle+
31
- norm_cond: True
32
- mix_text_min: 0.0
33
-
34
- # model
35
- out: 1
36
- extract_layers: [3, 7, 9]
37
- reduce_dim: 64
38
- depth: 3
39
- fix_shift: False
40
-
41
- loss: torch.nn.functional.binary_cross_entropy_with_logits
42
- amp: True
43
-
44
- test_configuration_common:
45
- normalize: True
46
- image_size: 352
47
- # max_iterations: 10
48
- batch_size: 8
49
- sigmoid: True
50
- test_dataset: coco
51
- metric: metrics.FixedIntervalMetrics
52
-
53
- test_configuration:
54
-
55
- -
56
- name: coco_t
57
- mask: text
58
-
59
- -
60
- name: coco_h
61
- mask: blur3_highlight01
62
-
63
- -
64
- name: coco_h2
65
- mask: crop_blur_highlight352
66
-
67
-
68
- columns: [i, name,
69
- coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5,
70
- coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5,
71
- coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t,
72
- train_loss, duration, date
73
- ]
74
-
75
- individual_configurations:
76
-
77
-
78
- - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
79
- - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
80
- - {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
81
- - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
82
-
83
-
84
- - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
85
- - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
86
- - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
87
- - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
88
-
89
-
90
- # ViT
91
- - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
92
- - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
93
- - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
94
- - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
95
-
96
-
97
- # BASELINE
98
- - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
99
- - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
100
- - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
101
- - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
 
clipseg/experiments/pascal_1shot.yaml DELETED
@@ -1,101 +0,0 @@
1
- configuration:
2
- batch_size: 64
3
- optimizer: torch.optim.AdamW
4
-
5
- lr: 0.001
6
-
7
- trainer: experiment_setup.train_loop
8
- scorer: experiment_setup.score
9
- model: models.clipseg.CLIPDensePredT
10
-
11
- lr_scheduler: cosine
12
- T_max: 20000
13
- eta_min: 0.0001
14
-
15
- max_iterations: 20000 # <-##########################################
16
- val_interval: null
17
-
18
- # dataset
19
- dataset: datasets.phrasecut.PhraseCut
20
- split_mode: pascal_test
21
- mode: train
22
- mask: text_and_crop_blur_highlight352
23
- image_size: 352
24
- normalize: True
25
- pre_crop_image_size: [sample, 1, 1.5]
26
- aug: 1new
27
- with_visual: True
28
- split: train
29
-
30
- # general
31
- mix: True
32
- prompt: shuffle+
33
- norm_cond: True
34
- mix_text_min: 0.0
35
-
36
- # model
37
- out: 1
38
- version: 'ViT-B/16'
39
- extract_layers: [3, 7, 9]
40
- reduce_dim: 64
41
- depth: 3
42
-
43
- loss: torch.nn.functional.binary_cross_entropy_with_logits
44
- amp: True
45
-
46
- test_configuration_common:
47
- normalize: True
48
- image_size: 352
49
- metric: metrics.FixedIntervalMetrics
50
- batch_size: 1
51
- test_dataset: pascal
52
- sigmoid: True
53
- # max_iterations: 250
54
-
55
- test_configuration:
56
-
57
- -
58
- name: pas_t
59
- mask: text
60
-
61
- -
62
- name: pas_h
63
- mask: blur3_highlight01
64
-
65
- -
66
- name: pas_h2
67
- mask: crop_blur_highlight352
68
-
69
-
70
- columns: [name,
71
- pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct,
72
- pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct,
73
- pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t,
74
- train_loss, duration, date
75
- ]
76
-
77
- individual_configurations:
78
-
79
- - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}}
80
- - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}}
81
- - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}}
82
- - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}}
83
-
84
-
85
- - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}}
86
- - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}}
87
- - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}}
88
- - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}}
89
-
90
-
91
- # baseline
92
- - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}}
93
- - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}}
94
- - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}}
95
- - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}}
96
-
97
- # ViT
98
- - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}}
99
- - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}}
100
- - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}}
101
- - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}}
 
clipseg/experiments/phrasecut.yaml DELETED
@@ -1,80 +0,0 @@
1
- configuration:
2
- batch_size: 64
3
- optimizer: torch.optim.AdamW
4
-
5
- lr: 0.001
6
-
7
- trainer: experiment_setup.train_loop
8
- scorer: experiment_setup.score
9
- model: models.clipseg.CLIPDensePredT
10
-
11
- lr_scheduler: cosine
12
- T_max: 20000
13
- eta_min: 0.0001
14
-
15
- max_iterations: 20000
16
- val_interval: null
17
-
18
- # dataset
19
- dataset: datasets.phrasecut.PhraseCut # <-----------------
20
- split_mode: pascal_test
21
- split: train
22
- mask: text_and_crop_blur_highlight352
23
- image_size: 352
24
- normalize: True
25
- pre_crop_image_size: [sample, 1, 1.5]
26
- aug: 1new
27
-
28
- # general
29
- mix: False # <-----------------
30
- prompt: shuffle+
31
- norm_cond: True
32
- mix_text_min: 0.0
33
-
34
- # model
35
- out: 1
36
- extract_layers: [3, 7, 9]
37
- reduce_dim: 64
38
- depth: 3
39
- fix_shift: False
40
-
41
- loss: torch.nn.functional.binary_cross_entropy_with_logits
42
- amp: True
43
-
44
- test_configuration_common:
45
- normalize: True
46
- image_size: 352
47
- batch_size: 32
48
- # max_iterations: 5
49
- # max_iterations: 150
50
-
51
- test_configuration:
52
-
53
- -
54
- name: pc # old: phrasecut
55
- metric: metrics.FixedIntervalMetrics
56
- test_dataset: phrasecut
57
- split: test
58
- mask: text
59
- label_support: True
60
- sigmoid: True
61
-
62
-
63
- columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date]
64
-
65
-
66
- individual_configurations:
67
-
68
- # important ones
69
-
70
-
71
- - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5}
72
-
73
- # this was accidentally trained using the old mask
74
- - {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01}
75
- - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False}
76
- # this was accidentally trained using the old mask
77
- - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01}
78
-
79
- - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003}
80
- - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001}
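These experiment YAMLs were consumed by training_config_from_cli_args in general_utils.py (deleted below): the shared configuration block is merged with one entry of individual_configurations, selected by index. Roughly, assuming the file still sits under experiments/:

import yaml

yaml_config = yaml.load(open('experiments/phrasecut.yaml'), Loader=yaml.SafeLoader)
config = {**yaml_config['configuration'],
          **yaml_config['individual_configurations'][0]}     # -> the 'rd64-uni' run
print(config['name'], config['lr'], config['mask'])          # rd64-uni 0.001 text_and_crop_blur_highlight352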
 
clipseg/general_utils.py DELETED
@@ -1,272 +0,0 @@
1
- import json
2
- import inspect
3
- import torch
4
- import os
5
- import sys
6
- import yaml
7
- from shutil import copy, copytree
8
- from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename
9
-
10
-
11
- class Logger(object):
12
-
13
- def __getattr__(self, k):
14
- return print
15
-
16
- log = Logger()
17
-
18
- def training_config_from_cli_args():
19
- experiment_name = sys.argv[1]
20
- experiment_id = int(sys.argv[2])
21
-
22
- yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
23
-
24
- config = yaml_config['configuration']
25
- config = {**config, **yaml_config['individual_configurations'][experiment_id]}
26
- config = AttributeDict(config)
27
- return config
28
-
29
-
30
- def score_config_from_cli_args():
31
- experiment_name = sys.argv[1]
32
- experiment_id = int(sys.argv[2])
33
-
34
-
35
- yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
36
-
37
- config = yaml_config['test_configuration_common']
38
-
39
- if type(yaml_config['test_configuration']) == list:
40
- test_id = int(sys.argv[3])
41
- config = {**config, **yaml_config['test_configuration'][test_id]}
42
- else:
43
- config = {**config, **yaml_config['test_configuration']}
44
-
45
- if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]:
46
- config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']}
47
-
48
- train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name']
49
-
50
- config = AttributeDict(config)
51
- return config, train_checkpoint_id
52
-
53
-
54
- def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository',
55
- local_dir='~/datasets'):
56
- """ copies files from repository to local folder.
57
-
58
- repo_files: list of filenames or list of tuples [filename, target path]
59
-
60
- e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar'])
61
- will create a folder 'MyDataset' in local_dir, and extract the content of
62
- '<repo_dir>/data/dataset1.tar' to <local_dir>/MyDataset/other/path.
63
- """
64
-
65
- local_dir = realpath(join(expanduser(local_dir), local_name))
66
-
67
- dataset_exists = True
68
-
69
- # check if folder is available
70
- if not isdir(local_dir):
71
- dataset_exists = False
72
-
73
- if integrity_check is not None:
74
- try:
75
- integrity_ok = integrity_check(local_dir)
76
- except BaseException:
77
- integrity_ok = False
78
-
79
- if integrity_ok:
80
- log.hint('Passed custom integrity check')
81
- else:
82
- log.hint('Custom integrity check failed')
83
-
84
- dataset_exists = dataset_exists and integrity_ok
85
-
86
- if not dataset_exists:
87
-
88
- repo_dir = realpath(expanduser(repo_dir))
89
-
90
- for i, filename in enumerate(repo_files):
91
-
92
- if type(filename) == str:
93
- origin, target = filename, filename
94
- archive_target = join(local_dir, basename(origin))
95
- extract_target = join(local_dir)
96
- else:
97
- origin, target = filename
98
- archive_target = join(local_dir, dirname(target), basename(origin))
99
- extract_target = join(local_dir, dirname(target))
100
-
101
- archive_origin = join(repo_dir, origin)
102
-
103
- log.hint(f'copy: {archive_origin} to {archive_target}')
104
-
105
- # make sure the path exists
106
- os.makedirs(dirname(archive_target), exist_ok=True)
107
-
108
- if os.path.isfile(archive_target):
109
- # only copy if size differs
110
- if os.path.getsize(archive_target) != os.path.getsize(archive_origin):
111
- log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}')
112
- copy(archive_origin, archive_target)
113
- else:
114
- copy(archive_origin, archive_target)
115
-
116
- extract_archive(archive_target, extract_target, noarchive_ok=True)
117
-
118
- # concurrent processes might have deleted the file
119
- if os.path.isfile(archive_target):
120
- os.remove(archive_target)
121
-
122
-
123
- def extract_archive(filename, target_folder=None, noarchive_ok=False):
124
- from subprocess import run, PIPE
125
-
126
- if filename.endswith('.tgz') or filename.endswith('.tar'):
127
- command = f'tar -xf {filename}'
128
- command += f' -C {target_folder}' if target_folder is not None else ''
129
- elif filename.endswith('.tar.gz'):
130
- command = f'tar -xzf {filename}'
131
- command += f' -C {target_folder}' if target_folder is not None else ''
132
- elif filename.endswith('zip'):
133
- command = f'unzip {filename}'
134
- command += f' -d {target_folder}' if target_folder is not None else ''
135
- else:
136
- if noarchive_ok:
137
- return
138
- else:
139
- raise ValueError(f'unsupported file ending of {filename}')
140
-
141
- log.hint(command)
142
- result = run(command.split(), stdout=PIPE, stderr=PIPE)
143
- if result.returncode != 0:
144
- print(result.stdout, result.stderr)
145
-
146
-
147
- class AttributeDict(dict):
148
- """
149
- An extended dictionary that allows access to elements as attributes and counts
150
- these accesses. This way, we know if some attributes were never used.
151
- """
152
-
153
- def __init__(self, *args, **kwargs):
154
- from collections import Counter
155
- super().__init__(*args, **kwargs)
156
- self.__dict__['counter'] = Counter()
157
-
158
- def __getitem__(self, k):
159
- self.__dict__['counter'][k] += 1
160
- return super().__getitem__(k)
161
-
162
- def __getattr__(self, k):
163
- self.__dict__['counter'][k] += 1
164
- return super().get(k)
165
-
166
- def __setattr__(self, k, v):
167
- return super().__setitem__(k, v)
168
-
169
- def __delattr__(self, k):
170
- return super().__delitem__(k)
171
-
172
- def unused_keys(self, exceptions=()):
173
- return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions]
174
-
175
- def assume_no_unused_keys(self, exceptions=()):
176
- if len(self.unused_keys(exceptions=exceptions)) > 0:
177
- log.warning('Unused keys:', self.unused_keys(exceptions=exceptions))
178
-
179
-
180
- def get_attribute(name):
181
- import importlib
182
-
183
- if name is None:
184
- raise ValueError('The provided attribute is None')
185
-
186
- name_split = name.split('.')
187
- mod = importlib.import_module('.'.join(name_split[:-1]))
188
- return getattr(mod, name_split[-1])
189
-
190
-
191
-
192
- def filter_args(input_args, default_args):
193
-
194
- updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()}
195
- used_args = {k: v for k, v in input_args.items() if k in default_args}
196
- unused_args = {k: v for k, v in input_args.items() if k not in default_args}
197
-
198
- return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args)
199
-
200
-
201
- def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False):
202
-
203
- config = json.load(open(join('logs', checkpoint_id, 'config.json')))
204
-
205
- if model_args != 'from_config' and type(model_args) != dict:
206
- raise ValueError('model_args must either be "from_config" or a dictionary of values')
207
-
208
- model_cls = get_attribute(config['model'])
209
-
210
- # load model
211
- if model_args == 'from_config':
212
- _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
213
-
214
- model = model_cls(**model_args)
215
-
216
- if weights_file is None:
217
- weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
218
- else:
219
- weights_file = realpath(join('logs', checkpoint_id, weights_file))
220
-
221
- if isfile(weights_file):
222
- weights = torch.load(weights_file)
223
- for _, w in weights.items():
224
- assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
225
- model.load_state_dict(weights, strict=strict)
226
- else:
227
- raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
228
-
229
- if with_config:
230
- return model, config
231
-
232
- return model
233
-
234
-
235
- class TrainingLogger(object):
236
-
237
- def __init__(self, model, log_dir, config=None, *args):
238
- super().__init__()
239
- self.model = model
240
- self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None
241
-
242
- os.makedirs('logs/', exist_ok=True)
243
- os.makedirs(self.base_path, exist_ok=True)
244
-
245
- if config is not None:
246
- json.dump(config, open(join(self.base_path, 'config.json'), 'w'))
247
-
248
- def iter(self, i, **kwargs):
249
- if i % 100 == 0 and 'loss' in kwargs:
250
- loss = kwargs['loss']
251
- print(f'iteration {i}: loss {loss:.4f}')
252
-
253
- def save_weights(self, only_trainable=False, weight_file='weights.pth'):
254
- if self.model is None:
255
- raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.')
256
-
257
- weights_path = join(self.base_path, weight_file)
258
-
259
- weight_dict = self.model.state_dict()
260
-
261
- if only_trainable:
262
- weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad}
263
-
264
- torch.save(weight_dict, weights_path)
265
- log.info(f'Saved weights to {weights_path}')
266
-
267
- def __enter__(self):
268
- return self
269
-
270
- def __exit__(self, type, value, traceback):
271
- """ automatically stop processes if used in a context manager """
272
- pass
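AttributeDict above is the small config container used by the training and scoring entry points; it counts attribute accesses so that unused config keys can be reported. A minimal sketch, assuming general_utils is importable:

from general_utils import AttributeDict

cfg = AttributeDict({'lr': 0.001, 'batch_size': 64, 'unused_flag': True})
print(cfg.lr, cfg['batch_size'])               # both access styles are counted
print(cfg.unused_keys())                       # ['unused_flag'] -- never read so far
cfg.assume_no_unused_keys(exceptions=('unused_flag',))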
 
clipseg/metrics.py DELETED
@@ -1,271 +0,0 @@
1
- from torch.functional import Tensor
2
- from general_utils import log
3
- from collections import defaultdict
4
- import numpy as np
5
-
6
- import torch
7
- from torch.nn import functional as nnf
8
-
9
-
10
- class BaseMetric(object):
11
-
12
- def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
13
- eval_validation=True):
14
- self._names = tuple(metric_names)
15
- self._eval_intermediate = eval_intermediate
16
- self._eval_validation = eval_validation
17
-
18
- self._pred_range = pred_range
19
- self._pred_index = pred_index
20
- self._gt_index = gt_index
21
-
22
- self.predictions = []
23
- self.ground_truths = []
24
-
25
- def eval_intermediate(self):
26
- return self._eval_intermediate
27
-
28
- def eval_validation(self):
29
- return self._eval_validation
30
-
31
- def names(self):
32
- return self._names
33
-
34
- def add(self, predictions, ground_truth):
35
- raise NotImplementedError
36
-
37
- def value(self):
38
- raise NotImplementedError
39
-
40
- def scores(self):
41
- # similar to value but returns dict
42
- value = self.value()
43
- if type(value) == dict:
44
- return value
45
- else:
46
- assert type(value) in {list, tuple}
47
- return list(zip(self.names(), self.value()))
48
-
49
- def _get_pred_gt(self, predictions, ground_truth):
50
- pred = predictions[self._pred_index]
51
- gt = ground_truth[self._gt_index]
52
-
53
- if self._pred_range is not None:
54
- pred = pred[:, self._pred_range[0]: self._pred_range[1]]
55
-
56
- return pred, gt
57
-
58
-
59
- class FixedIntervalMetrics(BaseMetric):
60
-
61
- def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
62
- resize_pred=None, n_values=51, custom_threshold=None):
63
-
64
-
65
- super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
66
- self.intersections = []
67
- self.unions = []
68
- # self.threshold = threshold
69
- self.sigmoid = sigmoid
70
- self.resize_to = resize_to
71
- self.resize_pred = resize_pred # resize prediction to match ground truth
72
- self.class_count = defaultdict(lambda: 0)
73
- self.per_class = defaultdict(lambda : [0,0])
74
- self.ignore_mask = ignore_mask
75
- self.custom_threshold = custom_threshold
76
-
77
- self.scores_ap = []
78
- self.scores_iou = []
79
- self.gts, self.preds = [], []
80
- self.classes = []
81
-
82
- # [1:-1] ignores 0 and 1
83
- self.threshold_values = np.linspace(0, 1, n_values)[1:-1]
84
-
85
- self.metrics = dict(tp=[], fp=[], fn=[], tn=[])
86
-
87
- def add(self, pred, gt):
88
-
89
- pred_batch = pred[0].cpu()
90
-
91
- if self.sigmoid:
92
- pred_batch = torch.sigmoid(pred_batch)
93
-
94
- gt_batch = gt[0].cpu()
95
- mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
96
- cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)
97
-
98
- if self.resize_to is not None:
99
- gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
100
- pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)
101
-
102
- if isinstance(cls_batch, torch.Tensor):
103
- cls_batch = cls_batch.cpu().numpy().tolist()
104
-
105
- assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'
106
-
107
- for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):
108
-
109
- if self.resize_pred:
110
- predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)
111
-
112
- p = predictions.flatten()
113
- g = ground_truth.flatten()
114
-
115
- assert len(p) == len(g)
116
-
117
- if mask is not None:
118
- m = mask.flatten().bool()
119
- p = p[m]
120
- g = g[m]
121
-
122
- p_sorted = p.sort()
123
- p = p_sorted.values
124
- g = g[p_sorted.indices]
125
-
126
- tps, fps, fns, tns = [], [], [], []
127
- for thresh in self.threshold_values:
128
-
129
- valid = torch.where(p > thresh)[0]
130
- if len(valid) > 0:
131
- n = int(valid[0])
132
- else:
133
- n = len(g)
134
-
135
- fn = int(g[:n].sum())
136
- tp = int(g[n:].sum())
137
- fns += [fn]
138
- tns += [n - fn]
139
- tps += [tp]
140
- fps += [len(g) - n - tp]
141
-
142
- self.metrics['tp'] += [tps]
143
- self.metrics['fp'] += [fps]
144
- self.metrics['fn'] += [fns]
145
- self.metrics['tn'] += [tns]
146
-
147
- self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls]
148
-
149
- def value(self):
150
-
151
- import time
152
- t_start = time.time()
153
-
154
- if set(self.classes) == set([None]):
155
- all_classes = None
156
- log.warning('classes were not provided, cannot compute mIoU')
157
- else:
158
- all_classes = set(int(c) for c in self.classes)
159
- # log.info(f'compute metrics for {len(all_classes)} classes')
160
-
161
- summed = {k: [sum([self.metrics[k][i][j]
162
- for i in range(len(self.metrics[k]))])
163
- for j in range(len(self.threshold_values))]
164
- for k in self.metrics.keys()}
165
-
166
- if all_classes is not None:
167
-
168
- assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])
169
- # group by class
170
- metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes}
171
- for i in range(len(self.metrics['tp'])):
172
- for k in self.metrics.keys():
173
- metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]]
174
-
175
- # sum over all instances within the classes
176
- summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()}
177
-
178
-
179
- # Compute average precision
180
-
181
- assert (np.array(summed['fp']) + np.array(summed['tp'])).sum(), 'no predictions were made'
182
-
183
- # only consider values where a prediction is made
184
- precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values))
185
- if summed['tp'][j] + summed['fp'][j] > 0]
186
- recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))
187
- if summed['tp'][j] + summed['fp'][j] > 0]
188
-
189
- # remove duplicate recall-precision-pairs (and sort by recall value)
190
- recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0]))
191
-
192
- from scipy.integrate import simps
193
- ap = simps(precisions, recalls)
194
-
195
- # Compute best IoU
196
- fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))]
197
-
198
- biniou_scores = [
199
- 0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
200
- 0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
201
- for j in range(len(self.threshold_values))
202
- ]
203
-
204
- index_0p5 = self.threshold_values.tolist().index(0.5)
205
- index_0p1 = self.threshold_values.tolist().index(0.1)
206
- index_0p2 = self.threshold_values.tolist().index(0.2)
207
- index_0p3 = self.threshold_values.tolist().index(0.3)
208
-
209
- if self.custom_threshold is not None:
210
- index_ct = self.threshold_values.tolist().index(self.custom_threshold)
211
-
212
- if all_classes is not None:
213
- # mean IoU
214
- mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j])
215
- for c in all_classes])
216
- for j in range(len(self.threshold_values))]
217
-
218
- mean_iou_dict = {
219
- 'miou_best': max(mean_ious) if all_classes is not None else None,
220
- 'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None,
221
- 'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None,
222
- 'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None,
223
- 'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None,
224
- 'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
225
- 'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None,
226
- 'mean_iou_scores': mean_ious,
227
- }
228
-
229
- print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')
230
-
231
- return {
232
- 'ap': ap,
233
-
234
- # fgiou
235
- 'fgiou_best': max(fgiou_scores),
236
- 'fgiou_0.5': fgiou_scores[index_0p5],
237
- 'fgiou_0.1': fgiou_scores[index_0p1],
238
- 'fgiou_0.2': fgiou_scores[index_0p2],
239
- 'fgiou_0.3': fgiou_scores[index_0p3],
240
- 'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],
241
-
242
- # mean iou
243
-
244
-
245
- # biniou
246
- 'biniou_best': max(biniou_scores),
247
- 'biniou_0.5': biniou_scores[index_0p5],
248
- 'biniou_0.1': biniou_scores[index_0p1],
249
- 'biniou_0.2': biniou_scores[index_0p2],
250
- 'biniou_0.3': biniou_scores[index_0p3],
251
- 'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],
252
-
253
- # custom threshold
254
- 'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None,
255
- 'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None,
256
- 'ct': self.custom_threshold,
257
-
258
- # statistics
259
- 'fgiou_scores': fgiou_scores,
260
- 'biniou_scores': biniou_scores,
261
- 'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))),
262
- 'summed_statistics': summed,
263
- 'summed_by_cls_statistics': summed_by_cls,
264
-
265
- **mean_iou_dict
266
- }
267
-
268
- # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'
269
-
270
- # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls}
271
-
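FixedIntervalMetrics sweeps a fixed grid of thresholds over the sigmoid outputs and reports AP, foreground IoU and, when class ids are supplied, mean IoU. A rough sketch with dummy tensors, assuming the deleted modules and an older SciPy (one that still provides scipy.integrate.simps) are available:

import torch
from metrics import FixedIntervalMetrics

metric = FixedIntervalMetrics(sigmoid=True, custom_threshold=0.5)
logits = torch.randn(4, 1, 352, 352)                      # raw model outputs
targets = torch.randint(0, 2, (4, 1, 352, 352)).float()   # binary ground-truth masks
classes = torch.tensor([0, 1, 0, 1])                      # per-sample class ids enable mIoU

metric.add([logits], (targets, torch.zeros(0), classes))
scores = metric.value()
print(scores['ap'], scores['fgiou_best'], scores['miou_0.5'], scores['fgiou_ct'])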
 
clipseg/models/clipseg.py DELETED
@@ -1,552 +0,0 @@
1
- import math
2
- from os.path import basename, dirname, join, isfile
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as nnf
6
- from torch.nn.modules.activation import ReLU
7
-
8
-
9
- def precompute_clip_vectors():
10
-
11
- from trails.initialization import init_dataset
12
- lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
13
- reduce_factor=None, add_bar=False, negative_prob=0.5)
14
-
15
- all_names = list(lvis.category_names.values())
16
-
17
- import clip
18
- from models.clip_prompts import imagenet_templates
19
- clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
20
- prompt_vectors = {}
21
- for name in all_names[:100]:
22
- with torch.no_grad():
23
- conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
24
- text_tokens = clip.tokenize(conditionals).cuda()
25
- cond = clip_model.encode_text(text_tokens).cpu()
26
-
27
- for cond, vec in zip(conditionals, cond):
28
- prompt_vectors[cond] = vec.cpu()
29
-
30
- import pickle
31
-
32
- pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb'))
33
-
34
-
35
- def get_prompt_list(prompt):
36
- if prompt == 'plain':
37
- return ['{}']
38
- elif prompt == 'fixed':
39
- return ['a photo of a {}.']
40
- elif prompt == 'shuffle':
41
- return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
42
- elif prompt == 'shuffle+':
43
- return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
44
- 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
45
- 'a bad photo of a {}.', 'a photo of the {}.']
46
- elif prompt == 'shuffle_clip':
47
- from models.clip_prompts import imagenet_templates
48
- return imagenet_templates
49
- else:
50
- raise ValueError('Invalid value for prompt')
51
-
52
-
53
- def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
54
- """
55
- Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
56
- The mlp and layer norm come from CLIP.
57
- x: input.
58
- b: multihead attention module.
59
- """
60
-
61
- x_ = b.ln_1(x)
62
- q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
63
- tgt_len, bsz, embed_dim = q.size()
64
-
65
- head_dim = embed_dim // b.attn.num_heads
66
- scaling = float(head_dim) ** -0.5
67
-
68
- q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
69
- k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
70
- v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
71
-
72
- q = q * scaling
73
-
74
- attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2
75
- if attn_mask is not None:
76
-
77
-
78
- attn_mask_type, attn_mask = attn_mask
79
- n_heads = attn_output_weights.size(0) // attn_mask.size(0)
80
- attn_mask = attn_mask.repeat(n_heads, 1)
81
-
82
- if attn_mask_type == 'cls_token':
83
- # the mask only affects similarities compared to the readout-token.
84
- attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...]
85
- # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
86
-
87
- if attn_mask_type == 'all':
88
- # print(attn_output_weights.shape, attn_mask[:, None].shape)
89
- attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
90
-
91
-
92
- attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
93
-
94
- attn_output = torch.bmm(attn_output_weights, v)
95
- attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
96
- attn_output = b.attn.out_proj(attn_output)
97
-
98
- x = x + attn_output
99
- x = x + b.mlp(b.ln_2(x))
100
-
101
- if with_aff:
102
- return x, attn_output_weights
103
- else:
104
- return x
105
-
106
-
107
- class CLIPDenseBase(nn.Module):
108
-
109
- def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
110
- super().__init__()
111
-
112
- import clip
113
-
114
- # prec = torch.FloatTensor
115
- self.clip_model, _ = clip.load(version, device='cpu', jit=False)
116
- self.model = self.clip_model.visual
117
-
118
- # if not None, scale conv weights such that we obtain n_tokens.
119
- self.n_tokens = n_tokens
120
-
121
- for p in self.clip_model.parameters():
122
- p.requires_grad_(False)
123
-
124
- # conditional
125
- if reduce_cond is not None:
126
- self.reduce_cond = nn.Linear(512, reduce_cond)
127
- for p in self.reduce_cond.parameters():
128
- p.requires_grad_(False)
129
- else:
130
- self.reduce_cond = None
131
-
132
- self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
133
- self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
134
-
135
- self.reduce = nn.Linear(768, reduce_dim)
136
-
137
- self.prompt_list = get_prompt_list(prompt)
138
-
139
- # precomputed prompts
140
- import pickle
141
- if isfile('precomputed_prompt_vectors.pickle'):
142
- precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
143
- self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
144
- else:
145
- self.precomputed_prompts = dict()
146
-
147
- def rescaled_pos_emb(self, new_size):
148
- assert len(new_size) == 2
149
-
150
- a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
151
- b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
152
- return torch.cat([self.model.positional_embedding[:1], b])
153
-
154
- def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
155
-
156
-
157
- with torch.no_grad():
158
-
159
- inp_size = x_inp.shape[2:]
160
-
161
- if self.n_tokens is not None:
162
- stride2 = x_inp.shape[2] // self.n_tokens
163
- conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
164
- x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
165
- else:
166
- x = self.model.conv1(x_inp) # shape = [*, width, grid, grid]
167
-
168
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
169
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
170
-
171
- x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
172
-
173
- standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
174
-
175
- if x.shape[1] != standard_n_tokens:
176
- new_shape = int(math.sqrt(x.shape[1]-1))
177
- x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:]
178
- else:
179
- x = x + self.model.positional_embedding.to(x.dtype)
180
-
181
- x = self.model.ln_pre(x)
182
-
183
- x = x.permute(1, 0, 2) # NLD -> LND
184
-
185
- activations, affinities = [], []
186
- for i, res_block in enumerate(self.model.transformer.resblocks):
187
-
188
- if mask is not None:
189
- mask_layer, mask_type, mask_tensor = mask
190
- if mask_layer == i or mask_layer == 'all':
191
- # import ipdb; ipdb.set_trace()
192
- size = int(math.sqrt(x.shape[0] - 1))
193
-
194
- attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))
195
-
196
- else:
197
- attn_mask = None
198
- else:
199
- attn_mask = None
200
-
201
- x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)
202
-
203
- if i in extract_layers:
204
- affinities += [aff_per_head]
205
-
206
- #if self.n_tokens is not None:
207
- # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
208
- #else:
209
- activations += [x]
210
-
211
- if len(extract_layers) > 0 and i == max(extract_layers) and skip:
212
- print('early skip')
213
- break
214
-
215
- x = x.permute(1, 0, 2) # LND -> NLD
216
- x = self.model.ln_post(x[:, 0, :])
217
-
218
- if self.model.proj is not None:
219
- x = x @ self.model.proj
220
-
221
- return x, activations, affinities
222
-
223
- def sample_prompts(self, words, prompt_list=None):
224
-
225
- prompt_list = prompt_list if prompt_list is not None else self.prompt_list
226
-
227
- prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
228
- prompts = [prompt_list[i] for i in prompt_indices]
229
- return [promt.format(w) for promt, w in zip(prompts, words)]
230
-
231
- def get_cond_vec(self, conditional, batch_size):
232
- # compute conditional from a single string
233
- if conditional is not None and type(conditional) == str:
234
- cond = self.compute_conditional(conditional)
235
- cond = cond.repeat(batch_size, 1)
236
-
237
- # compute conditional from string list/tuple
238
- elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
239
- assert len(conditional) == batch_size
240
- cond = self.compute_conditional(conditional)
241
-
242
- # use conditional directly
243
- elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
244
- cond = conditional
245
-
246
- # compute conditional from image
247
- elif conditional is not None and type(conditional) == torch.Tensor:
248
- with torch.no_grad():
249
- cond, _, _ = self.visual_forward(conditional)
250
- else:
251
- raise ValueError('invalid conditional')
252
- return cond
253
-
254
- def compute_conditional(self, conditional):
255
- import clip
256
-
257
- dev = next(self.parameters()).device
258
-
259
- if type(conditional) in {list, tuple}:
260
- text_tokens = clip.tokenize(conditional).to(dev)
261
- cond = self.clip_model.encode_text(text_tokens)
262
- else:
263
- if conditional in self.precomputed_prompts:
264
- cond = self.precomputed_prompts[conditional].float().to(dev)
265
- else:
266
- text_tokens = clip.tokenize([conditional]).to(dev)
267
- cond = self.clip_model.encode_text(text_tokens)[0]
268
-
269
- if self.shift_vector is not None:
270
- return cond + self.shift_vector
271
- else:
272
- return cond
273
-
274
-
275
- def clip_load_untrained(version):
276
- assert version == 'ViT-B/16'
277
- from clip.model import CLIP
278
- from clip.clip import _MODELS, _download
279
- model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
280
- state_dict = model.state_dict()
281
-
282
- vision_width = state_dict["visual.conv1.weight"].shape[0]
283
- vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
284
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
285
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
286
- image_resolution = vision_patch_size * grid_size
287
- embed_dim = state_dict["text_projection"].shape[1]
288
- context_length = state_dict["positional_embedding"].shape[0]
289
- vocab_size = state_dict["token_embedding.weight"].shape[0]
290
- transformer_width = state_dict["ln_final.weight"].shape[0]
291
- transformer_heads = transformer_width // 64
292
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
293
-
294
- return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
295
- context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
296
-
297
-
298
- class CLIPDensePredT(CLIPDenseBase):
299
-
300
- def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
301
- extra_blocks=0, reduce_cond=None, fix_shift=False,
302
- learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
303
- add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
304
-
305
- super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
306
- # device = 'cpu'
307
-
308
- self.extract_layers = extract_layers
309
- self.cond_layer = cond_layer
310
- self.limit_to_clip_only = limit_to_clip_only
311
- self.process_cond = None
312
- self.rev_activations = rev_activations
313
-
314
- depth = len(extract_layers)
315
-
316
- if add_calibration:
317
- self.calibration_conds = 1
318
-
319
- self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
320
-
321
- self.add_activation1 = True
322
-
323
- self.version = version
324
-
325
- self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
326
-
327
- if fix_shift:
328
- # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
329
- self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False)
330
- # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
331
- else:
332
- self.shift_vector = None
333
-
334
- if trans_conv is None:
335
- trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
336
- else:
337
- # explicitly define transposed conv kernel size
338
- trans_conv_ks = (trans_conv, trans_conv)
339
-
340
- self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
341
-
342
- assert len(self.extract_layers) == depth
343
-
344
- self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
345
- self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
346
- self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
347
-
348
- # refinement and trans conv
349
-
350
- if learn_trans_conv_only:
351
- for p in self.parameters():
352
- p.requires_grad_(False)
353
-
354
- for p in self.trans_conv.parameters():
355
- p.requires_grad_(True)
356
-
357
- self.prompt_list = get_prompt_list(prompt)
358
-
359
-
360
- def forward(self, inp_image, conditional=None, return_features=False, mask=None):
361
-
362
- assert type(return_features) == bool
363
-
364
- inp_image = inp_image.to(self.model.positional_embedding.device)
365
-
366
- if mask is not None:
367
- raise ValueError('mask not supported')
368
-
369
- # x_inp = normalize(inp_image)
370
- x_inp = inp_image
371
-
372
- bs, dev = inp_image.shape[0], x_inp.device
373
-
374
- cond = self.get_cond_vec(conditional, bs)
375
-
376
- visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
377
-
378
- activation1 = activations[0]
379
- activations = activations[1:]
380
-
381
- _activations = activations[::-1] if not self.rev_activations else activations
382
-
383
- a = None
384
- for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):
385
-
386
- if a is not None:
387
- a = reduce(activation) + a
388
- else:
389
- a = reduce(activation)
390
-
391
- if i == self.cond_layer:
392
- if self.reduce_cond is not None:
393
- cond = self.reduce_cond(cond)
394
-
395
- a = self.film_mul(cond) * a + self.film_add(cond)
396
-
397
- a = block(a)
398
-
399
- for block in self.extra_blocks:
400
- a = a + block(a)
401
-
402
- a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
403
-
404
- size = int(math.sqrt(a.shape[2]))
405
-
406
- a = a.view(bs, a.shape[1], size, size)
407
-
408
- a = self.trans_conv(a)
409
-
410
- if self.n_tokens is not None:
411
- a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)
412
-
413
- if self.upsample_proj is not None:
414
- a = self.upsample_proj(a)
415
- a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
416
-
417
- if return_features:
418
- return a, visual_q, cond, [activation1] + activations
419
- else:
420
- return a,
421
-
422
-
423
-
424
- class CLIPDensePredTMasked(CLIPDensePredT):
425
-
426
- def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
427
- prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
428
- refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
429
-
430
- super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim,
431
- n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond,
432
- fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only,
433
- limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration,
434
- n_tokens=n_tokens)
435
-
436
- def visual_forward_masked(self, img_s, seg_s):
437
- return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s))
438
-
439
- def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
440
-
441
- if seg_s is None:
442
- cond = cond_or_img_s
443
- else:
444
- img_s = cond_or_img_s
445
-
446
- with torch.no_grad():
447
- cond, _, _ = self.visual_forward_masked(img_s, seg_s)
448
-
449
- return super().forward(img_q, cond, return_features=return_features)
450
-
451
-
452
-
453
- class CLIPDenseBaseline(CLIPDenseBase):
454
-
455
- def __init__(self, version='ViT-B/32', cond_layer=0,
456
- extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
457
- reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
458
-
459
- super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
460
- device = 'cpu'
461
-
462
- # self.cond_layer = cond_layer
463
- self.extract_layer = extract_layer
464
- self.limit_to_clip_only = limit_to_clip_only
465
- self.shift_vector = None
466
-
467
- self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
468
-
469
- assert reduce2_dim is not None
470
-
471
- self.reduce2 = nn.Sequential(
472
- nn.Linear(reduce_dim, reduce2_dim),
473
- nn.ReLU(),
474
- nn.Linear(reduce2_dim, reduce_dim)
475
- )
476
-
477
- trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
478
- self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
479
-
480
-
481
- def forward(self, inp_image, conditional=None, return_features=False):
482
-
483
- inp_image = inp_image.to(self.model.positional_embedding.device)
484
-
485
- # x_inp = normalize(inp_image)
486
- x_inp = inp_image
487
-
488
- bs, dev = inp_image.shape[0], x_inp.device
489
-
490
- cond = self.get_cond_vec(conditional, bs)
491
-
492
- visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer])
493
-
494
- a = activations[0]
495
- a = self.reduce(a)
496
- a = self.film_mul(cond) * a + self.film_add(cond)
497
-
498
- if self.reduce2 is not None:
499
- a = self.reduce2(a)
500
-
501
- # the original model would execute a transformer block here
502
-
503
- a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
504
-
505
- size = int(math.sqrt(a.shape[2]))
506
-
507
- a = a.view(bs, a.shape[1], size, size)
508
- a = self.trans_conv(a)
509
-
510
- if return_features:
511
- return a, visual_q, cond, activations
512
- else:
513
- return a,
514
-
515
-
516
- class CLIPSegMultiLabel(nn.Module):
517
-
518
- def __init__(self, model) -> None:
519
- super().__init__()
520
-
521
- from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
522
-
523
- self.pascal_classes = VOC
524
-
525
- from models.clipseg import CLIPDensePredT
526
- from general_utils import load_model
527
- # self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
528
- self.clipseg = load_model(model, strict=False)
529
-
530
- self.clipseg.eval()
531
-
532
- def forward(self, x):
533
-
534
- bs = x.shape[0]
535
- out = torch.ones(21, bs, 352, 352).to(x.device) * -10
536
-
537
- for class_id, class_name in enumerate(self.pascal_classes):
538
-
539
- fac = 3 if class_name == 'background' else 1
540
-
541
- with torch.no_grad():
542
- pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac
543
-
544
- out[class_id] += pred
545
-
546
-
547
- out = out.permute(1, 0, 2, 3)
548
-
549
- return out
550
-
551
- # construct output tensor
552
-
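
For reference, a minimal inference sketch for the CLIPDensePredT model removed above, assuming a checkout that still has the old models/ layout and the rd64-uni.pth checkpoint from this repo; the image path and the query text are placeholders, not part of the commit:

import torch
from PIL import Image
from torchvision import transforms
from models.clipseg import CLIPDensePredT  # import path as it existed before this commit

# reduce_dim=64 matches the rd64 checkpoint; strict=False because the file only stores the decoder
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
model.load_state_dict(torch.load('clipseg/weights/rd64-uni.pth', map_location='cpu'), strict=False)

# CLIP normalization constants and a 352x352 input, as used elsewhere in the deleted code
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
    transforms.Resize((352, 352)),
])
img = transform(Image.open('example.png').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    logits, = model(img, 'a panda')        # forward() returns a one-element tuple
    heatmap = torch.sigmoid(logits[0, 0])  # HxW segmentation heatmap in [0, 1]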
 
clipseg/models/vitseg.py DELETED
@@ -1,286 +0,0 @@
1
- import math
2
- from posixpath import basename, dirname, join
3
- # import clip
4
- from clip.model import convert_weights
5
- import torch
6
- import json
7
- from torch import nn
8
- from torch.nn import functional as nnf
9
- from torch.nn.modules import activation
10
- from torch.nn.modules.activation import ReLU
11
- from torchvision import transforms
12
-
13
- normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
14
-
15
- from torchvision.models import ResNet
16
-
17
-
18
- def process_prompts(conditional, prompt_list, conditional_map):
19
- # DEPRECATED
20
-
21
- # randomly sample a synonym
22
- words = [conditional_map[int(i)] for i in conditional]
23
- words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
24
- words = [w.replace('_', ' ') for w in words]
25
-
26
- if prompt_list is not None:
27
- prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
28
- prompts = [prompt_list[i] for i in prompt_indices]
29
- else:
30
- prompts = ['a photo of {}'] * (len(words))
31
-
32
- return [promt.format(w) for promt, w in zip(prompts, words)]
33
-
34
-
35
- class VITDenseBase(nn.Module):
36
-
37
- def rescaled_pos_emb(self, new_size):
38
- assert len(new_size) == 2
39
-
40
- a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
41
- b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
42
- return torch.cat([self.model.positional_embedding[:1], b])
43
-
44
- def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
45
-
46
- with torch.no_grad():
47
-
48
- x_inp = nnf.interpolate(x_inp, (384, 384))
49
-
50
- x = self.model.patch_embed(x_inp)
51
- cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
52
- if self.model.dist_token is None:
53
- x = torch.cat((cls_token, x), dim=1)
54
- else:
55
- x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
56
- x = self.model.pos_drop(x + self.model.pos_embed)
57
-
58
- activations = []
59
- for i, block in enumerate(self.model.blocks):
60
- x = block(x)
61
-
62
- if i in extract_layers:
63
- # permute to be compatible with CLIP
64
- activations += [x.permute(1,0,2)]
65
-
66
- x = self.model.norm(x)
67
- x = self.model.head(self.model.pre_logits(x[:, 0]))
68
-
69
- # again for CLIP compatibility
70
- # x = x.permute(1, 0, 2)
71
-
72
- return x, activations, None
73
-
74
- def sample_prompts(self, words, prompt_list=None):
75
-
76
- prompt_list = prompt_list if prompt_list is not None else self.prompt_list
77
-
78
- prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
79
- prompts = [prompt_list[i] for i in prompt_indices]
80
- return [promt.format(w) for promt, w in zip(prompts, words)]
81
-
82
- def get_cond_vec(self, conditional, batch_size):
83
- # compute conditional from a single string
84
- if conditional is not None and type(conditional) == str:
85
- cond = self.compute_conditional(conditional)
86
- cond = cond.repeat(batch_size, 1)
87
-
88
- # compute conditional from string list/tuple
89
- elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
90
- assert len(conditional) == batch_size
91
- cond = self.compute_conditional(conditional)
92
-
93
- # use conditional directly
94
- elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
95
- cond = conditional
96
-
97
- # compute conditional from image
98
- elif conditional is not None and type(conditional) == torch.Tensor:
99
- with torch.no_grad():
100
- cond, _, _ = self.visual_forward(conditional)
101
- else:
102
- raise ValueError('invalid conditional')
103
- return cond
104
-
105
- def compute_conditional(self, conditional):
106
- import clip
107
-
108
- dev = next(self.parameters()).device
109
-
110
- if type(conditional) in {list, tuple}:
111
- text_tokens = clip.tokenize(conditional).to(dev)
112
- cond = self.clip_model.encode_text(text_tokens)
113
- else:
114
- if conditional in self.precomputed_prompts:
115
- cond = self.precomputed_prompts[conditional].float().to(dev)
116
- else:
117
- text_tokens = clip.tokenize([conditional]).to(dev)
118
- cond = self.clip_model.encode_text(text_tokens)[0]
119
-
120
- return cond
121
-
122
-
123
- class VITDensePredT(VITDenseBase):
124
-
125
- def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
126
- depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
127
- learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
128
- add_calibration=False, process_cond=None, not_pretrained=False):
129
- super().__init__()
130
- # device = 'cpu'
131
-
132
- self.extract_layers = extract_layers
133
- self.cond_layer = cond_layer
134
- self.limit_to_clip_only = limit_to_clip_only
135
- self.process_cond = None
136
-
137
- if add_calibration:
138
- self.calibration_conds = 1
139
-
140
- self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
141
-
142
- self.add_activation1 = True
143
-
144
- import timm
145
- self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
146
- self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
147
-
148
- for p in self.model.parameters():
149
- p.requires_grad_(False)
150
-
151
- import clip
152
- self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
153
- # del self.clip_model.visual
154
-
155
-
156
- self.token_shape = (14, 14)
157
-
158
- # conditional
159
- if reduce_cond is not None:
160
- self.reduce_cond = nn.Linear(512, reduce_cond)
161
- for p in self.reduce_cond.parameters():
162
- p.requires_grad_(False)
163
- else:
164
- self.reduce_cond = None
165
-
166
- # self.film = AVAILABLE_BLOCKS['film'](512, 128)
167
- self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
168
- self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
169
-
170
- # DEPRECATED
171
- # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
172
-
173
- assert len(self.extract_layers) == depth
174
-
175
- self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
176
- self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
177
- self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
178
-
179
- trans_conv_ks = (16, 16)
180
- self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
181
-
182
- # refinement and trans conv
183
-
184
- if learn_trans_conv_only:
185
- for p in self.parameters():
186
- p.requires_grad_(False)
187
-
188
- for p in self.trans_conv.parameters():
189
- p.requires_grad_(True)
190
-
191
- if prompt == 'fixed':
192
- self.prompt_list = ['a photo of a {}.']
193
- elif prompt == 'shuffle':
194
- self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
195
- elif prompt == 'shuffle+':
196
- self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
197
- 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
198
- 'a bad photo of a {}.', 'a photo of the {}.']
199
- elif prompt == 'shuffle_clip':
200
- from models.clip_prompts import imagenet_templates
201
- self.prompt_list = imagenet_templates
202
-
203
- if process_cond is not None:
204
- if process_cond == 'clamp' or process_cond[0] == 'clamp':
205
-
206
- val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
207
-
208
- def clamp_vec(x):
209
- return torch.clamp(x, -val, val)
210
-
211
- self.process_cond = clamp_vec
212
-
213
- elif process_cond.endswith('.pth'):
214
-
215
- shift = torch.load(process_cond)
216
- def add_shift(x):
217
- return x + shift.to(x.device)
218
-
219
- self.process_cond = add_shift
220
-
221
- import pickle
222
- precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
223
- self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
224
-
225
-
226
- def forward(self, inp_image, conditional=None, return_features=False, mask=None):
227
-
228
- assert type(return_features) == bool
229
-
230
- # inp_image = inp_image.to(self.model.positional_embedding.device)
231
-
232
- if mask is not None:
233
- raise ValueError('mask not supported')
234
-
235
- # x_inp = normalize(inp_image)
236
- x_inp = inp_image
237
-
238
- bs, dev = inp_image.shape[0], x_inp.device
239
-
240
- inp_image_size = inp_image.shape[2:]
241
-
242
- cond = self.get_cond_vec(conditional, bs)
243
-
244
- visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
245
-
246
- activation1 = activations[0]
247
- activations = activations[1:]
248
-
249
- a = None
250
- for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
251
-
252
- if a is not None:
253
- a = reduce(activation) + a
254
- else:
255
- a = reduce(activation)
256
-
257
- if i == self.cond_layer:
258
- if self.reduce_cond is not None:
259
- cond = self.reduce_cond(cond)
260
-
261
- a = self.film_mul(cond) * a + self.film_add(cond)
262
-
263
- a = block(a)
264
-
265
- for block in self.extra_blocks:
266
- a = a + block(a)
267
-
268
- a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
269
-
270
- size = int(math.sqrt(a.shape[2]))
271
-
272
- a = a.view(bs, a.shape[1], size, size)
273
-
274
- if self.trans_conv is not None:
275
- a = self.trans_conv(a)
276
-
277
- if self.upsample_proj is not None:
278
- a = self.upsample_proj(a)
279
- a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
280
-
281
- a = nnf.interpolate(a, inp_image_size)
282
-
283
- if return_features:
284
- return a, visual_q, cond, [activation1] + activations
285
- else:
286
- return a,
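
Both deleted model files carry the same prompt-template mechanism (get_prompt_list in clipseg.py, the inline prompt_list variants and sample_prompts here). A standalone sketch of what it does; the example words are purely illustrative:

import torch

# the 'shuffle' template set defined in both deleted files
prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']

def sample_prompts(words, prompt_list):
    # draw one template per word uniformly at random, mirroring sample_prompts() above
    idx = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
    return [prompt_list[i].format(w) for i, w in zip(idx, words)]

print(sample_prompts(['panda', 'bench'], prompt_list))  # e.g. ['an image of a panda.', 'bench.']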
 
clipseg/overview.png DELETED
Binary file (54 kB)
 
clipseg/score.py DELETED
@@ -1,453 +0,0 @@
1
- from torch.functional import Tensor
2
-
3
- import torch
4
- import inspect
5
- import json
6
- import yaml
7
- import time
8
- import sys
9
-
10
- from general_utils import log
11
-
12
- import numpy as np
13
- from os.path import expanduser, join, isfile, realpath
14
-
15
- from torch.utils.data import DataLoader
16
-
17
- from metrics import FixedIntervalMetrics
18
-
19
- from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args
20
-
21
-
22
- DATASET_CACHE = dict()
23
-
24
- def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
25
-
26
- config = json.load(open(join('logs', checkpoint_id, 'config.json')))
27
-
28
- if model_args != 'from_config' and type(model_args) != dict:
29
- raise ValueError('model_args must either be "from_config" or a dictionary of values')
30
-
31
- model_cls = get_attribute(config['model'])
32
-
33
- # load model
34
- if model_args == 'from_config':
35
- _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
36
-
37
- model = model_cls(**model_args)
38
-
39
- if weights_file is None:
40
- weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
41
- else:
42
- weights_file = realpath(join('logs', checkpoint_id, weights_file))
43
-
44
- if isfile(weights_file) and not ignore_weights:
45
- weights = torch.load(weights_file)
46
- for _, w in weights.items():
47
- assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
48
- model.load_state_dict(weights, strict=strict)
49
- else:
50
- if not ignore_weights:
51
- raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
52
-
53
- if with_config:
54
- return model, config
55
-
56
- return model
57
-
58
-
59
- def compute_shift2(model, datasets, seed=123, repetitions=1):
60
- """ computes shift """
61
-
62
- model.eval()
63
- model.cuda()
64
-
65
- import random
66
- random.seed(seed)
67
-
68
- preds, gts = [], []
69
- for i_dataset, dataset in enumerate(datasets):
70
-
71
- loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
72
-
73
- max_iterations = int(repetitions * len(dataset.dataset.data_list))
74
-
75
- with torch.no_grad():
76
-
77
- i, losses = 0, []
78
- for i_all, (data_x, data_y) in enumerate(loader):
79
-
80
- data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
81
- data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]
82
-
83
- pred, = model(data_x[0], data_x[1], data_x[2])
84
- preds += [pred.detach()]
85
- gts += [data_y]
86
-
87
- i += 1
88
- if max_iterations and i >= max_iterations:
89
- break
90
-
91
- from metrics import FixedIntervalMetrics
92
- n_values = 51
93
- thresholds = np.linspace(0, 1, n_values)[1:-1]
94
- metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)
95
-
96
- for p, y in zip(preds, gts):
97
- metric.add(p.unsqueeze(1), y)
98
-
99
- best_idx = np.argmax(metric.value()['fgiou_scores'])
100
- best_thresh = thresholds[best_idx]
101
-
102
- return best_thresh
103
-
104
-
105
- def get_cached_pascal_pfe(split, config):
106
- from datasets.pfe_dataset import PFEPascalWrapper
107
- try:
108
- dataset = DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)]
109
- except KeyError:
110
- dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support)
111
- DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] = dataset
112
- return dataset
113
-
114
-
115
-
116
-
117
- def main():
118
- config, train_checkpoint_id = score_config_from_cli_args()
119
-
120
- metrics = score(config, train_checkpoint_id, None)
121
-
122
- for dataset in metrics.keys():
123
- for k in metrics[dataset]:
124
- if type(metrics[dataset][k]) in {float, int}:
125
- print(dataset, f'{k:<16} {metrics[dataset][k]:.3f}')
126
-
127
-
128
- def score(config, train_checkpoint_id, train_config):
129
-
130
- config = AttributeDict(config)
131
-
132
- print(config)
133
-
134
- # use training dataset and loss
135
- train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json')))
136
-
137
- cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''
138
-
139
-
140
- model_cls = get_attribute(train_config['model'])
141
-
142
- _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)
143
-
144
- model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}
145
-
146
- strict_models = {'ConditionBase4', 'PFENetWrapper'}
147
- model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
148
- weights_file=f'weights{cp_str}.pth', )
149
-
150
-
151
- model.eval()
152
- model.cuda()
153
-
154
- metric_args = dict()
155
-
156
- if 'threshold' in config:
157
- if config.metric.split('.')[-1] == 'SkLearnMetrics':
158
- metric_args['threshold'] = config.threshold
159
-
160
- if 'resize_to' in config:
161
- metric_args['resize_to'] = config.resize_to
162
-
163
- if 'sigmoid' in config:
164
- metric_args['sigmoid'] = config.sigmoid
165
-
166
- if 'custom_threshold' in config:
167
- metric_args['custom_threshold'] = config.custom_threshold
168
-
169
- if config.test_dataset == 'pascal':
170
-
171
- loss_fn = get_attribute(train_config.loss)
172
- # assume that if no split is specified in train_config, test on all splits,
173
-
174
- if 'splits' in config:
175
- splits = config.splits
176
- else:
177
- if 'split' in train_config and type(train_config.split) == int:
178
- # unless train_config has a split set, in that case assume train mode in training
179
- splits = [train_config.split]
180
- assert train_config.mode == 'train'
181
- else:
182
- splits = [0,1,2,3]
183
-
184
- log.info('Test on these splits', splits)
185
-
186
- scores = dict()
187
- for split in splits:
188
-
189
- shift = config.shift if 'shift' in config else 0
190
-
191
- # automatic shift
192
- if shift == 'auto':
193
- shift_compute_t = time.time()
194
- shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
195
- log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')
196
-
197
- dataset = get_cached_pascal_pfe(split, config)
198
-
199
- eval_start_t = time.time()
200
-
201
- loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
202
-
203
- assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'
204
-
205
- metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)
206
-
207
- with torch.no_grad():
208
-
209
- i, losses = 0, []
210
- for i_all, (data_x, data_y) in enumerate(loader):
211
-
212
- data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
213
- data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
214
-
215
- if config.mask == 'separate': # for old CondBase model
216
- pred, = model(data_x[0], data_x[1], data_x[2])
217
- else:
218
- # assert config.mask in {'text', 'highlight'}
219
- pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
220
-
221
- # loss = loss_fn(pred, data_y[0])
222
- metric.add(pred.unsqueeze(1) + shift, data_y)
223
-
224
- # losses += [float(loss)]
225
-
226
- i += 1
227
- if config.max_iterations and i >= config.max_iterations:
228
- break
229
-
230
- #scores[split] = {m: s for m, s in zip(metric.names(), metric.value())}
231
-
232
- log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')
233
-
234
- print(metric.value()['mean_iou_scores'])
235
-
236
- scores[split] = metric.scores()
237
-
238
- log.info(f'Completed split {split}')
239
-
240
- key_prefix = config['name'] if 'name' in config else 'pas'
241
-
242
- all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])
243
-
244
- valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) for v in scores.values())]
245
-
246
- return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}
247
-
248
-
249
- if config.test_dataset == 'coco':
250
- from datasets.coco_wrapper import COCOWrapper
251
-
252
- coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
253
- with_class_label=True)
254
-
255
- log.info('Dataset length', len(coco_dataset))
256
- loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
257
-
258
- metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
259
-
260
- shift = config.shift if 'shift' in config else 0
261
-
262
- with torch.no_grad():
263
-
264
- i, losses = 0, []
265
- for i_all, (data_x, data_y) in enumerate(loader):
266
- data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
267
- data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
268
-
269
- if config.mask == 'separate': # for old CondBase model
270
- pred, = model(data_x[0], data_x[1], data_x[2])
271
- else:
272
- # assert config.mask in {'text', 'highlight'}
273
- pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
274
-
275
- metric.add([pred + shift], data_y)
276
-
277
- i += 1
278
- if config.max_iterations and i >= config.max_iterations:
279
- break
280
-
281
- key_prefix = config['name'] if 'name' in config else 'coco'
282
- return {key_prefix: metric.scores()}
283
- #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
284
-
285
-
286
- if config.test_dataset == 'phrasecut':
287
- from datasets.phrasecut import PhraseCut
288
-
289
- only_visual = config.only_visual is not None and config.only_visual
290
- with_visual = config.with_visual is not None and config.with_visual
291
-
292
- dataset = PhraseCut('test',
293
- image_size=train_config.image_size,
294
- mask=config.mask,
295
- with_visual=with_visual, only_visual=only_visual, aug_crop=False,
296
- aug_color=False)
297
-
298
- loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
299
- metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
300
-
301
- shift = config.shift if 'shift' in config else 0
302
-
303
-
304
- with torch.no_grad():
305
-
306
- i, losses = 0, []
307
- for i_all, (data_x, data_y) in enumerate(loader):
308
- data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
309
- data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
310
-
311
- pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
312
- metric.add([pred + shift], data_y)
313
-
314
- i += 1
315
- if config.max_iterations and i >= config.max_iterations:
316
- break
317
-
318
- key_prefix = config['name'] if 'name' in config else 'phrasecut'
319
- return {key_prefix: metric.scores()}
320
- #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
321
-
322
- if config.test_dataset == 'pascal_zs':
323
- from third_party.JoEm.model.metric import Evaluator
324
- from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
325
- from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS
326
-
327
- from models.clipseg import CLIPSegMultiLabel
328
-
329
- n_unseen = train_config.remove_classes[1]
330
-
331
- pz = PascalZeroShot('val', n_unseen, image_size=352)
332
- m = CLIPSegMultiLabel(model=train_config.name).cuda()
333
- m.eval();
334
-
335
- print(len(pz), n_unseen)
336
- print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])
337
-
338
- print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
339
- print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])
340
-
341
- loader = DataLoader(pz, batch_size=8)
342
- evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))
343
-
344
- for i, (data_x, data_y) in enumerate(loader):
345
- pred = m(data_x[0].cuda())
346
- evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())
347
-
348
- if config.max_iter is not None and i > config.max_iter:
349
- break
350
-
351
- scores = evaluator.Mean_Intersection_over_Union()
352
- key_prefix = config['name'] if 'name' in config else 'pas_zs'
353
-
354
- return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}
355
-
356
- elif config.test_dataset in {'same_as_training', 'affordance'}:
357
- loss_fn = get_attribute(train_config.loss)
358
-
359
- metric_cls = get_attribute(config.metric)
360
- metric = metric_cls(**metric_args)
361
-
362
- if config.test_dataset == 'same_as_training':
363
- dataset_cls = get_attribute(train_config.dataset)
364
- elif config.test_dataset == 'affordance':
365
- dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
366
- dataset_name = 'aff'
367
- else:
368
- dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot')
369
- dataset_name = 'lvis'
370
-
371
- _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
372
-
373
- dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation
374
-
375
- if model.__class__.__name__ == 'PFENetWrapper':
376
- dataset_args['image_size'] = config.image_size
377
-
378
- log.info('init dataset', str(dataset_cls))
379
- dataset = dataset_cls(**dataset_args)
380
-
381
- log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')
382
-
383
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)
384
-
385
- # explicitly set prompts
386
- if config.prompt == 'plain':
387
- model.prompt_list = ['{}']
388
- elif config.prompt == 'fixed':
389
- model.prompt_list = ['a photo of a {}.']
390
- elif config.prompt == 'shuffle':
391
- model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
392
- elif config.prompt == 'shuffle_clip':
393
- from models.clip_prompts import imagenet_templates
394
- model.prompt_list = imagenet_templates
395
-
396
- config.assume_no_unused_keys(exceptions=['max_iterations'])
397
-
398
- t_start = time.time()
399
-
400
- with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9)
401
- i, losses = 0, []
402
- for data_x, data_y in data_loader:
403
-
404
- data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
405
- data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
406
-
407
- if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
408
- pred, = model(data_x[0], data_x[1], data_x[2])
409
- visual_q = None
410
- else:
411
- pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True)
412
-
413
- loss = loss_fn(pred, data_y[0])
414
-
415
- metric.add([pred], data_y)
416
-
417
- losses += [float(loss)]
418
-
419
- i += 1
420
- if config.max_iterations and i >= config.max_iterations:
421
- break
422
-
423
- # scores = {m: s for m, s in zip(metric.names(), metric.value())}
424
- scores = metric.scores()
425
-
426
- keys = set(scores.keys())
427
- if dataset.negative_prob > 0 and 'mIoU' in keys:
428
- keys.remove('mIoU')
429
-
430
- name_mask = dataset.mask.replace('text_label', 'txt')[:3]
431
- name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)
432
-
433
- score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'
434
-
435
- scores = {score_name: {k: v for k,v in scores.items() if k in keys}}
436
- scores[score_name].update({'test_loss': np.mean(losses)})
437
-
438
- log.info(f'Evaluation took {time.time() - t_start:.1f}s')
439
-
440
- return scores
441
- else:
442
- raise ValueError('invalid test dataset')
443
-
444
-
445
-
446
-
447
-
448
-
449
-
450
-
451
-
452
- if __name__ == '__main__':
453
- main()
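
score.py dispatches on config.test_dataset ('pascal', 'coco', 'phrasecut', 'pascal_zs', 'same_as_training'/'affordance') and returns a {name: {metric: value}} dict. A hedged sketch of a programmatic call; the checkpoint id and every config value below are illustrative and assume a logs/<checkpoint>/ directory containing config.json and weights.pth:

from score import score  # module as it existed before this commit

config = {
    'name': 'phrasecut_eval',                  # key prefix of the returned dict
    'test_dataset': 'phrasecut',
    'mask': 'text',
    'metric': 'metrics.FixedIntervalMetrics',
    'batch_size': 32,
    'max_iterations': 100,
    'shift': 0,
    'with_visual': False,
    'only_visual': False,
    'iteration_cp': None,                      # use the final weights.pth
}
metrics = score(config, train_checkpoint_id='rd64-uni', train_config=None)
print(metrics['phrasecut_eval'])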
 
clipseg/setup.py DELETED
@@ -1,30 +0,0 @@
1
- from setuptools import setup
2
-
3
- with open("README.md", "r", encoding="utf-8") as readme_file:
4
- readme = readme_file.read()
5
-
6
- requirements = [
7
- "numpy",
8
- "scipy",
9
- "matplotlib",
10
- "torch",
11
- "torchvision",
12
- "opencv-python",
13
- "CLIP @ git+https://github.com/openai/CLIP.git"
14
- ]
15
-
16
- setup(
17
- name='clipseg',
18
- packages=['clipseg'],
19
- package_dir={'clipseg': 'models'},
20
- package_data={'clipseg': [
21
- "../weights/*.pth",
22
- ]},
23
- version='0.0.1',
24
- url='https://github.com/timojl/clipseg',
25
- python_requires='>=3.9',
26
- install_requires=requirements,
27
- description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
28
- long_description=readme,
29
- long_description_content_type="text/markdown",
30
- )
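
Note on the packaging removed here: package_dir maps the 'clipseg' package onto the models/ directory, so after installing it (for example with pip install . from a checkout of the linked repository) the deleted modules resolve under that namespace instead of models.*; a minimal sketch:

# with the package defined by the deleted setup.py installed,
# the modules that lived under models/ become submodules of 'clipseg'
from clipseg.clipseg import CLIPDensePredT
from clipseg.vitseg import VITDensePredT

model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)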
 
clipseg/training.py DELETED
@@ -1,266 +0,0 @@
1
- import torch
2
- import inspect
3
- import json
4
- import yaml
5
- import math
6
- import os
7
- import sys
8
-
9
- from general_utils import log
10
-
11
- import numpy as np
12
- from functools import partial
13
- from os.path import expanduser, join, isfile, basename
14
-
15
- from torch.cuda.amp import autocast, GradScaler
16
- from torch.optim.lr_scheduler import LambdaLR
17
- from contextlib import nullcontext
18
- from torch.utils.data import DataLoader
19
-
20
- from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
21
-
22
-
23
- def cosine_warmup_lr(i, warmup=10, max_iter=90):
24
- """ Cosine LR with Warmup """
25
- if i < warmup:
26
- return (i+1)/(warmup+1)
27
- else:
28
- return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup))))
29
-
30
-
31
- def validate(model, dataset, config):
32
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
33
-
34
- metric_class, use_metric = config.val_metric_class, config.use_val_metric
35
- loss_fn = get_attribute(config.loss)
36
-
37
- model.eval()
38
- model.cuda()
39
-
40
- if metric_class is not None:
41
- metric = get_attribute(metric_class)()
42
-
43
- with torch.no_grad():
44
-
45
- i, losses = 0, []
46
- for data_x, data_y in data_loader:
47
-
48
- data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
49
- data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
50
-
51
- prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
52
- pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)
53
-
54
- if metric_class is not None:
55
- metric.add([pred], data_y)
56
-
57
- # pred = model(data_x[0], prompts)
58
- # loss = loss_fn(pred[0], data_y[0])
59
- loss = loss_fn(pred, data_y[0])
60
- losses += [float(loss)]
61
-
62
- i += 1
63
-
64
- if config.val_max_iterations is not None and i > config.val_max_iterations:
65
- break
66
-
67
- if use_metric is None:
68
- return np.mean(losses), {}, False
69
- else:
70
- metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
71
- return np.mean(losses), metric_scores, True
72
-
73
-
74
- def main():
75
-
76
- config = training_config_from_cli_args()
77
-
78
- val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
79
-
80
- model_cls = get_attribute(config.model)
81
- _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
82
- model = model_cls(**model_args).cuda()
83
-
84
- dataset_cls = get_attribute(config.dataset)
85
- _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
86
-
87
- dataset = dataset_cls(**dataset_args)
88
-
89
- log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
90
-
91
- if val_interval is not None:
92
- dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
93
- _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
94
- print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
95
-
96
- dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
97
-
98
- # optimizer
99
- opt_cls = get_attribute(config.optimizer)
100
- if config.optimize == 'torch.optim.SGD':
101
- opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
102
- else:
103
- opt_args = {}
104
- opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
105
-
106
- if config.lr_scheduler == 'cosine':
107
- assert config.T_max is not None and config.eta_min is not None
108
- lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
109
- elif config.lr_scheduler == 'warmup_cosine':
110
- lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
111
- else:
112
- lr_scheduler = None
113
-
114
- batch_size, max_iterations = config.batch_size, config.max_iterations
115
-
116
- loss_fn = get_attribute(config.loss)
117
-
118
- if config.amp:
119
- log.info('Using AMP')
120
- autocast_fn = autocast
121
- scaler = GradScaler()
122
- else:
123
- autocast_fn, scaler = nullcontext, None
124
-
125
-
126
- save_only_trainable = True
127
- data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
128
-
129
- # disable config when hyperparam. opt. to avoid writing logs.
130
- tracker_config = config if not config.hyperparameter_optimization else None
131
-
132
- with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
133
-
134
- i = 0
135
- while True:
136
- for data_x, data_y in data_loader:
137
-
138
- # between caption and output feature.
139
- # 1. Sample random captions
140
- # 2. Check alignment with CLIP
141
-
142
- # randomly mix text and visual support conditionals
143
- if config.mix:
144
-
145
- assert config.mask.startswith('text_and')
146
-
147
- with autocast_fn():
148
- # data_x[1] = text label
149
- prompts = model.sample_prompts(data_x[1])
150
-
151
- # model.clip_model()
152
-
153
- text_cond = model.compute_conditional(prompts)
154
- if model.__class__.__name__ == 'CLIPDensePredTMasked':
155
- # when mask=='separate'
156
- visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
157
- else:
158
- # data_x[2] = visual prompt
159
- visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())
160
-
161
- max_txt = config.mix_text_max if config.mix_text_max is not None else 1
162
- batch_size = text_cond.shape[0]
163
-
164
- # sample weights for each element in batch
165
- text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
166
- text_weights = text_weights.cuda()
167
-
168
- if dataset.__class__.__name__ == 'PhraseCut':
169
- # give full weight to text where support_image is invalid
170
- visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
171
- text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)
172
-
173
- cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)
174
-
175
- else:
176
- # no mix
177
-
178
- if model.__class__.__name__ == 'CLIPDensePredTMasked':
179
- # compute conditional vector using CLIP masking
180
- with autocast_fn():
181
- assert config.mask == 'separate'
182
- cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
183
- else:
184
- cond = data_x[1]
185
- if isinstance(cond, torch.Tensor):
186
- cond = cond.cuda()
187
-
188
- with autocast_fn():
189
- visual_q = None
190
-
191
- pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)
192
-
193
- loss = loss_fn(pred, data_y[0].cuda())
194
-
195
- if torch.isnan(loss) or torch.isinf(loss):
196
- # skip if loss is nan
197
- log.warning('Training stopped due to inf/nan loss.')
198
- sys.exit(-1)
199
-
200
- extra_loss = 0
201
- loss += extra_loss
202
-
203
- opt.zero_grad()
204
-
205
- if scaler is None:
206
- loss.backward()
207
- opt.step()
208
- else:
209
- scaler.scale(loss).backward()
210
- scaler.step(opt)
211
- scaler.update()
212
-
213
- if lr_scheduler is not None:
214
- lr_scheduler.step()
215
- if i % 2000 == 0:
216
- current_lr = [g['lr'] for g in opt.param_groups][0]
217
- log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')
218
-
219
- logger.iter(i=i, loss=loss)
220
- i += 1
221
-
222
- if i >= max_iterations:
223
-
224
- if not isfile(join(logger.base_path, 'weights.pth')):
225
- # only write if no weights were already written
226
- logger.save_weights(only_trainable=save_only_trainable)
227
-
228
- sys.exit(0)
229
-
230
-
231
- if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
232
- logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')
233
-
234
-
235
- if val_interval is not None and i % val_interval == val_interval - 1:
236
-
237
- val_loss, val_scores, maximize = validate(model, dataset_val, config)
238
-
239
- if len(val_scores) > 0:
240
-
241
- score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())
242
-
243
- if maximize and val_scores[config.use_val_metric] > best_val_score:
244
- logger.save_weights(only_trainable=save_only_trainable)
245
- best_val_score = val_scores[config.use_val_metric]
246
-
247
- elif not maximize and val_scores[config.use_val_metric] < best_val_score:
248
- logger.save_weights(only_trainable=save_only_trainable)
249
- best_val_score = val_scores[config.use_val_metric]
250
-
251
- else:
252
- score_str = ''
253
- # if no score is used, fall back to loss
254
- if val_loss < best_val_loss:
255
- logger.save_weights(only_trainable=save_only_trainable)
256
- best_val_loss = val_loss
257
-
258
- log.info(f'Validation loss: {val_loss}' + score_str)
259
- logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
260
- model.train()
261
-
262
- print('epoch complete')
263
-
264
-
265
- if __name__ == '__main__':
266
- main()
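
training.py drives the learning rate with the cosine_warmup_lr function above through LambdaLR when config.lr_scheduler == 'warmup_cosine'. A standalone sketch of that warmup-plus-cosine-decay schedule; the optimizer, warmup length and iteration count are chosen only for illustration:

import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

def cosine_warmup_lr(i, warmup=10, max_iter=90):
    # linear warmup followed by cosine decay, as defined at the top of the deleted training.py
    if i < warmup:
        return (i + 1) / (warmup + 1)
    return 0.5 + 0.5 * math.cos(math.pi * (i - warmup) / (max_iter - warmup))

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
opt = torch.optim.AdamW(params, lr=1e-3)
sched = LambdaLR(opt, partial(cosine_warmup_lr, warmup=100, max_iter=20000))

for step in range(300):
    opt.step()    # gradient computation omitted; the sketch only tracks the schedule
    sched.step()
print(opt.param_groups[0]['lr'])  # past warmup, decaying along the cosine toward 0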
 
clipseg/weights/rd64-uni.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:13845f6cee4d54ca46f62ee19dd354822094a26e0efccc64e606be93d6a7e26f
3
- size 4306645
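
The three lines above are a Git LFS pointer, not the checkpoint itself; LFS substitutes the actual 4.3 MB state dict at checkout. A small sketch for sanity-checking a fetched copy against the recorded oid and size (the path assumes a full clone of this Space prior to the commit):

import hashlib
import os

path = 'clipseg/weights/rd64-uni.pth'
print(os.path.getsize(path))  # expected: 4306645 bytes, the size recorded in the pointer
with open(path, 'rb') as f:
    print(hashlib.sha256(f.read()).hexdigest())  # expected to equal the pointer's oid sha256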