Spaces:
Runtime error
Runtime error
files
Browse files- app.py +170 -50
- clipseg/LICENSE +21 -0
- clipseg/Quickstart.ipynb +107 -0
- clipseg/Readme.md +84 -0
- clipseg/Tables.ipynb +349 -0
- clipseg/Visual_Feature_Engineering.ipynb +366 -0
- clipseg/datasets/coco_wrapper.py +99 -0
- clipseg/datasets/pascal_classes.json +1 -0
- clipseg/datasets/pascal_zeroshot.py +60 -0
- clipseg/datasets/pfe_dataset.py +129 -0
- clipseg/datasets/phrasecut.py +335 -0
- clipseg/datasets/utils.py +68 -0
- clipseg/environment.yml +15 -0
- clipseg/evaluation_utils.py +292 -0
- clipseg/example_image.jpg +0 -0
- clipseg/experiments/ablation.yaml +84 -0
- clipseg/experiments/coco.yaml +101 -0
- clipseg/experiments/pascal_1shot.yaml +101 -0
- clipseg/experiments/phrasecut.yaml +80 -0
- clipseg/general_utils.py +272 -0
- clipseg/metrics.py +271 -0
- clipseg/models/clipseg.py +552 -0
- clipseg/models/vitseg.py +286 -0
- clipseg/overview.png +0 -0
- clipseg/score.py +453 -0
- clipseg/setup.py +30 -0
- clipseg/training.py +266 -0
- clipseg/weights/rd64-uni.pth +3 -0
- init_image.png +0 -0
- inpainting.py +194 -0
- mask_image.png +0 -0
app.py
CHANGED
@@ -1,54 +1,174 @@
|
|
1 |
-
from diffusers import StableDiffusionInpaintPipeline
|
2 |
import gradio as gr
|
3 |
-
|
4 |
-
import imageio
|
5 |
-
from PIL import Image
|
6 |
from io import BytesIO
|
|
|
|
|
|
|
|
|
7 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
|
|
|
|
|
3 |
from io import BytesIO
|
4 |
+
import requests
|
5 |
+
import PIL
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
import os
|
9 |
+
import uuid
|
10 |
+
import torch
|
11 |
+
from torch import autocast
|
12 |
+
import cv2
|
13 |
+
from matplotlib import pyplot as plt
|
14 |
+
from inpainting import StableDiffusionInpaintingPipeline
|
15 |
+
from torchvision import transforms
|
16 |
+
from clipseg.models.clipseg import CLIPDensePredT
|
17 |
+
|
18 |
+
auth_token = os.environ.get("API_TOKEN") or True
|
19 |
+
|
20 |
+
def download_image(url):
|
21 |
+
response = requests.get(url)
|
22 |
+
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
23 |
+
|
24 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
+
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
26 |
+
"CompVis/stable-diffusion-v1-4",
|
27 |
+
revision="fp16",
|
28 |
+
torch_dtype=torch.float16,
|
29 |
+
use_auth_token=auth_token,
|
30 |
+
).to(device)
|
31 |
+
|
32 |
+
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
|
33 |
+
model.eval()
|
34 |
+
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device('cuda')), strict=False)
|
35 |
+
|
36 |
+
transform = transforms.Compose([
|
37 |
+
transforms.ToTensor(),
|
38 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
39 |
+
transforms.Resize((512, 512)),
|
40 |
+
])
|
41 |
+
|
42 |
+
def predict(radio, dict, word_mask, prompt=""):
|
43 |
+
if(radio == "draw a mask above"):
|
44 |
+
with autocast("cuda"):
|
45 |
+
init_image = dict["image"].convert("RGB").resize((512, 512))
|
46 |
+
mask = dict["mask"].convert("RGB").resize((512, 512))
|
47 |
+
else:
|
48 |
+
img = transform(dict["image"]).unsqueeze(0)
|
49 |
+
word_masks = [word_mask]
|
50 |
+
with torch.no_grad():
|
51 |
+
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
52 |
+
init_image = dict['image'].convert('RGB').resize((512, 512))
|
53 |
+
filename = f"{uuid.uuid4()}.png"
|
54 |
+
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
55 |
+
img2 = cv2.imread(filename)
|
56 |
+
gray_image = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
57 |
+
(thresh, bw_image) = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
|
58 |
+
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
59 |
+
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
60 |
+
os.remove(filename)
|
61 |
+
with autocast("cuda"):
|
62 |
+
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
63 |
+
return images[0]
|
64 |
+
|
65 |
+
# examples = [[dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]]
|
66 |
+
css = '''
|
67 |
+
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
|
68 |
+
#image_upload{min-height:400px}
|
69 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
|
70 |
+
#mask_radio .gr-form{background:transparent; border: none}
|
71 |
+
#word_mask{margin-top: .75em !important}
|
72 |
+
#word_mask textarea:disabled{opacity: 0.3}
|
73 |
+
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
|
74 |
+
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
|
75 |
+
.dark .footer {border-color: #303030}
|
76 |
+
.dark .footer>p {background: #0b0f19}
|
77 |
+
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
|
78 |
+
#image_upload .touch-none{display: flex}
|
79 |
+
'''
|
80 |
+
def swap_word_mask(radio_option):
|
81 |
+
if(radio_option == "type what to mask below"):
|
82 |
+
return gr.update(interactive=True, placeholder="A cat")
|
83 |
+
else:
|
84 |
+
return gr.update(interactive=False, placeholder="Disabled")
|
85 |
|
86 |
+
image_blocks = gr.Blocks(css=css)
|
87 |
+
with image_blocks as demo:
|
88 |
+
gr.HTML(
|
89 |
+
"""
|
90 |
+
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
91 |
+
<div
|
92 |
+
style="
|
93 |
+
display: inline-flex;
|
94 |
+
align-items: center;
|
95 |
+
gap: 0.8rem;
|
96 |
+
font-size: 1.75rem;
|
97 |
+
"
|
98 |
+
>
|
99 |
+
<svg
|
100 |
+
width="0.65em"
|
101 |
+
height="0.65em"
|
102 |
+
viewBox="0 0 115 115"
|
103 |
+
fill="none"
|
104 |
+
xmlns="http://www.w3.org/2000/svg"
|
105 |
+
>
|
106 |
+
<rect width="23" height="23" fill="white"></rect>
|
107 |
+
<rect y="69" width="23" height="23" fill="white"></rect>
|
108 |
+
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
|
109 |
+
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
110 |
+
<rect x="46" width="23" height="23" fill="white"></rect>
|
111 |
+
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
|
112 |
+
<rect x="69" width="23" height="23" fill="black"></rect>
|
113 |
+
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
|
114 |
+
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
|
115 |
+
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
116 |
+
<rect x="115" y="46" width="23" height="23" fill="white"></rect>
|
117 |
+
<rect x="115" y="115" width="23" height="23" fill="white"></rect>
|
118 |
+
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
119 |
+
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
|
120 |
+
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
121 |
+
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
|
122 |
+
<rect x="69" y="46" width="23" height="23" fill="white"></rect>
|
123 |
+
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
|
124 |
+
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
125 |
+
<rect x="46" y="46" width="23" height="23" fill="black"></rect>
|
126 |
+
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
|
127 |
+
<rect x="46" y="69" width="23" height="23" fill="black"></rect>
|
128 |
+
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
|
129 |
+
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
130 |
+
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
|
131 |
+
</svg>
|
132 |
+
<h1 style="font-weight: 900; margin-bottom: 7px;">
|
133 |
+
Stable Diffusion Multi Inpainting
|
134 |
+
</h1>
|
135 |
+
</div>
|
136 |
+
<p style="margin-bottom: 10px; font-size: 94%">
|
137 |
+
Inpaint Stable Diffusion by either drawing a mask or typing what to replace
|
138 |
+
</p>
|
139 |
+
</div>
|
140 |
+
"""
|
141 |
+
)
|
142 |
+
with gr.Row():
|
143 |
+
with gr.Column():
|
144 |
+
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
|
145 |
+
with gr.Box(elem_id="mask_radio").style(border=False):
|
146 |
+
radio = gr.Radio(["draw a mask above", "type what to mask below"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
|
147 |
+
word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
|
148 |
+
prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
|
149 |
+
radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
|
150 |
+
radio.change(None, inputs=[], outputs=image_blocks, _js = """
|
151 |
+
() => {
|
152 |
+
css_style = document.styleSheets[document.styleSheets.length - 1]
|
153 |
+
last_item = css_style.cssRules[css_style.cssRules.length - 1]
|
154 |
+
last_item.style.display = ["flex", ""].includes(last_item.style.display) ? "none" : "flex";
|
155 |
+
}""")
|
156 |
+
btn = gr.Button("Run")
|
157 |
+
with gr.Column():
|
158 |
+
result = gr.Image(label="Result")
|
159 |
+
btn.click(fn=predict, inputs=[radio, image, word_mask, prompt], outputs=result)
|
160 |
+
gr.HTML(
|
161 |
+
"""
|
162 |
+
<div class="footer">
|
163 |
+
<p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Inpainting by <a href="https://github.com/nagolinc" style="text-decoration: underline;" target="_blank">nagolinc</a> and <a href="https://github.com/patil-suraj" style="text-decoration: underline;">patil-suraj</a>, inpainting with words by <a href="https://twitter.com/yvrjsharma/" style="text-decoration: underline;" target="_blank">@yvrjsharma</a> and <a href="https://twitter.com/1littlecoder" style="text-decoration: underline;">@1littlecoder</a> - Gradio Demo by π€ Hugging Face
|
164 |
+
</p>
|
165 |
+
</div>
|
166 |
+
<div class="acknowledgments">
|
167 |
+
<p><h4>LICENSE</h4>
|
168 |
+
The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
|
169 |
+
<p><h4>Biases and content acknowledgment</h4>
|
170 |
+
Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a></p>
|
171 |
+
</div>
|
172 |
+
"""
|
173 |
+
)
|
174 |
+
demo.launch()
|
clipseg/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
This license does not apply to the model weights.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
clipseg/Quickstart.ipynb
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"import requests\n",
|
11 |
+
"\n",
|
12 |
+
"! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
|
13 |
+
"! unzip -d weights -j weights.zip\n",
|
14 |
+
"from models.clipseg import CLIPDensePredT\n",
|
15 |
+
"from PIL import Image\n",
|
16 |
+
"from torchvision import transforms\n",
|
17 |
+
"from matplotlib import pyplot as plt\n",
|
18 |
+
"\n",
|
19 |
+
"# load model\n",
|
20 |
+
"model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
|
21 |
+
"model.eval();\n",
|
22 |
+
"\n",
|
23 |
+
"# non-strict, because we only stored decoder weights (not CLIP weights)\n",
|
24 |
+
"model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "markdown",
|
29 |
+
"metadata": {},
|
30 |
+
"source": [
|
31 |
+
"Load and normalize `example_image.jpg`. You can also load through an URL."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"# load and normalize image\n",
|
41 |
+
"input_image = Image.open('example_image.jpg')\n",
|
42 |
+
"\n",
|
43 |
+
"# or load from URL...\n",
|
44 |
+
"# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
|
45 |
+
"# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
|
46 |
+
"\n",
|
47 |
+
"transform = transforms.Compose([\n",
|
48 |
+
" transforms.ToTensor(),\n",
|
49 |
+
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
50 |
+
" transforms.Resize((352, 352)),\n",
|
51 |
+
"])\n",
|
52 |
+
"img = transform(input_image).unsqueeze(0)"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "markdown",
|
57 |
+
"metadata": {},
|
58 |
+
"source": [
|
59 |
+
"Predict and visualize (this might take a few seconds if running without GPU support)"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": null,
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [],
|
67 |
+
"source": [
|
68 |
+
"prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
|
69 |
+
"\n",
|
70 |
+
"# predict\n",
|
71 |
+
"with torch.no_grad():\n",
|
72 |
+
" preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
|
73 |
+
"\n",
|
74 |
+
"# visualize prediction\n",
|
75 |
+
"_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
|
76 |
+
"[a.axis('off') for a in ax.flatten()]\n",
|
77 |
+
"ax[0].imshow(input_image)\n",
|
78 |
+
"[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
|
79 |
+
"[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
|
80 |
+
]
|
81 |
+
}
|
82 |
+
],
|
83 |
+
"metadata": {
|
84 |
+
"interpreter": {
|
85 |
+
"hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
|
86 |
+
},
|
87 |
+
"kernelspec": {
|
88 |
+
"display_name": "Python 3",
|
89 |
+
"language": "python",
|
90 |
+
"name": "python3"
|
91 |
+
},
|
92 |
+
"language_info": {
|
93 |
+
"codemirror_mode": {
|
94 |
+
"name": "ipython",
|
95 |
+
"version": 3
|
96 |
+
},
|
97 |
+
"file_extension": ".py",
|
98 |
+
"mimetype": "text/x-python",
|
99 |
+
"name": "python",
|
100 |
+
"nbconvert_exporter": "python",
|
101 |
+
"pygments_lexer": "ipython3",
|
102 |
+
"version": "3.8.10"
|
103 |
+
}
|
104 |
+
},
|
105 |
+
"nbformat": 4,
|
106 |
+
"nbformat_minor": 4
|
107 |
+
}
|
clipseg/Readme.md
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Image Segmentation Using Text and Image Prompts
|
2 |
+
This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
|
3 |
+
|
4 |
+
**The Paper has been accepted to CVPR 2022!**
|
5 |
+
|
6 |
+
<img src="overview.png" alt="drawing" height="200em"/>
|
7 |
+
|
8 |
+
The systems allows to create segmentation models without training based on:
|
9 |
+
- An arbitrary text query
|
10 |
+
- Or an image with a mask highlighting stuff or an object.
|
11 |
+
|
12 |
+
### Quick Start
|
13 |
+
|
14 |
+
In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension.
|
15 |
+
It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
|
16 |
+
(please note that the VM does not use a GPU, thus inference takes a few seconds).
|
17 |
+
|
18 |
+
|
19 |
+
### Dependencies
|
20 |
+
This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
|
21 |
+
Additional dependencies are hidden for double blind review.
|
22 |
+
|
23 |
+
|
24 |
+
### Datasets
|
25 |
+
|
26 |
+
* `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
|
27 |
+
* `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
|
28 |
+
* `PascalZeroShot`: Wrapper class for PascalZeroShot
|
29 |
+
* `COCOWrapper`: Wrapper class for COCO.
|
30 |
+
|
31 |
+
### Models
|
32 |
+
|
33 |
+
* `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
|
34 |
+
* `ViTDensePredT`: CLIPSeg model with transformer-based decoder.
|
35 |
+
|
36 |
+
### Third Party Dependencies
|
37 |
+
For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
|
38 |
+
```bash
|
39 |
+
git clone https://github.com/cvlab-yonsei/JoEm
|
40 |
+
git clone https://github.com/Jia-Research-Lab/PFENet.git
|
41 |
+
git clone https://github.com/ChenyunWu/PhraseCutDataset.git
|
42 |
+
git clone https://github.com/juhongm999/hsnet.git
|
43 |
+
```
|
44 |
+
|
45 |
+
### Weights
|
46 |
+
|
47 |
+
The MIT license does not apply to these weights.
|
48 |
+
|
49 |
+
We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
|
50 |
+
```
|
51 |
+
wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
|
52 |
+
unzip -d weights -j weights.zip
|
53 |
+
```
|
54 |
+
|
55 |
+
|
56 |
+
### Training and Evaluation
|
57 |
+
|
58 |
+
To train use the `training.py` script with experiment file and experiment id parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment which is defined by the `configuration` and first `individual_configurations` parameters. Model weights will be written in `logs/`.
|
59 |
+
|
60 |
+
For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will train the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`.
|
61 |
+
|
62 |
+
|
63 |
+
### Usage of PFENet Wrappers
|
64 |
+
|
65 |
+
In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
|
66 |
+
`git clone https://github.com/Jia-Research-Lab/PFENet.git `
|
67 |
+
|
68 |
+
|
69 |
+
### License
|
70 |
+
|
71 |
+
The source code files in this repository (excluding model weights) are released under MIT license.
|
72 |
+
|
73 |
+
### Citation
|
74 |
+
```
|
75 |
+
@InProceedings{lueddecke22_cvpr,
|
76 |
+
author = {L\"uddecke, Timo and Ecker, Alexander},
|
77 |
+
title = {Image Segmentation Using Text and Image Prompts},
|
78 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
79 |
+
month = {June},
|
80 |
+
year = {2022},
|
81 |
+
pages = {7086-7096}
|
82 |
+
}
|
83 |
+
|
84 |
+
```
|
clipseg/Tables.ipynb
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"%load_ext autoreload\n",
|
10 |
+
"%autoreload 2\n",
|
11 |
+
"\n",
|
12 |
+
"import clip\n",
|
13 |
+
"from evaluation_utils import norm, denorm\n",
|
14 |
+
"from general_utils import *\n",
|
15 |
+
"from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"cell_type": "markdown",
|
20 |
+
"metadata": {},
|
21 |
+
"source": [
|
22 |
+
"# PhraseCut"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
|
50 |
+
"tab1 = pc[['name'] + cols]\n",
|
51 |
+
"for k in cols:\n",
|
52 |
+
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
53 |
+
"tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
|
54 |
+
"tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
|
55 |
+
"print(tab1.to_latex(header=False, index=False))"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "markdown",
|
60 |
+
"metadata": {},
|
61 |
+
"source": [
|
62 |
+
"For 0.1 threshold"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
|
72 |
+
"tab1 = pc[['name'] + cols]\n",
|
73 |
+
"for k in cols:\n",
|
74 |
+
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
75 |
+
"tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
|
76 |
+
"tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
|
77 |
+
"print(tab1.to_latex(header=False, index=False))"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "markdown",
|
82 |
+
"metadata": {},
|
83 |
+
"source": [
|
84 |
+
"# One-shot"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "markdown",
|
89 |
+
"metadata": {},
|
90 |
+
"source": [
|
91 |
+
"### Pascal"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": null,
|
97 |
+
"metadata": {},
|
98 |
+
"outputs": [],
|
99 |
+
"source": [
|
100 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": null,
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [],
|
108 |
+
"source": [
|
109 |
+
"pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": null,
|
115 |
+
"metadata": {},
|
116 |
+
"outputs": [],
|
117 |
+
"source": [
|
118 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
119 |
+
"tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
|
120 |
+
"print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
121 |
+
"print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
122 |
+
"\n",
|
123 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
124 |
+
"tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
|
125 |
+
"print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
126 |
+
"\n",
|
127 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
|
128 |
+
"tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
|
129 |
+
"print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "markdown",
|
134 |
+
"metadata": {},
|
135 |
+
"source": [
|
136 |
+
"#### Pascal Zero-shot (in one-shot setting)\n",
|
137 |
+
"\n",
|
138 |
+
"Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"cell_type": "code",
|
143 |
+
"execution_count": null,
|
144 |
+
"metadata": {},
|
145 |
+
"outputs": [],
|
146 |
+
"source": [
|
147 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
148 |
+
"tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
|
149 |
+
"print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
150 |
+
"print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
151 |
+
"\n",
|
152 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
153 |
+
"tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
|
154 |
+
"print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
155 |
+
"\n",
|
156 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
|
157 |
+
"tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
|
158 |
+
"print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": null,
|
164 |
+
"metadata": {},
|
165 |
+
"outputs": [],
|
166 |
+
"source": [
|
167 |
+
"# without fixed thresholds...\n",
|
168 |
+
"\n",
|
169 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
170 |
+
"tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
|
171 |
+
"print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
172 |
+
"print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
173 |
+
"\n",
|
174 |
+
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
175 |
+
"tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
|
176 |
+
"print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "markdown",
|
181 |
+
"metadata": {},
|
182 |
+
"source": [
|
183 |
+
"### COCO"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": null,
|
189 |
+
"metadata": {},
|
190 |
+
"outputs": [],
|
191 |
+
"source": [
|
192 |
+
"coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "code",
|
197 |
+
"execution_count": null,
|
198 |
+
"metadata": {},
|
199 |
+
"outputs": [],
|
200 |
+
"source": [
|
201 |
+
"tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
|
202 |
+
"tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
|
203 |
+
"tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
|
204 |
+
"print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
|
205 |
+
"print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
206 |
+
"print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
|
207 |
+
"print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"metadata": {},
|
213 |
+
"source": [
|
214 |
+
"# Zero-shot"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "code",
|
219 |
+
"execution_count": null,
|
220 |
+
"metadata": {},
|
221 |
+
"outputs": [],
|
222 |
+
"source": [
|
223 |
+
"zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
|
224 |
+
]
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"cell_type": "code",
|
228 |
+
"execution_count": null,
|
229 |
+
"metadata": {},
|
230 |
+
"outputs": [],
|
231 |
+
"source": [
|
232 |
+
"\n",
|
233 |
+
"tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
|
234 |
+
"print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
|
235 |
+
"print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
|
236 |
+
"print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "markdown",
|
241 |
+
"metadata": {},
|
242 |
+
"source": [
|
243 |
+
"# Ablation"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": null,
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "code",
|
257 |
+
"execution_count": null,
|
258 |
+
"metadata": {},
|
259 |
+
"outputs": [],
|
260 |
+
"source": [
|
261 |
+
"tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
|
262 |
+
"for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
|
263 |
+
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
264 |
+
"tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"execution_count": null,
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"cell_type": "code",
|
278 |
+
"execution_count": null,
|
279 |
+
"metadata": {},
|
280 |
+
"outputs": [],
|
281 |
+
"source": [
|
282 |
+
"print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "markdown",
|
287 |
+
"metadata": {},
|
288 |
+
"source": [
|
289 |
+
"# Generalization"
|
290 |
+
]
|
291 |
+
},
|
292 |
+
{
|
293 |
+
"cell_type": "code",
|
294 |
+
"execution_count": null,
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"generalization = experiment('experiments/generalize.yaml').dataframe()"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": null,
|
304 |
+
"metadata": {},
|
305 |
+
"outputs": [],
|
306 |
+
"source": [
|
307 |
+
"gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"cell_type": "code",
|
312 |
+
"execution_count": null,
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"print(\n",
|
317 |
+
" 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
|
318 |
+
" 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
|
319 |
+
" 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
|
320 |
+
" 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
|
321 |
+
")"
|
322 |
+
]
|
323 |
+
}
|
324 |
+
],
|
325 |
+
"metadata": {
|
326 |
+
"interpreter": {
|
327 |
+
"hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
|
328 |
+
},
|
329 |
+
"kernelspec": {
|
330 |
+
"display_name": "env2",
|
331 |
+
"language": "python",
|
332 |
+
"name": "env2"
|
333 |
+
},
|
334 |
+
"language_info": {
|
335 |
+
"codemirror_mode": {
|
336 |
+
"name": "ipython",
|
337 |
+
"version": 3
|
338 |
+
},
|
339 |
+
"file_extension": ".py",
|
340 |
+
"mimetype": "text/x-python",
|
341 |
+
"name": "python",
|
342 |
+
"nbconvert_exporter": "python",
|
343 |
+
"pygments_lexer": "ipython3",
|
344 |
+
"version": "3.8.8"
|
345 |
+
}
|
346 |
+
},
|
347 |
+
"nbformat": 4,
|
348 |
+
"nbformat_minor": 4
|
349 |
+
}
|
clipseg/Visual_Feature_Engineering.ipynb
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Systematic"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"%load_ext autoreload\n",
|
17 |
+
"%autoreload 2\n",
|
18 |
+
"\n",
|
19 |
+
"import clip\n",
|
20 |
+
"from evaluation_utils import norm, denorm\n",
|
21 |
+
"from general_utils import *\n",
|
22 |
+
"from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
|
23 |
+
"\n",
|
24 |
+
"clip_device = 'cuda'\n",
|
25 |
+
"clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
|
26 |
+
"clip_model.eval();\n",
|
27 |
+
"\n",
|
28 |
+
"from models.clipseg import CLIPDensePredTMasked\n",
|
29 |
+
"\n",
|
30 |
+
"clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
|
31 |
+
"clip_mask_model.eval();"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
|
41 |
+
" text_class_labels=True, image_size=352, min_area=0.1,\n",
|
42 |
+
" min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": null,
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"plot_data(lvis)"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": null,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"from collections import defaultdict\n",
|
61 |
+
"import json\n",
|
62 |
+
"\n",
|
63 |
+
"lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
|
64 |
+
"lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
|
65 |
+
"\n",
|
66 |
+
"objects_per_image = defaultdict(lambda : set())\n",
|
67 |
+
"for ann in lvis_raw['annotations']:\n",
|
68 |
+
" objects_per_image[ann['image_id']].add(ann['category_id'])\n",
|
69 |
+
" \n",
|
70 |
+
"for ann in lvis_val_raw['annotations']:\n",
|
71 |
+
" objects_per_image[ann['image_id']].add(ann['category_id']) \n",
|
72 |
+
" \n",
|
73 |
+
"objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
|
74 |
+
"\n",
|
75 |
+
"del lvis_raw, lvis_val_raw"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": null,
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [],
|
83 |
+
"source": [
|
84 |
+
"#bs = 32\n",
|
85 |
+
"#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": null,
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"from general_utils import get_batch\n",
|
95 |
+
"from functools import partial\n",
|
96 |
+
"from evaluation_utils import img_preprocess\n",
|
97 |
+
"import torch\n",
|
98 |
+
"\n",
|
99 |
+
"def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
|
100 |
+
"\n",
|
101 |
+
" # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
|
102 |
+
"\n",
|
103 |
+
" all_prompts = []\n",
|
104 |
+
" \n",
|
105 |
+
" with torch.no_grad():\n",
|
106 |
+
" valid_sims = []\n",
|
107 |
+
" torch.manual_seed(571)\n",
|
108 |
+
" \n",
|
109 |
+
" if type(batches_or_dataset) == list:\n",
|
110 |
+
" loader = batches_or_dataset # already loaded\n",
|
111 |
+
" max_iter = float('inf')\n",
|
112 |
+
" else:\n",
|
113 |
+
" loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
|
114 |
+
" max_iter = 50\n",
|
115 |
+
" \n",
|
116 |
+
" global batch\n",
|
117 |
+
" for i_batch, (batch, batch_y) in enumerate(loader):\n",
|
118 |
+
" \n",
|
119 |
+
" if i_batch >= max_iter: break\n",
|
120 |
+
" \n",
|
121 |
+
" processed_batch = process(batch)\n",
|
122 |
+
" if type(processed_batch) == dict:\n",
|
123 |
+
" \n",
|
124 |
+
" # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
|
125 |
+
" image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
|
126 |
+
" else:\n",
|
127 |
+
" processed_batch = process(batch).to(clip_device)\n",
|
128 |
+
" processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
|
129 |
+
" #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
|
130 |
+
" image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
|
131 |
+
" \n",
|
132 |
+
" image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
|
133 |
+
" bs = len(batch[0])\n",
|
134 |
+
" for j in range(bs):\n",
|
135 |
+
" \n",
|
136 |
+
" c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
|
137 |
+
" support_image = basename(lvis.samples[c][sid])\n",
|
138 |
+
" \n",
|
139 |
+
" img_objs = [o for o in objects_per_image[int(support_image)]]\n",
|
140 |
+
" img_objs = [o.replace('_', ' ') for o in img_objs]\n",
|
141 |
+
" \n",
|
142 |
+
" other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
|
143 |
+
" if o != batch_y[2][j]]\n",
|
144 |
+
" \n",
|
145 |
+
" prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
|
146 |
+
" all_prompts += [prompts]\n",
|
147 |
+
" \n",
|
148 |
+
" text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
|
149 |
+
" text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
|
150 |
+
"\n",
|
151 |
+
" global logits\n",
|
152 |
+
" logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
|
153 |
+
"\n",
|
154 |
+
" global sim\n",
|
155 |
+
" sim = torch.softmax(logits, dim=-1)\n",
|
156 |
+
" \n",
|
157 |
+
" valid_sims += [sim]\n",
|
158 |
+
" \n",
|
159 |
+
" #valid_sims = torch.stack(valid_sims)\n",
|
160 |
+
" return valid_sims, all_prompts\n",
|
161 |
+
" \n",
|
162 |
+
"\n",
|
163 |
+
"def new_img_preprocess(x):\n",
|
164 |
+
" return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
|
165 |
+
" \n",
|
166 |
+
"#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
|
167 |
+
"get_similarities(lvis, lambda x: x[1]);"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"execution_count": null,
|
173 |
+
"metadata": {},
|
174 |
+
"outputs": [],
|
175 |
+
"source": [
|
176 |
+
"preprocessing_functions = [\n",
|
177 |
+
"# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
|
178 |
+
"# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
|
179 |
+
"# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
|
180 |
+
"# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
|
181 |
+
"# ['add red outline', partial(img_preprocess, outline=True)],\n",
|
182 |
+
" \n",
|
183 |
+
"# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
|
184 |
+
"# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
|
185 |
+
"# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
|
186 |
+
"# ['BG blur', partial(img_preprocess, blur=3)],\n",
|
187 |
+
"# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
|
188 |
+
" \n",
|
189 |
+
"# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
|
190 |
+
"# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
|
191 |
+
" ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
|
192 |
+
" ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
|
193 |
+
"# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
|
194 |
+
"]\n",
|
195 |
+
"\n",
|
196 |
+
"preprocessing_functions = preprocessing_functions\n",
|
197 |
+
"\n",
|
198 |
+
"base, base_p = get_similarities(lvis, lambda x: x[1])\n",
|
199 |
+
"outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"cell_type": "code",
|
204 |
+
"execution_count": null,
|
205 |
+
"metadata": {},
|
206 |
+
"outputs": [],
|
207 |
+
"source": [
|
208 |
+
"outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"execution_count": null,
|
214 |
+
"metadata": {},
|
215 |
+
"outputs": [],
|
216 |
+
"source": [
|
217 |
+
"for j in range(1):\n",
|
218 |
+
" print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": null,
|
224 |
+
"metadata": {},
|
225 |
+
"outputs": [],
|
226 |
+
"source": [
|
227 |
+
"from pandas import DataFrame\n",
|
228 |
+
"tab = dict()\n",
|
229 |
+
"for j, (name, _) in enumerate(preprocessing_functions):\n",
|
230 |
+
" tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
|
231 |
+
" \n",
|
232 |
+
" \n",
|
233 |
+
"print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "markdown",
|
238 |
+
"metadata": {},
|
239 |
+
"source": [
|
240 |
+
"# Visual"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": null,
|
246 |
+
"metadata": {},
|
247 |
+
"outputs": [],
|
248 |
+
"source": [
|
249 |
+
"from evaluation_utils import denorm, norm"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "code",
|
254 |
+
"execution_count": null,
|
255 |
+
"metadata": {},
|
256 |
+
"outputs": [],
|
257 |
+
"source": [
|
258 |
+
"def load_sample(filename, filename2):\n",
|
259 |
+
" from os.path import join\n",
|
260 |
+
" bp = expanduser('~/cloud/resources/sample_images')\n",
|
261 |
+
" tf = transforms.Compose([\n",
|
262 |
+
" transforms.ToTensor(),\n",
|
263 |
+
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
264 |
+
" transforms.Resize(224),\n",
|
265 |
+
" transforms.CenterCrop(224)\n",
|
266 |
+
" ])\n",
|
267 |
+
" tf2 = transforms.Compose([\n",
|
268 |
+
" transforms.ToTensor(),\n",
|
269 |
+
" transforms.Resize(224),\n",
|
270 |
+
" transforms.CenterCrop(224)\n",
|
271 |
+
" ])\n",
|
272 |
+
" inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
|
273 |
+
" inp1[1] = inp1[1].unsqueeze(0)\n",
|
274 |
+
" inp1[2] = inp1[2][:1] \n",
|
275 |
+
" return inp1\n",
|
276 |
+
"\n",
|
277 |
+
"def all_preprocessing(inp1):\n",
|
278 |
+
" return [\n",
|
279 |
+
" img_preprocess(inp1),\n",
|
280 |
+
" img_preprocess(inp1, colorize=True),\n",
|
281 |
+
" img_preprocess(inp1, outline=True), \n",
|
282 |
+
" img_preprocess(inp1, blur=3),\n",
|
283 |
+
" img_preprocess(inp1, bg_fac=0.1),\n",
|
284 |
+
" #img_preprocess(inp1, bg_fac=0.5),\n",
|
285 |
+
" #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
|
286 |
+
" img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
|
287 |
+
" ]\n",
|
288 |
+
"\n"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": null,
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"from torchvision import transforms\n",
|
298 |
+
"from PIL import Image\n",
|
299 |
+
"from matplotlib import pyplot as plt\n",
|
300 |
+
"from evaluation_utils import img_preprocess\n",
|
301 |
+
"import clip\n",
|
302 |
+
"\n",
|
303 |
+
"images_queries = [\n",
|
304 |
+
" [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
|
305 |
+
" [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
|
306 |
+
"]\n",
|
307 |
+
"\n",
|
308 |
+
"\n",
|
309 |
+
"_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
|
310 |
+
"\n",
|
311 |
+
"for j, (images, objects) in enumerate(images_queries):\n",
|
312 |
+
" \n",
|
313 |
+
" joint_image = all_preprocessing(images)\n",
|
314 |
+
" \n",
|
315 |
+
" joint_image = torch.stack(joint_image)[:,0]\n",
|
316 |
+
" clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
|
317 |
+
" image_features = clip_model.encode_image(joint_image)\n",
|
318 |
+
" image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
|
319 |
+
" \n",
|
320 |
+
" prompts = [f'a photo of a {obj}'for obj in objects]\n",
|
321 |
+
" text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
|
322 |
+
" text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
|
323 |
+
" logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
|
324 |
+
" sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
|
325 |
+
"\n",
|
326 |
+
" for i, img in enumerate(joint_image):\n",
|
327 |
+
" ax[2*j, i].axis('off')\n",
|
328 |
+
" \n",
|
329 |
+
" ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
|
330 |
+
" ax[2*j+ 1, i].grid(True)\n",
|
331 |
+
" \n",
|
332 |
+
" ax[2*j + 1, i].set_ylim(0,1)\n",
|
333 |
+
" ax[2*j + 1, i].set_yticklabels([])\n",
|
334 |
+
" ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
|
335 |
+
"# ax[1, i].set_xticklabels(objects, rotation=90)\n",
|
336 |
+
" for k in range(len(sim[i])):\n",
|
337 |
+
" ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
|
338 |
+
" ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
|
339 |
+
"\n",
|
340 |
+
"plt.tight_layout()\n",
|
341 |
+
"plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
|
342 |
+
]
|
343 |
+
}
|
344 |
+
],
|
345 |
+
"metadata": {
|
346 |
+
"kernelspec": {
|
347 |
+
"display_name": "env2",
|
348 |
+
"language": "python",
|
349 |
+
"name": "env2"
|
350 |
+
},
|
351 |
+
"language_info": {
|
352 |
+
"codemirror_mode": {
|
353 |
+
"name": "ipython",
|
354 |
+
"version": 3
|
355 |
+
},
|
356 |
+
"file_extension": ".py",
|
357 |
+
"mimetype": "text/x-python",
|
358 |
+
"name": "python",
|
359 |
+
"nbconvert_exporter": "python",
|
360 |
+
"pygments_lexer": "ipython3",
|
361 |
+
"version": "3.8.8"
|
362 |
+
}
|
363 |
+
},
|
364 |
+
"nbformat": 4,
|
365 |
+
"nbformat_minor": 4
|
366 |
+
}
|
clipseg/datasets/coco_wrapper.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
from types import new_class
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
|
8 |
+
from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
|
9 |
+
from random import shuffle, seed as set_seed
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from itertools import combinations
|
13 |
+
from torchvision import transforms
|
14 |
+
from torchvision.transforms.transforms import Resize
|
15 |
+
|
16 |
+
from datasets.utils import blend_image_segmentation
|
17 |
+
from general_utils import get_from_repository
|
18 |
+
|
19 |
+
COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
|
20 |
+
|
21 |
+
class COCOWrapper(object):
|
22 |
+
|
23 |
+
def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
|
24 |
+
with_class_label=False):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.mask = mask
|
28 |
+
self.with_class_label = with_class_label
|
29 |
+
self.negative_prob = negative_prob
|
30 |
+
|
31 |
+
from third_party.hsnet.data.coco import DatasetCOCO
|
32 |
+
|
33 |
+
get_from_repository('COCO-20i', ['COCO-20i.tar'])
|
34 |
+
|
35 |
+
foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')
|
36 |
+
|
37 |
+
def build_img_metadata_classwise(self):
|
38 |
+
with open(foldpath % (self.split, self.fold), 'rb') as f:
|
39 |
+
img_metadata_classwise = pickle.load(f)
|
40 |
+
return img_metadata_classwise
|
41 |
+
|
42 |
+
|
43 |
+
DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
|
44 |
+
# DatasetCOCO.read_mask = read_mask
|
45 |
+
|
46 |
+
mean = [0.485, 0.456, 0.406]
|
47 |
+
std = [0.229, 0.224, 0.225]
|
48 |
+
transform = transforms.Compose([
|
49 |
+
transforms.Resize((image_size, image_size)),
|
50 |
+
transforms.ToTensor(),
|
51 |
+
transforms.Normalize(mean, std)
|
52 |
+
])
|
53 |
+
|
54 |
+
self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False)
|
55 |
+
|
56 |
+
self.all_classes = [self.coco.class_ids]
|
57 |
+
self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
return len(self.coco)
|
61 |
+
|
62 |
+
def __getitem__(self, i):
|
63 |
+
sample = self.coco[i]
|
64 |
+
|
65 |
+
label_name = COCO_CLASSES[int(sample['class_id'])]
|
66 |
+
|
67 |
+
img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0]
|
68 |
+
|
69 |
+
if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
|
70 |
+
new_class_id = sample['class_id']
|
71 |
+
while new_class_id == sample['class_id']:
|
72 |
+
sample2 = self.coco[torch.randint(0, len(self), (1,)).item()]
|
73 |
+
new_class_id = sample2['class_id']
|
74 |
+
img_s = sample2['support_imgs'][0]
|
75 |
+
seg_s = torch.zeros_like(seg_s)
|
76 |
+
|
77 |
+
mask = self.mask
|
78 |
+
if mask == 'separate':
|
79 |
+
supp = (img_s, seg_s)
|
80 |
+
elif mask == 'text_label':
|
81 |
+
# DEPRECATED
|
82 |
+
supp = [int(sample['class_id'])]
|
83 |
+
elif mask == 'text':
|
84 |
+
supp = [label_name]
|
85 |
+
else:
|
86 |
+
if mask.startswith('text_and_'):
|
87 |
+
mask = mask[9:]
|
88 |
+
label_add = [label_name]
|
89 |
+
else:
|
90 |
+
label_add = []
|
91 |
+
|
92 |
+
supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask)
|
93 |
+
|
94 |
+
if self.with_class_label:
|
95 |
+
label = (torch.zeros(0), sample['class_id'],)
|
96 |
+
else:
|
97 |
+
label = (torch.zeros(0), )
|
98 |
+
|
99 |
+
return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label
|
clipseg/datasets/pascal_classes.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
|
clipseg/datasets/pascal_zeroshot.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import expanduser
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
import torchvision
|
5 |
+
from general_utils import get_from_repository
|
6 |
+
from general_utils import log
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
|
10 |
+
['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
|
11 |
+
['chair.n.01', 'pot_plant.n.01']]
|
12 |
+
|
13 |
+
|
14 |
+
class PascalZeroShot(object):
|
15 |
+
|
16 |
+
def __init__(self, split, n_unseen, image_size=224) -> None:
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
import sys
|
20 |
+
sys.path.append('third_party/JoEm')
|
21 |
+
from third_party.JoEm.data_loader.dataset import VOCSegmentation
|
22 |
+
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
23 |
+
|
24 |
+
self.pascal_classes = VOC
|
25 |
+
self.image_size = image_size
|
26 |
+
|
27 |
+
self.transform = transforms.Compose([
|
28 |
+
transforms.Resize((image_size, image_size)),
|
29 |
+
])
|
30 |
+
|
31 |
+
if split == 'train':
|
32 |
+
self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
|
33 |
+
split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
|
34 |
+
ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
|
35 |
+
elif split == 'val':
|
36 |
+
self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
|
37 |
+
split=split, transform=False,
|
38 |
+
ignore_bg=False, ignore_unseen=False)
|
39 |
+
|
40 |
+
self.unseen_idx = get_unseen_idx(n_unseen)
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return len(self.voc)
|
44 |
+
|
45 |
+
def __getitem__(self, i):
|
46 |
+
|
47 |
+
sample = self.voc[i]
|
48 |
+
label = sample['label'].long()
|
49 |
+
all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
|
50 |
+
class_indices = [l for l in all_labels]
|
51 |
+
class_names = [self.pascal_classes[l] for l in all_labels]
|
52 |
+
|
53 |
+
image = self.transform(sample['image'])
|
54 |
+
|
55 |
+
label = transforms.Resize((self.image_size, self.image_size),
|
56 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
|
57 |
+
|
58 |
+
return (image,), (label, )
|
59 |
+
|
60 |
+
|
clipseg/datasets/pfe_dataset.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import expanduser
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from general_utils import get_from_repository
|
5 |
+
from datasets.lvis_oneshot3 import blend_image_segmentation
|
6 |
+
from general_utils import log
|
7 |
+
|
8 |
+
PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
|
9 |
+
|
10 |
+
|
11 |
+
class PFEPascalWrapper(object):
|
12 |
+
|
13 |
+
def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
|
14 |
+
import sys
|
15 |
+
# sys.path.append(expanduser('~/projects/new_one_shot'))
|
16 |
+
from third_party.PFENet.util.dataset import SemData
|
17 |
+
|
18 |
+
get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
|
19 |
+
|
20 |
+
self.p_negative = p_negative
|
21 |
+
self.size = size
|
22 |
+
self.mode = mode
|
23 |
+
self.image_size = image_size
|
24 |
+
|
25 |
+
if label_support in {True, False}:
|
26 |
+
log.warning('label_support argument is deprecated. Use mask instead.')
|
27 |
+
#raise ValueError()
|
28 |
+
|
29 |
+
self.mask = mask
|
30 |
+
|
31 |
+
value_scale = 255
|
32 |
+
mean = [0.485, 0.456, 0.406]
|
33 |
+
mean = [item * value_scale for item in mean]
|
34 |
+
std = [0.229, 0.224, 0.225]
|
35 |
+
std = [item * value_scale for item in std]
|
36 |
+
|
37 |
+
import third_party.PFENet.util.transform as transform
|
38 |
+
|
39 |
+
if mode == 'val':
|
40 |
+
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
|
41 |
+
|
42 |
+
data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
|
43 |
+
data_transform += [
|
44 |
+
transform.ToTensor(),
|
45 |
+
transform.Normalize(mean=mean, std=std)
|
46 |
+
]
|
47 |
+
|
48 |
+
|
49 |
+
elif mode == 'train':
|
50 |
+
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
|
51 |
+
|
52 |
+
assert image_size != 'original'
|
53 |
+
|
54 |
+
data_transform = [
|
55 |
+
transform.RandScale([0.9, 1.1]),
|
56 |
+
transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
|
57 |
+
transform.RandomGaussianBlur(),
|
58 |
+
transform.RandomHorizontalFlip(),
|
59 |
+
transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
|
60 |
+
transform.ToTensor(),
|
61 |
+
transform.Normalize(mean=mean, std=std)
|
62 |
+
]
|
63 |
+
|
64 |
+
data_transform = transform.Compose(data_transform)
|
65 |
+
|
66 |
+
self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
|
67 |
+
data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
|
68 |
+
|
69 |
+
self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
|
70 |
+
|
71 |
+
# verify that subcls_list always has length 1
|
72 |
+
# assert len(set([len(d[4]) for d in self.dataset])) == 1
|
73 |
+
|
74 |
+
print('actual length', len(self.dataset.data_list))
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
if self.mode == 'val':
|
78 |
+
return len(self.dataset.data_list)
|
79 |
+
else:
|
80 |
+
return len(self.dataset.data_list)
|
81 |
+
|
82 |
+
def __getitem__(self, index):
|
83 |
+
if self.dataset.mode == 'train':
|
84 |
+
image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
|
85 |
+
elif self.dataset.mode == 'val':
|
86 |
+
image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
|
87 |
+
ori_label = torch.from_numpy(ori_label).unsqueeze(0)
|
88 |
+
|
89 |
+
if self.image_size != 'original':
|
90 |
+
longerside = max(ori_label.size(1), ori_label.size(2))
|
91 |
+
backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
|
92 |
+
backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
|
93 |
+
label = backmask.clone().long()
|
94 |
+
else:
|
95 |
+
label = label.unsqueeze(0)
|
96 |
+
|
97 |
+
# assert label.shape == (473, 473)
|
98 |
+
|
99 |
+
if self.p_negative > 0:
|
100 |
+
if torch.rand(1).item() < self.p_negative:
|
101 |
+
while True:
|
102 |
+
idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
|
103 |
+
_, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
|
104 |
+
if subcls_list[0] != subcls_list_tmp[0]:
|
105 |
+
break
|
106 |
+
|
107 |
+
s_x = s_x[0]
|
108 |
+
s_y = (s_y == 1)[0]
|
109 |
+
label_fg = (label == 1).float()
|
110 |
+
val_mask = (label != 255).float()
|
111 |
+
|
112 |
+
class_id = self.class_list[subcls_list[0]]
|
113 |
+
|
114 |
+
label_name = PASCAL_CLASSES[class_id][0]
|
115 |
+
label_add = ()
|
116 |
+
mask = self.mask
|
117 |
+
|
118 |
+
if mask == 'text':
|
119 |
+
support = ('a photo of a ' + label_name + '.',)
|
120 |
+
elif mask == 'separate':
|
121 |
+
support = (s_x, s_y)
|
122 |
+
else:
|
123 |
+
if mask.startswith('text_and_'):
|
124 |
+
label_add = (label_name,)
|
125 |
+
mask = mask[9:]
|
126 |
+
|
127 |
+
support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
|
128 |
+
|
129 |
+
return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
|
clipseg/datasets/phrasecut.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
|
6 |
+
from os.path import join, isdir, isfile, expanduser
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from torchvision import transforms
|
10 |
+
from torchvision.transforms.transforms import Resize
|
11 |
+
|
12 |
+
from torch.nn import functional as nnf
|
13 |
+
from general_utils import get_from_repository
|
14 |
+
|
15 |
+
from skimage.draw import polygon2mask
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def random_crop_slices(origin_size, target_size):
|
20 |
+
"""Gets slices of a random crop. """
|
21 |
+
assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'
|
22 |
+
|
23 |
+
offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item() # range: 0 <= value < high
|
24 |
+
offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item()
|
25 |
+
|
26 |
+
return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1])
|
27 |
+
|
28 |
+
|
29 |
+
def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
|
30 |
+
|
31 |
+
|
32 |
+
best_crops = []
|
33 |
+
best_crop_not_ok = float('-inf'), None, None
|
34 |
+
min_sum = 0
|
35 |
+
|
36 |
+
seg = seg.astype('bool')
|
37 |
+
|
38 |
+
if min_frac is not None:
|
39 |
+
#min_sum = seg.sum() * min_frac
|
40 |
+
min_sum = seg.shape[0] * seg.shape[1] * min_frac
|
41 |
+
|
42 |
+
for iteration in range(iterations):
|
43 |
+
sl_y, sl_x = random_crop_slices(seg.shape, image_size)
|
44 |
+
seg_ = seg[sl_y, sl_x]
|
45 |
+
sum_seg_ = seg_.sum()
|
46 |
+
|
47 |
+
if sum_seg_ > min_sum:
|
48 |
+
|
49 |
+
if best_of is None:
|
50 |
+
return sl_y, sl_x, False
|
51 |
+
else:
|
52 |
+
best_crops += [(sum_seg_, sl_y, sl_x)]
|
53 |
+
if len(best_crops) >= best_of:
|
54 |
+
best_crops.sort(key=lambda x:x[0], reverse=True)
|
55 |
+
sl_y, sl_x = best_crops[0][1:]
|
56 |
+
|
57 |
+
return sl_y, sl_x, False
|
58 |
+
|
59 |
+
else:
|
60 |
+
if sum_seg_ > best_crop_not_ok[0]:
|
61 |
+
best_crop_not_ok = sum_seg_, sl_y, sl_x
|
62 |
+
|
63 |
+
else:
|
64 |
+
# return best segmentation found
|
65 |
+
return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,)
|
66 |
+
|
67 |
+
|
68 |
+
class PhraseCut(object):
|
69 |
+
|
70 |
+
def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
|
71 |
+
min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
self.negative_prob = negative_prob
|
75 |
+
self.image_size = image_size
|
76 |
+
self.with_visual = with_visual
|
77 |
+
self.only_visual = only_visual
|
78 |
+
self.phrase_form = '{}'
|
79 |
+
self.mask = mask
|
80 |
+
self.aug_crop = aug_crop
|
81 |
+
|
82 |
+
if aug_color:
|
83 |
+
self.aug_color = transforms.Compose([
|
84 |
+
transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
|
85 |
+
])
|
86 |
+
else:
|
87 |
+
self.aug_color = None
|
88 |
+
|
89 |
+
get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
|
90 |
+
isdir(join(local_dir, 'VGPhraseCut_v0')),
|
91 |
+
isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
|
92 |
+
isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
|
93 |
+
len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
|
94 |
+
]))
|
95 |
+
|
96 |
+
from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
|
97 |
+
self.refvg_loader = RefVGLoader(split=split)
|
98 |
+
|
99 |
+
# img_ids where the size in the annotations does not match actual size
|
100 |
+
invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
|
101 |
+
150333, 286065, 285814, 498187, 285761, 498042])
|
102 |
+
|
103 |
+
mean = [0.485, 0.456, 0.406]
|
104 |
+
std = [0.229, 0.224, 0.225]
|
105 |
+
self.normalize = transforms.Normalize(mean, std)
|
106 |
+
|
107 |
+
self.sample_ids = [(i, j)
|
108 |
+
for i in self.refvg_loader.img_ids
|
109 |
+
for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
|
110 |
+
if i not in invalid_img_ids]
|
111 |
+
|
112 |
+
|
113 |
+
# self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']]))
|
114 |
+
|
115 |
+
from nltk.stem import WordNetLemmatizer
|
116 |
+
wnl = WordNetLemmatizer()
|
117 |
+
|
118 |
+
# Filter by class (if remove_classes is set)
|
119 |
+
if remove_classes is None:
|
120 |
+
pass
|
121 |
+
else:
|
122 |
+
from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
|
123 |
+
from nltk.corpus import wordnet
|
124 |
+
|
125 |
+
print('remove pascal classes...')
|
126 |
+
|
127 |
+
get_data = self.refvg_loader.get_img_ref_data # shortcut
|
128 |
+
keep_sids = None
|
129 |
+
|
130 |
+
if remove_classes[0] == 'pas5i':
|
131 |
+
subset_id = remove_classes[1]
|
132 |
+
from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
|
133 |
+
avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]
|
134 |
+
|
135 |
+
|
136 |
+
elif remove_classes[0] == 'zs':
|
137 |
+
stop = remove_classes[1]
|
138 |
+
|
139 |
+
from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
|
140 |
+
|
141 |
+
avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
|
142 |
+
print(avoid)
|
143 |
+
|
144 |
+
elif remove_classes[0] == 'aff':
|
145 |
+
# avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02']
|
146 |
+
# all_lemmas = set(['drink', 'sit', 'ride'])
|
147 |
+
avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
|
148 |
+
'ride', 'rides', 'riding',
|
149 |
+
'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
|
150 |
+
'swim', 'swims', 'swimming',
|
151 |
+
'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
|
152 |
+
keep_sids = [(i, j) for i, j in self.sample_ids if
|
153 |
+
all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
|
154 |
+
|
155 |
+
print('avoid classes:', avoid)
|
156 |
+
|
157 |
+
|
158 |
+
if keep_sids is None:
|
159 |
+
all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
|
160 |
+
all_lemmas = list(set(all_lemmas))
|
161 |
+
all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
|
162 |
+
all_lemmas = set(all_lemmas)
|
163 |
+
|
164 |
+
# divide into multi word and single word
|
165 |
+
all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
|
166 |
+
all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)
|
167 |
+
|
168 |
+
# new3
|
169 |
+
phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
|
170 |
+
remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
|
171 |
+
if any(l in phrase for l in all_lemmas_m) or
|
172 |
+
len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
|
173 |
+
)
|
174 |
+
keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]
|
175 |
+
|
176 |
+
print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
|
177 |
+
removed_ids = set(self.sample_ids) - set(keep_sids)
|
178 |
+
|
179 |
+
print('Examples of removed', len(removed_ids))
|
180 |
+
for i, j in list(removed_ids)[:20]:
|
181 |
+
print(i, get_data(i)['phrases'][j])
|
182 |
+
|
183 |
+
self.sample_ids = keep_sids
|
184 |
+
|
185 |
+
from itertools import groupby
|
186 |
+
samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
|
187 |
+
for i, j in self.sample_ids]
|
188 |
+
samples_by_phrase = sorted(samples_by_phrase)
|
189 |
+
samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])
|
190 |
+
|
191 |
+
self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}
|
192 |
+
|
193 |
+
self.all_phrases = list(set(self.samples_by_phrase.keys()))
|
194 |
+
|
195 |
+
|
196 |
+
if self.only_visual:
|
197 |
+
assert self.with_visual
|
198 |
+
self.sample_ids = [(i, j) for i, j in self.sample_ids
|
199 |
+
if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]
|
200 |
+
|
201 |
+
# Filter by size (if min_size is set)
|
202 |
+
sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
|
203 |
+
image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
|
204 |
+
#self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes]
|
205 |
+
self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]
|
206 |
+
|
207 |
+
if min_size:
|
208 |
+
print('filter by size')
|
209 |
+
|
210 |
+
self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]
|
211 |
+
|
212 |
+
self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))
|
213 |
+
|
214 |
+
def __len__(self):
|
215 |
+
return len(self.sample_ids)
|
216 |
+
|
217 |
+
|
218 |
+
def load_sample(self, sample_i, j):
|
219 |
+
|
220 |
+
img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)
|
221 |
+
|
222 |
+
polys_phrase0 = img_ref_data['gt_Polygons'][j]
|
223 |
+
phrase = img_ref_data['phrases'][j]
|
224 |
+
phrase = self.phrase_form.format(phrase)
|
225 |
+
|
226 |
+
masks = []
|
227 |
+
for polys in polys_phrase0:
|
228 |
+
for poly in polys:
|
229 |
+
poly = [p[::-1] for p in poly] # swap x,y
|
230 |
+
masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]
|
231 |
+
|
232 |
+
seg = np.stack(masks).max(0)
|
233 |
+
img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))
|
234 |
+
|
235 |
+
min_shape = min(img.shape[:2])
|
236 |
+
|
237 |
+
if self.aug_crop:
|
238 |
+
sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
|
239 |
+
else:
|
240 |
+
sly, slx = slice(0, None), slice(0, None)
|
241 |
+
|
242 |
+
seg = seg[sly, slx]
|
243 |
+
img = img[sly, slx]
|
244 |
+
|
245 |
+
seg = seg.astype('uint8')
|
246 |
+
seg = torch.from_numpy(seg).view(1, 1, *seg.shape)
|
247 |
+
|
248 |
+
if img.ndim == 2:
|
249 |
+
img = np.dstack([img] * 3)
|
250 |
+
|
251 |
+
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()
|
252 |
+
|
253 |
+
seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
|
254 |
+
img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]
|
255 |
+
|
256 |
+
# img = img.permute([2,0, 1])
|
257 |
+
img = img / 255.0
|
258 |
+
|
259 |
+
if self.aug_color is not None:
|
260 |
+
img = self.aug_color(img)
|
261 |
+
|
262 |
+
img = self.normalize(img)
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
return img, seg, phrase
|
267 |
+
|
268 |
+
def __getitem__(self, i):
|
269 |
+
|
270 |
+
sample_i, j = self.sample_ids[i]
|
271 |
+
|
272 |
+
img, seg, phrase = self.load_sample(sample_i, j)
|
273 |
+
|
274 |
+
if self.negative_prob > 0:
|
275 |
+
if torch.rand((1,)).item() < self.negative_prob:
|
276 |
+
|
277 |
+
new_phrase = None
|
278 |
+
while new_phrase is None or new_phrase == phrase:
|
279 |
+
idx = torch.randint(0, len(self.all_phrases), (1,)).item()
|
280 |
+
new_phrase = self.all_phrases[idx]
|
281 |
+
phrase = new_phrase
|
282 |
+
seg = torch.zeros_like(seg)
|
283 |
+
|
284 |
+
if self.with_visual:
|
285 |
+
# find a corresponding visual image
|
286 |
+
if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
|
287 |
+
idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
|
288 |
+
other_sample = self.samples_by_phrase[phrase][idx]
|
289 |
+
#print(other_sample)
|
290 |
+
img_s, seg_s, _ = self.load_sample(*other_sample)
|
291 |
+
|
292 |
+
from datasets.utils import blend_image_segmentation
|
293 |
+
|
294 |
+
if self.mask in {'separate', 'text_and_separate'}:
|
295 |
+
# assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:]
|
296 |
+
add_phrase = [phrase] if self.mask == 'text_and_separate' else []
|
297 |
+
vis_s = add_phrase + [img_s, seg_s, True]
|
298 |
+
else:
|
299 |
+
if self.mask.startswith('text_and_'):
|
300 |
+
mask_mode = self.mask[9:]
|
301 |
+
label_add = [phrase]
|
302 |
+
else:
|
303 |
+
mask_mode = self.mask
|
304 |
+
label_add = []
|
305 |
+
|
306 |
+
masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
|
307 |
+
vis_s = label_add + [masked_img_s, True]
|
308 |
+
|
309 |
+
else:
|
310 |
+
# phrase is unique
|
311 |
+
vis_s = torch.zeros_like(img)
|
312 |
+
|
313 |
+
if self.mask in {'separate', 'text_and_separate'}:
|
314 |
+
add_phrase = [phrase] if self.mask == 'text_and_separate' else []
|
315 |
+
vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
|
316 |
+
elif self.mask.startswith('text_and_'):
|
317 |
+
vis_s = [phrase, vis_s, False]
|
318 |
+
else:
|
319 |
+
vis_s = [vis_s, False]
|
320 |
+
else:
|
321 |
+
assert self.mask == 'text'
|
322 |
+
vis_s = [phrase]
|
323 |
+
|
324 |
+
seg = seg.unsqueeze(0).float()
|
325 |
+
|
326 |
+
data_x = (img,) + tuple(vis_s)
|
327 |
+
|
328 |
+
return data_x, (seg, torch.zeros(0), i)
|
329 |
+
|
330 |
+
|
331 |
+
class PhraseCutPlus(PhraseCut):
|
332 |
+
|
333 |
+
def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
|
334 |
+
super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size,
|
335 |
+
remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask)
|
clipseg/datasets/utils.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def blend_image_segmentation(img, seg, mode, image_size=224):
|
7 |
+
|
8 |
+
|
9 |
+
if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}:
|
10 |
+
if isinstance(img, np.ndarray):
|
11 |
+
img = torch.from_numpy(img)
|
12 |
+
|
13 |
+
if isinstance(seg, np.ndarray):
|
14 |
+
seg = torch.from_numpy(seg)
|
15 |
+
|
16 |
+
if mode == 'overlay':
|
17 |
+
out = img * seg
|
18 |
+
out = [out.astype('float32')]
|
19 |
+
elif mode == 'highlight':
|
20 |
+
out = img * seg[None, :, :] * 0.85 + 0.15 * img
|
21 |
+
out = [out.astype('float32')]
|
22 |
+
elif mode == 'highlight2':
|
23 |
+
img = img / 2
|
24 |
+
out = (img+0.1) * seg[None, :, :] + 0.3 * img
|
25 |
+
out = [out.astype('float32')]
|
26 |
+
elif mode == 'blur_highlight':
|
27 |
+
from evaluation_utils import img_preprocess
|
28 |
+
out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01]
|
29 |
+
elif mode == 'blur3_highlight':
|
30 |
+
from evaluation_utils import img_preprocess
|
31 |
+
out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01]
|
32 |
+
elif mode == 'blur3_highlight01':
|
33 |
+
from evaluation_utils import img_preprocess
|
34 |
+
out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01]
|
35 |
+
elif mode == 'blur_highlight_random':
|
36 |
+
from evaluation_utils import img_preprocess
|
37 |
+
out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01]
|
38 |
+
elif mode == 'crop':
|
39 |
+
from evaluation_utils import img_preprocess
|
40 |
+
out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()]
|
41 |
+
elif mode == 'crop_blur_highlight':
|
42 |
+
from evaluation_utils import img_preprocess
|
43 |
+
out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()]
|
44 |
+
elif mode == 'crop_blur_highlight352':
|
45 |
+
from evaluation_utils import img_preprocess
|
46 |
+
out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()]
|
47 |
+
elif mode == 'shape':
|
48 |
+
out = [np.stack([seg[:, :]]*3).astype('float32')]
|
49 |
+
elif mode == 'concat':
|
50 |
+
out = [np.concatenate([img, seg[None, :, :]]).astype('float32')]
|
51 |
+
elif mode == 'image_only':
|
52 |
+
out = [img.astype('float32')]
|
53 |
+
elif mode == 'image_black':
|
54 |
+
out = [img.astype('float32')*0]
|
55 |
+
elif mode is None:
|
56 |
+
out = [img.astype('float32')]
|
57 |
+
elif mode == 'separate':
|
58 |
+
out = [img.astype('float32'), seg.astype('int64')]
|
59 |
+
elif mode == 'separate_img_black':
|
60 |
+
out = [img.astype('float32')*0, seg.astype('int64')]
|
61 |
+
elif mode == 'separate_seg_ones':
|
62 |
+
out = [img.astype('float32'), np.ones_like(seg).astype('int64')]
|
63 |
+
elif mode == 'separate_both_black':
|
64 |
+
out = [img.astype('float32')*0, seg.astype('int64')*0]
|
65 |
+
else:
|
66 |
+
raise ValueError(f'invalid mode: {mode}')
|
67 |
+
|
68 |
+
return out
|
clipseg/environment.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: clipseg-environment
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- pytorch
|
5 |
+
dependencies:
|
6 |
+
- numpy
|
7 |
+
- scipy
|
8 |
+
- matplotlib-base
|
9 |
+
- pip
|
10 |
+
- pip:
|
11 |
+
- --find-links https://download.pytorch.org/whl/torch_stable.html
|
12 |
+
- torch==1.10.0+cpu
|
13 |
+
- torchvision==0.11.1+cpu
|
14 |
+
- opencv-python
|
15 |
+
- git+https://github.com/openai/CLIP.git
|
clipseg/evaluation_utils.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.functional import Tensor
|
2 |
+
from general_utils import load_model
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def denorm(img):
|
8 |
+
|
9 |
+
np_input = False
|
10 |
+
if isinstance(img, np.ndarray):
|
11 |
+
img = torch.from_numpy(img)
|
12 |
+
np_input = True
|
13 |
+
|
14 |
+
mean = torch.Tensor([0.485, 0.456, 0.406])
|
15 |
+
std = torch.Tensor([0.229, 0.224, 0.225])
|
16 |
+
|
17 |
+
img_denorm = (img*std[:,None,None]) + mean[:,None,None]
|
18 |
+
|
19 |
+
if np_input:
|
20 |
+
img_denorm = np.clip(img_denorm.numpy(), 0, 1)
|
21 |
+
else:
|
22 |
+
img_denorm = torch.clamp(img_denorm, 0, 1)
|
23 |
+
|
24 |
+
return img_denorm
|
25 |
+
|
26 |
+
|
27 |
+
def norm(img):
|
28 |
+
mean = torch.Tensor([0.485, 0.456, 0.406])
|
29 |
+
std = torch.Tensor([0.229, 0.224, 0.225])
|
30 |
+
return (img - mean[:,None,None]) / std[:,None,None]
|
31 |
+
|
32 |
+
|
33 |
+
def fast_iou_curve(p, g):
|
34 |
+
|
35 |
+
g = g[p.sort().indices]
|
36 |
+
p = torch.sigmoid(p.sort().values)
|
37 |
+
|
38 |
+
scores = []
|
39 |
+
vals = np.linspace(0, 1, 50)
|
40 |
+
|
41 |
+
for q in vals:
|
42 |
+
|
43 |
+
n = int(len(g) * q)
|
44 |
+
|
45 |
+
valid = torch.where(p > q)[0]
|
46 |
+
if len(valid) > 0:
|
47 |
+
n = int(valid[0])
|
48 |
+
else:
|
49 |
+
n = len(g)
|
50 |
+
|
51 |
+
fn = g[:n].sum()
|
52 |
+
tn = n - fn
|
53 |
+
tp = g[n:].sum()
|
54 |
+
fp = len(g) - n - tp
|
55 |
+
|
56 |
+
iou = tp / (tp + fn + fp)
|
57 |
+
|
58 |
+
precision = tp / (tp + fp)
|
59 |
+
recall = tp / (tp + fn)
|
60 |
+
|
61 |
+
scores += [iou]
|
62 |
+
|
63 |
+
return vals, scores
|
64 |
+
|
65 |
+
|
66 |
+
def fast_rp_curve(p, g):
|
67 |
+
|
68 |
+
g = g[p.sort().indices]
|
69 |
+
p = torch.sigmoid(p.sort().values)
|
70 |
+
|
71 |
+
precisions, recalls = [], []
|
72 |
+
vals = np.linspace(p.min(), p.max(), 250)
|
73 |
+
|
74 |
+
for q in p[::100000]:
|
75 |
+
|
76 |
+
n = int(len(g) * q)
|
77 |
+
|
78 |
+
valid = torch.where(p > q)[0]
|
79 |
+
if len(valid) > 0:
|
80 |
+
n = int(valid[0])
|
81 |
+
else:
|
82 |
+
n = len(g)
|
83 |
+
|
84 |
+
fn = g[:n].sum()
|
85 |
+
tn = n - fn
|
86 |
+
tp = g[n:].sum()
|
87 |
+
fp = len(g) - n - tp
|
88 |
+
|
89 |
+
iou = tp / (tp + fn + fp)
|
90 |
+
|
91 |
+
precision = tp / (tp + fp)
|
92 |
+
recall = tp / (tp + fn)
|
93 |
+
|
94 |
+
precisions += [precision]
|
95 |
+
recalls += [recall]
|
96 |
+
|
97 |
+
return recalls, precisions
|
98 |
+
|
99 |
+
|
100 |
+
# Image processing
|
101 |
+
|
102 |
+
def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
|
103 |
+
brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
|
104 |
+
import cv2
|
105 |
+
|
106 |
+
rw = rect_width
|
107 |
+
|
108 |
+
out = []
|
109 |
+
for img, mask in zip(batch[1], batch[2]):
|
110 |
+
|
111 |
+
img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
|
112 |
+
mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)
|
113 |
+
|
114 |
+
img *= brightness
|
115 |
+
img_bl = img
|
116 |
+
if blur > 0: # best 5
|
117 |
+
img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)
|
118 |
+
|
119 |
+
if grayscale:
|
120 |
+
img_bl = img_bl[1][None]
|
121 |
+
|
122 |
+
#img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl
|
123 |
+
# img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask)
|
124 |
+
img_inp = img*mask + (bg_fac) * img_bl * (1-mask)
|
125 |
+
|
126 |
+
if rect:
|
127 |
+
_, bbox = crop_mask(img, mask, context=0.1)
|
128 |
+
img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None]
|
129 |
+
img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None]
|
130 |
+
img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
|
131 |
+
img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
|
132 |
+
|
133 |
+
|
134 |
+
if center_context is not None:
|
135 |
+
img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)
|
136 |
+
|
137 |
+
if colorize:
|
138 |
+
img_gray = denorm(img)
|
139 |
+
img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
|
140 |
+
img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
|
141 |
+
img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
|
142 |
+
img_inp = norm(img_inp)
|
143 |
+
|
144 |
+
if outline:
|
145 |
+
cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
146 |
+
outline_img = np.zeros(mask.shape, dtype=np.uint8)
|
147 |
+
cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
|
148 |
+
outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
|
149 |
+
img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
|
150 |
+
img_inp = norm(img_inp)
|
151 |
+
|
152 |
+
out += [img_inp]
|
153 |
+
|
154 |
+
return torch.stack(out)
|
155 |
+
|
156 |
+
|
157 |
+
def object_crop(img, mask, context=0.0, square=False, image_size=224):
|
158 |
+
img_crop, bbox = crop_mask(img, mask, context=context, square=square)
|
159 |
+
img_crop = pad_to_square(img_crop, channel_dim=0)
|
160 |
+
img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0)
|
161 |
+
return img_crop
|
162 |
+
|
163 |
+
|
164 |
+
def crop_mask(img, mask, context=0.0, square=False):
|
165 |
+
|
166 |
+
assert img.shape[1:] == mask.shape
|
167 |
+
|
168 |
+
bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()]
|
169 |
+
bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()]
|
170 |
+
bbox = [int(x) for x in bbox]
|
171 |
+
|
172 |
+
width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
|
173 |
+
|
174 |
+
# square mask
|
175 |
+
if square:
|
176 |
+
bbox[0] = int(max(0, bbox[0] - context * height))
|
177 |
+
bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
|
178 |
+
bbox[2] = int(max(0, bbox[2] - context * width))
|
179 |
+
bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
|
180 |
+
|
181 |
+
width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
|
182 |
+
if height > width:
|
183 |
+
bbox[2] = int(max(0, (bbox[2] - 0.5*height)))
|
184 |
+
bbox[3] = bbox[2] + height
|
185 |
+
else:
|
186 |
+
bbox[0] = int(max(0, (bbox[0] - 0.5*width)))
|
187 |
+
bbox[1] = bbox[0] + width
|
188 |
+
else:
|
189 |
+
bbox[0] = int(max(0, bbox[0] - context * height))
|
190 |
+
bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
|
191 |
+
bbox[2] = int(max(0, bbox[2] - context * width))
|
192 |
+
bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
|
193 |
+
|
194 |
+
width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
|
195 |
+
img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
|
196 |
+
return img_crop, bbox
|
197 |
+
|
198 |
+
|
199 |
+
def pad_to_square(img, channel_dim=2, fill=0):
|
200 |
+
"""
|
201 |
+
|
202 |
+
|
203 |
+
add padding such that a squared image is returned """
|
204 |
+
|
205 |
+
from torchvision.transforms.functional import pad
|
206 |
+
|
207 |
+
if channel_dim == 2:
|
208 |
+
img = img.permute(2, 0, 1)
|
209 |
+
elif channel_dim == 0:
|
210 |
+
pass
|
211 |
+
else:
|
212 |
+
raise ValueError('invalid channel_dim')
|
213 |
+
|
214 |
+
h, w = img.shape[1:]
|
215 |
+
pady1 = pady2 = padx1 = padx2 = 0
|
216 |
+
|
217 |
+
if h > w:
|
218 |
+
padx1 = (h - w) // 2
|
219 |
+
padx2 = h - w - padx1
|
220 |
+
elif w > h:
|
221 |
+
pady1 = (w - h) // 2
|
222 |
+
pady2 = w - h - pady1
|
223 |
+
|
224 |
+
img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant')
|
225 |
+
|
226 |
+
if channel_dim == 2:
|
227 |
+
img_padded = img_padded.permute(1, 2, 0)
|
228 |
+
|
229 |
+
return img_padded
|
230 |
+
|
231 |
+
|
232 |
+
# qualitative
|
233 |
+
|
234 |
+
def split_sentence(inp, limit=9):
|
235 |
+
t_new, current_len = [], 0
|
236 |
+
for k, t in enumerate(inp.split(' ')):
|
237 |
+
current_len += len(t) + 1
|
238 |
+
t_new += [t+' ']
|
239 |
+
# not last
|
240 |
+
if current_len > limit and k != len(inp.split(' ')) - 1:
|
241 |
+
current_len = 0
|
242 |
+
t_new += ['\n']
|
243 |
+
|
244 |
+
t_new = ''.join(t_new)
|
245 |
+
return t_new
|
246 |
+
|
247 |
+
|
248 |
+
from matplotlib import pyplot as plt
|
249 |
+
|
250 |
+
|
251 |
+
def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
|
252 |
+
|
253 |
+
row_off = 0 if labels is None else 1
|
254 |
+
_, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)))
|
255 |
+
[a.axis('off') for a in ax.flatten()]
|
256 |
+
|
257 |
+
if labels is not None:
|
258 |
+
for j in range(len(labels)):
|
259 |
+
t_new = split_sentence(labels[j], limit=6)
|
260 |
+
ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale)
|
261 |
+
|
262 |
+
|
263 |
+
for i in range(len(imgs)):
|
264 |
+
ax[i + row_off,0].imshow(imgs[i])
|
265 |
+
for j in range(len(preds)):
|
266 |
+
img = preds[j][i][0].detach().cpu().numpy()
|
267 |
+
|
268 |
+
if gt_labels is not None and labels[j] == gt_labels[i]:
|
269 |
+
print(j, labels[j], gt_labels[i])
|
270 |
+
edgecolor = 'red'
|
271 |
+
if aps is not None:
|
272 |
+
ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
|
273 |
+
else:
|
274 |
+
edgecolor = 'k'
|
275 |
+
|
276 |
+
rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none",
|
277 |
+
edgecolor=edgecolor, linewidth=3)
|
278 |
+
ax[i + row_off,1 + j].add_patch(rect)
|
279 |
+
|
280 |
+
if vmax is None:
|
281 |
+
this_vmax = 1
|
282 |
+
elif vmax == 'per_prompt':
|
283 |
+
this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))])
|
284 |
+
elif vmax == 'per_image':
|
285 |
+
this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))])
|
286 |
+
|
287 |
+
ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)
|
288 |
+
|
289 |
+
|
290 |
+
# ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max())
|
291 |
+
plt.tight_layout()
|
292 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
clipseg/example_image.jpg
ADDED
clipseg/experiments/ablation.yaml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
configuration:
|
2 |
+
batch_size: 64
|
3 |
+
optimizer: torch.optim.AdamW
|
4 |
+
|
5 |
+
lr: 0.001
|
6 |
+
|
7 |
+
trainer: experiment_setup.train_loop
|
8 |
+
scorer: experiment_setup.score
|
9 |
+
model: models.clipseg.CLIPDensePredT
|
10 |
+
|
11 |
+
lr_scheduler: cosine
|
12 |
+
T_max: 20000
|
13 |
+
eta_min: 0.0001
|
14 |
+
|
15 |
+
max_iterations: 20000 # <-##########################################
|
16 |
+
val_interval: null
|
17 |
+
|
18 |
+
# dataset
|
19 |
+
dataset: datasets.phrasecut.PhraseCut # <-----------------
|
20 |
+
split_mode: pascal_test
|
21 |
+
split: train
|
22 |
+
mask: text_and_crop_blur_highlight352
|
23 |
+
image_size: 352
|
24 |
+
negative_prob: 0.2
|
25 |
+
mix_text_max: 0.5
|
26 |
+
|
27 |
+
# general
|
28 |
+
mix: True # <-----------------
|
29 |
+
prompt: shuffle+
|
30 |
+
norm_cond: True
|
31 |
+
mix_text_min: 0.0
|
32 |
+
with_visual: True
|
33 |
+
|
34 |
+
# model
|
35 |
+
version: 'ViT-B/16'
|
36 |
+
extract_layers: [3, 7, 9]
|
37 |
+
reduce_dim: 64
|
38 |
+
depth: 3
|
39 |
+
fix_shift: False # <-##########################################
|
40 |
+
|
41 |
+
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
42 |
+
amp: True
|
43 |
+
|
44 |
+
test_configuration_common:
|
45 |
+
normalize: True
|
46 |
+
image_size: 352
|
47 |
+
batch_size: 32
|
48 |
+
sigmoid: True
|
49 |
+
split: test
|
50 |
+
label_support: True
|
51 |
+
|
52 |
+
test_configuration:
|
53 |
+
|
54 |
+
-
|
55 |
+
name: pc
|
56 |
+
metric: metrics.FixedIntervalMetrics
|
57 |
+
test_dataset: phrasecut
|
58 |
+
mask: text
|
59 |
+
|
60 |
+
-
|
61 |
+
name: pc-vis
|
62 |
+
metric: metrics.FixedIntervalMetrics
|
63 |
+
test_dataset: phrasecut
|
64 |
+
mask: crop_blur_highlight352
|
65 |
+
with_visual: True
|
66 |
+
visual_only: True
|
67 |
+
|
68 |
+
|
69 |
+
columns: [name,
|
70 |
+
pc_fgiou_best, pc_miou_best, pc_fgiou_0.5,
|
71 |
+
pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5,
|
72 |
+
duration]
|
73 |
+
|
74 |
+
|
75 |
+
individual_configurations:
|
76 |
+
|
77 |
+
- {name: rd64-uni}
|
78 |
+
- {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003}
|
79 |
+
- {name: rd64-no-negatives, negative_prob: 0.0}
|
80 |
+
- {name: rd64-neg0.5, negative_prob: 0.5}
|
81 |
+
- {name: rd64-no-visual, with_visual: False, mix: False}
|
82 |
+
- {name: rd16-uni, reduce_dim: 16}
|
83 |
+
- {name: rd64-layer3, extract_layers: [3], depth: 1}
|
84 |
+
- {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}}
|
clipseg/experiments/coco.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
configuration:
|
2 |
+
batch_size: 64
|
3 |
+
optimizer: torch.optim.AdamW
|
4 |
+
|
5 |
+
lr: 0.001
|
6 |
+
|
7 |
+
trainer: experiment_setup.train_loop
|
8 |
+
scorer: experiment_setup.score
|
9 |
+
model: models.clipseg.CLIPDensePredT
|
10 |
+
|
11 |
+
lr_scheduler: cosine
|
12 |
+
T_max: 20000
|
13 |
+
eta_min: 0.0001
|
14 |
+
|
15 |
+
max_iterations: 20000
|
16 |
+
val_interval: null
|
17 |
+
|
18 |
+
# dataset
|
19 |
+
dataset: datasets.coco_wrapper.COCOWrapper
|
20 |
+
# split_mode: pascal_test
|
21 |
+
split: train
|
22 |
+
mask: text_and_blur3_highlight01
|
23 |
+
image_size: 352
|
24 |
+
normalize: True
|
25 |
+
pre_crop_image_size: [sample, 1, 1.5]
|
26 |
+
aug: 1new
|
27 |
+
|
28 |
+
# general
|
29 |
+
mix: True
|
30 |
+
prompt: shuffle+
|
31 |
+
norm_cond: True
|
32 |
+
mix_text_min: 0.0
|
33 |
+
|
34 |
+
# model
|
35 |
+
out: 1
|
36 |
+
extract_layers: [3, 7, 9]
|
37 |
+
reduce_dim: 64
|
38 |
+
depth: 3
|
39 |
+
fix_shift: False
|
40 |
+
|
41 |
+
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
42 |
+
amp: True
|
43 |
+
|
44 |
+
test_configuration_common:
|
45 |
+
normalize: True
|
46 |
+
image_size: 352
|
47 |
+
# max_iterations: 10
|
48 |
+
batch_size: 8
|
49 |
+
sigmoid: True
|
50 |
+
test_dataset: coco
|
51 |
+
metric: metrics.FixedIntervalMetrics
|
52 |
+
|
53 |
+
test_configuration:
|
54 |
+
|
55 |
+
-
|
56 |
+
name: coco_t
|
57 |
+
mask: text
|
58 |
+
|
59 |
+
-
|
60 |
+
name: coco_h
|
61 |
+
mask: blur3_highlight01
|
62 |
+
|
63 |
+
-
|
64 |
+
name: coco_h2
|
65 |
+
mask: crop_blur_highlight352
|
66 |
+
|
67 |
+
|
68 |
+
columns: [i, name,
|
69 |
+
coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5,
|
70 |
+
coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5,
|
71 |
+
coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t,
|
72 |
+
train_loss, duration, date
|
73 |
+
]
|
74 |
+
|
75 |
+
individual_configurations:
|
76 |
+
|
77 |
+
|
78 |
+
- {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
79 |
+
- {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
80 |
+
- {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
81 |
+
- {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
82 |
+
|
83 |
+
|
84 |
+
- {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
85 |
+
- {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
86 |
+
- {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
87 |
+
- {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
88 |
+
|
89 |
+
|
90 |
+
# ViT
|
91 |
+
- {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
92 |
+
- {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
93 |
+
- {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
94 |
+
- {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
95 |
+
|
96 |
+
|
97 |
+
# BASELINE
|
98 |
+
- {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
99 |
+
- {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
100 |
+
- {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
101 |
+
- {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
clipseg/experiments/pascal_1shot.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
configuration:
|
2 |
+
batch_size: 64
|
3 |
+
optimizer: torch.optim.AdamW
|
4 |
+
|
5 |
+
lr: 0.001
|
6 |
+
|
7 |
+
trainer: experiment_setup.train_loop
|
8 |
+
scorer: experiment_setup.score
|
9 |
+
model: models.clipseg.CLIPDensePredT
|
10 |
+
|
11 |
+
lr_scheduler: cosine
|
12 |
+
T_max: 20000
|
13 |
+
eta_min: 0.0001
|
14 |
+
|
15 |
+
max_iterations: 20000 # <-##########################################
|
16 |
+
val_interval: null
|
17 |
+
|
18 |
+
# dataset
|
19 |
+
dataset: datasets.phrasecut.PhraseCut
|
20 |
+
split_mode: pascal_test
|
21 |
+
mode: train
|
22 |
+
mask: text_and_crop_blur_highlight352
|
23 |
+
image_size: 352
|
24 |
+
normalize: True
|
25 |
+
pre_crop_image_size: [sample, 1, 1.5]
|
26 |
+
aug: 1new
|
27 |
+
with_visual: True
|
28 |
+
split: train
|
29 |
+
|
30 |
+
# general
|
31 |
+
mix: True
|
32 |
+
prompt: shuffle+
|
33 |
+
norm_cond: True
|
34 |
+
mix_text_min: 0.0
|
35 |
+
|
36 |
+
# model
|
37 |
+
out: 1
|
38 |
+
version: 'ViT-B/16'
|
39 |
+
extract_layers: [3, 7, 9]
|
40 |
+
reduce_dim: 64
|
41 |
+
depth: 3
|
42 |
+
|
43 |
+
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
44 |
+
amp: True
|
45 |
+
|
46 |
+
test_configuration_common:
|
47 |
+
normalize: True
|
48 |
+
image_size: 352
|
49 |
+
metric: metrics.FixedIntervalMetrics
|
50 |
+
batch_size: 1
|
51 |
+
test_dataset: pascal
|
52 |
+
sigmoid: True
|
53 |
+
# max_iterations: 250
|
54 |
+
|
55 |
+
test_configuration:
|
56 |
+
|
57 |
+
-
|
58 |
+
name: pas_t
|
59 |
+
mask: text
|
60 |
+
|
61 |
+
-
|
62 |
+
name: pas_h
|
63 |
+
mask: blur3_highlight01
|
64 |
+
|
65 |
+
-
|
66 |
+
name: pas_h2
|
67 |
+
mask: crop_blur_highlight352
|
68 |
+
|
69 |
+
|
70 |
+
columns: [name,
|
71 |
+
pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct,
|
72 |
+
pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct,
|
73 |
+
pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t,
|
74 |
+
train_loss, duration, date
|
75 |
+
]
|
76 |
+
|
77 |
+
individual_configurations:
|
78 |
+
|
79 |
+
- {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}}
|
80 |
+
- {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}}
|
81 |
+
- {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}}
|
82 |
+
- {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}}
|
83 |
+
|
84 |
+
|
85 |
+
- {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}}
|
86 |
+
- {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}}
|
87 |
+
- {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}}
|
88 |
+
- {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}}
|
89 |
+
|
90 |
+
|
91 |
+
# baseline
|
92 |
+
- {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}}
|
93 |
+
- {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}}
|
94 |
+
- {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}}
|
95 |
+
- {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}}
|
96 |
+
|
97 |
+
# ViT
|
98 |
+
- {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}}
|
99 |
+
- {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}}
|
100 |
+
- {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}}
|
101 |
+
- {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}}
|
clipseg/experiments/phrasecut.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
configuration:
|
2 |
+
batch_size: 64
|
3 |
+
optimizer: torch.optim.AdamW
|
4 |
+
|
5 |
+
lr: 0.001
|
6 |
+
|
7 |
+
trainer: experiment_setup.train_loop
|
8 |
+
scorer: experiment_setup.score
|
9 |
+
model: models.clipseg.CLIPDensePredT
|
10 |
+
|
11 |
+
lr_scheduler: cosine
|
12 |
+
T_max: 20000
|
13 |
+
eta_min: 0.0001
|
14 |
+
|
15 |
+
max_iterations: 20000
|
16 |
+
val_interval: null
|
17 |
+
|
18 |
+
# dataset
|
19 |
+
dataset: datasets.phrasecut.PhraseCut # <-----------------
|
20 |
+
split_mode: pascal_test
|
21 |
+
split: train
|
22 |
+
mask: text_and_crop_blur_highlight352
|
23 |
+
image_size: 352
|
24 |
+
normalize: True
|
25 |
+
pre_crop_image_size: [sample, 1, 1.5]
|
26 |
+
aug: 1new
|
27 |
+
|
28 |
+
# general
|
29 |
+
mix: False # <-----------------
|
30 |
+
prompt: shuffle+
|
31 |
+
norm_cond: True
|
32 |
+
mix_text_min: 0.0
|
33 |
+
|
34 |
+
# model
|
35 |
+
out: 1
|
36 |
+
extract_layers: [3, 7, 9]
|
37 |
+
reduce_dim: 64
|
38 |
+
depth: 3
|
39 |
+
fix_shift: False
|
40 |
+
|
41 |
+
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
42 |
+
amp: True
|
43 |
+
|
44 |
+
test_configuration_common:
|
45 |
+
normalize: True
|
46 |
+
image_size: 352
|
47 |
+
batch_size: 32
|
48 |
+
# max_iterations: 5
|
49 |
+
# max_iterations: 150
|
50 |
+
|
51 |
+
test_configuration:
|
52 |
+
|
53 |
+
-
|
54 |
+
name: pc # old: phrasecut
|
55 |
+
metric: metrics.FixedIntervalMetrics
|
56 |
+
test_dataset: phrasecut
|
57 |
+
split: test
|
58 |
+
mask: text
|
59 |
+
label_support: True
|
60 |
+
sigmoid: True
|
61 |
+
|
62 |
+
|
63 |
+
columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date]
|
64 |
+
|
65 |
+
|
66 |
+
individual_configurations:
|
67 |
+
|
68 |
+
# important ones
|
69 |
+
|
70 |
+
|
71 |
+
- {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5}
|
72 |
+
|
73 |
+
# this was accedentally trained using old mask
|
74 |
+
- {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01}
|
75 |
+
- {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False}
|
76 |
+
# this was accedentally trained using old mask
|
77 |
+
- {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01}
|
78 |
+
|
79 |
+
- {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003}
|
80 |
+
- {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001}
|
clipseg/general_utils.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import inspect
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import yaml
|
7 |
+
from shutil import copy, copytree
|
8 |
+
from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename
|
9 |
+
|
10 |
+
|
11 |
+
class Logger(object):
|
12 |
+
|
13 |
+
def __getattr__(self, k):
|
14 |
+
return print
|
15 |
+
|
16 |
+
log = Logger()
|
17 |
+
|
18 |
+
def training_config_from_cli_args():
|
19 |
+
experiment_name = sys.argv[1]
|
20 |
+
experiment_id = int(sys.argv[2])
|
21 |
+
|
22 |
+
yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
|
23 |
+
|
24 |
+
config = yaml_config['configuration']
|
25 |
+
config = {**config, **yaml_config['individual_configurations'][experiment_id]}
|
26 |
+
config = AttributeDict(config)
|
27 |
+
return config
|
28 |
+
|
29 |
+
|
30 |
+
def score_config_from_cli_args():
|
31 |
+
experiment_name = sys.argv[1]
|
32 |
+
experiment_id = int(sys.argv[2])
|
33 |
+
|
34 |
+
|
35 |
+
yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
|
36 |
+
|
37 |
+
config = yaml_config['test_configuration_common']
|
38 |
+
|
39 |
+
if type(yaml_config['test_configuration']) == list:
|
40 |
+
test_id = int(sys.argv[3])
|
41 |
+
config = {**config, **yaml_config['test_configuration'][test_id]}
|
42 |
+
else:
|
43 |
+
config = {**config, **yaml_config['test_configuration']}
|
44 |
+
|
45 |
+
if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]:
|
46 |
+
config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']}
|
47 |
+
|
48 |
+
train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name']
|
49 |
+
|
50 |
+
config = AttributeDict(config)
|
51 |
+
return config, train_checkpoint_id
|
52 |
+
|
53 |
+
|
54 |
+
def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository',
|
55 |
+
local_dir='~/datasets'):
|
56 |
+
""" copies files from repository to local folder.
|
57 |
+
|
58 |
+
repo_files: list of filenames or list of tuples [filename, target path]
|
59 |
+
|
60 |
+
e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar'])
|
61 |
+
will create a folder 'MyDataset' in local_dir, and extract the content of
|
62 |
+
'<repo_dir>/data/dataset1.tar' to <local_dir>/MyDataset/other/path.
|
63 |
+
"""
|
64 |
+
|
65 |
+
local_dir = realpath(join(expanduser(local_dir), local_name))
|
66 |
+
|
67 |
+
dataset_exists = True
|
68 |
+
|
69 |
+
# check if folder is available
|
70 |
+
if not isdir(local_dir):
|
71 |
+
dataset_exists = False
|
72 |
+
|
73 |
+
if integrity_check is not None:
|
74 |
+
try:
|
75 |
+
integrity_ok = integrity_check(local_dir)
|
76 |
+
except BaseException:
|
77 |
+
integrity_ok = False
|
78 |
+
|
79 |
+
if integrity_ok:
|
80 |
+
log.hint('Passed custom integrity check')
|
81 |
+
else:
|
82 |
+
log.hint('Custom integrity check failed')
|
83 |
+
|
84 |
+
dataset_exists = dataset_exists and integrity_ok
|
85 |
+
|
86 |
+
if not dataset_exists:
|
87 |
+
|
88 |
+
repo_dir = realpath(expanduser(repo_dir))
|
89 |
+
|
90 |
+
for i, filename in enumerate(repo_files):
|
91 |
+
|
92 |
+
if type(filename) == str:
|
93 |
+
origin, target = filename, filename
|
94 |
+
archive_target = join(local_dir, basename(origin))
|
95 |
+
extract_target = join(local_dir)
|
96 |
+
else:
|
97 |
+
origin, target = filename
|
98 |
+
archive_target = join(local_dir, dirname(target), basename(origin))
|
99 |
+
extract_target = join(local_dir, dirname(target))
|
100 |
+
|
101 |
+
archive_origin = join(repo_dir, origin)
|
102 |
+
|
103 |
+
log.hint(f'copy: {archive_origin} to {archive_target}')
|
104 |
+
|
105 |
+
# make sure the path exists
|
106 |
+
os.makedirs(dirname(archive_target), exist_ok=True)
|
107 |
+
|
108 |
+
if os.path.isfile(archive_target):
|
109 |
+
# only copy if size differs
|
110 |
+
if os.path.getsize(archive_target) != os.path.getsize(archive_origin):
|
111 |
+
log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}')
|
112 |
+
copy(archive_origin, archive_target)
|
113 |
+
else:
|
114 |
+
copy(archive_origin, archive_target)
|
115 |
+
|
116 |
+
extract_archive(archive_target, extract_target, noarchive_ok=True)
|
117 |
+
|
118 |
+
# concurrent processes might have deleted the file
|
119 |
+
if os.path.isfile(archive_target):
|
120 |
+
os.remove(archive_target)
|
121 |
+
|
122 |
+
|
123 |
+
def extract_archive(filename, target_folder=None, noarchive_ok=False):
|
124 |
+
from subprocess import run, PIPE
|
125 |
+
|
126 |
+
if filename.endswith('.tgz') or filename.endswith('.tar'):
|
127 |
+
command = f'tar -xf {filename}'
|
128 |
+
command += f' -C {target_folder}' if target_folder is not None else ''
|
129 |
+
elif filename.endswith('.tar.gz'):
|
130 |
+
command = f'tar -xzf {filename}'
|
131 |
+
command += f' -C {target_folder}' if target_folder is not None else ''
|
132 |
+
elif filename.endswith('zip'):
|
133 |
+
command = f'unzip {filename}'
|
134 |
+
command += f' -d {target_folder}' if target_folder is not None else ''
|
135 |
+
else:
|
136 |
+
if noarchive_ok:
|
137 |
+
return
|
138 |
+
else:
|
139 |
+
raise ValueError(f'unsuppored file ending of {filename}')
|
140 |
+
|
141 |
+
log.hint(command)
|
142 |
+
result = run(command.split(), stdout=PIPE, stderr=PIPE)
|
143 |
+
if result.returncode != 0:
|
144 |
+
print(result.stdout, result.stderr)
|
145 |
+
|
146 |
+
|
147 |
+
class AttributeDict(dict):
|
148 |
+
"""
|
149 |
+
An extended dictionary that allows access to elements as atttributes and counts
|
150 |
+
these accesses. This way, we know if some attributes were never used.
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(self, *args, **kwargs):
|
154 |
+
from collections import Counter
|
155 |
+
super().__init__(*args, **kwargs)
|
156 |
+
self.__dict__['counter'] = Counter()
|
157 |
+
|
158 |
+
def __getitem__(self, k):
|
159 |
+
self.__dict__['counter'][k] += 1
|
160 |
+
return super().__getitem__(k)
|
161 |
+
|
162 |
+
def __getattr__(self, k):
|
163 |
+
self.__dict__['counter'][k] += 1
|
164 |
+
return super().get(k)
|
165 |
+
|
166 |
+
def __setattr__(self, k, v):
|
167 |
+
return super().__setitem__(k, v)
|
168 |
+
|
169 |
+
def __delattr__(self, k, v):
|
170 |
+
return super().__delitem__(k, v)
|
171 |
+
|
172 |
+
def unused_keys(self, exceptions=()):
|
173 |
+
return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions]
|
174 |
+
|
175 |
+
def assume_no_unused_keys(self, exceptions=()):
|
176 |
+
if len(self.unused_keys(exceptions=exceptions)) > 0:
|
177 |
+
log.warning('Unused keys:', self.unused_keys(exceptions=exceptions))
|
178 |
+
|
179 |
+
|
180 |
+
def get_attribute(name):
|
181 |
+
import importlib
|
182 |
+
|
183 |
+
if name is None:
|
184 |
+
raise ValueError('The provided attribute is None')
|
185 |
+
|
186 |
+
name_split = name.split('.')
|
187 |
+
mod = importlib.import_module('.'.join(name_split[:-1]))
|
188 |
+
return getattr(mod, name_split[-1])
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
def filter_args(input_args, default_args):
|
193 |
+
|
194 |
+
updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()}
|
195 |
+
used_args = {k: v for k, v in input_args.items() if k in default_args}
|
196 |
+
unused_args = {k: v for k, v in input_args.items() if k not in default_args}
|
197 |
+
|
198 |
+
return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args)
|
199 |
+
|
200 |
+
|
201 |
+
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False):
|
202 |
+
|
203 |
+
config = json.load(open(join('logs', checkpoint_id, 'config.json')))
|
204 |
+
|
205 |
+
if model_args != 'from_config' and type(model_args) != dict:
|
206 |
+
raise ValueError('model_args must either be "from_config" or a dictionary of values')
|
207 |
+
|
208 |
+
model_cls = get_attribute(config['model'])
|
209 |
+
|
210 |
+
# load model
|
211 |
+
if model_args == 'from_config':
|
212 |
+
_, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
|
213 |
+
|
214 |
+
model = model_cls(**model_args)
|
215 |
+
|
216 |
+
if weights_file is None:
|
217 |
+
weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
|
218 |
+
else:
|
219 |
+
weights_file = realpath(join('logs', checkpoint_id, weights_file))
|
220 |
+
|
221 |
+
if isfile(weights_file):
|
222 |
+
weights = torch.load(weights_file)
|
223 |
+
for _, w in weights.items():
|
224 |
+
assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
|
225 |
+
model.load_state_dict(weights, strict=strict)
|
226 |
+
else:
|
227 |
+
raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
|
228 |
+
|
229 |
+
if with_config:
|
230 |
+
return model, config
|
231 |
+
|
232 |
+
return model
|
233 |
+
|
234 |
+
|
235 |
+
class TrainingLogger(object):
|
236 |
+
|
237 |
+
def __init__(self, model, log_dir, config=None, *args):
|
238 |
+
super().__init__()
|
239 |
+
self.model = model
|
240 |
+
self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None
|
241 |
+
|
242 |
+
os.makedirs('logs/', exist_ok=True)
|
243 |
+
os.makedirs(self.base_path, exist_ok=True)
|
244 |
+
|
245 |
+
if config is not None:
|
246 |
+
json.dump(config, open(join(self.base_path, 'config.json'), 'w'))
|
247 |
+
|
248 |
+
def iter(self, i, **kwargs):
|
249 |
+
if i % 100 == 0 and 'loss' in kwargs:
|
250 |
+
loss = kwargs['loss']
|
251 |
+
print(f'iteration {i}: loss {loss:.4f}')
|
252 |
+
|
253 |
+
def save_weights(self, only_trainable=False, weight_file='weights.pth'):
|
254 |
+
if self.model is None:
|
255 |
+
raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.')
|
256 |
+
|
257 |
+
weights_path = join(self.base_path, weight_file)
|
258 |
+
|
259 |
+
weight_dict = self.model.state_dict()
|
260 |
+
|
261 |
+
if only_trainable:
|
262 |
+
weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad}
|
263 |
+
|
264 |
+
torch.save(weight_dict, weights_path)
|
265 |
+
log.info(f'Saved weights to {weights_path}')
|
266 |
+
|
267 |
+
def __enter__(self):
|
268 |
+
return self
|
269 |
+
|
270 |
+
def __exit__(self, type, value, traceback):
|
271 |
+
""" automatically stop processes if used in a context manager """
|
272 |
+
pass
|
clipseg/metrics.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.functional import Tensor
|
2 |
+
from general_utils import log
|
3 |
+
from collections import defaultdict
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as nnf
|
8 |
+
|
9 |
+
|
10 |
+
class BaseMetric(object):
|
11 |
+
|
12 |
+
def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
|
13 |
+
eval_validation=True):
|
14 |
+
self._names = tuple(metric_names)
|
15 |
+
self._eval_intermediate = eval_intermediate
|
16 |
+
self._eval_validation = eval_validation
|
17 |
+
|
18 |
+
self._pred_range = pred_range
|
19 |
+
self._pred_index = pred_index
|
20 |
+
self._gt_index = gt_index
|
21 |
+
|
22 |
+
self.predictions = []
|
23 |
+
self.ground_truths = []
|
24 |
+
|
25 |
+
def eval_intermediate(self):
|
26 |
+
return self._eval_intermediate
|
27 |
+
|
28 |
+
def eval_validation(self):
|
29 |
+
return self._eval_validation
|
30 |
+
|
31 |
+
def names(self):
|
32 |
+
return self._names
|
33 |
+
|
34 |
+
def add(self, predictions, ground_truth):
|
35 |
+
raise NotImplementedError
|
36 |
+
|
37 |
+
def value(self):
|
38 |
+
raise NotImplementedError
|
39 |
+
|
40 |
+
def scores(self):
|
41 |
+
# similar to value but returns dict
|
42 |
+
value = self.value()
|
43 |
+
if type(value) == dict:
|
44 |
+
return value
|
45 |
+
else:
|
46 |
+
assert type(value) in {list, tuple}
|
47 |
+
return list(zip(self.names(), self.value()))
|
48 |
+
|
49 |
+
def _get_pred_gt(self, predictions, ground_truth):
|
50 |
+
pred = predictions[self._pred_index]
|
51 |
+
gt = ground_truth[self._gt_index]
|
52 |
+
|
53 |
+
if self._pred_range is not None:
|
54 |
+
pred = pred[:, self._pred_range[0]: self._pred_range[1]]
|
55 |
+
|
56 |
+
return pred, gt
|
57 |
+
|
58 |
+
|
59 |
+
class FixedIntervalMetrics(BaseMetric):
|
60 |
+
|
61 |
+
def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
|
62 |
+
resize_pred=None, n_values=51, custom_threshold=None):
|
63 |
+
|
64 |
+
|
65 |
+
super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
|
66 |
+
self.intersections = []
|
67 |
+
self.unions = []
|
68 |
+
# self.threshold = threshold
|
69 |
+
self.sigmoid = sigmoid
|
70 |
+
self.resize_to = resize_to
|
71 |
+
self.resize_pred = resize_pred # resize prediction to match ground truth
|
72 |
+
self.class_count = defaultdict(lambda: 0)
|
73 |
+
self.per_class = defaultdict(lambda : [0,0])
|
74 |
+
self.ignore_mask = ignore_mask
|
75 |
+
self.custom_threshold = custom_threshold
|
76 |
+
|
77 |
+
self.scores_ap = []
|
78 |
+
self.scores_iou = []
|
79 |
+
self.gts, self.preds = [], []
|
80 |
+
self.classes = []
|
81 |
+
|
82 |
+
# [1:-1] ignores 0 and 1
|
83 |
+
self.threshold_values = np.linspace(0, 1, n_values)[1:-1]
|
84 |
+
|
85 |
+
self.metrics = dict(tp=[], fp=[], fn=[], tn=[])
|
86 |
+
|
87 |
+
def add(self, pred, gt):
|
88 |
+
|
89 |
+
pred_batch = pred[0].cpu()
|
90 |
+
|
91 |
+
if self.sigmoid:
|
92 |
+
pred_batch = torch.sigmoid(pred_batch)
|
93 |
+
|
94 |
+
gt_batch = gt[0].cpu()
|
95 |
+
mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
|
96 |
+
cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)
|
97 |
+
|
98 |
+
if self.resize_to is not None:
|
99 |
+
gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
|
100 |
+
pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)
|
101 |
+
|
102 |
+
if isinstance(cls_batch, torch.Tensor):
|
103 |
+
cls_batch = cls_batch.cpu().numpy().tolist()
|
104 |
+
|
105 |
+
assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'
|
106 |
+
|
107 |
+
for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):
|
108 |
+
|
109 |
+
if self.resize_pred:
|
110 |
+
predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)
|
111 |
+
|
112 |
+
p = predictions.flatten()
|
113 |
+
g = ground_truth.flatten()
|
114 |
+
|
115 |
+
assert len(p) == len(g)
|
116 |
+
|
117 |
+
if mask is not None:
|
118 |
+
m = mask.flatten().bool()
|
119 |
+
p = p[m]
|
120 |
+
g = g[m]
|
121 |
+
|
122 |
+
p_sorted = p.sort()
|
123 |
+
p = p_sorted.values
|
124 |
+
g = g[p_sorted.indices]
|
125 |
+
|
126 |
+
tps, fps, fns, tns = [], [], [], []
|
127 |
+
for thresh in self.threshold_values:
|
128 |
+
|
129 |
+
valid = torch.where(p > thresh)[0]
|
130 |
+
if len(valid) > 0:
|
131 |
+
n = int(valid[0])
|
132 |
+
else:
|
133 |
+
n = len(g)
|
134 |
+
|
135 |
+
fn = int(g[:n].sum())
|
136 |
+
tp = int(g[n:].sum())
|
137 |
+
fns += [fn]
|
138 |
+
tns += [n - fn]
|
139 |
+
tps += [tp]
|
140 |
+
fps += [len(g) - n - tp]
|
141 |
+
|
142 |
+
self.metrics['tp'] += [tps]
|
143 |
+
self.metrics['fp'] += [fps]
|
144 |
+
self.metrics['fn'] += [fns]
|
145 |
+
self.metrics['tn'] += [tns]
|
146 |
+
|
147 |
+
self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls]
|
148 |
+
|
149 |
+
def value(self):
|
150 |
+
|
151 |
+
import time
|
152 |
+
t_start = time.time()
|
153 |
+
|
154 |
+
if set(self.classes) == set([None]):
|
155 |
+
all_classes = None
|
156 |
+
log.warning('classes were not provided, cannot compute mIoU')
|
157 |
+
else:
|
158 |
+
all_classes = set(int(c) for c in self.classes)
|
159 |
+
# log.info(f'compute metrics for {len(all_classes)} classes')
|
160 |
+
|
161 |
+
summed = {k: [sum([self.metrics[k][i][j]
|
162 |
+
for i in range(len(self.metrics[k]))])
|
163 |
+
for j in range(len(self.threshold_values))]
|
164 |
+
for k in self.metrics.keys()}
|
165 |
+
|
166 |
+
if all_classes is not None:
|
167 |
+
|
168 |
+
assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])
|
169 |
+
# group by class
|
170 |
+
metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes}
|
171 |
+
for i in range(len(self.metrics['tp'])):
|
172 |
+
for k in self.metrics.keys():
|
173 |
+
metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]]
|
174 |
+
|
175 |
+
# sum over all instances within the classes
|
176 |
+
summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()}
|
177 |
+
|
178 |
+
|
179 |
+
# Compute average precision
|
180 |
+
|
181 |
+
assert (np.array(summed['fp']) + np.array(summed['tp']) ).sum(), 'no predictions is made'
|
182 |
+
|
183 |
+
# only consider values where a prediction is made
|
184 |
+
precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values))
|
185 |
+
if summed['tp'][j] + summed['fp'][j] > 0]
|
186 |
+
recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))
|
187 |
+
if summed['tp'][j] + summed['fp'][j] > 0]
|
188 |
+
|
189 |
+
# remove duplicate recall-precision-pairs (and sort by recall value)
|
190 |
+
recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0]))
|
191 |
+
|
192 |
+
from scipy.integrate import simps
|
193 |
+
ap = simps(precisions, recalls)
|
194 |
+
|
195 |
+
# Compute best IoU
|
196 |
+
fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))]
|
197 |
+
|
198 |
+
biniou_scores = [
|
199 |
+
0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
|
200 |
+
0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
|
201 |
+
for j in range(len(self.threshold_values))
|
202 |
+
]
|
203 |
+
|
204 |
+
index_0p5 = self.threshold_values.tolist().index(0.5)
|
205 |
+
index_0p1 = self.threshold_values.tolist().index(0.1)
|
206 |
+
index_0p2 = self.threshold_values.tolist().index(0.2)
|
207 |
+
index_0p3 = self.threshold_values.tolist().index(0.3)
|
208 |
+
|
209 |
+
if self.custom_threshold is not None:
|
210 |
+
index_ct = self.threshold_values.tolist().index(self.custom_threshold)
|
211 |
+
|
212 |
+
if all_classes is not None:
|
213 |
+
# mean IoU
|
214 |
+
mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j])
|
215 |
+
for c in all_classes])
|
216 |
+
for j in range(len(self.threshold_values))]
|
217 |
+
|
218 |
+
mean_iou_dict = {
|
219 |
+
'miou_best': max(mean_ious) if all_classes is not None else None,
|
220 |
+
'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None,
|
221 |
+
'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None,
|
222 |
+
'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None,
|
223 |
+
'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None,
|
224 |
+
'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
|
225 |
+
'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None,
|
226 |
+
'mean_iou_scores': mean_ious,
|
227 |
+
}
|
228 |
+
|
229 |
+
print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')
|
230 |
+
|
231 |
+
return {
|
232 |
+
'ap': ap,
|
233 |
+
|
234 |
+
# fgiou
|
235 |
+
'fgiou_best': max(fgiou_scores),
|
236 |
+
'fgiou_0.5': fgiou_scores[index_0p5],
|
237 |
+
'fgiou_0.1': fgiou_scores[index_0p1],
|
238 |
+
'fgiou_0.2': fgiou_scores[index_0p2],
|
239 |
+
'fgiou_0.3': fgiou_scores[index_0p3],
|
240 |
+
'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],
|
241 |
+
|
242 |
+
# mean iou
|
243 |
+
|
244 |
+
|
245 |
+
# biniou
|
246 |
+
'biniou_best': max(biniou_scores),
|
247 |
+
'biniou_0.5': biniou_scores[index_0p5],
|
248 |
+
'biniou_0.1': biniou_scores[index_0p1],
|
249 |
+
'biniou_0.2': biniou_scores[index_0p2],
|
250 |
+
'biniou_0.3': biniou_scores[index_0p3],
|
251 |
+
'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],
|
252 |
+
|
253 |
+
# custom threshold
|
254 |
+
'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None,
|
255 |
+
'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None,
|
256 |
+
'ct': self.custom_threshold,
|
257 |
+
|
258 |
+
# statistics
|
259 |
+
'fgiou_scores': fgiou_scores,
|
260 |
+
'biniou_scores': biniou_scores,
|
261 |
+
'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))),
|
262 |
+
'summed_statistics': summed,
|
263 |
+
'summed_by_cls_statistics': summed_by_cls,
|
264 |
+
|
265 |
+
**mean_iou_dict
|
266 |
+
}
|
267 |
+
|
268 |
+
# ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'
|
269 |
+
|
270 |
+
# return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls}
|
271 |
+
|
clipseg/models/clipseg.py
ADDED
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from os.path import basename, dirname, join, isfile
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as nnf
|
6 |
+
from torch.nn.modules.activation import ReLU
|
7 |
+
|
8 |
+
|
9 |
+
def precompute_clip_vectors():
|
10 |
+
|
11 |
+
from trails.initialization import init_dataset
|
12 |
+
lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
|
13 |
+
reduce_factor=None, add_bar=False, negative_prob=0.5)
|
14 |
+
|
15 |
+
all_names = list(lvis.category_names.values())
|
16 |
+
|
17 |
+
import clip
|
18 |
+
from models.clip_prompts import imagenet_templates
|
19 |
+
clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
|
20 |
+
prompt_vectors = {}
|
21 |
+
for name in all_names[:100]:
|
22 |
+
with torch.no_grad():
|
23 |
+
conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
|
24 |
+
text_tokens = clip.tokenize(conditionals).cuda()
|
25 |
+
cond = clip_model.encode_text(text_tokens).cpu()
|
26 |
+
|
27 |
+
for cond, vec in zip(conditionals, cond):
|
28 |
+
prompt_vectors[cond] = vec.cpu()
|
29 |
+
|
30 |
+
import pickle
|
31 |
+
|
32 |
+
pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb'))
|
33 |
+
|
34 |
+
|
35 |
+
def get_prompt_list(prompt):
|
36 |
+
if prompt == 'plain':
|
37 |
+
return ['{}']
|
38 |
+
elif prompt == 'fixed':
|
39 |
+
return ['a photo of a {}.']
|
40 |
+
elif prompt == 'shuffle':
|
41 |
+
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
42 |
+
elif prompt == 'shuffle+':
|
43 |
+
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
44 |
+
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
45 |
+
'a bad photo of a {}.', 'a photo of the {}.']
|
46 |
+
elif prompt == 'shuffle_clip':
|
47 |
+
from models.clip_prompts import imagenet_templates
|
48 |
+
return imagenet_templates
|
49 |
+
else:
|
50 |
+
raise ValueError('Invalid value for prompt')
|
51 |
+
|
52 |
+
|
53 |
+
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
|
54 |
+
"""
|
55 |
+
Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
|
56 |
+
The mlp and layer norm come from CLIP.
|
57 |
+
x: input.
|
58 |
+
b: multihead attention module.
|
59 |
+
"""
|
60 |
+
|
61 |
+
x_ = b.ln_1(x)
|
62 |
+
q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
|
63 |
+
tgt_len, bsz, embed_dim = q.size()
|
64 |
+
|
65 |
+
head_dim = embed_dim // b.attn.num_heads
|
66 |
+
scaling = float(head_dim) ** -0.5
|
67 |
+
|
68 |
+
q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
69 |
+
k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
70 |
+
v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
71 |
+
|
72 |
+
q = q * scaling
|
73 |
+
|
74 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2
|
75 |
+
if attn_mask is not None:
|
76 |
+
|
77 |
+
|
78 |
+
attn_mask_type, attn_mask = attn_mask
|
79 |
+
n_heads = attn_output_weights.size(0) // attn_mask.size(0)
|
80 |
+
attn_mask = attn_mask.repeat(n_heads, 1)
|
81 |
+
|
82 |
+
if attn_mask_type == 'cls_token':
|
83 |
+
# the mask only affects similarities compared to the readout-token.
|
84 |
+
attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...]
|
85 |
+
# attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
|
86 |
+
|
87 |
+
if attn_mask_type == 'all':
|
88 |
+
# print(attn_output_weights.shape, attn_mask[:, None].shape)
|
89 |
+
attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
|
90 |
+
|
91 |
+
|
92 |
+
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
|
93 |
+
|
94 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
95 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
96 |
+
attn_output = b.attn.out_proj(attn_output)
|
97 |
+
|
98 |
+
x = x + attn_output
|
99 |
+
x = x + b.mlp(b.ln_2(x))
|
100 |
+
|
101 |
+
if with_aff:
|
102 |
+
return x, attn_output_weights
|
103 |
+
else:
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class CLIPDenseBase(nn.Module):
|
108 |
+
|
109 |
+
def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
|
110 |
+
super().__init__()
|
111 |
+
|
112 |
+
import clip
|
113 |
+
|
114 |
+
# prec = torch.FloatTensor
|
115 |
+
self.clip_model, _ = clip.load(version, device='cpu', jit=False)
|
116 |
+
self.model = self.clip_model.visual
|
117 |
+
|
118 |
+
# if not None, scale conv weights such that we obtain n_tokens.
|
119 |
+
self.n_tokens = n_tokens
|
120 |
+
|
121 |
+
for p in self.clip_model.parameters():
|
122 |
+
p.requires_grad_(False)
|
123 |
+
|
124 |
+
# conditional
|
125 |
+
if reduce_cond is not None:
|
126 |
+
self.reduce_cond = nn.Linear(512, reduce_cond)
|
127 |
+
for p in self.reduce_cond.parameters():
|
128 |
+
p.requires_grad_(False)
|
129 |
+
else:
|
130 |
+
self.reduce_cond = None
|
131 |
+
|
132 |
+
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
133 |
+
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
134 |
+
|
135 |
+
self.reduce = nn.Linear(768, reduce_dim)
|
136 |
+
|
137 |
+
self.prompt_list = get_prompt_list(prompt)
|
138 |
+
|
139 |
+
# precomputed prompts
|
140 |
+
import pickle
|
141 |
+
if isfile('precomputed_prompt_vectors.pickle'):
|
142 |
+
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
143 |
+
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
144 |
+
else:
|
145 |
+
self.precomputed_prompts = dict()
|
146 |
+
|
147 |
+
def rescaled_pos_emb(self, new_size):
|
148 |
+
assert len(new_size) == 2
|
149 |
+
|
150 |
+
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
151 |
+
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
152 |
+
return torch.cat([self.model.positional_embedding[:1], b])
|
153 |
+
|
154 |
+
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
155 |
+
|
156 |
+
|
157 |
+
with torch.no_grad():
|
158 |
+
|
159 |
+
inp_size = x_inp.shape[2:]
|
160 |
+
|
161 |
+
if self.n_tokens is not None:
|
162 |
+
stride2 = x_inp.shape[2] // self.n_tokens
|
163 |
+
conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
|
164 |
+
x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
|
165 |
+
else:
|
166 |
+
x = self.model.conv1(x_inp) # shape = [*, width, grid, grid]
|
167 |
+
|
168 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
169 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
170 |
+
|
171 |
+
x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
172 |
+
|
173 |
+
standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
|
174 |
+
|
175 |
+
if x.shape[1] != standard_n_tokens:
|
176 |
+
new_shape = int(math.sqrt(x.shape[1]-1))
|
177 |
+
x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:]
|
178 |
+
else:
|
179 |
+
x = x + self.model.positional_embedding.to(x.dtype)
|
180 |
+
|
181 |
+
x = self.model.ln_pre(x)
|
182 |
+
|
183 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
184 |
+
|
185 |
+
activations, affinities = [], []
|
186 |
+
for i, res_block in enumerate(self.model.transformer.resblocks):
|
187 |
+
|
188 |
+
if mask is not None:
|
189 |
+
mask_layer, mask_type, mask_tensor = mask
|
190 |
+
if mask_layer == i or mask_layer == 'all':
|
191 |
+
# import ipdb; ipdb.set_trace()
|
192 |
+
size = int(math.sqrt(x.shape[0] - 1))
|
193 |
+
|
194 |
+
attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))
|
195 |
+
|
196 |
+
else:
|
197 |
+
attn_mask = None
|
198 |
+
else:
|
199 |
+
attn_mask = None
|
200 |
+
|
201 |
+
x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)
|
202 |
+
|
203 |
+
if i in extract_layers:
|
204 |
+
affinities += [aff_per_head]
|
205 |
+
|
206 |
+
#if self.n_tokens is not None:
|
207 |
+
# activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
|
208 |
+
#else:
|
209 |
+
activations += [x]
|
210 |
+
|
211 |
+
if len(extract_layers) > 0 and i == max(extract_layers) and skip:
|
212 |
+
print('early skip')
|
213 |
+
break
|
214 |
+
|
215 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
216 |
+
x = self.model.ln_post(x[:, 0, :])
|
217 |
+
|
218 |
+
if self.model.proj is not None:
|
219 |
+
x = x @ self.model.proj
|
220 |
+
|
221 |
+
return x, activations, affinities
|
222 |
+
|
223 |
+
def sample_prompts(self, words, prompt_list=None):
|
224 |
+
|
225 |
+
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
226 |
+
|
227 |
+
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
228 |
+
prompts = [prompt_list[i] for i in prompt_indices]
|
229 |
+
return [promt.format(w) for promt, w in zip(prompts, words)]
|
230 |
+
|
231 |
+
def get_cond_vec(self, conditional, batch_size):
|
232 |
+
# compute conditional from a single string
|
233 |
+
if conditional is not None and type(conditional) == str:
|
234 |
+
cond = self.compute_conditional(conditional)
|
235 |
+
cond = cond.repeat(batch_size, 1)
|
236 |
+
|
237 |
+
# compute conditional from string list/tuple
|
238 |
+
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
239 |
+
assert len(conditional) == batch_size
|
240 |
+
cond = self.compute_conditional(conditional)
|
241 |
+
|
242 |
+
# use conditional directly
|
243 |
+
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
244 |
+
cond = conditional
|
245 |
+
|
246 |
+
# compute conditional from image
|
247 |
+
elif conditional is not None and type(conditional) == torch.Tensor:
|
248 |
+
with torch.no_grad():
|
249 |
+
cond, _, _ = self.visual_forward(conditional)
|
250 |
+
else:
|
251 |
+
raise ValueError('invalid conditional')
|
252 |
+
return cond
|
253 |
+
|
254 |
+
def compute_conditional(self, conditional):
|
255 |
+
import clip
|
256 |
+
|
257 |
+
dev = next(self.parameters()).device
|
258 |
+
|
259 |
+
if type(conditional) in {list, tuple}:
|
260 |
+
text_tokens = clip.tokenize(conditional).to(dev)
|
261 |
+
cond = self.clip_model.encode_text(text_tokens)
|
262 |
+
else:
|
263 |
+
if conditional in self.precomputed_prompts:
|
264 |
+
cond = self.precomputed_prompts[conditional].float().to(dev)
|
265 |
+
else:
|
266 |
+
text_tokens = clip.tokenize([conditional]).to(dev)
|
267 |
+
cond = self.clip_model.encode_text(text_tokens)[0]
|
268 |
+
|
269 |
+
if self.shift_vector is not None:
|
270 |
+
return cond + self.shift_vector
|
271 |
+
else:
|
272 |
+
return cond
|
273 |
+
|
274 |
+
|
275 |
+
def clip_load_untrained(version):
|
276 |
+
assert version == 'ViT-B/16'
|
277 |
+
from clip.model import CLIP
|
278 |
+
from clip.clip import _MODELS, _download
|
279 |
+
model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
|
280 |
+
state_dict = model.state_dict()
|
281 |
+
|
282 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
283 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
284 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
285 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
286 |
+
image_resolution = vision_patch_size * grid_size
|
287 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
288 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
289 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
290 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
291 |
+
transformer_heads = transformer_width // 64
|
292 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
293 |
+
|
294 |
+
return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
|
295 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
|
296 |
+
|
297 |
+
|
298 |
+
class CLIPDensePredT(CLIPDenseBase):
|
299 |
+
|
300 |
+
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
301 |
+
extra_blocks=0, reduce_cond=None, fix_shift=False,
|
302 |
+
learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
|
303 |
+
add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
|
304 |
+
|
305 |
+
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
|
306 |
+
# device = 'cpu'
|
307 |
+
|
308 |
+
self.extract_layers = extract_layers
|
309 |
+
self.cond_layer = cond_layer
|
310 |
+
self.limit_to_clip_only = limit_to_clip_only
|
311 |
+
self.process_cond = None
|
312 |
+
self.rev_activations = rev_activations
|
313 |
+
|
314 |
+
depth = len(extract_layers)
|
315 |
+
|
316 |
+
if add_calibration:
|
317 |
+
self.calibration_conds = 1
|
318 |
+
|
319 |
+
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
320 |
+
|
321 |
+
self.add_activation1 = True
|
322 |
+
|
323 |
+
self.version = version
|
324 |
+
|
325 |
+
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
|
326 |
+
|
327 |
+
if fix_shift:
|
328 |
+
# self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
|
329 |
+
self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False)
|
330 |
+
# self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
|
331 |
+
else:
|
332 |
+
self.shift_vector = None
|
333 |
+
|
334 |
+
if trans_conv is None:
|
335 |
+
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
|
336 |
+
else:
|
337 |
+
# explicitly define transposed conv kernel size
|
338 |
+
trans_conv_ks = (trans_conv, trans_conv)
|
339 |
+
|
340 |
+
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
341 |
+
|
342 |
+
assert len(self.extract_layers) == depth
|
343 |
+
|
344 |
+
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
345 |
+
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
346 |
+
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
347 |
+
|
348 |
+
# refinement and trans conv
|
349 |
+
|
350 |
+
if learn_trans_conv_only:
|
351 |
+
for p in self.parameters():
|
352 |
+
p.requires_grad_(False)
|
353 |
+
|
354 |
+
for p in self.trans_conv.parameters():
|
355 |
+
p.requires_grad_(True)
|
356 |
+
|
357 |
+
self.prompt_list = get_prompt_list(prompt)
|
358 |
+
|
359 |
+
|
360 |
+
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
361 |
+
|
362 |
+
assert type(return_features) == bool
|
363 |
+
|
364 |
+
inp_image = inp_image.to(self.model.positional_embedding.device)
|
365 |
+
|
366 |
+
if mask is not None:
|
367 |
+
raise ValueError('mask not supported')
|
368 |
+
|
369 |
+
# x_inp = normalize(inp_image)
|
370 |
+
x_inp = inp_image
|
371 |
+
|
372 |
+
bs, dev = inp_image.shape[0], x_inp.device
|
373 |
+
|
374 |
+
cond = self.get_cond_vec(conditional, bs)
|
375 |
+
|
376 |
+
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
377 |
+
|
378 |
+
activation1 = activations[0]
|
379 |
+
activations = activations[1:]
|
380 |
+
|
381 |
+
_activations = activations[::-1] if not self.rev_activations else activations
|
382 |
+
|
383 |
+
a = None
|
384 |
+
for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):
|
385 |
+
|
386 |
+
if a is not None:
|
387 |
+
a = reduce(activation) + a
|
388 |
+
else:
|
389 |
+
a = reduce(activation)
|
390 |
+
|
391 |
+
if i == self.cond_layer:
|
392 |
+
if self.reduce_cond is not None:
|
393 |
+
cond = self.reduce_cond(cond)
|
394 |
+
|
395 |
+
a = self.film_mul(cond) * a + self.film_add(cond)
|
396 |
+
|
397 |
+
a = block(a)
|
398 |
+
|
399 |
+
for block in self.extra_blocks:
|
400 |
+
a = a + block(a)
|
401 |
+
|
402 |
+
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
403 |
+
|
404 |
+
size = int(math.sqrt(a.shape[2]))
|
405 |
+
|
406 |
+
a = a.view(bs, a.shape[1], size, size)
|
407 |
+
|
408 |
+
a = self.trans_conv(a)
|
409 |
+
|
410 |
+
if self.n_tokens is not None:
|
411 |
+
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)
|
412 |
+
|
413 |
+
if self.upsample_proj is not None:
|
414 |
+
a = self.upsample_proj(a)
|
415 |
+
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
416 |
+
|
417 |
+
if return_features:
|
418 |
+
return a, visual_q, cond, [activation1] + activations
|
419 |
+
else:
|
420 |
+
return a,
|
421 |
+
|
422 |
+
|
423 |
+
|
424 |
+
class CLIPDensePredTMasked(CLIPDensePredT):
|
425 |
+
|
426 |
+
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
|
427 |
+
prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
|
428 |
+
refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
|
429 |
+
|
430 |
+
super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim,
|
431 |
+
n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond,
|
432 |
+
fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only,
|
433 |
+
limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration,
|
434 |
+
n_tokens=n_tokens)
|
435 |
+
|
436 |
+
def visual_forward_masked(self, img_s, seg_s):
|
437 |
+
return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s))
|
438 |
+
|
439 |
+
def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
|
440 |
+
|
441 |
+
if seg_s is None:
|
442 |
+
cond = cond_or_img_s
|
443 |
+
else:
|
444 |
+
img_s = cond_or_img_s
|
445 |
+
|
446 |
+
with torch.no_grad():
|
447 |
+
cond, _, _ = self.visual_forward_masked(img_s, seg_s)
|
448 |
+
|
449 |
+
return super().forward(img_q, cond, return_features=return_features)
|
450 |
+
|
451 |
+
|
452 |
+
|
453 |
+
class CLIPDenseBaseline(CLIPDenseBase):
|
454 |
+
|
455 |
+
def __init__(self, version='ViT-B/32', cond_layer=0,
|
456 |
+
extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
|
457 |
+
reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
|
458 |
+
|
459 |
+
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
|
460 |
+
device = 'cpu'
|
461 |
+
|
462 |
+
# self.cond_layer = cond_layer
|
463 |
+
self.extract_layer = extract_layer
|
464 |
+
self.limit_to_clip_only = limit_to_clip_only
|
465 |
+
self.shift_vector = None
|
466 |
+
|
467 |
+
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
|
468 |
+
|
469 |
+
assert reduce2_dim is not None
|
470 |
+
|
471 |
+
self.reduce2 = nn.Sequential(
|
472 |
+
nn.Linear(reduce_dim, reduce2_dim),
|
473 |
+
nn.ReLU(),
|
474 |
+
nn.Linear(reduce2_dim, reduce_dim)
|
475 |
+
)
|
476 |
+
|
477 |
+
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
|
478 |
+
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
479 |
+
|
480 |
+
|
481 |
+
def forward(self, inp_image, conditional=None, return_features=False):
|
482 |
+
|
483 |
+
inp_image = inp_image.to(self.model.positional_embedding.device)
|
484 |
+
|
485 |
+
# x_inp = normalize(inp_image)
|
486 |
+
x_inp = inp_image
|
487 |
+
|
488 |
+
bs, dev = inp_image.shape[0], x_inp.device
|
489 |
+
|
490 |
+
cond = self.get_cond_vec(conditional, bs)
|
491 |
+
|
492 |
+
visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer])
|
493 |
+
|
494 |
+
a = activations[0]
|
495 |
+
a = self.reduce(a)
|
496 |
+
a = self.film_mul(cond) * a + self.film_add(cond)
|
497 |
+
|
498 |
+
if self.reduce2 is not None:
|
499 |
+
a = self.reduce2(a)
|
500 |
+
|
501 |
+
# the original model would execute a transformer block here
|
502 |
+
|
503 |
+
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
504 |
+
|
505 |
+
size = int(math.sqrt(a.shape[2]))
|
506 |
+
|
507 |
+
a = a.view(bs, a.shape[1], size, size)
|
508 |
+
a = self.trans_conv(a)
|
509 |
+
|
510 |
+
if return_features:
|
511 |
+
return a, visual_q, cond, activations
|
512 |
+
else:
|
513 |
+
return a,
|
514 |
+
|
515 |
+
|
516 |
+
class CLIPSegMultiLabel(nn.Module):
|
517 |
+
|
518 |
+
def __init__(self, model) -> None:
|
519 |
+
super().__init__()
|
520 |
+
|
521 |
+
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
522 |
+
|
523 |
+
self.pascal_classes = VOC
|
524 |
+
|
525 |
+
from models.clipseg import CLIPDensePredT
|
526 |
+
from general_utils import load_model
|
527 |
+
# self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
|
528 |
+
self.clipseg = load_model(model, strict=False)
|
529 |
+
|
530 |
+
self.clipseg.eval()
|
531 |
+
|
532 |
+
def forward(self, x):
|
533 |
+
|
534 |
+
bs = x.shape[0]
|
535 |
+
out = torch.ones(21, bs, 352, 352).to(x.device) * -10
|
536 |
+
|
537 |
+
for class_id, class_name in enumerate(self.pascal_classes):
|
538 |
+
|
539 |
+
fac = 3 if class_name == 'background' else 1
|
540 |
+
|
541 |
+
with torch.no_grad():
|
542 |
+
pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac
|
543 |
+
|
544 |
+
out[class_id] += pred
|
545 |
+
|
546 |
+
|
547 |
+
out = out.permute(1, 0, 2, 3)
|
548 |
+
|
549 |
+
return out
|
550 |
+
|
551 |
+
# construct output tensor
|
552 |
+
|
clipseg/models/vitseg.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from posixpath import basename, dirname, join
|
3 |
+
# import clip
|
4 |
+
from clip.model import convert_weights
|
5 |
+
import torch
|
6 |
+
import json
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as nnf
|
9 |
+
from torch.nn.modules import activation
|
10 |
+
from torch.nn.modules.activation import ReLU
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
14 |
+
|
15 |
+
from torchvision.models import ResNet
|
16 |
+
|
17 |
+
|
18 |
+
def process_prompts(conditional, prompt_list, conditional_map):
|
19 |
+
# DEPRECATED
|
20 |
+
|
21 |
+
# randomly sample a synonym
|
22 |
+
words = [conditional_map[int(i)] for i in conditional]
|
23 |
+
words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
|
24 |
+
words = [w.replace('_', ' ') for w in words]
|
25 |
+
|
26 |
+
if prompt_list is not None:
|
27 |
+
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
28 |
+
prompts = [prompt_list[i] for i in prompt_indices]
|
29 |
+
else:
|
30 |
+
prompts = ['a photo of {}'] * (len(words))
|
31 |
+
|
32 |
+
return [promt.format(w) for promt, w in zip(prompts, words)]
|
33 |
+
|
34 |
+
|
35 |
+
class VITDenseBase(nn.Module):
|
36 |
+
|
37 |
+
def rescaled_pos_emb(self, new_size):
|
38 |
+
assert len(new_size) == 2
|
39 |
+
|
40 |
+
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
41 |
+
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
42 |
+
return torch.cat([self.model.positional_embedding[:1], b])
|
43 |
+
|
44 |
+
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
45 |
+
|
46 |
+
with torch.no_grad():
|
47 |
+
|
48 |
+
x_inp = nnf.interpolate(x_inp, (384, 384))
|
49 |
+
|
50 |
+
x = self.model.patch_embed(x_inp)
|
51 |
+
cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
52 |
+
if self.model.dist_token is None:
|
53 |
+
x = torch.cat((cls_token, x), dim=1)
|
54 |
+
else:
|
55 |
+
x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
56 |
+
x = self.model.pos_drop(x + self.model.pos_embed)
|
57 |
+
|
58 |
+
activations = []
|
59 |
+
for i, block in enumerate(self.model.blocks):
|
60 |
+
x = block(x)
|
61 |
+
|
62 |
+
if i in extract_layers:
|
63 |
+
# permute to be compatible with CLIP
|
64 |
+
activations += [x.permute(1,0,2)]
|
65 |
+
|
66 |
+
x = self.model.norm(x)
|
67 |
+
x = self.model.head(self.model.pre_logits(x[:, 0]))
|
68 |
+
|
69 |
+
# again for CLIP compatibility
|
70 |
+
# x = x.permute(1, 0, 2)
|
71 |
+
|
72 |
+
return x, activations, None
|
73 |
+
|
74 |
+
def sample_prompts(self, words, prompt_list=None):
|
75 |
+
|
76 |
+
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
77 |
+
|
78 |
+
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
79 |
+
prompts = [prompt_list[i] for i in prompt_indices]
|
80 |
+
return [promt.format(w) for promt, w in zip(prompts, words)]
|
81 |
+
|
82 |
+
def get_cond_vec(self, conditional, batch_size):
|
83 |
+
# compute conditional from a single string
|
84 |
+
if conditional is not None and type(conditional) == str:
|
85 |
+
cond = self.compute_conditional(conditional)
|
86 |
+
cond = cond.repeat(batch_size, 1)
|
87 |
+
|
88 |
+
# compute conditional from string list/tuple
|
89 |
+
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
90 |
+
assert len(conditional) == batch_size
|
91 |
+
cond = self.compute_conditional(conditional)
|
92 |
+
|
93 |
+
# use conditional directly
|
94 |
+
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
95 |
+
cond = conditional
|
96 |
+
|
97 |
+
# compute conditional from image
|
98 |
+
elif conditional is not None and type(conditional) == torch.Tensor:
|
99 |
+
with torch.no_grad():
|
100 |
+
cond, _, _ = self.visual_forward(conditional)
|
101 |
+
else:
|
102 |
+
raise ValueError('invalid conditional')
|
103 |
+
return cond
|
104 |
+
|
105 |
+
def compute_conditional(self, conditional):
|
106 |
+
import clip
|
107 |
+
|
108 |
+
dev = next(self.parameters()).device
|
109 |
+
|
110 |
+
if type(conditional) in {list, tuple}:
|
111 |
+
text_tokens = clip.tokenize(conditional).to(dev)
|
112 |
+
cond = self.clip_model.encode_text(text_tokens)
|
113 |
+
else:
|
114 |
+
if conditional in self.precomputed_prompts:
|
115 |
+
cond = self.precomputed_prompts[conditional].float().to(dev)
|
116 |
+
else:
|
117 |
+
text_tokens = clip.tokenize([conditional]).to(dev)
|
118 |
+
cond = self.clip_model.encode_text(text_tokens)[0]
|
119 |
+
|
120 |
+
return cond
|
121 |
+
|
122 |
+
|
123 |
+
class VITDensePredT(VITDenseBase):
|
124 |
+
|
125 |
+
def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
126 |
+
depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
|
127 |
+
learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
|
128 |
+
add_calibration=False, process_cond=None, not_pretrained=False):
|
129 |
+
super().__init__()
|
130 |
+
# device = 'cpu'
|
131 |
+
|
132 |
+
self.extract_layers = extract_layers
|
133 |
+
self.cond_layer = cond_layer
|
134 |
+
self.limit_to_clip_only = limit_to_clip_only
|
135 |
+
self.process_cond = None
|
136 |
+
|
137 |
+
if add_calibration:
|
138 |
+
self.calibration_conds = 1
|
139 |
+
|
140 |
+
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
141 |
+
|
142 |
+
self.add_activation1 = True
|
143 |
+
|
144 |
+
import timm
|
145 |
+
self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
|
146 |
+
self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
|
147 |
+
|
148 |
+
for p in self.model.parameters():
|
149 |
+
p.requires_grad_(False)
|
150 |
+
|
151 |
+
import clip
|
152 |
+
self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
|
153 |
+
# del self.clip_model.visual
|
154 |
+
|
155 |
+
|
156 |
+
self.token_shape = (14, 14)
|
157 |
+
|
158 |
+
# conditional
|
159 |
+
if reduce_cond is not None:
|
160 |
+
self.reduce_cond = nn.Linear(512, reduce_cond)
|
161 |
+
for p in self.reduce_cond.parameters():
|
162 |
+
p.requires_grad_(False)
|
163 |
+
else:
|
164 |
+
self.reduce_cond = None
|
165 |
+
|
166 |
+
# self.film = AVAILABLE_BLOCKS['film'](512, 128)
|
167 |
+
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
168 |
+
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
169 |
+
|
170 |
+
# DEPRECATED
|
171 |
+
# self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
|
172 |
+
|
173 |
+
assert len(self.extract_layers) == depth
|
174 |
+
|
175 |
+
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
176 |
+
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
177 |
+
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
178 |
+
|
179 |
+
trans_conv_ks = (16, 16)
|
180 |
+
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
181 |
+
|
182 |
+
# refinement and trans conv
|
183 |
+
|
184 |
+
if learn_trans_conv_only:
|
185 |
+
for p in self.parameters():
|
186 |
+
p.requires_grad_(False)
|
187 |
+
|
188 |
+
for p in self.trans_conv.parameters():
|
189 |
+
p.requires_grad_(True)
|
190 |
+
|
191 |
+
if prompt == 'fixed':
|
192 |
+
self.prompt_list = ['a photo of a {}.']
|
193 |
+
elif prompt == 'shuffle':
|
194 |
+
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
195 |
+
elif prompt == 'shuffle+':
|
196 |
+
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
197 |
+
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
198 |
+
'a bad photo of a {}.', 'a photo of the {}.']
|
199 |
+
elif prompt == 'shuffle_clip':
|
200 |
+
from models.clip_prompts import imagenet_templates
|
201 |
+
self.prompt_list = imagenet_templates
|
202 |
+
|
203 |
+
if process_cond is not None:
|
204 |
+
if process_cond == 'clamp' or process_cond[0] == 'clamp':
|
205 |
+
|
206 |
+
val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
|
207 |
+
|
208 |
+
def clamp_vec(x):
|
209 |
+
return torch.clamp(x, -val, val)
|
210 |
+
|
211 |
+
self.process_cond = clamp_vec
|
212 |
+
|
213 |
+
elif process_cond.endswith('.pth'):
|
214 |
+
|
215 |
+
shift = torch.load(process_cond)
|
216 |
+
def add_shift(x):
|
217 |
+
return x + shift.to(x.device)
|
218 |
+
|
219 |
+
self.process_cond = add_shift
|
220 |
+
|
221 |
+
import pickle
|
222 |
+
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
223 |
+
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
224 |
+
|
225 |
+
|
226 |
+
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
227 |
+
|
228 |
+
assert type(return_features) == bool
|
229 |
+
|
230 |
+
# inp_image = inp_image.to(self.model.positional_embedding.device)
|
231 |
+
|
232 |
+
if mask is not None:
|
233 |
+
raise ValueError('mask not supported')
|
234 |
+
|
235 |
+
# x_inp = normalize(inp_image)
|
236 |
+
x_inp = inp_image
|
237 |
+
|
238 |
+
bs, dev = inp_image.shape[0], x_inp.device
|
239 |
+
|
240 |
+
inp_image_size = inp_image.shape[2:]
|
241 |
+
|
242 |
+
cond = self.get_cond_vec(conditional, bs)
|
243 |
+
|
244 |
+
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
245 |
+
|
246 |
+
activation1 = activations[0]
|
247 |
+
activations = activations[1:]
|
248 |
+
|
249 |
+
a = None
|
250 |
+
for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
|
251 |
+
|
252 |
+
if a is not None:
|
253 |
+
a = reduce(activation) + a
|
254 |
+
else:
|
255 |
+
a = reduce(activation)
|
256 |
+
|
257 |
+
if i == self.cond_layer:
|
258 |
+
if self.reduce_cond is not None:
|
259 |
+
cond = self.reduce_cond(cond)
|
260 |
+
|
261 |
+
a = self.film_mul(cond) * a + self.film_add(cond)
|
262 |
+
|
263 |
+
a = block(a)
|
264 |
+
|
265 |
+
for block in self.extra_blocks:
|
266 |
+
a = a + block(a)
|
267 |
+
|
268 |
+
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
269 |
+
|
270 |
+
size = int(math.sqrt(a.shape[2]))
|
271 |
+
|
272 |
+
a = a.view(bs, a.shape[1], size, size)
|
273 |
+
|
274 |
+
if self.trans_conv is not None:
|
275 |
+
a = self.trans_conv(a)
|
276 |
+
|
277 |
+
if self.upsample_proj is not None:
|
278 |
+
a = self.upsample_proj(a)
|
279 |
+
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
280 |
+
|
281 |
+
a = nnf.interpolate(a, inp_image_size)
|
282 |
+
|
283 |
+
if return_features:
|
284 |
+
return a, visual_q, cond, [activation1] + activations
|
285 |
+
else:
|
286 |
+
return a,
|
clipseg/overview.png
ADDED
clipseg/score.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.functional import Tensor
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import inspect
|
5 |
+
import json
|
6 |
+
import yaml
|
7 |
+
import time
|
8 |
+
import sys
|
9 |
+
|
10 |
+
from general_utils import log
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from os.path import expanduser, join, isfile, realpath
|
14 |
+
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
|
17 |
+
from metrics import FixedIntervalMetrics
|
18 |
+
|
19 |
+
from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args
|
20 |
+
|
21 |
+
|
22 |
+
DATASET_CACHE = dict()
|
23 |
+
|
24 |
+
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
|
25 |
+
|
26 |
+
config = json.load(open(join('logs', checkpoint_id, 'config.json')))
|
27 |
+
|
28 |
+
if model_args != 'from_config' and type(model_args) != dict:
|
29 |
+
raise ValueError('model_args must either be "from_config" or a dictionary of values')
|
30 |
+
|
31 |
+
model_cls = get_attribute(config['model'])
|
32 |
+
|
33 |
+
# load model
|
34 |
+
if model_args == 'from_config':
|
35 |
+
_, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
|
36 |
+
|
37 |
+
model = model_cls(**model_args)
|
38 |
+
|
39 |
+
if weights_file is None:
|
40 |
+
weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
|
41 |
+
else:
|
42 |
+
weights_file = realpath(join('logs', checkpoint_id, weights_file))
|
43 |
+
|
44 |
+
if isfile(weights_file) and not ignore_weights:
|
45 |
+
weights = torch.load(weights_file)
|
46 |
+
for _, w in weights.items():
|
47 |
+
assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
|
48 |
+
model.load_state_dict(weights, strict=strict)
|
49 |
+
else:
|
50 |
+
if not ignore_weights:
|
51 |
+
raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
|
52 |
+
|
53 |
+
if with_config:
|
54 |
+
return model, config
|
55 |
+
|
56 |
+
return model
|
57 |
+
|
58 |
+
|
59 |
+
def compute_shift2(model, datasets, seed=123, repetitions=1):
|
60 |
+
""" computes shift """
|
61 |
+
|
62 |
+
model.eval()
|
63 |
+
model.cuda()
|
64 |
+
|
65 |
+
import random
|
66 |
+
random.seed(seed)
|
67 |
+
|
68 |
+
preds, gts = [], []
|
69 |
+
for i_dataset, dataset in enumerate(datasets):
|
70 |
+
|
71 |
+
loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
|
72 |
+
|
73 |
+
max_iterations = int(repetitions * len(dataset.dataset.data_list))
|
74 |
+
|
75 |
+
with torch.no_grad():
|
76 |
+
|
77 |
+
i, losses = 0, []
|
78 |
+
for i_all, (data_x, data_y) in enumerate(loader):
|
79 |
+
|
80 |
+
data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
|
81 |
+
data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]
|
82 |
+
|
83 |
+
pred, = model(data_x[0], data_x[1], data_x[2])
|
84 |
+
preds += [pred.detach()]
|
85 |
+
gts += [data_y]
|
86 |
+
|
87 |
+
i += 1
|
88 |
+
if max_iterations and i >= max_iterations:
|
89 |
+
break
|
90 |
+
|
91 |
+
from metrics import FixedIntervalMetrics
|
92 |
+
n_values = 51
|
93 |
+
thresholds = np.linspace(0, 1, n_values)[1:-1]
|
94 |
+
metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)
|
95 |
+
|
96 |
+
for p, y in zip(preds, gts):
|
97 |
+
metric.add(p.unsqueeze(1), y)
|
98 |
+
|
99 |
+
best_idx = np.argmax(metric.value()['fgiou_scores'])
|
100 |
+
best_thresh = thresholds[best_idx]
|
101 |
+
|
102 |
+
return best_thresh
|
103 |
+
|
104 |
+
|
105 |
+
def get_cached_pascal_pfe(split, config):
|
106 |
+
from datasets.pfe_dataset import PFEPascalWrapper
|
107 |
+
try:
|
108 |
+
dataset = DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)]
|
109 |
+
except KeyError:
|
110 |
+
dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support)
|
111 |
+
DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] = dataset
|
112 |
+
return dataset
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
def main():
|
118 |
+
config, train_checkpoint_id = score_config_from_cli_args()
|
119 |
+
|
120 |
+
metrics = score(config, train_checkpoint_id, None)
|
121 |
+
|
122 |
+
for dataset in metrics.keys():
|
123 |
+
for k in metrics[dataset]:
|
124 |
+
if type(metrics[dataset][k]) in {float, int}:
|
125 |
+
print(dataset, f'{k:<16} {metrics[dataset][k]:.3f}')
|
126 |
+
|
127 |
+
|
128 |
+
def score(config, train_checkpoint_id, train_config):
|
129 |
+
|
130 |
+
config = AttributeDict(config)
|
131 |
+
|
132 |
+
print(config)
|
133 |
+
|
134 |
+
# use training dataset and loss
|
135 |
+
train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json')))
|
136 |
+
|
137 |
+
cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''
|
138 |
+
|
139 |
+
|
140 |
+
model_cls = get_attribute(train_config['model'])
|
141 |
+
|
142 |
+
_, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)
|
143 |
+
|
144 |
+
model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}
|
145 |
+
|
146 |
+
strict_models = {'ConditionBase4', 'PFENetWrapper'}
|
147 |
+
model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
|
148 |
+
weights_file=f'weights{cp_str}.pth', )
|
149 |
+
|
150 |
+
|
151 |
+
model.eval()
|
152 |
+
model.cuda()
|
153 |
+
|
154 |
+
metric_args = dict()
|
155 |
+
|
156 |
+
if 'threshold' in config:
|
157 |
+
if config.metric.split('.')[-1] == 'SkLearnMetrics':
|
158 |
+
metric_args['threshold'] = config.threshold
|
159 |
+
|
160 |
+
if 'resize_to' in config:
|
161 |
+
metric_args['resize_to'] = config.resize_to
|
162 |
+
|
163 |
+
if 'sigmoid' in config:
|
164 |
+
metric_args['sigmoid'] = config.sigmoid
|
165 |
+
|
166 |
+
if 'custom_threshold' in config:
|
167 |
+
metric_args['custom_threshold'] = config.custom_threshold
|
168 |
+
|
169 |
+
if config.test_dataset == 'pascal':
|
170 |
+
|
171 |
+
loss_fn = get_attribute(train_config.loss)
|
172 |
+
# assume that if no split is specified in train_config, test on all splits,
|
173 |
+
|
174 |
+
if 'splits' in config:
|
175 |
+
splits = config.splits
|
176 |
+
else:
|
177 |
+
if 'split' in train_config and type(train_config.split) == int:
|
178 |
+
# unless train_config has a split set, in that case assume train mode in training
|
179 |
+
splits = [train_config.split]
|
180 |
+
assert train_config.mode == 'train'
|
181 |
+
else:
|
182 |
+
splits = [0,1,2,3]
|
183 |
+
|
184 |
+
log.info('Test on these splits', splits)
|
185 |
+
|
186 |
+
scores = dict()
|
187 |
+
for split in splits:
|
188 |
+
|
189 |
+
shift = config.shift if 'shift' in config else 0
|
190 |
+
|
191 |
+
# automatic shift
|
192 |
+
if shift == 'auto':
|
193 |
+
shift_compute_t = time.time()
|
194 |
+
shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
|
195 |
+
log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')
|
196 |
+
|
197 |
+
dataset = get_cached_pascal_pfe(split, config)
|
198 |
+
|
199 |
+
eval_start_t = time.time()
|
200 |
+
|
201 |
+
loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
|
202 |
+
|
203 |
+
assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'
|
204 |
+
|
205 |
+
metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)
|
206 |
+
|
207 |
+
with torch.no_grad():
|
208 |
+
|
209 |
+
i, losses = 0, []
|
210 |
+
for i_all, (data_x, data_y) in enumerate(loader):
|
211 |
+
|
212 |
+
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
|
213 |
+
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
|
214 |
+
|
215 |
+
if config.mask == 'separate': # for old CondBase model
|
216 |
+
pred, = model(data_x[0], data_x[1], data_x[2])
|
217 |
+
else:
|
218 |
+
# assert config.mask in {'text', 'highlight'}
|
219 |
+
pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
|
220 |
+
|
221 |
+
# loss = loss_fn(pred, data_y[0])
|
222 |
+
metric.add(pred.unsqueeze(1) + shift, data_y)
|
223 |
+
|
224 |
+
# losses += [float(loss)]
|
225 |
+
|
226 |
+
i += 1
|
227 |
+
if config.max_iterations and i >= config.max_iterations:
|
228 |
+
break
|
229 |
+
|
230 |
+
#scores[split] = {m: s for m, s in zip(metric.names(), metric.value())}
|
231 |
+
|
232 |
+
log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')
|
233 |
+
|
234 |
+
print(metric.value()['mean_iou_scores'])
|
235 |
+
|
236 |
+
scores[split] = metric.scores()
|
237 |
+
|
238 |
+
log.info(f'Completed split {split}')
|
239 |
+
|
240 |
+
key_prefix = config['name'] if 'name' in config else 'pas'
|
241 |
+
|
242 |
+
all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])
|
243 |
+
|
244 |
+
valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) for v in scores.values())]
|
245 |
+
|
246 |
+
return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}
|
247 |
+
|
248 |
+
|
249 |
+
if config.test_dataset == 'coco':
|
250 |
+
from datasets.coco_wrapper import COCOWrapper
|
251 |
+
|
252 |
+
coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
|
253 |
+
with_class_label=True)
|
254 |
+
|
255 |
+
log.info('Dataset length', len(coco_dataset))
|
256 |
+
loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
|
257 |
+
|
258 |
+
metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
|
259 |
+
|
260 |
+
shift = config.shift if 'shift' in config else 0
|
261 |
+
|
262 |
+
with torch.no_grad():
|
263 |
+
|
264 |
+
i, losses = 0, []
|
265 |
+
for i_all, (data_x, data_y) in enumerate(loader):
|
266 |
+
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
|
267 |
+
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
|
268 |
+
|
269 |
+
if config.mask == 'separate': # for old CondBase model
|
270 |
+
pred, = model(data_x[0], data_x[1], data_x[2])
|
271 |
+
else:
|
272 |
+
# assert config.mask in {'text', 'highlight'}
|
273 |
+
pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
|
274 |
+
|
275 |
+
metric.add([pred + shift], data_y)
|
276 |
+
|
277 |
+
i += 1
|
278 |
+
if config.max_iterations and i >= config.max_iterations:
|
279 |
+
break
|
280 |
+
|
281 |
+
key_prefix = config['name'] if 'name' in config else 'coco'
|
282 |
+
return {key_prefix: metric.scores()}
|
283 |
+
#return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
|
284 |
+
|
285 |
+
|
286 |
+
if config.test_dataset == 'phrasecut':
|
287 |
+
from datasets.phrasecut import PhraseCut
|
288 |
+
|
289 |
+
only_visual = config.only_visual is not None and config.only_visual
|
290 |
+
with_visual = config.with_visual is not None and config.with_visual
|
291 |
+
|
292 |
+
dataset = PhraseCut('test',
|
293 |
+
image_size=train_config.image_size,
|
294 |
+
mask=config.mask,
|
295 |
+
with_visual=with_visual, only_visual=only_visual, aug_crop=False,
|
296 |
+
aug_color=False)
|
297 |
+
|
298 |
+
loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
|
299 |
+
metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
|
300 |
+
|
301 |
+
shift = config.shift if 'shift' in config else 0
|
302 |
+
|
303 |
+
|
304 |
+
with torch.no_grad():
|
305 |
+
|
306 |
+
i, losses = 0, []
|
307 |
+
for i_all, (data_x, data_y) in enumerate(loader):
|
308 |
+
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
|
309 |
+
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
|
310 |
+
|
311 |
+
pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
|
312 |
+
metric.add([pred + shift], data_y)
|
313 |
+
|
314 |
+
i += 1
|
315 |
+
if config.max_iterations and i >= config.max_iterations:
|
316 |
+
break
|
317 |
+
|
318 |
+
key_prefix = config['name'] if 'name' in config else 'phrasecut'
|
319 |
+
return {key_prefix: metric.scores()}
|
320 |
+
#return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
|
321 |
+
|
322 |
+
if config.test_dataset == 'pascal_zs':
|
323 |
+
from third_party.JoEm.model.metric import Evaluator
|
324 |
+
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
325 |
+
from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS
|
326 |
+
|
327 |
+
from models.clipseg import CLIPSegMultiLabel
|
328 |
+
|
329 |
+
n_unseen = train_config.remove_classes[1]
|
330 |
+
|
331 |
+
pz = PascalZeroShot('val', n_unseen, image_size=352)
|
332 |
+
m = CLIPSegMultiLabel(model=train_config.name).cuda()
|
333 |
+
m.eval();
|
334 |
+
|
335 |
+
print(len(pz), n_unseen)
|
336 |
+
print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])
|
337 |
+
|
338 |
+
print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
|
339 |
+
print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])
|
340 |
+
|
341 |
+
loader = DataLoader(pz, batch_size=8)
|
342 |
+
evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))
|
343 |
+
|
344 |
+
for i, (data_x, data_y) in enumerate(loader):
|
345 |
+
pred = m(data_x[0].cuda())
|
346 |
+
evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())
|
347 |
+
|
348 |
+
if config.max_iter is not None and i > config.max_iter:
|
349 |
+
break
|
350 |
+
|
351 |
+
scores = evaluator.Mean_Intersection_over_Union()
|
352 |
+
key_prefix = config['name'] if 'name' in config else 'pas_zs'
|
353 |
+
|
354 |
+
return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}
|
355 |
+
|
356 |
+
elif config.test_dataset in {'same_as_training', 'affordance'}:
|
357 |
+
loss_fn = get_attribute(train_config.loss)
|
358 |
+
|
359 |
+
metric_cls = get_attribute(config.metric)
|
360 |
+
metric = metric_cls(**metric_args)
|
361 |
+
|
362 |
+
if config.test_dataset == 'same_as_training':
|
363 |
+
dataset_cls = get_attribute(train_config.dataset)
|
364 |
+
elif config.test_dataset == 'affordance':
|
365 |
+
dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
|
366 |
+
dataset_name = 'aff'
|
367 |
+
else:
|
368 |
+
dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot')
|
369 |
+
dataset_name = 'lvis'
|
370 |
+
|
371 |
+
_, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
|
372 |
+
|
373 |
+
dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation
|
374 |
+
|
375 |
+
if model.__class__.__name__ == 'PFENetWrapper':
|
376 |
+
dataset_args['image_size'] = config.image_size
|
377 |
+
|
378 |
+
log.info('init dataset', str(dataset_cls))
|
379 |
+
dataset = dataset_cls(**dataset_args)
|
380 |
+
|
381 |
+
log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')
|
382 |
+
|
383 |
+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)
|
384 |
+
|
385 |
+
# explicitly set prompts
|
386 |
+
if config.prompt == 'plain':
|
387 |
+
model.prompt_list = ['{}']
|
388 |
+
elif config.prompt == 'fixed':
|
389 |
+
model.prompt_list = ['a photo of a {}.']
|
390 |
+
elif config.prompt == 'shuffle':
|
391 |
+
model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
392 |
+
elif config.prompt == 'shuffle_clip':
|
393 |
+
from models.clip_prompts import imagenet_templates
|
394 |
+
model.prompt_list = imagenet_templates
|
395 |
+
|
396 |
+
config.assume_no_unused_keys(exceptions=['max_iterations'])
|
397 |
+
|
398 |
+
t_start = time.time()
|
399 |
+
|
400 |
+
with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9)
|
401 |
+
i, losses = 0, []
|
402 |
+
for data_x, data_y in data_loader:
|
403 |
+
|
404 |
+
data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
|
405 |
+
data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
|
406 |
+
|
407 |
+
if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
|
408 |
+
pred, = model(data_x[0], data_x[1], data_x[2])
|
409 |
+
visual_q = None
|
410 |
+
else:
|
411 |
+
pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True)
|
412 |
+
|
413 |
+
loss = loss_fn(pred, data_y[0])
|
414 |
+
|
415 |
+
metric.add([pred], data_y)
|
416 |
+
|
417 |
+
losses += [float(loss)]
|
418 |
+
|
419 |
+
i += 1
|
420 |
+
if config.max_iterations and i >= config.max_iterations:
|
421 |
+
break
|
422 |
+
|
423 |
+
# scores = {m: s for m, s in zip(metric.names(), metric.value())}
|
424 |
+
scores = metric.scores()
|
425 |
+
|
426 |
+
keys = set(scores.keys())
|
427 |
+
if dataset.negative_prob > 0 and 'mIoU' in keys:
|
428 |
+
keys.remove('mIoU')
|
429 |
+
|
430 |
+
name_mask = dataset.mask.replace('text_label', 'txt')[:3]
|
431 |
+
name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)
|
432 |
+
|
433 |
+
score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'
|
434 |
+
|
435 |
+
scores = {score_name: {k: v for k,v in scores.items() if k in keys}}
|
436 |
+
scores[score_name].update({'test_loss': np.mean(losses)})
|
437 |
+
|
438 |
+
log.info(f'Evaluation took {time.time() - t_start:.1f}s')
|
439 |
+
|
440 |
+
return scores
|
441 |
+
else:
|
442 |
+
raise ValueError('invalid test dataset')
|
443 |
+
|
444 |
+
|
445 |
+
|
446 |
+
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
|
451 |
+
|
452 |
+
if __name__ == '__main__':
|
453 |
+
main()
|
clipseg/setup.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
with open("README.md", "r", encoding="utf-8") as readme_file:
|
4 |
+
readme = readme_file.read()
|
5 |
+
|
6 |
+
requirements = [
|
7 |
+
"numpy",
|
8 |
+
"scipy",
|
9 |
+
"matplotlib",
|
10 |
+
"torch",
|
11 |
+
"torchvision",
|
12 |
+
"opencv-python",
|
13 |
+
"CLIP @ git+https://github.com/openai/CLIP.git"
|
14 |
+
]
|
15 |
+
|
16 |
+
setup(
|
17 |
+
name='clipseg',
|
18 |
+
packages=['clipseg'],
|
19 |
+
package_dir={'clipseg': 'models'},
|
20 |
+
package_data={'clipseg': [
|
21 |
+
"../weights/*.pth",
|
22 |
+
]},
|
23 |
+
version='0.0.1',
|
24 |
+
url='https://github.com/timojl/clipseg',
|
25 |
+
python_requires='>=3.9',
|
26 |
+
install_requires=requirements,
|
27 |
+
description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
|
28 |
+
long_description=readme,
|
29 |
+
long_description_content_type="text/markdown",
|
30 |
+
)
|
clipseg/training.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import inspect
|
3 |
+
import json
|
4 |
+
import yaml
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
|
9 |
+
from general_utils import log
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
from functools import partial
|
13 |
+
from os.path import expanduser, join, isfile, basename
|
14 |
+
|
15 |
+
from torch.cuda.amp import autocast, GradScaler
|
16 |
+
from torch.optim.lr_scheduler import LambdaLR
|
17 |
+
from contextlib import nullcontext
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
|
21 |
+
|
22 |
+
|
23 |
+
def cosine_warmup_lr(i, warmup=10, max_iter=90):
|
24 |
+
""" Cosine LR with Warmup """
|
25 |
+
if i < warmup:
|
26 |
+
return (i+1)/(warmup+1)
|
27 |
+
else:
|
28 |
+
return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup))))
|
29 |
+
|
30 |
+
|
31 |
+
def validate(model, dataset, config):
|
32 |
+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
|
33 |
+
|
34 |
+
metric_class, use_metric = config.val_metric_class, config.use_val_metric
|
35 |
+
loss_fn = get_attribute(config.loss)
|
36 |
+
|
37 |
+
model.eval()
|
38 |
+
model.cuda()
|
39 |
+
|
40 |
+
if metric_class is not None:
|
41 |
+
metric = get_attribute(metric_class)()
|
42 |
+
|
43 |
+
with torch.no_grad():
|
44 |
+
|
45 |
+
i, losses = 0, []
|
46 |
+
for data_x, data_y in data_loader:
|
47 |
+
|
48 |
+
data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
|
49 |
+
data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
|
50 |
+
|
51 |
+
prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
|
52 |
+
pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)
|
53 |
+
|
54 |
+
if metric_class is not None:
|
55 |
+
metric.add([pred], data_y)
|
56 |
+
|
57 |
+
# pred = model(data_x[0], prompts)
|
58 |
+
# loss = loss_fn(pred[0], data_y[0])
|
59 |
+
loss = loss_fn(pred, data_y[0])
|
60 |
+
losses += [float(loss)]
|
61 |
+
|
62 |
+
i += 1
|
63 |
+
|
64 |
+
if config.val_max_iterations is not None and i > config.val_max_iterations:
|
65 |
+
break
|
66 |
+
|
67 |
+
if use_metric is None:
|
68 |
+
return np.mean(losses), {}, False
|
69 |
+
else:
|
70 |
+
metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
|
71 |
+
return np.mean(losses), metric_scores, True
|
72 |
+
|
73 |
+
|
74 |
+
def main():
|
75 |
+
|
76 |
+
config = training_config_from_cli_args()
|
77 |
+
|
78 |
+
val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
|
79 |
+
|
80 |
+
model_cls = get_attribute(config.model)
|
81 |
+
_, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
|
82 |
+
model = model_cls(**model_args).cuda()
|
83 |
+
|
84 |
+
dataset_cls = get_attribute(config.dataset)
|
85 |
+
_, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
|
86 |
+
|
87 |
+
dataset = dataset_cls(**dataset_args)
|
88 |
+
|
89 |
+
log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
|
90 |
+
|
91 |
+
if val_interval is not None:
|
92 |
+
dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
|
93 |
+
_, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
|
94 |
+
print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
|
95 |
+
|
96 |
+
dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
|
97 |
+
|
98 |
+
# optimizer
|
99 |
+
opt_cls = get_attribute(config.optimizer)
|
100 |
+
if config.optimize == 'torch.optim.SGD':
|
101 |
+
opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
|
102 |
+
else:
|
103 |
+
opt_args = {}
|
104 |
+
opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
|
105 |
+
|
106 |
+
if config.lr_scheduler == 'cosine':
|
107 |
+
assert config.T_max is not None and config.eta_min is not None
|
108 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
|
109 |
+
elif config.lr_scheduler == 'warmup_cosine':
|
110 |
+
lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
|
111 |
+
else:
|
112 |
+
lr_scheduler = None
|
113 |
+
|
114 |
+
batch_size, max_iterations = config.batch_size, config.max_iterations
|
115 |
+
|
116 |
+
loss_fn = get_attribute(config.loss)
|
117 |
+
|
118 |
+
if config.amp:
|
119 |
+
log.info('Using AMP')
|
120 |
+
autocast_fn = autocast
|
121 |
+
scaler = GradScaler()
|
122 |
+
else:
|
123 |
+
autocast_fn, scaler = nullcontext, None
|
124 |
+
|
125 |
+
|
126 |
+
save_only_trainable = True
|
127 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
|
128 |
+
|
129 |
+
# disable config when hyperparam. opt. to avoid writing logs.
|
130 |
+
tracker_config = config if not config.hyperparameter_optimization else None
|
131 |
+
|
132 |
+
with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
|
133 |
+
|
134 |
+
i = 0
|
135 |
+
while True:
|
136 |
+
for data_x, data_y in data_loader:
|
137 |
+
|
138 |
+
# between caption and output feature.
|
139 |
+
# 1. Sample random captions
|
140 |
+
# 2. Check alignment with CLIP
|
141 |
+
|
142 |
+
# randomly mix text and visual support conditionals
|
143 |
+
if config.mix:
|
144 |
+
|
145 |
+
assert config.mask.startswith('text_and')
|
146 |
+
|
147 |
+
with autocast_fn():
|
148 |
+
# data_x[1] = text label
|
149 |
+
prompts = model.sample_prompts(data_x[1])
|
150 |
+
|
151 |
+
# model.clip_model()
|
152 |
+
|
153 |
+
text_cond = model.compute_conditional(prompts)
|
154 |
+
if model.__class__.__name__ == 'CLIPDensePredTMasked':
|
155 |
+
# when mask=='separate'
|
156 |
+
visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
|
157 |
+
else:
|
158 |
+
# data_x[2] = visual prompt
|
159 |
+
visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())
|
160 |
+
|
161 |
+
max_txt = config.mix_text_max if config.mix_text_max is not None else 1
|
162 |
+
batch_size = text_cond.shape[0]
|
163 |
+
|
164 |
+
# sample weights for each element in batch
|
165 |
+
text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
|
166 |
+
text_weights = text_weights.cuda()
|
167 |
+
|
168 |
+
if dataset.__class__.__name__ == 'PhraseCut':
|
169 |
+
# give full weight to text where support_image is invalid
|
170 |
+
visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
|
171 |
+
text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)
|
172 |
+
|
173 |
+
cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)
|
174 |
+
|
175 |
+
else:
|
176 |
+
# no mix
|
177 |
+
|
178 |
+
if model.__class__.__name__ == 'CLIPDensePredTMasked':
|
179 |
+
# compute conditional vector using CLIP masking
|
180 |
+
with autocast_fn():
|
181 |
+
assert config.mask == 'separate'
|
182 |
+
cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
|
183 |
+
else:
|
184 |
+
cond = data_x[1]
|
185 |
+
if isinstance(cond, torch.Tensor):
|
186 |
+
cond = cond.cuda()
|
187 |
+
|
188 |
+
with autocast_fn():
|
189 |
+
visual_q = None
|
190 |
+
|
191 |
+
pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)
|
192 |
+
|
193 |
+
loss = loss_fn(pred, data_y[0].cuda())
|
194 |
+
|
195 |
+
if torch.isnan(loss) or torch.isinf(loss):
|
196 |
+
# skip if loss is nan
|
197 |
+
log.warning('Training stopped due to inf/nan loss.')
|
198 |
+
sys.exit(-1)
|
199 |
+
|
200 |
+
extra_loss = 0
|
201 |
+
loss += extra_loss
|
202 |
+
|
203 |
+
opt.zero_grad()
|
204 |
+
|
205 |
+
if scaler is None:
|
206 |
+
loss.backward()
|
207 |
+
opt.step()
|
208 |
+
else:
|
209 |
+
scaler.scale(loss).backward()
|
210 |
+
scaler.step(opt)
|
211 |
+
scaler.update()
|
212 |
+
|
213 |
+
if lr_scheduler is not None:
|
214 |
+
lr_scheduler.step()
|
215 |
+
if i % 2000 == 0:
|
216 |
+
current_lr = [g['lr'] for g in opt.param_groups][0]
|
217 |
+
log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')
|
218 |
+
|
219 |
+
logger.iter(i=i, loss=loss)
|
220 |
+
i += 1
|
221 |
+
|
222 |
+
if i >= max_iterations:
|
223 |
+
|
224 |
+
if not isfile(join(logger.base_path, 'weights.pth')):
|
225 |
+
# only write if no weights were already written
|
226 |
+
logger.save_weights(only_trainable=save_only_trainable)
|
227 |
+
|
228 |
+
sys.exit(0)
|
229 |
+
|
230 |
+
|
231 |
+
if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
|
232 |
+
logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')
|
233 |
+
|
234 |
+
|
235 |
+
if val_interval is not None and i % val_interval == val_interval - 1:
|
236 |
+
|
237 |
+
val_loss, val_scores, maximize = validate(model, dataset_val, config)
|
238 |
+
|
239 |
+
if len(val_scores) > 0:
|
240 |
+
|
241 |
+
score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())
|
242 |
+
|
243 |
+
if maximize and val_scores[config.use_val_metric] > best_val_score:
|
244 |
+
logger.save_weights(only_trainable=save_only_trainable)
|
245 |
+
best_val_score = val_scores[config.use_val_metric]
|
246 |
+
|
247 |
+
elif not maximize and val_scores[config.use_val_metric] < best_val_score:
|
248 |
+
logger.save_weights(only_trainable=save_only_trainable)
|
249 |
+
best_val_score = val_scores[config.use_val_metric]
|
250 |
+
|
251 |
+
else:
|
252 |
+
score_str = ''
|
253 |
+
# if no score is used, fall back to loss
|
254 |
+
if val_loss < best_val_loss:
|
255 |
+
logger.save_weights(only_trainable=save_only_trainable)
|
256 |
+
best_val_loss = val_loss
|
257 |
+
|
258 |
+
log.info(f'Validation loss: {val_loss}' + score_str)
|
259 |
+
logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
|
260 |
+
model.train()
|
261 |
+
|
262 |
+
print('epoch complete')
|
263 |
+
|
264 |
+
|
265 |
+
if __name__ == '__main__':
|
266 |
+
main()
|
clipseg/weights/rd64-uni.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:13845f6cee4d54ca46f62ee19dd354822094a26e0efccc64e606be93d6a7e26f
|
3 |
+
size 4306645
|
init_image.png
ADDED
inpainting.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
import PIL
|
8 |
+
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
|
9 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
12 |
+
|
13 |
+
|
14 |
+
def preprocess_image(image):
|
15 |
+
w, h = image.size
|
16 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
17 |
+
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
18 |
+
image = np.array(image).astype(np.float32) / 255.0
|
19 |
+
image = image[None].transpose(0, 3, 1, 2)
|
20 |
+
image = torch.from_numpy(image)
|
21 |
+
return 2.0 * image - 1.0
|
22 |
+
|
23 |
+
|
24 |
+
def preprocess_mask(mask):
|
25 |
+
mask = mask.convert("L")
|
26 |
+
w, h = mask.size
|
27 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
28 |
+
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
|
29 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
30 |
+
mask = np.tile(mask, (4, 1, 1))
|
31 |
+
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
32 |
+
mask = 1 - mask # repaint white, keep black
|
33 |
+
mask = torch.from_numpy(mask)
|
34 |
+
return mask
|
35 |
+
|
36 |
+
class StableDiffusionInpaintingPipeline(DiffusionPipeline):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
vae: AutoencoderKL,
|
40 |
+
text_encoder: CLIPTextModel,
|
41 |
+
tokenizer: CLIPTokenizer,
|
42 |
+
unet: UNet2DConditionModel,
|
43 |
+
scheduler: Union[DDIMScheduler, PNDMScheduler],
|
44 |
+
safety_checker: StableDiffusionSafetyChecker,
|
45 |
+
feature_extractor: CLIPFeatureExtractor,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
scheduler = scheduler.set_format("pt")
|
49 |
+
self.register_modules(
|
50 |
+
vae=vae,
|
51 |
+
text_encoder=text_encoder,
|
52 |
+
tokenizer=tokenizer,
|
53 |
+
unet=unet,
|
54 |
+
scheduler=scheduler,
|
55 |
+
safety_checker=safety_checker,
|
56 |
+
feature_extractor=feature_extractor,
|
57 |
+
)
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def __call__(
|
61 |
+
self,
|
62 |
+
prompt: Union[str, List[str]],
|
63 |
+
init_image: torch.FloatTensor,
|
64 |
+
mask_image: torch.FloatTensor,
|
65 |
+
strength: float = 0.8,
|
66 |
+
num_inference_steps: Optional[int] = 50,
|
67 |
+
guidance_scale: Optional[float] = 7.5,
|
68 |
+
eta: Optional[float] = 0.0,
|
69 |
+
generator: Optional[torch.Generator] = None,
|
70 |
+
output_type: Optional[str] = "pil",
|
71 |
+
):
|
72 |
+
|
73 |
+
if isinstance(prompt, str):
|
74 |
+
batch_size = 1
|
75 |
+
elif isinstance(prompt, list):
|
76 |
+
batch_size = len(prompt)
|
77 |
+
else:
|
78 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
79 |
+
|
80 |
+
if strength < 0 or strength > 1:
|
81 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
82 |
+
|
83 |
+
# set timesteps
|
84 |
+
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
|
85 |
+
extra_set_kwargs = {}
|
86 |
+
offset = 0
|
87 |
+
if accepts_offset:
|
88 |
+
offset = 1
|
89 |
+
extra_set_kwargs["offset"] = 1
|
90 |
+
|
91 |
+
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
92 |
+
|
93 |
+
# preprocess image
|
94 |
+
init_image = preprocess_image(init_image).to(self.device)
|
95 |
+
|
96 |
+
# encode the init image into latents and scale the latents
|
97 |
+
init_latent_dist = self.vae.encode(init_image).latent_dist
|
98 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
99 |
+
init_latents = 0.18215 * init_latents
|
100 |
+
|
101 |
+
# prepare init_latents noise to latents
|
102 |
+
init_latents = torch.cat([init_latents] * batch_size)
|
103 |
+
init_latents_orig = init_latents
|
104 |
+
|
105 |
+
# preprocess mask
|
106 |
+
mask = preprocess_mask(mask_image).to(self.device)
|
107 |
+
mask = torch.cat([mask] * batch_size)
|
108 |
+
|
109 |
+
# check sizes
|
110 |
+
if not mask.shape == init_latents.shape:
|
111 |
+
raise ValueError(f"The mask and init_image should be the same size!")
|
112 |
+
|
113 |
+
# get the original timestep using init_timestep
|
114 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
115 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
116 |
+
timesteps = self.scheduler.timesteps[-init_timestep]
|
117 |
+
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
|
118 |
+
|
119 |
+
# add noise to latents using the timesteps
|
120 |
+
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
121 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
122 |
+
|
123 |
+
# get prompt text embeddings
|
124 |
+
text_input = self.tokenizer(
|
125 |
+
prompt,
|
126 |
+
padding="max_length",
|
127 |
+
max_length=self.tokenizer.model_max_length,
|
128 |
+
truncation=True,
|
129 |
+
return_tensors="pt",
|
130 |
+
)
|
131 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
132 |
+
|
133 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
134 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
135 |
+
# corresponds to doing no classifier free guidance.
|
136 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
137 |
+
# get unconditional embeddings for classifier free guidance
|
138 |
+
if do_classifier_free_guidance:
|
139 |
+
max_length = text_input.input_ids.shape[-1]
|
140 |
+
uncond_input = self.tokenizer(
|
141 |
+
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
142 |
+
)
|
143 |
+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
144 |
+
|
145 |
+
# For classifier free guidance, we need to do two forward passes.
|
146 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
147 |
+
# to avoid doing two forward passes
|
148 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
149 |
+
|
150 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
151 |
+
# eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
152 |
+
# eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
|
153 |
+
# and should be between [0, 1]
|
154 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
155 |
+
extra_step_kwargs = {}
|
156 |
+
if accepts_eta:
|
157 |
+
extra_step_kwargs["eta"] = eta
|
158 |
+
|
159 |
+
latents = init_latents
|
160 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
161 |
+
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
162 |
+
# expand the latents if we are doing classifier free guidance
|
163 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
164 |
+
|
165 |
+
# predict the noise residual
|
166 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
167 |
+
|
168 |
+
# perform guidance
|
169 |
+
if do_classifier_free_guidance:
|
170 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
171 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
172 |
+
|
173 |
+
# compute the previous noisy sample x_t -> x_t-1
|
174 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
|
175 |
+
|
176 |
+
# masking
|
177 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
|
178 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
179 |
+
|
180 |
+
# scale and decode the image latents with vae
|
181 |
+
latents = 1 / 0.18215 * latents
|
182 |
+
image = self.vae.decode(latents).sample
|
183 |
+
|
184 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
185 |
+
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
186 |
+
|
187 |
+
# run safety checker
|
188 |
+
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
189 |
+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
|
190 |
+
|
191 |
+
if output_type == "pil":
|
192 |
+
image = self.numpy_to_pil(image)
|
193 |
+
|
194 |
+
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
|
mask_image.png
ADDED