Spaces:
Runtime error
Runtime error
Upgrade and any image size handling
#22
by
Fabrice-TIERCELIN
- opened
- .gitattributes +4 -31
- clipseg/weights/rd64-uni.pth → Examples/Example1.png +2 -2
- Examples/Example2.webp +3 -0
- Examples/Example3.jpg +0 -0
- Examples/Example4.gif +0 -0
- Examples/Example5.bmp +3 -0
- Examples/Mask1.webp +0 -0
- Examples/Mask2.png +0 -0
- Examples/Mask3.gif +0 -0
- Examples/Mask4.bmp +3 -0
- Examples/Mask5.png +0 -0
- README.md +16 -6
- app.py +339 -162
- clipseg/LICENSE +0 -21
- clipseg/Quickstart.ipynb +0 -107
- clipseg/Readme.md +0 -84
- clipseg/Tables.ipynb +0 -349
- clipseg/Visual_Feature_Engineering.ipynb +0 -366
- clipseg/datasets/coco_wrapper.py +0 -99
- clipseg/datasets/pascal_classes.json +0 -1
- clipseg/datasets/pascal_zeroshot.py +0 -60
- clipseg/datasets/pfe_dataset.py +0 -129
- clipseg/datasets/phrasecut.py +0 -335
- clipseg/datasets/utils.py +0 -68
- clipseg/environment.yml +0 -15
- clipseg/evaluation_utils.py +0 -292
- clipseg/example_image.jpg +0 -0
- clipseg/experiments/ablation.yaml +0 -84
- clipseg/experiments/coco.yaml +0 -101
- clipseg/experiments/pascal_1shot.yaml +0 -101
- clipseg/experiments/phrasecut.yaml +0 -80
- clipseg/general_utils.py +0 -272
- clipseg/metrics.py +0 -271
- clipseg/models/clipseg.py +0 -552
- clipseg/models/vitseg.py +0 -286
- clipseg/overview.png +0 -0
- clipseg/score.py +0 -453
- clipseg/setup.py +0 -30
- clipseg/training.py +0 -266
- init_image.png +0 -0
- inpainting.py +0 -194
- mask_image.png +0 -0
- requirements.txt +6 -9
.gitattributes
CHANGED
@@ -1,31 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
23 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
26 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
Examples/Example1.png filter=lfs diff=lfs merge=lfs -text
|
2 |
+
Examples/Example2.webp filter=lfs diff=lfs merge=lfs -text
|
3 |
+
Examples/Example5.bmp filter=lfs diff=lfs merge=lfs -text
|
4 |
+
Examples/Mask4.bmp filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/weights/rd64-uni.pth → Examples/Example1.png
RENAMED
File without changes
|
Examples/Example2.webp
ADDED
Git LFS Details
|
Examples/Example3.jpg
ADDED
Examples/Example4.gif
ADDED
Examples/Example5.bmp
ADDED
Git LFS Details
|
Examples/Mask1.webp
ADDED
Examples/Mask2.png
ADDED
Examples/Mask3.gif
ADDED
Examples/Mask4.bmp
ADDED
Git LFS Details
|
Examples/Mask5.png
ADDED
README.md
CHANGED
@@ -1,13 +1,23 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
11 |
---
|
12 |
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Inpaint SDXL (any size)
|
3 |
+
emoji: ↕️
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
tags:
|
7 |
+
- Image-to-Image
|
8 |
+
- Image-2-Image
|
9 |
+
- Img-to-Img
|
10 |
+
- Img-2-Img
|
11 |
+
- SDXL
|
12 |
+
- Stable Diffusion
|
13 |
+
- language models
|
14 |
+
- LLMs
|
15 |
sdk: gradio
|
16 |
+
sdk_version: 3.41.2
|
17 |
app_file: app.py
|
18 |
pinned: false
|
19 |
license: mit
|
20 |
+
short_description: Modifies one detail of your image, at any resolution, freely
|
21 |
---
|
22 |
|
23 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,174 +1,351 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
#test
|
3 |
-
from io import BytesIO
|
4 |
-
import requests
|
5 |
-
import PIL
|
6 |
-
from PIL import Image
|
7 |
import numpy as np
|
8 |
-
import
|
9 |
-
import
|
|
|
|
|
10 |
import torch
|
11 |
-
from torch import autocast
|
12 |
-
import cv2
|
13 |
-
from matplotlib import pyplot as plt
|
14 |
-
from diffusers import DiffusionPipeline
|
15 |
-
from torchvision import transforms
|
16 |
-
from clipseg.models.clipseg import CLIPDensePredT
|
17 |
-
|
18 |
-
auth_token = os.environ.get("API_TOKEN") or True
|
19 |
|
20 |
-
|
21 |
-
response = requests.get(url)
|
22 |
-
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
23 |
|
24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
else:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
59 |
-
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
60 |
-
os.remove(filename)
|
61 |
-
#with autocast("cuda"):
|
62 |
-
output = pipe(prompt = prompt, image=init_image, mask_image=mask, strength=0.8)
|
63 |
-
return output.images[0]
|
64 |
-
|
65 |
-
# examples = [[dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]]
|
66 |
-
css = '''
|
67 |
-
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
|
68 |
-
#image_upload{min-height:400px}
|
69 |
-
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
|
70 |
-
#mask_radio .gr-form{background:transparent; border: none}
|
71 |
-
#word_mask{margin-top: .75em !important}
|
72 |
-
#word_mask textarea:disabled{opacity: 0.3}
|
73 |
-
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
|
74 |
-
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
|
75 |
-
.dark .footer {border-color: #303030}
|
76 |
-
.dark .footer>p {background: #0b0f19}
|
77 |
-
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
|
78 |
-
#image_upload .touch-none{display: flex}
|
79 |
-
'''
|
80 |
-
def swap_word_mask(radio_option):
|
81 |
-
if(radio_option == "type what to mask below"):
|
82 |
-
return gr.update(interactive=True, placeholder="A cat")
|
83 |
else:
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
"""
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
height="0.65em"
|
102 |
-
viewBox="0 0 115 115"
|
103 |
-
fill="none"
|
104 |
-
xmlns="http://www.w3.org/2000/svg"
|
105 |
-
>
|
106 |
-
<rect width="23" height="23" fill="white"></rect>
|
107 |
-
<rect y="69" width="23" height="23" fill="white"></rect>
|
108 |
-
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
|
109 |
-
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
110 |
-
<rect x="46" width="23" height="23" fill="white"></rect>
|
111 |
-
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
|
112 |
-
<rect x="69" width="23" height="23" fill="black"></rect>
|
113 |
-
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
|
114 |
-
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
|
115 |
-
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
116 |
-
<rect x="115" y="46" width="23" height="23" fill="white"></rect>
|
117 |
-
<rect x="115" y="115" width="23" height="23" fill="white"></rect>
|
118 |
-
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
119 |
-
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
|
120 |
-
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
121 |
-
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
|
122 |
-
<rect x="69" y="46" width="23" height="23" fill="white"></rect>
|
123 |
-
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
|
124 |
-
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
125 |
-
<rect x="46" y="46" width="23" height="23" fill="black"></rect>
|
126 |
-
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
|
127 |
-
<rect x="46" y="69" width="23" height="23" fill="black"></rect>
|
128 |
-
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
|
129 |
-
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
130 |
-
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
|
131 |
-
</svg>
|
132 |
-
<h1 style="font-weight: 900; margin-bottom: 7px;">
|
133 |
-
Stable Diffusion Multi Inpainting
|
134 |
-
</h1>
|
135 |
-
</div>
|
136 |
-
<p style="margin-bottom: 10px; font-size: 94%">
|
137 |
-
Inpaint Stable Diffusion by either drawing a mask or typing what to replace
|
138 |
-
</p>
|
139 |
-
</div>
|
140 |
"""
|
141 |
)
|
142 |
-
with gr.
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import StableDiffusionXLInpaintPipeline
|
2 |
+
from PIL import Image, ImageFilter
|
3 |
+
|
4 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
5 |
import numpy as np
|
6 |
+
import time
|
7 |
+
import math
|
8 |
+
import random
|
9 |
+
import imageio
|
10 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
max_64_bit_int = 2**63 - 1
|
|
|
|
|
13 |
|
14 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
+
floatType = torch.float16 if torch.cuda.is_available() else torch.float32
|
16 |
+
variant = "fp16" if torch.cuda.is_available() else None
|
17 |
+
pipe = StableDiffusionXLInpaintPipeline.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype = floatType, variant = variant)
|
18 |
+
pipe = pipe.to(device)
|
19 |
+
|
20 |
+
def check(
|
21 |
+
source_img,
|
22 |
+
prompt,
|
23 |
+
uploaded_mask,
|
24 |
+
negative_prompt,
|
25 |
+
denoising_steps,
|
26 |
+
num_inference_steps,
|
27 |
+
guidance_scale,
|
28 |
+
image_guidance_scale,
|
29 |
+
strength,
|
30 |
+
randomize_seed,
|
31 |
+
seed,
|
32 |
+
debug_mode,
|
33 |
+
progress = gr.Progress()
|
34 |
+
):
|
35 |
+
if source_img is None:
|
36 |
+
raise gr.Error("Please provide an image.")
|
37 |
+
|
38 |
+
if prompt is None or prompt == "":
|
39 |
+
raise gr.Error("Please provide a prompt input.")
|
40 |
+
|
41 |
+
def inpaint(
|
42 |
+
source_img,
|
43 |
+
prompt,
|
44 |
+
uploaded_mask,
|
45 |
+
negative_prompt,
|
46 |
+
denoising_steps,
|
47 |
+
num_inference_steps,
|
48 |
+
guidance_scale,
|
49 |
+
image_guidance_scale,
|
50 |
+
strength,
|
51 |
+
randomize_seed,
|
52 |
+
seed,
|
53 |
+
debug_mode,
|
54 |
+
progress = gr.Progress()
|
55 |
+
):
|
56 |
+
check(
|
57 |
+
source_img,
|
58 |
+
prompt,
|
59 |
+
uploaded_mask,
|
60 |
+
negative_prompt,
|
61 |
+
denoising_steps,
|
62 |
+
num_inference_steps,
|
63 |
+
guidance_scale,
|
64 |
+
image_guidance_scale,
|
65 |
+
strength,
|
66 |
+
randomize_seed,
|
67 |
+
seed,
|
68 |
+
debug_mode
|
69 |
+
)
|
70 |
+
start = time.time()
|
71 |
+
progress(0, desc = "Preparing data...")
|
72 |
+
|
73 |
+
if negative_prompt is None:
|
74 |
+
negative_prompt = ""
|
75 |
+
|
76 |
+
if denoising_steps is None:
|
77 |
+
denoising_steps = 1000
|
78 |
+
|
79 |
+
if num_inference_steps is None:
|
80 |
+
num_inference_steps = 25
|
81 |
+
|
82 |
+
if guidance_scale is None:
|
83 |
+
guidance_scale = 7
|
84 |
+
|
85 |
+
if image_guidance_scale is None:
|
86 |
+
image_guidance_scale = 1.1
|
87 |
+
|
88 |
+
if strength is None:
|
89 |
+
strength = 0.99
|
90 |
+
|
91 |
+
if randomize_seed:
|
92 |
+
seed = random.randint(0, max_64_bit_int)
|
93 |
+
|
94 |
+
random.seed(seed)
|
95 |
+
#pipe = pipe.manual_seed(seed)
|
96 |
+
|
97 |
+
input_image = source_img["image"].convert("RGB")
|
98 |
+
|
99 |
+
original_height, original_width, original_channel = np.array(input_image).shape
|
100 |
+
output_width = original_width
|
101 |
+
output_height = original_height
|
102 |
+
|
103 |
+
if uploaded_mask is None:
|
104 |
+
mask_image = source_img["mask"].convert("RGB")
|
105 |
else:
|
106 |
+
mask_image = uploaded_mask.convert("RGB")
|
107 |
+
mask_image = mask_image.resize((original_width, original_height))
|
108 |
+
|
109 |
+
# Limited to 1 million pixels
|
110 |
+
if 1024 * 1024 < output_width * output_height:
|
111 |
+
factor = ((1024 * 1024) / (output_width * output_height))**0.5
|
112 |
+
process_width = math.floor(output_width * factor)
|
113 |
+
process_height = math.floor(output_height * factor)
|
114 |
+
|
115 |
+
limitation = " Due to technical limitation, the image have been downscaled and then upscaled.";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
else:
|
117 |
+
process_width = output_width
|
118 |
+
process_height = output_height
|
119 |
+
|
120 |
+
limitation = "";
|
121 |
+
|
122 |
+
# Width and height must be multiple of 8
|
123 |
+
if (process_width % 8) != 0 or (process_height % 8) != 0:
|
124 |
+
if ((process_width - (process_width % 8) + 8) * (process_height - (process_height % 8) + 8)) <= (1024 * 1024):
|
125 |
+
process_width = process_width - (process_width % 8) + 8
|
126 |
+
process_height = process_height - (process_height % 8) + 8
|
127 |
+
elif (process_height % 8) <= (process_width % 8) and ((process_width - (process_width % 8) + 8) * process_height) <= (1024 * 1024):
|
128 |
+
process_width = process_width - (process_width % 8) + 8
|
129 |
+
process_height = process_height - (process_height % 8)
|
130 |
+
elif (process_width % 8) <= (process_height % 8) and (process_width * (process_height - (process_height % 8) + 8)) <= (1024 * 1024):
|
131 |
+
process_width = process_width - (process_width % 8)
|
132 |
+
process_height = process_height - (process_height % 8) + 8
|
133 |
+
else:
|
134 |
+
process_width = process_width - (process_width % 8)
|
135 |
+
process_height = process_height - (process_height % 8)
|
136 |
+
|
137 |
+
progress(None, desc = "Processing...")
|
138 |
+
output_image = pipe(
|
139 |
+
seeds = [seed],
|
140 |
+
width = process_width,
|
141 |
+
height = process_height,
|
142 |
+
prompt = prompt,
|
143 |
+
negative_prompt = negative_prompt,
|
144 |
+
image = input_image,
|
145 |
+
mask_image = mask_image,
|
146 |
+
num_inference_steps = num_inference_steps,
|
147 |
+
guidance_scale = guidance_scale,
|
148 |
+
image_guidance_scale = image_guidance_scale,
|
149 |
+
strength = strength,
|
150 |
+
denoising_steps = denoising_steps,
|
151 |
+
show_progress_bar = True
|
152 |
+
).images[0]
|
153 |
|
154 |
+
if limitation != "":
|
155 |
+
output_image = output_image.resize((output_width, output_height))
|
156 |
+
|
157 |
+
if debug_mode == False:
|
158 |
+
input_image = None
|
159 |
+
mask_image = None
|
160 |
+
|
161 |
+
end = time.time()
|
162 |
+
secondes = int(end - start)
|
163 |
+
minutes = secondes // 60
|
164 |
+
secondes = secondes - (minutes * 60)
|
165 |
+
hours = minutes // 60
|
166 |
+
minutes = minutes - (hours * 60)
|
167 |
+
return [
|
168 |
+
output_image,
|
169 |
+
"Start again to get a different result. The new image is " + str(output_width) + " pixels large and " + str(output_height) + " pixels high, so an image of " + f'{output_width * output_height:,}' + " pixels. The image have been generated in " + str(hours) + " h, " + str(minutes) + " min, " + str(secondes) + " sec." + limitation,
|
170 |
+
input_image,
|
171 |
+
mask_image
|
172 |
+
]
|
173 |
+
|
174 |
+
def toggle_debug(is_debug_mode):
|
175 |
+
if is_debug_mode:
|
176 |
+
return [gr.update(visible = True)] * 2
|
177 |
+
else:
|
178 |
+
return [gr.update(visible = False)] * 2
|
179 |
+
|
180 |
+
with gr.Blocks() as interface:
|
181 |
+
gr.Markdown(
|
182 |
"""
|
183 |
+
<p style="text-align: center;"><b><big><big><big>Inpaint</big></big></big></b></p>
|
184 |
+
<p style="text-align: center;">Modifies one detail of your image, at any resolution, freely, without account, without watermark, without installation, which can be downloaded</p>
|
185 |
+
<br/>
|
186 |
+
<br/>
|
187 |
+
🚀 Powered by <i>SDXL 1.0</i> artificial intellingence.
|
188 |
+
<br/>
|
189 |
+
🐌 Slow process... ~1 hour.<br>You can duplicate this space on a free account, it works on CPU and should also run on CUDA.<br/>
|
190 |
+
<a href='https://huggingface.co/spaces/multimodalart/stable-diffusion-inpainting?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14'></a>
|
191 |
+
<br/>
|
192 |
+
⚖️ You can use, modify and share the generated images but not for commercial uses.
|
193 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
"""
|
195 |
)
|
196 |
+
with gr.Column():
|
197 |
+
source_img = gr.Image(label = "Your image", source = "upload", tool = "sketch", type = "pil")
|
198 |
+
prompt = gr.Textbox(label = "Prompt", info = "Describe the subject, the background and the style of image; 77 token limit", placeholder = "Describe what you want to see in the entire image")
|
199 |
+
with gr.Accordion("Upload a mask", open = False):
|
200 |
+
uploaded_mask = gr.Image(label = "Already made mask (black pixels will be preserved, white pixels will be redrawn)", source = "upload", type = "pil")
|
201 |
+
with gr.Accordion("Advanced options", open = False):
|
202 |
+
negative_prompt = gr.Textbox(label = "Negative prompt", placeholder = "Describe what you do NOT want to see in the entire image", value = "Ugly, malformed, noise, blur, watermark")
|
203 |
+
denoising_steps = gr.Slider(minimum = 0, maximum = 1000, value = 1000, step = 1, label = "Denoising", info = "lower=irrelevant result, higher=relevant result")
|
204 |
+
num_inference_steps = gr.Slider(minimum = 10, maximum = 100, value = 25, step = 1, label = "Number of inference steps", info = "lower=faster, higher=image quality")
|
205 |
+
guidance_scale = gr.Slider(minimum = 1, maximum = 13, value = 7, step = 0.1, label = "Classifier-Free Guidance Scale", info = "lower=image quality, higher=follow the prompt")
|
206 |
+
image_guidance_scale = gr.Slider(minimum = 1, value = 1.1, step = 0.1, label = "Image Guidance Scale", info = "lower=image quality, higher=follow the image")
|
207 |
+
strength = gr.Number(value = 0.99, minimum = 0.01, maximum = 1.0, step = 0.01, label = "Strength", info = "lower=follow the original area, higher=redraw from scratch")
|
208 |
+
randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed (not working, always checked)", value = True, info = "If checked, result is always different")
|
209 |
+
seed = gr.Slider(minimum = 0, maximum = max_64_bit_int, step = 1, randomize = True, label = "Seed (if not randomized)")
|
210 |
+
debug_mode = gr.Checkbox(label = "Debug mode", value = False, info = "Show intermediate results")
|
211 |
+
|
212 |
+
submit = gr.Button("Inpaint", variant = "primary")
|
213 |
+
|
214 |
+
inpainted_image = gr.Image(label = "Inpainted image")
|
215 |
+
information = gr.Label(label = "Information")
|
216 |
+
original_image = gr.Image(label = "Original image", visible = False)
|
217 |
+
mask_image = gr.Image(label = "Mask image", visible = False)
|
218 |
+
|
219 |
+
submit.click(toggle_debug, debug_mode, [
|
220 |
+
original_image,
|
221 |
+
mask_image
|
222 |
+
], queue = False, show_progress = False).then(check, inputs = [
|
223 |
+
source_img,
|
224 |
+
prompt,
|
225 |
+
uploaded_mask,
|
226 |
+
negative_prompt,
|
227 |
+
denoising_steps,
|
228 |
+
num_inference_steps,
|
229 |
+
guidance_scale,
|
230 |
+
image_guidance_scale,
|
231 |
+
strength,
|
232 |
+
randomize_seed,
|
233 |
+
seed,
|
234 |
+
debug_mode
|
235 |
+
], outputs = [], queue = False, show_progress = False).success(inpaint, inputs = [
|
236 |
+
source_img,
|
237 |
+
prompt,
|
238 |
+
uploaded_mask,
|
239 |
+
negative_prompt,
|
240 |
+
denoising_steps,
|
241 |
+
num_inference_steps,
|
242 |
+
guidance_scale,
|
243 |
+
image_guidance_scale,
|
244 |
+
strength,
|
245 |
+
randomize_seed,
|
246 |
+
seed,
|
247 |
+
debug_mode
|
248 |
+
], outputs = [
|
249 |
+
inpainted_image,
|
250 |
+
information,
|
251 |
+
original_image,
|
252 |
+
mask_image
|
253 |
+
], scroll_to_output = True)
|
254 |
+
|
255 |
+
gr.Examples(
|
256 |
+
inputs = [
|
257 |
+
source_img,
|
258 |
+
prompt,
|
259 |
+
uploaded_mask,
|
260 |
+
negative_prompt,
|
261 |
+
denoising_steps,
|
262 |
+
num_inference_steps,
|
263 |
+
guidance_scale,
|
264 |
+
image_guidance_scale,
|
265 |
+
strength,
|
266 |
+
randomize_seed,
|
267 |
+
seed,
|
268 |
+
debug_mode
|
269 |
+
],
|
270 |
+
outputs = [
|
271 |
+
inpainted_image,
|
272 |
+
information,
|
273 |
+
original_image,
|
274 |
+
mask_image
|
275 |
+
],
|
276 |
+
examples = [
|
277 |
+
[
|
278 |
+
"./Examples/Example1.png",
|
279 |
+
"A deer, in a forest landscape, ultrarealistic, realistic, photorealistic, 8k",
|
280 |
+
"./Examples/Mask1.webp",
|
281 |
+
"Painting, drawing, cartoon, ugly, malformed, noise, blur, watermark",
|
282 |
+
1000,
|
283 |
+
25,
|
284 |
+
7,
|
285 |
+
1.1,
|
286 |
+
0.99,
|
287 |
+
True,
|
288 |
+
42,
|
289 |
+
False
|
290 |
+
],
|
291 |
+
[
|
292 |
+
"./Examples/Example3.jpg",
|
293 |
+
"An angry old woman, ultrarealistic, realistic, photorealistic, 8k",
|
294 |
+
"./Examples/Mask3.gif",
|
295 |
+
"Painting, drawing, cartoon, ugly, malformed, noise, blur, watermark",
|
296 |
+
1000,
|
297 |
+
25,
|
298 |
+
7,
|
299 |
+
1.5,
|
300 |
+
0.99,
|
301 |
+
True,
|
302 |
+
42,
|
303 |
+
False
|
304 |
+
],
|
305 |
+
[
|
306 |
+
"./Examples/Example4.gif",
|
307 |
+
"A laptop, ultrarealistic, realistic, photorealistic, 8k",
|
308 |
+
"./Examples/Mask4.bmp",
|
309 |
+
"Painting, drawing, cartoon, ugly, malformed, noise, blur, watermark",
|
310 |
+
1000,
|
311 |
+
25,
|
312 |
+
7,
|
313 |
+
1.1,
|
314 |
+
0.99,
|
315 |
+
True,
|
316 |
+
42,
|
317 |
+
False
|
318 |
+
],
|
319 |
+
[
|
320 |
+
"./Examples/Example5.bmp",
|
321 |
+
"A sand castle, ultrarealistic, realistic, photorealistic, 8k",
|
322 |
+
"./Examples/Mask5.png",
|
323 |
+
"Painting, drawing, cartoon, ugly, malformed, noise, blur, watermark",
|
324 |
+
1000,
|
325 |
+
50,
|
326 |
+
7,
|
327 |
+
1.5,
|
328 |
+
0.5,
|
329 |
+
True,
|
330 |
+
42,
|
331 |
+
False
|
332 |
+
],
|
333 |
+
[
|
334 |
+
"./Examples/Example2.webp",
|
335 |
+
"A cat, ultrarealistic, realistic, photorealistic, 8k",
|
336 |
+
"./Examples/Mask2.png",
|
337 |
+
"Painting, drawing, cartoon, ugly, malformed, noise, blur, watermark",
|
338 |
+
1000,
|
339 |
+
25,
|
340 |
+
7,
|
341 |
+
1.1,
|
342 |
+
0.99,
|
343 |
+
True,
|
344 |
+
42,
|
345 |
+
False
|
346 |
+
],
|
347 |
+
],
|
348 |
+
cache_examples = False,
|
349 |
+
)
|
350 |
+
|
351 |
+
interface.queue().launch()
|
clipseg/LICENSE
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
MIT License
|
2 |
-
|
3 |
-
This license does not apply to the model weights.
|
4 |
-
|
5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
-
of this software and associated documentation files (the "Software"), to deal
|
7 |
-
in the Software without restriction, including without limitation the rights
|
8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
-
copies of the Software, and to permit persons to whom the Software is
|
10 |
-
furnished to do so, subject to the following conditions:
|
11 |
-
|
12 |
-
The above copyright notice and this permission notice shall be included in all
|
13 |
-
copies or substantial portions of the Software.
|
14 |
-
|
15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/Quickstart.ipynb
DELETED
@@ -1,107 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"import torch\n",
|
10 |
-
"import requests\n",
|
11 |
-
"\n",
|
12 |
-
"! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
|
13 |
-
"! unzip -d weights -j weights.zip\n",
|
14 |
-
"from models.clipseg import CLIPDensePredT\n",
|
15 |
-
"from PIL import Image\n",
|
16 |
-
"from torchvision import transforms\n",
|
17 |
-
"from matplotlib import pyplot as plt\n",
|
18 |
-
"\n",
|
19 |
-
"# load model\n",
|
20 |
-
"model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
|
21 |
-
"model.eval();\n",
|
22 |
-
"\n",
|
23 |
-
"# non-strict, because we only stored decoder weights (not CLIP weights)\n",
|
24 |
-
"model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
|
25 |
-
]
|
26 |
-
},
|
27 |
-
{
|
28 |
-
"cell_type": "markdown",
|
29 |
-
"metadata": {},
|
30 |
-
"source": [
|
31 |
-
"Load and normalize `example_image.jpg`. You can also load through an URL."
|
32 |
-
]
|
33 |
-
},
|
34 |
-
{
|
35 |
-
"cell_type": "code",
|
36 |
-
"execution_count": null,
|
37 |
-
"metadata": {},
|
38 |
-
"outputs": [],
|
39 |
-
"source": [
|
40 |
-
"# load and normalize image\n",
|
41 |
-
"input_image = Image.open('example_image.jpg')\n",
|
42 |
-
"\n",
|
43 |
-
"# or load from URL...\n",
|
44 |
-
"# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
|
45 |
-
"# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
|
46 |
-
"\n",
|
47 |
-
"transform = transforms.Compose([\n",
|
48 |
-
" transforms.ToTensor(),\n",
|
49 |
-
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
50 |
-
" transforms.Resize((352, 352)),\n",
|
51 |
-
"])\n",
|
52 |
-
"img = transform(input_image).unsqueeze(0)"
|
53 |
-
]
|
54 |
-
},
|
55 |
-
{
|
56 |
-
"cell_type": "markdown",
|
57 |
-
"metadata": {},
|
58 |
-
"source": [
|
59 |
-
"Predict and visualize (this might take a few seconds if running without GPU support)"
|
60 |
-
]
|
61 |
-
},
|
62 |
-
{
|
63 |
-
"cell_type": "code",
|
64 |
-
"execution_count": null,
|
65 |
-
"metadata": {},
|
66 |
-
"outputs": [],
|
67 |
-
"source": [
|
68 |
-
"prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
|
69 |
-
"\n",
|
70 |
-
"# predict\n",
|
71 |
-
"with torch.no_grad():\n",
|
72 |
-
" preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
|
73 |
-
"\n",
|
74 |
-
"# visualize prediction\n",
|
75 |
-
"_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
|
76 |
-
"[a.axis('off') for a in ax.flatten()]\n",
|
77 |
-
"ax[0].imshow(input_image)\n",
|
78 |
-
"[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
|
79 |
-
"[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
|
80 |
-
]
|
81 |
-
}
|
82 |
-
],
|
83 |
-
"metadata": {
|
84 |
-
"interpreter": {
|
85 |
-
"hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
|
86 |
-
},
|
87 |
-
"kernelspec": {
|
88 |
-
"display_name": "Python 3",
|
89 |
-
"language": "python",
|
90 |
-
"name": "python3"
|
91 |
-
},
|
92 |
-
"language_info": {
|
93 |
-
"codemirror_mode": {
|
94 |
-
"name": "ipython",
|
95 |
-
"version": 3
|
96 |
-
},
|
97 |
-
"file_extension": ".py",
|
98 |
-
"mimetype": "text/x-python",
|
99 |
-
"name": "python",
|
100 |
-
"nbconvert_exporter": "python",
|
101 |
-
"pygments_lexer": "ipython3",
|
102 |
-
"version": "3.8.10"
|
103 |
-
}
|
104 |
-
},
|
105 |
-
"nbformat": 4,
|
106 |
-
"nbformat_minor": 4
|
107 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/Readme.md
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
# Image Segmentation Using Text and Image Prompts
|
2 |
-
This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
|
3 |
-
|
4 |
-
**The Paper has been accepted to CVPR 2022!**
|
5 |
-
|
6 |
-
<img src="overview.png" alt="drawing" height="200em"/>
|
7 |
-
|
8 |
-
The systems allows to create segmentation models without training based on:
|
9 |
-
- An arbitrary text query
|
10 |
-
- Or an image with a mask highlighting stuff or an object.
|
11 |
-
|
12 |
-
### Quick Start
|
13 |
-
|
14 |
-
In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you downloaded the `rd64-uni.pth` weights, either manually or via git lfs extension.
|
15 |
-
It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
|
16 |
-
(please note that the VM does not use a GPU, thus inference takes a few seconds).
|
17 |
-
|
18 |
-
|
19 |
-
### Dependencies
|
20 |
-
This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
|
21 |
-
Additional dependencies are hidden for double blind review.
|
22 |
-
|
23 |
-
|
24 |
-
### Datasets
|
25 |
-
|
26 |
-
* `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
|
27 |
-
* `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
|
28 |
-
* `PascalZeroShot`: Wrapper class for PascalZeroShot
|
29 |
-
* `COCOWrapper`: Wrapper class for COCO.
|
30 |
-
|
31 |
-
### Models
|
32 |
-
|
33 |
-
* `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
|
34 |
-
* `ViTDensePredT`: CLIPSeg model with transformer-based decoder.
|
35 |
-
|
36 |
-
### Third Party Dependencies
|
37 |
-
For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
|
38 |
-
```bash
|
39 |
-
git clone https://github.com/cvlab-yonsei/JoEm
|
40 |
-
git clone https://github.com/Jia-Research-Lab/PFENet.git
|
41 |
-
git clone https://github.com/ChenyunWu/PhraseCutDataset.git
|
42 |
-
git clone https://github.com/juhongm999/hsnet.git
|
43 |
-
```
|
44 |
-
|
45 |
-
### Weights
|
46 |
-
|
47 |
-
The MIT license does not apply to these weights.
|
48 |
-
|
49 |
-
We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
|
50 |
-
```
|
51 |
-
wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
|
52 |
-
unzip -d weights -j weights.zip
|
53 |
-
```
|
54 |
-
|
55 |
-
|
56 |
-
### Training and Evaluation
|
57 |
-
|
58 |
-
To train use the `training.py` script with experiment file and experiment id parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment which is defined by the `configuration` and first `individual_configurations` parameters. Model weights will be written in `logs/`.
|
59 |
-
|
60 |
-
For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will train the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`.
|
61 |
-
|
62 |
-
|
63 |
-
### Usage of PFENet Wrappers
|
64 |
-
|
65 |
-
In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
|
66 |
-
`git clone https://github.com/Jia-Research-Lab/PFENet.git `
|
67 |
-
|
68 |
-
|
69 |
-
### License
|
70 |
-
|
71 |
-
The source code files in this repository (excluding model weights) are released under MIT license.
|
72 |
-
|
73 |
-
### Citation
|
74 |
-
```
|
75 |
-
@InProceedings{lueddecke22_cvpr,
|
76 |
-
author = {L\"uddecke, Timo and Ecker, Alexander},
|
77 |
-
title = {Image Segmentation Using Text and Image Prompts},
|
78 |
-
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
79 |
-
month = {June},
|
80 |
-
year = {2022},
|
81 |
-
pages = {7086-7096}
|
82 |
-
}
|
83 |
-
|
84 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/Tables.ipynb
DELETED
@@ -1,349 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"%autoreload 2\n",
|
11 |
-
"\n",
|
12 |
-
"import clip\n",
|
13 |
-
"from evaluation_utils import norm, denorm\n",
|
14 |
-
"from general_utils import *\n",
|
15 |
-
"from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
|
16 |
-
]
|
17 |
-
},
|
18 |
-
{
|
19 |
-
"cell_type": "markdown",
|
20 |
-
"metadata": {},
|
21 |
-
"source": [
|
22 |
-
"# PhraseCut"
|
23 |
-
]
|
24 |
-
},
|
25 |
-
{
|
26 |
-
"cell_type": "code",
|
27 |
-
"execution_count": null,
|
28 |
-
"metadata": {},
|
29 |
-
"outputs": [],
|
30 |
-
"source": [
|
31 |
-
"pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
|
32 |
-
]
|
33 |
-
},
|
34 |
-
{
|
35 |
-
"cell_type": "code",
|
36 |
-
"execution_count": null,
|
37 |
-
"metadata": {},
|
38 |
-
"outputs": [],
|
39 |
-
"source": [
|
40 |
-
"tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
|
41 |
-
]
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"cell_type": "code",
|
45 |
-
"execution_count": null,
|
46 |
-
"metadata": {},
|
47 |
-
"outputs": [],
|
48 |
-
"source": [
|
49 |
-
"cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
|
50 |
-
"tab1 = pc[['name'] + cols]\n",
|
51 |
-
"for k in cols:\n",
|
52 |
-
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
53 |
-
"tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
|
54 |
-
"tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
|
55 |
-
"print(tab1.to_latex(header=False, index=False))"
|
56 |
-
]
|
57 |
-
},
|
58 |
-
{
|
59 |
-
"cell_type": "markdown",
|
60 |
-
"metadata": {},
|
61 |
-
"source": [
|
62 |
-
"For 0.1 threshold"
|
63 |
-
]
|
64 |
-
},
|
65 |
-
{
|
66 |
-
"cell_type": "code",
|
67 |
-
"execution_count": null,
|
68 |
-
"metadata": {},
|
69 |
-
"outputs": [],
|
70 |
-
"source": [
|
71 |
-
"cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
|
72 |
-
"tab1 = pc[['name'] + cols]\n",
|
73 |
-
"for k in cols:\n",
|
74 |
-
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
75 |
-
"tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
|
76 |
-
"tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
|
77 |
-
"print(tab1.to_latex(header=False, index=False))"
|
78 |
-
]
|
79 |
-
},
|
80 |
-
{
|
81 |
-
"cell_type": "markdown",
|
82 |
-
"metadata": {},
|
83 |
-
"source": [
|
84 |
-
"# One-shot"
|
85 |
-
]
|
86 |
-
},
|
87 |
-
{
|
88 |
-
"cell_type": "markdown",
|
89 |
-
"metadata": {},
|
90 |
-
"source": [
|
91 |
-
"### Pascal"
|
92 |
-
]
|
93 |
-
},
|
94 |
-
{
|
95 |
-
"cell_type": "code",
|
96 |
-
"execution_count": null,
|
97 |
-
"metadata": {},
|
98 |
-
"outputs": [],
|
99 |
-
"source": [
|
100 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
|
101 |
-
]
|
102 |
-
},
|
103 |
-
{
|
104 |
-
"cell_type": "code",
|
105 |
-
"execution_count": null,
|
106 |
-
"metadata": {},
|
107 |
-
"outputs": [],
|
108 |
-
"source": [
|
109 |
-
"pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
|
110 |
-
]
|
111 |
-
},
|
112 |
-
{
|
113 |
-
"cell_type": "code",
|
114 |
-
"execution_count": null,
|
115 |
-
"metadata": {},
|
116 |
-
"outputs": [],
|
117 |
-
"source": [
|
118 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
119 |
-
"tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
|
120 |
-
"print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
121 |
-
"print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
122 |
-
"\n",
|
123 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
124 |
-
"tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
|
125 |
-
"print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
126 |
-
"\n",
|
127 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
|
128 |
-
"tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
|
129 |
-
"print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
130 |
-
]
|
131 |
-
},
|
132 |
-
{
|
133 |
-
"cell_type": "markdown",
|
134 |
-
"metadata": {},
|
135 |
-
"source": [
|
136 |
-
"#### Pascal Zero-shot (in one-shot setting)\n",
|
137 |
-
"\n",
|
138 |
-
"Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
|
139 |
-
]
|
140 |
-
},
|
141 |
-
{
|
142 |
-
"cell_type": "code",
|
143 |
-
"execution_count": null,
|
144 |
-
"metadata": {},
|
145 |
-
"outputs": [],
|
146 |
-
"source": [
|
147 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
148 |
-
"tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
|
149 |
-
"print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
150 |
-
"print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
151 |
-
"\n",
|
152 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
153 |
-
"tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
|
154 |
-
"print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
155 |
-
"\n",
|
156 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
|
157 |
-
"tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
|
158 |
-
"print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
159 |
-
]
|
160 |
-
},
|
161 |
-
{
|
162 |
-
"cell_type": "code",
|
163 |
-
"execution_count": null,
|
164 |
-
"metadata": {},
|
165 |
-
"outputs": [],
|
166 |
-
"source": [
|
167 |
-
"# without fixed thresholds...\n",
|
168 |
-
"\n",
|
169 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
|
170 |
-
"tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
|
171 |
-
"print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
|
172 |
-
"print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
173 |
-
"\n",
|
174 |
-
"pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
|
175 |
-
"tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
|
176 |
-
"print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
|
177 |
-
]
|
178 |
-
},
|
179 |
-
{
|
180 |
-
"cell_type": "markdown",
|
181 |
-
"metadata": {},
|
182 |
-
"source": [
|
183 |
-
"### COCO"
|
184 |
-
]
|
185 |
-
},
|
186 |
-
{
|
187 |
-
"cell_type": "code",
|
188 |
-
"execution_count": null,
|
189 |
-
"metadata": {},
|
190 |
-
"outputs": [],
|
191 |
-
"source": [
|
192 |
-
"coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
|
193 |
-
]
|
194 |
-
},
|
195 |
-
{
|
196 |
-
"cell_type": "code",
|
197 |
-
"execution_count": null,
|
198 |
-
"metadata": {},
|
199 |
-
"outputs": [],
|
200 |
-
"source": [
|
201 |
-
"tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
|
202 |
-
"tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
|
203 |
-
"tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
|
204 |
-
"print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
|
205 |
-
"print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
|
206 |
-
"print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
|
207 |
-
"print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
|
208 |
-
]
|
209 |
-
},
|
210 |
-
{
|
211 |
-
"cell_type": "markdown",
|
212 |
-
"metadata": {},
|
213 |
-
"source": [
|
214 |
-
"# Zero-shot"
|
215 |
-
]
|
216 |
-
},
|
217 |
-
{
|
218 |
-
"cell_type": "code",
|
219 |
-
"execution_count": null,
|
220 |
-
"metadata": {},
|
221 |
-
"outputs": [],
|
222 |
-
"source": [
|
223 |
-
"zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
|
224 |
-
]
|
225 |
-
},
|
226 |
-
{
|
227 |
-
"cell_type": "code",
|
228 |
-
"execution_count": null,
|
229 |
-
"metadata": {},
|
230 |
-
"outputs": [],
|
231 |
-
"source": [
|
232 |
-
"\n",
|
233 |
-
"tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
|
234 |
-
"print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
|
235 |
-
"print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
|
236 |
-
"print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
|
237 |
-
]
|
238 |
-
},
|
239 |
-
{
|
240 |
-
"cell_type": "markdown",
|
241 |
-
"metadata": {},
|
242 |
-
"source": [
|
243 |
-
"# Ablation"
|
244 |
-
]
|
245 |
-
},
|
246 |
-
{
|
247 |
-
"cell_type": "code",
|
248 |
-
"execution_count": null,
|
249 |
-
"metadata": {},
|
250 |
-
"outputs": [],
|
251 |
-
"source": [
|
252 |
-
"ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
|
253 |
-
]
|
254 |
-
},
|
255 |
-
{
|
256 |
-
"cell_type": "code",
|
257 |
-
"execution_count": null,
|
258 |
-
"metadata": {},
|
259 |
-
"outputs": [],
|
260 |
-
"source": [
|
261 |
-
"tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
|
262 |
-
"for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
|
263 |
-
" tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
|
264 |
-
"tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
|
265 |
-
]
|
266 |
-
},
|
267 |
-
{
|
268 |
-
"cell_type": "code",
|
269 |
-
"execution_count": null,
|
270 |
-
"metadata": {},
|
271 |
-
"outputs": [],
|
272 |
-
"source": [
|
273 |
-
"print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
|
274 |
-
]
|
275 |
-
},
|
276 |
-
{
|
277 |
-
"cell_type": "code",
|
278 |
-
"execution_count": null,
|
279 |
-
"metadata": {},
|
280 |
-
"outputs": [],
|
281 |
-
"source": [
|
282 |
-
"print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
|
283 |
-
]
|
284 |
-
},
|
285 |
-
{
|
286 |
-
"cell_type": "markdown",
|
287 |
-
"metadata": {},
|
288 |
-
"source": [
|
289 |
-
"# Generalization"
|
290 |
-
]
|
291 |
-
},
|
292 |
-
{
|
293 |
-
"cell_type": "code",
|
294 |
-
"execution_count": null,
|
295 |
-
"metadata": {},
|
296 |
-
"outputs": [],
|
297 |
-
"source": [
|
298 |
-
"generalization = experiment('experiments/generalize.yaml').dataframe()"
|
299 |
-
]
|
300 |
-
},
|
301 |
-
{
|
302 |
-
"cell_type": "code",
|
303 |
-
"execution_count": null,
|
304 |
-
"metadata": {},
|
305 |
-
"outputs": [],
|
306 |
-
"source": [
|
307 |
-
"gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
|
308 |
-
]
|
309 |
-
},
|
310 |
-
{
|
311 |
-
"cell_type": "code",
|
312 |
-
"execution_count": null,
|
313 |
-
"metadata": {},
|
314 |
-
"outputs": [],
|
315 |
-
"source": [
|
316 |
-
"print(\n",
|
317 |
-
" 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
|
318 |
-
" 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
|
319 |
-
" 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
|
320 |
-
" 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
|
321 |
-
")"
|
322 |
-
]
|
323 |
-
}
|
324 |
-
],
|
325 |
-
"metadata": {
|
326 |
-
"interpreter": {
|
327 |
-
"hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
|
328 |
-
},
|
329 |
-
"kernelspec": {
|
330 |
-
"display_name": "env2",
|
331 |
-
"language": "python",
|
332 |
-
"name": "env2"
|
333 |
-
},
|
334 |
-
"language_info": {
|
335 |
-
"codemirror_mode": {
|
336 |
-
"name": "ipython",
|
337 |
-
"version": 3
|
338 |
-
},
|
339 |
-
"file_extension": ".py",
|
340 |
-
"mimetype": "text/x-python",
|
341 |
-
"name": "python",
|
342 |
-
"nbconvert_exporter": "python",
|
343 |
-
"pygments_lexer": "ipython3",
|
344 |
-
"version": "3.8.8"
|
345 |
-
}
|
346 |
-
},
|
347 |
-
"nbformat": 4,
|
348 |
-
"nbformat_minor": 4
|
349 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/Visual_Feature_Engineering.ipynb
DELETED
@@ -1,366 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"metadata": {},
|
6 |
-
"source": [
|
7 |
-
"# Systematic"
|
8 |
-
]
|
9 |
-
},
|
10 |
-
{
|
11 |
-
"cell_type": "code",
|
12 |
-
"execution_count": null,
|
13 |
-
"metadata": {},
|
14 |
-
"outputs": [],
|
15 |
-
"source": [
|
16 |
-
"%load_ext autoreload\n",
|
17 |
-
"%autoreload 2\n",
|
18 |
-
"\n",
|
19 |
-
"import clip\n",
|
20 |
-
"from evaluation_utils import norm, denorm\n",
|
21 |
-
"from general_utils import *\n",
|
22 |
-
"from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
|
23 |
-
"\n",
|
24 |
-
"clip_device = 'cuda'\n",
|
25 |
-
"clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
|
26 |
-
"clip_model.eval();\n",
|
27 |
-
"\n",
|
28 |
-
"from models.clipseg import CLIPDensePredTMasked\n",
|
29 |
-
"\n",
|
30 |
-
"clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
|
31 |
-
"clip_mask_model.eval();"
|
32 |
-
]
|
33 |
-
},
|
34 |
-
{
|
35 |
-
"cell_type": "code",
|
36 |
-
"execution_count": null,
|
37 |
-
"metadata": {},
|
38 |
-
"outputs": [],
|
39 |
-
"source": [
|
40 |
-
"lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
|
41 |
-
" text_class_labels=True, image_size=352, min_area=0.1,\n",
|
42 |
-
" min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "code",
|
47 |
-
"execution_count": null,
|
48 |
-
"metadata": {},
|
49 |
-
"outputs": [],
|
50 |
-
"source": [
|
51 |
-
"plot_data(lvis)"
|
52 |
-
]
|
53 |
-
},
|
54 |
-
{
|
55 |
-
"cell_type": "code",
|
56 |
-
"execution_count": null,
|
57 |
-
"metadata": {},
|
58 |
-
"outputs": [],
|
59 |
-
"source": [
|
60 |
-
"from collections import defaultdict\n",
|
61 |
-
"import json\n",
|
62 |
-
"\n",
|
63 |
-
"lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
|
64 |
-
"lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
|
65 |
-
"\n",
|
66 |
-
"objects_per_image = defaultdict(lambda : set())\n",
|
67 |
-
"for ann in lvis_raw['annotations']:\n",
|
68 |
-
" objects_per_image[ann['image_id']].add(ann['category_id'])\n",
|
69 |
-
" \n",
|
70 |
-
"for ann in lvis_val_raw['annotations']:\n",
|
71 |
-
" objects_per_image[ann['image_id']].add(ann['category_id']) \n",
|
72 |
-
" \n",
|
73 |
-
"objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
|
74 |
-
"\n",
|
75 |
-
"del lvis_raw, lvis_val_raw"
|
76 |
-
]
|
77 |
-
},
|
78 |
-
{
|
79 |
-
"cell_type": "code",
|
80 |
-
"execution_count": null,
|
81 |
-
"metadata": {},
|
82 |
-
"outputs": [],
|
83 |
-
"source": [
|
84 |
-
"#bs = 32\n",
|
85 |
-
"#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
|
86 |
-
]
|
87 |
-
},
|
88 |
-
{
|
89 |
-
"cell_type": "code",
|
90 |
-
"execution_count": null,
|
91 |
-
"metadata": {},
|
92 |
-
"outputs": [],
|
93 |
-
"source": [
|
94 |
-
"from general_utils import get_batch\n",
|
95 |
-
"from functools import partial\n",
|
96 |
-
"from evaluation_utils import img_preprocess\n",
|
97 |
-
"import torch\n",
|
98 |
-
"\n",
|
99 |
-
"def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
|
100 |
-
"\n",
|
101 |
-
" # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
|
102 |
-
"\n",
|
103 |
-
" all_prompts = []\n",
|
104 |
-
" \n",
|
105 |
-
" with torch.no_grad():\n",
|
106 |
-
" valid_sims = []\n",
|
107 |
-
" torch.manual_seed(571)\n",
|
108 |
-
" \n",
|
109 |
-
" if type(batches_or_dataset) == list:\n",
|
110 |
-
" loader = batches_or_dataset # already loaded\n",
|
111 |
-
" max_iter = float('inf')\n",
|
112 |
-
" else:\n",
|
113 |
-
" loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
|
114 |
-
" max_iter = 50\n",
|
115 |
-
" \n",
|
116 |
-
" global batch\n",
|
117 |
-
" for i_batch, (batch, batch_y) in enumerate(loader):\n",
|
118 |
-
" \n",
|
119 |
-
" if i_batch >= max_iter: break\n",
|
120 |
-
" \n",
|
121 |
-
" processed_batch = process(batch)\n",
|
122 |
-
" if type(processed_batch) == dict:\n",
|
123 |
-
" \n",
|
124 |
-
" # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
|
125 |
-
" image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
|
126 |
-
" else:\n",
|
127 |
-
" processed_batch = process(batch).to(clip_device)\n",
|
128 |
-
" processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
|
129 |
-
" #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
|
130 |
-
" image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
|
131 |
-
" \n",
|
132 |
-
" image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
|
133 |
-
" bs = len(batch[0])\n",
|
134 |
-
" for j in range(bs):\n",
|
135 |
-
" \n",
|
136 |
-
" c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
|
137 |
-
" support_image = basename(lvis.samples[c][sid])\n",
|
138 |
-
" \n",
|
139 |
-
" img_objs = [o for o in objects_per_image[int(support_image)]]\n",
|
140 |
-
" img_objs = [o.replace('_', ' ') for o in img_objs]\n",
|
141 |
-
" \n",
|
142 |
-
" other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
|
143 |
-
" if o != batch_y[2][j]]\n",
|
144 |
-
" \n",
|
145 |
-
" prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
|
146 |
-
" all_prompts += [prompts]\n",
|
147 |
-
" \n",
|
148 |
-
" text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
|
149 |
-
" text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
|
150 |
-
"\n",
|
151 |
-
" global logits\n",
|
152 |
-
" logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
|
153 |
-
"\n",
|
154 |
-
" global sim\n",
|
155 |
-
" sim = torch.softmax(logits, dim=-1)\n",
|
156 |
-
" \n",
|
157 |
-
" valid_sims += [sim]\n",
|
158 |
-
" \n",
|
159 |
-
" #valid_sims = torch.stack(valid_sims)\n",
|
160 |
-
" return valid_sims, all_prompts\n",
|
161 |
-
" \n",
|
162 |
-
"\n",
|
163 |
-
"def new_img_preprocess(x):\n",
|
164 |
-
" return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
|
165 |
-
" \n",
|
166 |
-
"#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
|
167 |
-
"get_similarities(lvis, lambda x: x[1]);"
|
168 |
-
]
|
169 |
-
},
|
170 |
-
{
|
171 |
-
"cell_type": "code",
|
172 |
-
"execution_count": null,
|
173 |
-
"metadata": {},
|
174 |
-
"outputs": [],
|
175 |
-
"source": [
|
176 |
-
"preprocessing_functions = [\n",
|
177 |
-
"# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
|
178 |
-
"# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
|
179 |
-
"# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
|
180 |
-
"# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
|
181 |
-
"# ['add red outline', partial(img_preprocess, outline=True)],\n",
|
182 |
-
" \n",
|
183 |
-
"# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
|
184 |
-
"# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
|
185 |
-
"# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
|
186 |
-
"# ['BG blur', partial(img_preprocess, blur=3)],\n",
|
187 |
-
"# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
|
188 |
-
" \n",
|
189 |
-
"# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
|
190 |
-
"# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
|
191 |
-
" ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
|
192 |
-
" ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
|
193 |
-
"# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
|
194 |
-
"]\n",
|
195 |
-
"\n",
|
196 |
-
"preprocessing_functions = preprocessing_functions\n",
|
197 |
-
"\n",
|
198 |
-
"base, base_p = get_similarities(lvis, lambda x: x[1])\n",
|
199 |
-
"outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
|
200 |
-
]
|
201 |
-
},
|
202 |
-
{
|
203 |
-
"cell_type": "code",
|
204 |
-
"execution_count": null,
|
205 |
-
"metadata": {},
|
206 |
-
"outputs": [],
|
207 |
-
"source": [
|
208 |
-
"outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
|
209 |
-
]
|
210 |
-
},
|
211 |
-
{
|
212 |
-
"cell_type": "code",
|
213 |
-
"execution_count": null,
|
214 |
-
"metadata": {},
|
215 |
-
"outputs": [],
|
216 |
-
"source": [
|
217 |
-
"for j in range(1):\n",
|
218 |
-
" print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
|
219 |
-
]
|
220 |
-
},
|
221 |
-
{
|
222 |
-
"cell_type": "code",
|
223 |
-
"execution_count": null,
|
224 |
-
"metadata": {},
|
225 |
-
"outputs": [],
|
226 |
-
"source": [
|
227 |
-
"from pandas import DataFrame\n",
|
228 |
-
"tab = dict()\n",
|
229 |
-
"for j, (name, _) in enumerate(preprocessing_functions):\n",
|
230 |
-
" tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
|
231 |
-
" \n",
|
232 |
-
" \n",
|
233 |
-
"print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
|
234 |
-
]
|
235 |
-
},
|
236 |
-
{
|
237 |
-
"cell_type": "markdown",
|
238 |
-
"metadata": {},
|
239 |
-
"source": [
|
240 |
-
"# Visual"
|
241 |
-
]
|
242 |
-
},
|
243 |
-
{
|
244 |
-
"cell_type": "code",
|
245 |
-
"execution_count": null,
|
246 |
-
"metadata": {},
|
247 |
-
"outputs": [],
|
248 |
-
"source": [
|
249 |
-
"from evaluation_utils import denorm, norm"
|
250 |
-
]
|
251 |
-
},
|
252 |
-
{
|
253 |
-
"cell_type": "code",
|
254 |
-
"execution_count": null,
|
255 |
-
"metadata": {},
|
256 |
-
"outputs": [],
|
257 |
-
"source": [
|
258 |
-
"def load_sample(filename, filename2):\n",
|
259 |
-
" from os.path import join\n",
|
260 |
-
" bp = expanduser('~/cloud/resources/sample_images')\n",
|
261 |
-
" tf = transforms.Compose([\n",
|
262 |
-
" transforms.ToTensor(),\n",
|
263 |
-
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
264 |
-
" transforms.Resize(224),\n",
|
265 |
-
" transforms.CenterCrop(224)\n",
|
266 |
-
" ])\n",
|
267 |
-
" tf2 = transforms.Compose([\n",
|
268 |
-
" transforms.ToTensor(),\n",
|
269 |
-
" transforms.Resize(224),\n",
|
270 |
-
" transforms.CenterCrop(224)\n",
|
271 |
-
" ])\n",
|
272 |
-
" inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
|
273 |
-
" inp1[1] = inp1[1].unsqueeze(0)\n",
|
274 |
-
" inp1[2] = inp1[2][:1] \n",
|
275 |
-
" return inp1\n",
|
276 |
-
"\n",
|
277 |
-
"def all_preprocessing(inp1):\n",
|
278 |
-
" return [\n",
|
279 |
-
" img_preprocess(inp1),\n",
|
280 |
-
" img_preprocess(inp1, colorize=True),\n",
|
281 |
-
" img_preprocess(inp1, outline=True), \n",
|
282 |
-
" img_preprocess(inp1, blur=3),\n",
|
283 |
-
" img_preprocess(inp1, bg_fac=0.1),\n",
|
284 |
-
" #img_preprocess(inp1, bg_fac=0.5),\n",
|
285 |
-
" #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
|
286 |
-
" img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
|
287 |
-
" ]\n",
|
288 |
-
"\n"
|
289 |
-
]
|
290 |
-
},
|
291 |
-
{
|
292 |
-
"cell_type": "code",
|
293 |
-
"execution_count": null,
|
294 |
-
"metadata": {},
|
295 |
-
"outputs": [],
|
296 |
-
"source": [
|
297 |
-
"from torchvision import transforms\n",
|
298 |
-
"from PIL import Image\n",
|
299 |
-
"from matplotlib import pyplot as plt\n",
|
300 |
-
"from evaluation_utils import img_preprocess\n",
|
301 |
-
"import clip\n",
|
302 |
-
"\n",
|
303 |
-
"images_queries = [\n",
|
304 |
-
" [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
|
305 |
-
" [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
|
306 |
-
"]\n",
|
307 |
-
"\n",
|
308 |
-
"\n",
|
309 |
-
"_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
|
310 |
-
"\n",
|
311 |
-
"for j, (images, objects) in enumerate(images_queries):\n",
|
312 |
-
" \n",
|
313 |
-
" joint_image = all_preprocessing(images)\n",
|
314 |
-
" \n",
|
315 |
-
" joint_image = torch.stack(joint_image)[:,0]\n",
|
316 |
-
" clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
|
317 |
-
" image_features = clip_model.encode_image(joint_image)\n",
|
318 |
-
" image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
|
319 |
-
" \n",
|
320 |
-
" prompts = [f'a photo of a {obj}'for obj in objects]\n",
|
321 |
-
" text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
|
322 |
-
" text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
|
323 |
-
" logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
|
324 |
-
" sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
|
325 |
-
"\n",
|
326 |
-
" for i, img in enumerate(joint_image):\n",
|
327 |
-
" ax[2*j, i].axis('off')\n",
|
328 |
-
" \n",
|
329 |
-
" ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
|
330 |
-
" ax[2*j+ 1, i].grid(True)\n",
|
331 |
-
" \n",
|
332 |
-
" ax[2*j + 1, i].set_ylim(0,1)\n",
|
333 |
-
" ax[2*j + 1, i].set_yticklabels([])\n",
|
334 |
-
" ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
|
335 |
-
"# ax[1, i].set_xticklabels(objects, rotation=90)\n",
|
336 |
-
" for k in range(len(sim[i])):\n",
|
337 |
-
" ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
|
338 |
-
" ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
|
339 |
-
"\n",
|
340 |
-
"plt.tight_layout()\n",
|
341 |
-
"plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
|
342 |
-
]
|
343 |
-
}
|
344 |
-
],
|
345 |
-
"metadata": {
|
346 |
-
"kernelspec": {
|
347 |
-
"display_name": "env2",
|
348 |
-
"language": "python",
|
349 |
-
"name": "env2"
|
350 |
-
},
|
351 |
-
"language_info": {
|
352 |
-
"codemirror_mode": {
|
353 |
-
"name": "ipython",
|
354 |
-
"version": 3
|
355 |
-
},
|
356 |
-
"file_extension": ".py",
|
357 |
-
"mimetype": "text/x-python",
|
358 |
-
"name": "python",
|
359 |
-
"nbconvert_exporter": "python",
|
360 |
-
"pygments_lexer": "ipython3",
|
361 |
-
"version": "3.8.8"
|
362 |
-
}
|
363 |
-
},
|
364 |
-
"nbformat": 4,
|
365 |
-
"nbformat_minor": 4
|
366 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/datasets/coco_wrapper.py
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
import pickle
|
2 |
-
from types import new_class
|
3 |
-
import torch
|
4 |
-
import numpy as np
|
5 |
-
import os
|
6 |
-
import json
|
7 |
-
|
8 |
-
from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
|
9 |
-
from random import shuffle, seed as set_seed
|
10 |
-
from PIL import Image
|
11 |
-
|
12 |
-
from itertools import combinations
|
13 |
-
from torchvision import transforms
|
14 |
-
from torchvision.transforms.transforms import Resize
|
15 |
-
|
16 |
-
from datasets.utils import blend_image_segmentation
|
17 |
-
from general_utils import get_from_repository
|
18 |
-
|
19 |
-
COCO_CLASSES = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
|
20 |
-
|
21 |
-
class COCOWrapper(object):
|
22 |
-
|
23 |
-
def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
|
24 |
-
with_class_label=False):
|
25 |
-
super().__init__()
|
26 |
-
|
27 |
-
self.mask = mask
|
28 |
-
self.with_class_label = with_class_label
|
29 |
-
self.negative_prob = negative_prob
|
30 |
-
|
31 |
-
from third_party.hsnet.data.coco import DatasetCOCO
|
32 |
-
|
33 |
-
get_from_repository('COCO-20i', ['COCO-20i.tar'])
|
34 |
-
|
35 |
-
foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')
|
36 |
-
|
37 |
-
def build_img_metadata_classwise(self):
|
38 |
-
with open(foldpath % (self.split, self.fold), 'rb') as f:
|
39 |
-
img_metadata_classwise = pickle.load(f)
|
40 |
-
return img_metadata_classwise
|
41 |
-
|
42 |
-
|
43 |
-
DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise
|
44 |
-
# DatasetCOCO.read_mask = read_mask
|
45 |
-
|
46 |
-
mean = [0.485, 0.456, 0.406]
|
47 |
-
std = [0.229, 0.224, 0.225]
|
48 |
-
transform = transforms.Compose([
|
49 |
-
transforms.Resize((image_size, image_size)),
|
50 |
-
transforms.ToTensor(),
|
51 |
-
transforms.Normalize(mean, std)
|
52 |
-
])
|
53 |
-
|
54 |
-
self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, transform, split, 1, False)
|
55 |
-
|
56 |
-
self.all_classes = [self.coco.class_ids]
|
57 |
-
self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))
|
58 |
-
|
59 |
-
def __len__(self):
|
60 |
-
return len(self.coco)
|
61 |
-
|
62 |
-
def __getitem__(self, i):
|
63 |
-
sample = self.coco[i]
|
64 |
-
|
65 |
-
label_name = COCO_CLASSES[int(sample['class_id'])]
|
66 |
-
|
67 |
-
img_s, seg_s = sample['support_imgs'][0], sample['support_masks'][0]
|
68 |
-
|
69 |
-
if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
|
70 |
-
new_class_id = sample['class_id']
|
71 |
-
while new_class_id == sample['class_id']:
|
72 |
-
sample2 = self.coco[torch.randint(0, len(self), (1,)).item()]
|
73 |
-
new_class_id = sample2['class_id']
|
74 |
-
img_s = sample2['support_imgs'][0]
|
75 |
-
seg_s = torch.zeros_like(seg_s)
|
76 |
-
|
77 |
-
mask = self.mask
|
78 |
-
if mask == 'separate':
|
79 |
-
supp = (img_s, seg_s)
|
80 |
-
elif mask == 'text_label':
|
81 |
-
# DEPRECATED
|
82 |
-
supp = [int(sample['class_id'])]
|
83 |
-
elif mask == 'text':
|
84 |
-
supp = [label_name]
|
85 |
-
else:
|
86 |
-
if mask.startswith('text_and_'):
|
87 |
-
mask = mask[9:]
|
88 |
-
label_add = [label_name]
|
89 |
-
else:
|
90 |
-
label_add = []
|
91 |
-
|
92 |
-
supp = label_add + blend_image_segmentation(img_s, seg_s, mode=mask)
|
93 |
-
|
94 |
-
if self.with_class_label:
|
95 |
-
label = (torch.zeros(0), sample['class_id'],)
|
96 |
-
else:
|
97 |
-
label = (torch.zeros(0), )
|
98 |
-
|
99 |
-
return (sample['query_img'],) + tuple(supp), (sample['query_mask'].unsqueeze(0),) + label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/datasets/pascal_classes.json
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
[{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
|
|
|
|
clipseg/datasets/pascal_zeroshot.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
from os.path import expanduser
|
2 |
-
import torch
|
3 |
-
import json
|
4 |
-
import torchvision
|
5 |
-
from general_utils import get_from_repository
|
6 |
-
from general_utils import log
|
7 |
-
from torchvision import transforms
|
8 |
-
|
9 |
-
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
|
10 |
-
['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
|
11 |
-
['chair.n.01', 'pot_plant.n.01']]
|
12 |
-
|
13 |
-
|
14 |
-
class PascalZeroShot(object):
|
15 |
-
|
16 |
-
def __init__(self, split, n_unseen, image_size=224) -> None:
|
17 |
-
super().__init__()
|
18 |
-
|
19 |
-
import sys
|
20 |
-
sys.path.append('third_party/JoEm')
|
21 |
-
from third_party.JoEm.data_loader.dataset import VOCSegmentation
|
22 |
-
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
23 |
-
|
24 |
-
self.pascal_classes = VOC
|
25 |
-
self.image_size = image_size
|
26 |
-
|
27 |
-
self.transform = transforms.Compose([
|
28 |
-
transforms.Resize((image_size, image_size)),
|
29 |
-
])
|
30 |
-
|
31 |
-
if split == 'train':
|
32 |
-
self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
|
33 |
-
split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
|
34 |
-
ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
|
35 |
-
elif split == 'val':
|
36 |
-
self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
|
37 |
-
split=split, transform=False,
|
38 |
-
ignore_bg=False, ignore_unseen=False)
|
39 |
-
|
40 |
-
self.unseen_idx = get_unseen_idx(n_unseen)
|
41 |
-
|
42 |
-
def __len__(self):
|
43 |
-
return len(self.voc)
|
44 |
-
|
45 |
-
def __getitem__(self, i):
|
46 |
-
|
47 |
-
sample = self.voc[i]
|
48 |
-
label = sample['label'].long()
|
49 |
-
all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
|
50 |
-
class_indices = [l for l in all_labels]
|
51 |
-
class_names = [self.pascal_classes[l] for l in all_labels]
|
52 |
-
|
53 |
-
image = self.transform(sample['image'])
|
54 |
-
|
55 |
-
label = transforms.Resize((self.image_size, self.image_size),
|
56 |
-
interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
|
57 |
-
|
58 |
-
return (image,), (label, )
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/datasets/pfe_dataset.py
DELETED
@@ -1,129 +0,0 @@
|
|
1 |
-
from os.path import expanduser
|
2 |
-
import torch
|
3 |
-
import json
|
4 |
-
from general_utils import get_from_repository
|
5 |
-
from datasets.lvis_oneshot3 import blend_image_segmentation
|
6 |
-
from general_utils import log
|
7 |
-
|
8 |
-
PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
|
9 |
-
|
10 |
-
|
11 |
-
class PFEPascalWrapper(object):
|
12 |
-
|
13 |
-
def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
|
14 |
-
import sys
|
15 |
-
# sys.path.append(expanduser('~/projects/new_one_shot'))
|
16 |
-
from third_party.PFENet.util.dataset import SemData
|
17 |
-
|
18 |
-
get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
|
19 |
-
|
20 |
-
self.p_negative = p_negative
|
21 |
-
self.size = size
|
22 |
-
self.mode = mode
|
23 |
-
self.image_size = image_size
|
24 |
-
|
25 |
-
if label_support in {True, False}:
|
26 |
-
log.warning('label_support argument is deprecated. Use mask instead.')
|
27 |
-
#raise ValueError()
|
28 |
-
|
29 |
-
self.mask = mask
|
30 |
-
|
31 |
-
value_scale = 255
|
32 |
-
mean = [0.485, 0.456, 0.406]
|
33 |
-
mean = [item * value_scale for item in mean]
|
34 |
-
std = [0.229, 0.224, 0.225]
|
35 |
-
std = [item * value_scale for item in std]
|
36 |
-
|
37 |
-
import third_party.PFENet.util.transform as transform
|
38 |
-
|
39 |
-
if mode == 'val':
|
40 |
-
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
|
41 |
-
|
42 |
-
data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
|
43 |
-
data_transform += [
|
44 |
-
transform.ToTensor(),
|
45 |
-
transform.Normalize(mean=mean, std=std)
|
46 |
-
]
|
47 |
-
|
48 |
-
|
49 |
-
elif mode == 'train':
|
50 |
-
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
|
51 |
-
|
52 |
-
assert image_size != 'original'
|
53 |
-
|
54 |
-
data_transform = [
|
55 |
-
transform.RandScale([0.9, 1.1]),
|
56 |
-
transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
|
57 |
-
transform.RandomGaussianBlur(),
|
58 |
-
transform.RandomHorizontalFlip(),
|
59 |
-
transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
|
60 |
-
transform.ToTensor(),
|
61 |
-
transform.Normalize(mean=mean, std=std)
|
62 |
-
]
|
63 |
-
|
64 |
-
data_transform = transform.Compose(data_transform)
|
65 |
-
|
66 |
-
self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
|
67 |
-
data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
|
68 |
-
|
69 |
-
self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
|
70 |
-
|
71 |
-
# verify that subcls_list always has length 1
|
72 |
-
# assert len(set([len(d[4]) for d in self.dataset])) == 1
|
73 |
-
|
74 |
-
print('actual length', len(self.dataset.data_list))
|
75 |
-
|
76 |
-
def __len__(self):
|
77 |
-
if self.mode == 'val':
|
78 |
-
return len(self.dataset.data_list)
|
79 |
-
else:
|
80 |
-
return len(self.dataset.data_list)
|
81 |
-
|
82 |
-
def __getitem__(self, index):
|
83 |
-
if self.dataset.mode == 'train':
|
84 |
-
image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
|
85 |
-
elif self.dataset.mode == 'val':
|
86 |
-
image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
|
87 |
-
ori_label = torch.from_numpy(ori_label).unsqueeze(0)
|
88 |
-
|
89 |
-
if self.image_size != 'original':
|
90 |
-
longerside = max(ori_label.size(1), ori_label.size(2))
|
91 |
-
backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
|
92 |
-
backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
|
93 |
-
label = backmask.clone().long()
|
94 |
-
else:
|
95 |
-
label = label.unsqueeze(0)
|
96 |
-
|
97 |
-
# assert label.shape == (473, 473)
|
98 |
-
|
99 |
-
if self.p_negative > 0:
|
100 |
-
if torch.rand(1).item() < self.p_negative:
|
101 |
-
while True:
|
102 |
-
idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
|
103 |
-
_, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
|
104 |
-
if subcls_list[0] != subcls_list_tmp[0]:
|
105 |
-
break
|
106 |
-
|
107 |
-
s_x = s_x[0]
|
108 |
-
s_y = (s_y == 1)[0]
|
109 |
-
label_fg = (label == 1).float()
|
110 |
-
val_mask = (label != 255).float()
|
111 |
-
|
112 |
-
class_id = self.class_list[subcls_list[0]]
|
113 |
-
|
114 |
-
label_name = PASCAL_CLASSES[class_id][0]
|
115 |
-
label_add = ()
|
116 |
-
mask = self.mask
|
117 |
-
|
118 |
-
if mask == 'text':
|
119 |
-
support = ('a photo of a ' + label_name + '.',)
|
120 |
-
elif mask == 'separate':
|
121 |
-
support = (s_x, s_y)
|
122 |
-
else:
|
123 |
-
if mask.startswith('text_and_'):
|
124 |
-
label_add = (label_name,)
|
125 |
-
mask = mask[9:]
|
126 |
-
|
127 |
-
support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
|
128 |
-
|
129 |
-
return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/datasets/phrasecut.py
DELETED
@@ -1,335 +0,0 @@
|
|
1 |
-
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
import os
|
5 |
-
|
6 |
-
from os.path import join, isdir, isfile, expanduser
|
7 |
-
from PIL import Image
|
8 |
-
|
9 |
-
from torchvision import transforms
|
10 |
-
from torchvision.transforms.transforms import Resize
|
11 |
-
|
12 |
-
from torch.nn import functional as nnf
|
13 |
-
from general_utils import get_from_repository
|
14 |
-
|
15 |
-
from skimage.draw import polygon2mask
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
def random_crop_slices(origin_size, target_size):
|
20 |
-
"""Gets slices of a random crop. """
|
21 |
-
assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'
|
22 |
-
|
23 |
-
offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item() # range: 0 <= value < high
|
24 |
-
offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item()
|
25 |
-
|
26 |
-
return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1])
|
27 |
-
|
28 |
-
|
29 |
-
def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
|
30 |
-
|
31 |
-
|
32 |
-
best_crops = []
|
33 |
-
best_crop_not_ok = float('-inf'), None, None
|
34 |
-
min_sum = 0
|
35 |
-
|
36 |
-
seg = seg.astype('bool')
|
37 |
-
|
38 |
-
if min_frac is not None:
|
39 |
-
#min_sum = seg.sum() * min_frac
|
40 |
-
min_sum = seg.shape[0] * seg.shape[1] * min_frac
|
41 |
-
|
42 |
-
for iteration in range(iterations):
|
43 |
-
sl_y, sl_x = random_crop_slices(seg.shape, image_size)
|
44 |
-
seg_ = seg[sl_y, sl_x]
|
45 |
-
sum_seg_ = seg_.sum()
|
46 |
-
|
47 |
-
if sum_seg_ > min_sum:
|
48 |
-
|
49 |
-
if best_of is None:
|
50 |
-
return sl_y, sl_x, False
|
51 |
-
else:
|
52 |
-
best_crops += [(sum_seg_, sl_y, sl_x)]
|
53 |
-
if len(best_crops) >= best_of:
|
54 |
-
best_crops.sort(key=lambda x:x[0], reverse=True)
|
55 |
-
sl_y, sl_x = best_crops[0][1:]
|
56 |
-
|
57 |
-
return sl_y, sl_x, False
|
58 |
-
|
59 |
-
else:
|
60 |
-
if sum_seg_ > best_crop_not_ok[0]:
|
61 |
-
best_crop_not_ok = sum_seg_, sl_y, sl_x
|
62 |
-
|
63 |
-
else:
|
64 |
-
# return best segmentation found
|
65 |
-
return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,)
|
66 |
-
|
67 |
-
|
68 |
-
class PhraseCut(object):
|
69 |
-
|
70 |
-
def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
|
71 |
-
min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
|
72 |
-
super().__init__()
|
73 |
-
|
74 |
-
self.negative_prob = negative_prob
|
75 |
-
self.image_size = image_size
|
76 |
-
self.with_visual = with_visual
|
77 |
-
self.only_visual = only_visual
|
78 |
-
self.phrase_form = '{}'
|
79 |
-
self.mask = mask
|
80 |
-
self.aug_crop = aug_crop
|
81 |
-
|
82 |
-
if aug_color:
|
83 |
-
self.aug_color = transforms.Compose([
|
84 |
-
transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
|
85 |
-
])
|
86 |
-
else:
|
87 |
-
self.aug_color = None
|
88 |
-
|
89 |
-
get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
|
90 |
-
isdir(join(local_dir, 'VGPhraseCut_v0')),
|
91 |
-
isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
|
92 |
-
isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
|
93 |
-
len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
|
94 |
-
]))
|
95 |
-
|
96 |
-
from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
|
97 |
-
self.refvg_loader = RefVGLoader(split=split)
|
98 |
-
|
99 |
-
# img_ids where the size in the annotations does not match actual size
|
100 |
-
invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
|
101 |
-
150333, 286065, 285814, 498187, 285761, 498042])
|
102 |
-
|
103 |
-
mean = [0.485, 0.456, 0.406]
|
104 |
-
std = [0.229, 0.224, 0.225]
|
105 |
-
self.normalize = transforms.Normalize(mean, std)
|
106 |
-
|
107 |
-
self.sample_ids = [(i, j)
|
108 |
-
for i in self.refvg_loader.img_ids
|
109 |
-
for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
|
110 |
-
if i not in invalid_img_ids]
|
111 |
-
|
112 |
-
|
113 |
-
# self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']]))
|
114 |
-
|
115 |
-
from nltk.stem import WordNetLemmatizer
|
116 |
-
wnl = WordNetLemmatizer()
|
117 |
-
|
118 |
-
# Filter by class (if remove_classes is set)
|
119 |
-
if remove_classes is None:
|
120 |
-
pass
|
121 |
-
else:
|
122 |
-
from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
|
123 |
-
from nltk.corpus import wordnet
|
124 |
-
|
125 |
-
print('remove pascal classes...')
|
126 |
-
|
127 |
-
get_data = self.refvg_loader.get_img_ref_data # shortcut
|
128 |
-
keep_sids = None
|
129 |
-
|
130 |
-
if remove_classes[0] == 'pas5i':
|
131 |
-
subset_id = remove_classes[1]
|
132 |
-
from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
|
133 |
-
avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]
|
134 |
-
|
135 |
-
|
136 |
-
elif remove_classes[0] == 'zs':
|
137 |
-
stop = remove_classes[1]
|
138 |
-
|
139 |
-
from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
|
140 |
-
|
141 |
-
avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
|
142 |
-
print(avoid)
|
143 |
-
|
144 |
-
elif remove_classes[0] == 'aff':
|
145 |
-
# avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02']
|
146 |
-
# all_lemmas = set(['drink', 'sit', 'ride'])
|
147 |
-
avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
|
148 |
-
'ride', 'rides', 'riding',
|
149 |
-
'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
|
150 |
-
'swim', 'swims', 'swimming',
|
151 |
-
'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
|
152 |
-
keep_sids = [(i, j) for i, j in self.sample_ids if
|
153 |
-
all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
|
154 |
-
|
155 |
-
print('avoid classes:', avoid)
|
156 |
-
|
157 |
-
|
158 |
-
if keep_sids is None:
|
159 |
-
all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
|
160 |
-
all_lemmas = list(set(all_lemmas))
|
161 |
-
all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
|
162 |
-
all_lemmas = set(all_lemmas)
|
163 |
-
|
164 |
-
# divide into multi word and single word
|
165 |
-
all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
|
166 |
-
all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)
|
167 |
-
|
168 |
-
# new3
|
169 |
-
phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
|
170 |
-
remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
|
171 |
-
if any(l in phrase for l in all_lemmas_m) or
|
172 |
-
len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
|
173 |
-
)
|
174 |
-
keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]
|
175 |
-
|
176 |
-
print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
|
177 |
-
removed_ids = set(self.sample_ids) - set(keep_sids)
|
178 |
-
|
179 |
-
print('Examples of removed', len(removed_ids))
|
180 |
-
for i, j in list(removed_ids)[:20]:
|
181 |
-
print(i, get_data(i)['phrases'][j])
|
182 |
-
|
183 |
-
self.sample_ids = keep_sids
|
184 |
-
|
185 |
-
from itertools import groupby
|
186 |
-
samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
|
187 |
-
for i, j in self.sample_ids]
|
188 |
-
samples_by_phrase = sorted(samples_by_phrase)
|
189 |
-
samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])
|
190 |
-
|
191 |
-
self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}
|
192 |
-
|
193 |
-
self.all_phrases = list(set(self.samples_by_phrase.keys()))
|
194 |
-
|
195 |
-
|
196 |
-
if self.only_visual:
|
197 |
-
assert self.with_visual
|
198 |
-
self.sample_ids = [(i, j) for i, j in self.sample_ids
|
199 |
-
if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]
|
200 |
-
|
201 |
-
# Filter by size (if min_size is set)
|
202 |
-
sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
|
203 |
-
image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
|
204 |
-
#self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes]
|
205 |
-
self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]
|
206 |
-
|
207 |
-
if min_size:
|
208 |
-
print('filter by size')
|
209 |
-
|
210 |
-
self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]
|
211 |
-
|
212 |
-
self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))
|
213 |
-
|
214 |
-
def __len__(self):
|
215 |
-
return len(self.sample_ids)
|
216 |
-
|
217 |
-
|
218 |
-
def load_sample(self, sample_i, j):
|
219 |
-
|
220 |
-
img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)
|
221 |
-
|
222 |
-
polys_phrase0 = img_ref_data['gt_Polygons'][j]
|
223 |
-
phrase = img_ref_data['phrases'][j]
|
224 |
-
phrase = self.phrase_form.format(phrase)
|
225 |
-
|
226 |
-
masks = []
|
227 |
-
for polys in polys_phrase0:
|
228 |
-
for poly in polys:
|
229 |
-
poly = [p[::-1] for p in poly] # swap x,y
|
230 |
-
masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]
|
231 |
-
|
232 |
-
seg = np.stack(masks).max(0)
|
233 |
-
img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))
|
234 |
-
|
235 |
-
min_shape = min(img.shape[:2])
|
236 |
-
|
237 |
-
if self.aug_crop:
|
238 |
-
sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
|
239 |
-
else:
|
240 |
-
sly, slx = slice(0, None), slice(0, None)
|
241 |
-
|
242 |
-
seg = seg[sly, slx]
|
243 |
-
img = img[sly, slx]
|
244 |
-
|
245 |
-
seg = seg.astype('uint8')
|
246 |
-
seg = torch.from_numpy(seg).view(1, 1, *seg.shape)
|
247 |
-
|
248 |
-
if img.ndim == 2:
|
249 |
-
img = np.dstack([img] * 3)
|
250 |
-
|
251 |
-
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float()
|
252 |
-
|
253 |
-
seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0,0]
|
254 |
-
img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]
|
255 |
-
|
256 |
-
# img = img.permute([2,0, 1])
|
257 |
-
img = img / 255.0
|
258 |
-
|
259 |
-
if self.aug_color is not None:
|
260 |
-
img = self.aug_color(img)
|
261 |
-
|
262 |
-
img = self.normalize(img)
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
return img, seg, phrase
|
267 |
-
|
268 |
-
def __getitem__(self, i):
|
269 |
-
|
270 |
-
sample_i, j = self.sample_ids[i]
|
271 |
-
|
272 |
-
img, seg, phrase = self.load_sample(sample_i, j)
|
273 |
-
|
274 |
-
if self.negative_prob > 0:
|
275 |
-
if torch.rand((1,)).item() < self.negative_prob:
|
276 |
-
|
277 |
-
new_phrase = None
|
278 |
-
while new_phrase is None or new_phrase == phrase:
|
279 |
-
idx = torch.randint(0, len(self.all_phrases), (1,)).item()
|
280 |
-
new_phrase = self.all_phrases[idx]
|
281 |
-
phrase = new_phrase
|
282 |
-
seg = torch.zeros_like(seg)
|
283 |
-
|
284 |
-
if self.with_visual:
|
285 |
-
# find a corresponding visual image
|
286 |
-
if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
|
287 |
-
idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
|
288 |
-
other_sample = self.samples_by_phrase[phrase][idx]
|
289 |
-
#print(other_sample)
|
290 |
-
img_s, seg_s, _ = self.load_sample(*other_sample)
|
291 |
-
|
292 |
-
from datasets.utils import blend_image_segmentation
|
293 |
-
|
294 |
-
if self.mask in {'separate', 'text_and_separate'}:
|
295 |
-
# assert img.shape[1:] == img_s.shape[1:] == seg_s.shape == seg.shape[1:]
|
296 |
-
add_phrase = [phrase] if self.mask == 'text_and_separate' else []
|
297 |
-
vis_s = add_phrase + [img_s, seg_s, True]
|
298 |
-
else:
|
299 |
-
if self.mask.startswith('text_and_'):
|
300 |
-
mask_mode = self.mask[9:]
|
301 |
-
label_add = [phrase]
|
302 |
-
else:
|
303 |
-
mask_mode = self.mask
|
304 |
-
label_add = []
|
305 |
-
|
306 |
-
masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
|
307 |
-
vis_s = label_add + [masked_img_s, True]
|
308 |
-
|
309 |
-
else:
|
310 |
-
# phrase is unique
|
311 |
-
vis_s = torch.zeros_like(img)
|
312 |
-
|
313 |
-
if self.mask in {'separate', 'text_and_separate'}:
|
314 |
-
add_phrase = [phrase] if self.mask == 'text_and_separate' else []
|
315 |
-
vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
|
316 |
-
elif self.mask.startswith('text_and_'):
|
317 |
-
vis_s = [phrase, vis_s, False]
|
318 |
-
else:
|
319 |
-
vis_s = [vis_s, False]
|
320 |
-
else:
|
321 |
-
assert self.mask == 'text'
|
322 |
-
vis_s = [phrase]
|
323 |
-
|
324 |
-
seg = seg.unsqueeze(0).float()
|
325 |
-
|
326 |
-
data_x = (img,) + tuple(vis_s)
|
327 |
-
|
328 |
-
return data_x, (seg, torch.zeros(0), i)
|
329 |
-
|
330 |
-
|
331 |
-
class PhraseCutPlus(PhraseCut):
|
332 |
-
|
333 |
-
def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
|
334 |
-
super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color, aug_crop=aug_crop, min_size=min_size,
|
335 |
-
remove_classes=remove_classes, with_visual=True, only_visual=only_visual, mask=mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/datasets/utils.py
DELETED
@@ -1,68 +0,0 @@
|
|
1 |
-
|
2 |
-
import numpy as np
|
3 |
-
import torch
|
4 |
-
|
5 |
-
|
6 |
-
def blend_image_segmentation(img, seg, mode, image_size=224):
|
7 |
-
|
8 |
-
|
9 |
-
if mode in {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}:
|
10 |
-
if isinstance(img, np.ndarray):
|
11 |
-
img = torch.from_numpy(img)
|
12 |
-
|
13 |
-
if isinstance(seg, np.ndarray):
|
14 |
-
seg = torch.from_numpy(seg)
|
15 |
-
|
16 |
-
if mode == 'overlay':
|
17 |
-
out = img * seg
|
18 |
-
out = [out.astype('float32')]
|
19 |
-
elif mode == 'highlight':
|
20 |
-
out = img * seg[None, :, :] * 0.85 + 0.15 * img
|
21 |
-
out = [out.astype('float32')]
|
22 |
-
elif mode == 'highlight2':
|
23 |
-
img = img / 2
|
24 |
-
out = (img+0.1) * seg[None, :, :] + 0.3 * img
|
25 |
-
out = [out.astype('float32')]
|
26 |
-
elif mode == 'blur_highlight':
|
27 |
-
from evaluation_utils import img_preprocess
|
28 |
-
out = [img_preprocess((None, [img], [seg]), blur=1, bg_fac=0.5).numpy()[0] - 0.01]
|
29 |
-
elif mode == 'blur3_highlight':
|
30 |
-
from evaluation_utils import img_preprocess
|
31 |
-
out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.5).numpy()[0] - 0.01]
|
32 |
-
elif mode == 'blur3_highlight01':
|
33 |
-
from evaluation_utils import img_preprocess
|
34 |
-
out = [img_preprocess((None, [img], [seg]), blur=3, bg_fac=0.1).numpy()[0] - 0.01]
|
35 |
-
elif mode == 'blur_highlight_random':
|
36 |
-
from evaluation_utils import img_preprocess
|
37 |
-
out = [img_preprocess((None, [img], [seg]), blur=0 + torch.randint(0, 3, (1,)).item(), bg_fac=0.1 + 0.8*torch.rand(1).item()).numpy()[0] - 0.01]
|
38 |
-
elif mode == 'crop':
|
39 |
-
from evaluation_utils import img_preprocess
|
40 |
-
out = [img_preprocess((None, [img], [seg]), blur=1, center_context=0.1, image_size=image_size)[0].numpy()]
|
41 |
-
elif mode == 'crop_blur_highlight':
|
42 |
-
from evaluation_utils import img_preprocess
|
43 |
-
out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)[0].numpy()]
|
44 |
-
elif mode == 'crop_blur_highlight352':
|
45 |
-
from evaluation_utils import img_preprocess
|
46 |
-
out = [img_preprocess((None, [img], [seg]), blur=3, center_context=0.1, bg_fac=0.1, image_size=352)[0].numpy()]
|
47 |
-
elif mode == 'shape':
|
48 |
-
out = [np.stack([seg[:, :]]*3).astype('float32')]
|
49 |
-
elif mode == 'concat':
|
50 |
-
out = [np.concatenate([img, seg[None, :, :]]).astype('float32')]
|
51 |
-
elif mode == 'image_only':
|
52 |
-
out = [img.astype('float32')]
|
53 |
-
elif mode == 'image_black':
|
54 |
-
out = [img.astype('float32')*0]
|
55 |
-
elif mode is None:
|
56 |
-
out = [img.astype('float32')]
|
57 |
-
elif mode == 'separate':
|
58 |
-
out = [img.astype('float32'), seg.astype('int64')]
|
59 |
-
elif mode == 'separate_img_black':
|
60 |
-
out = [img.astype('float32')*0, seg.astype('int64')]
|
61 |
-
elif mode == 'separate_seg_ones':
|
62 |
-
out = [img.astype('float32'), np.ones_like(seg).astype('int64')]
|
63 |
-
elif mode == 'separate_both_black':
|
64 |
-
out = [img.astype('float32')*0, seg.astype('int64')*0]
|
65 |
-
else:
|
66 |
-
raise ValueError(f'invalid mode: {mode}')
|
67 |
-
|
68 |
-
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/environment.yml
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
name: clipseg-environment
|
2 |
-
channels:
|
3 |
-
- conda-forge
|
4 |
-
- pytorch
|
5 |
-
dependencies:
|
6 |
-
- numpy
|
7 |
-
- scipy
|
8 |
-
- matplotlib-base
|
9 |
-
- pip
|
10 |
-
- pip:
|
11 |
-
- --find-links https://download.pytorch.org/whl/torch_stable.html
|
12 |
-
- torch==1.10.0+cpu
|
13 |
-
- torchvision==0.11.1+cpu
|
14 |
-
- opencv-python
|
15 |
-
- git+https://github.com/openai/CLIP.git
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/evaluation_utils.py
DELETED
@@ -1,292 +0,0 @@
|
|
1 |
-
from torch.functional import Tensor
|
2 |
-
from general_utils import load_model
|
3 |
-
from torch.utils.data import DataLoader
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
def denorm(img):
|
8 |
-
|
9 |
-
np_input = False
|
10 |
-
if isinstance(img, np.ndarray):
|
11 |
-
img = torch.from_numpy(img)
|
12 |
-
np_input = True
|
13 |
-
|
14 |
-
mean = torch.Tensor([0.485, 0.456, 0.406])
|
15 |
-
std = torch.Tensor([0.229, 0.224, 0.225])
|
16 |
-
|
17 |
-
img_denorm = (img*std[:,None,None]) + mean[:,None,None]
|
18 |
-
|
19 |
-
if np_input:
|
20 |
-
img_denorm = np.clip(img_denorm.numpy(), 0, 1)
|
21 |
-
else:
|
22 |
-
img_denorm = torch.clamp(img_denorm, 0, 1)
|
23 |
-
|
24 |
-
return img_denorm
|
25 |
-
|
26 |
-
|
27 |
-
def norm(img):
|
28 |
-
mean = torch.Tensor([0.485, 0.456, 0.406])
|
29 |
-
std = torch.Tensor([0.229, 0.224, 0.225])
|
30 |
-
return (img - mean[:,None,None]) / std[:,None,None]
|
31 |
-
|
32 |
-
|
33 |
-
def fast_iou_curve(p, g):
|
34 |
-
|
35 |
-
g = g[p.sort().indices]
|
36 |
-
p = torch.sigmoid(p.sort().values)
|
37 |
-
|
38 |
-
scores = []
|
39 |
-
vals = np.linspace(0, 1, 50)
|
40 |
-
|
41 |
-
for q in vals:
|
42 |
-
|
43 |
-
n = int(len(g) * q)
|
44 |
-
|
45 |
-
valid = torch.where(p > q)[0]
|
46 |
-
if len(valid) > 0:
|
47 |
-
n = int(valid[0])
|
48 |
-
else:
|
49 |
-
n = len(g)
|
50 |
-
|
51 |
-
fn = g[:n].sum()
|
52 |
-
tn = n - fn
|
53 |
-
tp = g[n:].sum()
|
54 |
-
fp = len(g) - n - tp
|
55 |
-
|
56 |
-
iou = tp / (tp + fn + fp)
|
57 |
-
|
58 |
-
precision = tp / (tp + fp)
|
59 |
-
recall = tp / (tp + fn)
|
60 |
-
|
61 |
-
scores += [iou]
|
62 |
-
|
63 |
-
return vals, scores
|
64 |
-
|
65 |
-
|
66 |
-
def fast_rp_curve(p, g):
|
67 |
-
|
68 |
-
g = g[p.sort().indices]
|
69 |
-
p = torch.sigmoid(p.sort().values)
|
70 |
-
|
71 |
-
precisions, recalls = [], []
|
72 |
-
vals = np.linspace(p.min(), p.max(), 250)
|
73 |
-
|
74 |
-
for q in p[::100000]:
|
75 |
-
|
76 |
-
n = int(len(g) * q)
|
77 |
-
|
78 |
-
valid = torch.where(p > q)[0]
|
79 |
-
if len(valid) > 0:
|
80 |
-
n = int(valid[0])
|
81 |
-
else:
|
82 |
-
n = len(g)
|
83 |
-
|
84 |
-
fn = g[:n].sum()
|
85 |
-
tn = n - fn
|
86 |
-
tp = g[n:].sum()
|
87 |
-
fp = len(g) - n - tp
|
88 |
-
|
89 |
-
iou = tp / (tp + fn + fp)
|
90 |
-
|
91 |
-
precision = tp / (tp + fp)
|
92 |
-
recall = tp / (tp + fn)
|
93 |
-
|
94 |
-
precisions += [precision]
|
95 |
-
recalls += [recall]
|
96 |
-
|
97 |
-
return recalls, precisions
|
98 |
-
|
99 |
-
|
100 |
-
# Image processing
|
101 |
-
|
102 |
-
def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
|
103 |
-
brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
|
104 |
-
import cv2
|
105 |
-
|
106 |
-
rw = rect_width
|
107 |
-
|
108 |
-
out = []
|
109 |
-
for img, mask in zip(batch[1], batch[2]):
|
110 |
-
|
111 |
-
img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
|
112 |
-
mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)
|
113 |
-
|
114 |
-
img *= brightness
|
115 |
-
img_bl = img
|
116 |
-
if blur > 0: # best 5
|
117 |
-
img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)
|
118 |
-
|
119 |
-
if grayscale:
|
120 |
-
img_bl = img_bl[1][None]
|
121 |
-
|
122 |
-
#img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl
|
123 |
-
# img_inp = img_ratio*img*mask + (1-img_ratio)*img_bl * (1-mask)
|
124 |
-
img_inp = img*mask + (bg_fac) * img_bl * (1-mask)
|
125 |
-
|
126 |
-
if rect:
|
127 |
-
_, bbox = crop_mask(img, mask, context=0.1)
|
128 |
-
img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = torch.tensor(rect_color)[:,None,None]
|
129 |
-
img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = torch.tensor(rect_color)[:,None,None]
|
130 |
-
img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
|
131 |
-
img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = torch.tensor(rect_color)[:,None,None]
|
132 |
-
|
133 |
-
|
134 |
-
if center_context is not None:
|
135 |
-
img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)
|
136 |
-
|
137 |
-
if colorize:
|
138 |
-
img_gray = denorm(img)
|
139 |
-
img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
|
140 |
-
img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
|
141 |
-
img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
|
142 |
-
img_inp = norm(img_inp)
|
143 |
-
|
144 |
-
if outline:
|
145 |
-
cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
146 |
-
outline_img = np.zeros(mask.shape, dtype=np.uint8)
|
147 |
-
cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
|
148 |
-
outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
|
149 |
-
img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
|
150 |
-
img_inp = norm(img_inp)
|
151 |
-
|
152 |
-
out += [img_inp]
|
153 |
-
|
154 |
-
return torch.stack(out)
|
155 |
-
|
156 |
-
|
157 |
-
def object_crop(img, mask, context=0.0, square=False, image_size=224):
|
158 |
-
img_crop, bbox = crop_mask(img, mask, context=context, square=square)
|
159 |
-
img_crop = pad_to_square(img_crop, channel_dim=0)
|
160 |
-
img_crop = torch.nn.functional.interpolate(img_crop.unsqueeze(0), (image_size, image_size)).squeeze(0)
|
161 |
-
return img_crop
|
162 |
-
|
163 |
-
|
164 |
-
def crop_mask(img, mask, context=0.0, square=False):
|
165 |
-
|
166 |
-
assert img.shape[1:] == mask.shape
|
167 |
-
|
168 |
-
bbox = [mask.max(0).values.argmax(), mask.size(0) - mask.max(0).values.flip(0).argmax()]
|
169 |
-
bbox += [mask.max(1).values.argmax(), mask.size(1) - mask.max(1).values.flip(0).argmax()]
|
170 |
-
bbox = [int(x) for x in bbox]
|
171 |
-
|
172 |
-
width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
|
173 |
-
|
174 |
-
# square mask
|
175 |
-
if square:
|
176 |
-
bbox[0] = int(max(0, bbox[0] - context * height))
|
177 |
-
bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
|
178 |
-
bbox[2] = int(max(0, bbox[2] - context * width))
|
179 |
-
bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
|
180 |
-
|
181 |
-
width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
|
182 |
-
if height > width:
|
183 |
-
bbox[2] = int(max(0, (bbox[2] - 0.5*height)))
|
184 |
-
bbox[3] = bbox[2] + height
|
185 |
-
else:
|
186 |
-
bbox[0] = int(max(0, (bbox[0] - 0.5*width)))
|
187 |
-
bbox[1] = bbox[0] + width
|
188 |
-
else:
|
189 |
-
bbox[0] = int(max(0, bbox[0] - context * height))
|
190 |
-
bbox[1] = int(min(mask.size(0), bbox[1] + context * height))
|
191 |
-
bbox[2] = int(max(0, bbox[2] - context * width))
|
192 |
-
bbox[3] = int(min(mask.size(1), bbox[3] + context * width))
|
193 |
-
|
194 |
-
width, height = (bbox[3] - bbox[2]), (bbox[1] - bbox[0])
|
195 |
-
img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
|
196 |
-
return img_crop, bbox
|
197 |
-
|
198 |
-
|
199 |
-
def pad_to_square(img, channel_dim=2, fill=0):
|
200 |
-
"""
|
201 |
-
|
202 |
-
|
203 |
-
add padding such that a squared image is returned """
|
204 |
-
|
205 |
-
from torchvision.transforms.functional import pad
|
206 |
-
|
207 |
-
if channel_dim == 2:
|
208 |
-
img = img.permute(2, 0, 1)
|
209 |
-
elif channel_dim == 0:
|
210 |
-
pass
|
211 |
-
else:
|
212 |
-
raise ValueError('invalid channel_dim')
|
213 |
-
|
214 |
-
h, w = img.shape[1:]
|
215 |
-
pady1 = pady2 = padx1 = padx2 = 0
|
216 |
-
|
217 |
-
if h > w:
|
218 |
-
padx1 = (h - w) // 2
|
219 |
-
padx2 = h - w - padx1
|
220 |
-
elif w > h:
|
221 |
-
pady1 = (w - h) // 2
|
222 |
-
pady2 = w - h - pady1
|
223 |
-
|
224 |
-
img_padded = pad(img, padding=(padx1, pady1, padx2, pady2), padding_mode='constant')
|
225 |
-
|
226 |
-
if channel_dim == 2:
|
227 |
-
img_padded = img_padded.permute(1, 2, 0)
|
228 |
-
|
229 |
-
return img_padded
|
230 |
-
|
231 |
-
|
232 |
-
# qualitative
|
233 |
-
|
234 |
-
def split_sentence(inp, limit=9):
|
235 |
-
t_new, current_len = [], 0
|
236 |
-
for k, t in enumerate(inp.split(' ')):
|
237 |
-
current_len += len(t) + 1
|
238 |
-
t_new += [t+' ']
|
239 |
-
# not last
|
240 |
-
if current_len > limit and k != len(inp.split(' ')) - 1:
|
241 |
-
current_len = 0
|
242 |
-
t_new += ['\n']
|
243 |
-
|
244 |
-
t_new = ''.join(t_new)
|
245 |
-
return t_new
|
246 |
-
|
247 |
-
|
248 |
-
from matplotlib import pyplot as plt
|
249 |
-
|
250 |
-
|
251 |
-
def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
|
252 |
-
|
253 |
-
row_off = 0 if labels is None else 1
|
254 |
-
_, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds), figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)))
|
255 |
-
[a.axis('off') for a in ax.flatten()]
|
256 |
-
|
257 |
-
if labels is not None:
|
258 |
-
for j in range(len(labels)):
|
259 |
-
t_new = split_sentence(labels[j], limit=6)
|
260 |
-
ax[0, 1+ j].text(0.5, 0.1, t_new, ha='center', fontsize=3+ 10*scale)
|
261 |
-
|
262 |
-
|
263 |
-
for i in range(len(imgs)):
|
264 |
-
ax[i + row_off,0].imshow(imgs[i])
|
265 |
-
for j in range(len(preds)):
|
266 |
-
img = preds[j][i][0].detach().cpu().numpy()
|
267 |
-
|
268 |
-
if gt_labels is not None and labels[j] == gt_labels[i]:
|
269 |
-
print(j, labels[j], gt_labels[i])
|
270 |
-
edgecolor = 'red'
|
271 |
-
if aps is not None:
|
272 |
-
ax[i + row_off, 1 + j].text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
|
273 |
-
else:
|
274 |
-
edgecolor = 'k'
|
275 |
-
|
276 |
-
rect = plt.Rectangle([0,0], img.shape[0], img.shape[1], facecolor="none",
|
277 |
-
edgecolor=edgecolor, linewidth=3)
|
278 |
-
ax[i + row_off,1 + j].add_patch(rect)
|
279 |
-
|
280 |
-
if vmax is None:
|
281 |
-
this_vmax = 1
|
282 |
-
elif vmax == 'per_prompt':
|
283 |
-
this_vmax = max([preds[j][_i][0].max() for _i in range(len(imgs))])
|
284 |
-
elif vmax == 'per_image':
|
285 |
-
this_vmax = max([preds[_j][i][0].max() for _j in range(len(preds))])
|
286 |
-
|
287 |
-
ax[i + row_off,1 + j].imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)
|
288 |
-
|
289 |
-
|
290 |
-
# ax[i,1 + j].imshow(preds[j][i][0].detach().cpu().numpy(), vmin=preds[j].min(), vmax=preds[j].max())
|
291 |
-
plt.tight_layout()
|
292 |
-
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/example_image.jpg
DELETED
Binary file (91.5 kB)
|
|
clipseg/experiments/ablation.yaml
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
configuration:
|
2 |
-
batch_size: 64
|
3 |
-
optimizer: torch.optim.AdamW
|
4 |
-
|
5 |
-
lr: 0.001
|
6 |
-
|
7 |
-
trainer: experiment_setup.train_loop
|
8 |
-
scorer: experiment_setup.score
|
9 |
-
model: models.clipseg.CLIPDensePredT
|
10 |
-
|
11 |
-
lr_scheduler: cosine
|
12 |
-
T_max: 20000
|
13 |
-
eta_min: 0.0001
|
14 |
-
|
15 |
-
max_iterations: 20000 # <-##########################################
|
16 |
-
val_interval: null
|
17 |
-
|
18 |
-
# dataset
|
19 |
-
dataset: datasets.phrasecut.PhraseCut # <-----------------
|
20 |
-
split_mode: pascal_test
|
21 |
-
split: train
|
22 |
-
mask: text_and_crop_blur_highlight352
|
23 |
-
image_size: 352
|
24 |
-
negative_prob: 0.2
|
25 |
-
mix_text_max: 0.5
|
26 |
-
|
27 |
-
# general
|
28 |
-
mix: True # <-----------------
|
29 |
-
prompt: shuffle+
|
30 |
-
norm_cond: True
|
31 |
-
mix_text_min: 0.0
|
32 |
-
with_visual: True
|
33 |
-
|
34 |
-
# model
|
35 |
-
version: 'ViT-B/16'
|
36 |
-
extract_layers: [3, 7, 9]
|
37 |
-
reduce_dim: 64
|
38 |
-
depth: 3
|
39 |
-
fix_shift: False # <-##########################################
|
40 |
-
|
41 |
-
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
42 |
-
amp: True
|
43 |
-
|
44 |
-
test_configuration_common:
|
45 |
-
normalize: True
|
46 |
-
image_size: 352
|
47 |
-
batch_size: 32
|
48 |
-
sigmoid: True
|
49 |
-
split: test
|
50 |
-
label_support: True
|
51 |
-
|
52 |
-
test_configuration:
|
53 |
-
|
54 |
-
-
|
55 |
-
name: pc
|
56 |
-
metric: metrics.FixedIntervalMetrics
|
57 |
-
test_dataset: phrasecut
|
58 |
-
mask: text
|
59 |
-
|
60 |
-
-
|
61 |
-
name: pc-vis
|
62 |
-
metric: metrics.FixedIntervalMetrics
|
63 |
-
test_dataset: phrasecut
|
64 |
-
mask: crop_blur_highlight352
|
65 |
-
with_visual: True
|
66 |
-
visual_only: True
|
67 |
-
|
68 |
-
|
69 |
-
columns: [name,
|
70 |
-
pc_fgiou_best, pc_miou_best, pc_fgiou_0.5,
|
71 |
-
pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5,
|
72 |
-
duration]
|
73 |
-
|
74 |
-
|
75 |
-
individual_configurations:
|
76 |
-
|
77 |
-
- {name: rd64-uni}
|
78 |
-
- {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003}
|
79 |
-
- {name: rd64-no-negatives, negative_prob: 0.0}
|
80 |
-
- {name: rd64-neg0.5, negative_prob: 0.5}
|
81 |
-
- {name: rd64-no-visual, with_visual: False, mix: False}
|
82 |
-
- {name: rd16-uni, reduce_dim: 16}
|
83 |
-
- {name: rd64-layer3, extract_layers: [3], depth: 1}
|
84 |
-
- {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/experiments/coco.yaml
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
configuration:
|
2 |
-
batch_size: 64
|
3 |
-
optimizer: torch.optim.AdamW
|
4 |
-
|
5 |
-
lr: 0.001
|
6 |
-
|
7 |
-
trainer: experiment_setup.train_loop
|
8 |
-
scorer: experiment_setup.score
|
9 |
-
model: models.clipseg.CLIPDensePredT
|
10 |
-
|
11 |
-
lr_scheduler: cosine
|
12 |
-
T_max: 20000
|
13 |
-
eta_min: 0.0001
|
14 |
-
|
15 |
-
max_iterations: 20000
|
16 |
-
val_interval: null
|
17 |
-
|
18 |
-
# dataset
|
19 |
-
dataset: datasets.coco_wrapper.COCOWrapper
|
20 |
-
# split_mode: pascal_test
|
21 |
-
split: train
|
22 |
-
mask: text_and_blur3_highlight01
|
23 |
-
image_size: 352
|
24 |
-
normalize: True
|
25 |
-
pre_crop_image_size: [sample, 1, 1.5]
|
26 |
-
aug: 1new
|
27 |
-
|
28 |
-
# general
|
29 |
-
mix: True
|
30 |
-
prompt: shuffle+
|
31 |
-
norm_cond: True
|
32 |
-
mix_text_min: 0.0
|
33 |
-
|
34 |
-
# model
|
35 |
-
out: 1
|
36 |
-
extract_layers: [3, 7, 9]
|
37 |
-
reduce_dim: 64
|
38 |
-
depth: 3
|
39 |
-
fix_shift: False
|
40 |
-
|
41 |
-
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
42 |
-
amp: True
|
43 |
-
|
44 |
-
test_configuration_common:
|
45 |
-
normalize: True
|
46 |
-
image_size: 352
|
47 |
-
# max_iterations: 10
|
48 |
-
batch_size: 8
|
49 |
-
sigmoid: True
|
50 |
-
test_dataset: coco
|
51 |
-
metric: metrics.FixedIntervalMetrics
|
52 |
-
|
53 |
-
test_configuration:
|
54 |
-
|
55 |
-
-
|
56 |
-
name: coco_t
|
57 |
-
mask: text
|
58 |
-
|
59 |
-
-
|
60 |
-
name: coco_h
|
61 |
-
mask: blur3_highlight01
|
62 |
-
|
63 |
-
-
|
64 |
-
name: coco_h2
|
65 |
-
mask: crop_blur_highlight352
|
66 |
-
|
67 |
-
|
68 |
-
columns: [i, name,
|
69 |
-
coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5,
|
70 |
-
coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5,
|
71 |
-
coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t,
|
72 |
-
train_loss, duration, date
|
73 |
-
]
|
74 |
-
|
75 |
-
individual_configurations:
|
76 |
-
|
77 |
-
|
78 |
-
- {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
79 |
-
- {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
80 |
-
- {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
81 |
-
- {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
82 |
-
|
83 |
-
|
84 |
-
- {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
85 |
-
- {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
86 |
-
- {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
87 |
-
- {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
88 |
-
|
89 |
-
|
90 |
-
# ViT
|
91 |
-
- {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
92 |
-
- {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
93 |
-
- {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
94 |
-
- {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
|
95 |
-
|
96 |
-
|
97 |
-
# BASELINE
|
98 |
-
- {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
99 |
-
- {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
100 |
-
- {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
101 |
-
- {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/experiments/pascal_1shot.yaml
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
configuration:
|
2 |
-
batch_size: 64
|
3 |
-
optimizer: torch.optim.AdamW
|
4 |
-
|
5 |
-
lr: 0.001
|
6 |
-
|
7 |
-
trainer: experiment_setup.train_loop
|
8 |
-
scorer: experiment_setup.score
|
9 |
-
model: models.clipseg.CLIPDensePredT
|
10 |
-
|
11 |
-
lr_scheduler: cosine
|
12 |
-
T_max: 20000
|
13 |
-
eta_min: 0.0001
|
14 |
-
|
15 |
-
max_iterations: 20000 # <-##########################################
|
16 |
-
val_interval: null
|
17 |
-
|
18 |
-
# dataset
|
19 |
-
dataset: datasets.phrasecut.PhraseCut
|
20 |
-
split_mode: pascal_test
|
21 |
-
mode: train
|
22 |
-
mask: text_and_crop_blur_highlight352
|
23 |
-
image_size: 352
|
24 |
-
normalize: True
|
25 |
-
pre_crop_image_size: [sample, 1, 1.5]
|
26 |
-
aug: 1new
|
27 |
-
with_visual: True
|
28 |
-
split: train
|
29 |
-
|
30 |
-
# general
|
31 |
-
mix: True
|
32 |
-
prompt: shuffle+
|
33 |
-
norm_cond: True
|
34 |
-
mix_text_min: 0.0
|
35 |
-
|
36 |
-
# model
|
37 |
-
out: 1
|
38 |
-
version: 'ViT-B/16'
|
39 |
-
extract_layers: [3, 7, 9]
|
40 |
-
reduce_dim: 64
|
41 |
-
depth: 3
|
42 |
-
|
43 |
-
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
44 |
-
amp: True
|
45 |
-
|
46 |
-
test_configuration_common:
|
47 |
-
normalize: True
|
48 |
-
image_size: 352
|
49 |
-
metric: metrics.FixedIntervalMetrics
|
50 |
-
batch_size: 1
|
51 |
-
test_dataset: pascal
|
52 |
-
sigmoid: True
|
53 |
-
# max_iterations: 250
|
54 |
-
|
55 |
-
test_configuration:
|
56 |
-
|
57 |
-
-
|
58 |
-
name: pas_t
|
59 |
-
mask: text
|
60 |
-
|
61 |
-
-
|
62 |
-
name: pas_h
|
63 |
-
mask: blur3_highlight01
|
64 |
-
|
65 |
-
-
|
66 |
-
name: pas_h2
|
67 |
-
mask: crop_blur_highlight352
|
68 |
-
|
69 |
-
|
70 |
-
columns: [name,
|
71 |
-
pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct,
|
72 |
-
pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct,
|
73 |
-
pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t,
|
74 |
-
train_loss, duration, date
|
75 |
-
]
|
76 |
-
|
77 |
-
individual_configurations:
|
78 |
-
|
79 |
-
- {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}}
|
80 |
-
- {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}}
|
81 |
-
- {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}}
|
82 |
-
- {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}}
|
83 |
-
|
84 |
-
|
85 |
-
- {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}}
|
86 |
-
- {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}}
|
87 |
-
- {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}}
|
88 |
-
- {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}}
|
89 |
-
|
90 |
-
|
91 |
-
# baseline
|
92 |
-
- {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}}
|
93 |
-
- {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}}
|
94 |
-
- {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}}
|
95 |
-
- {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}}
|
96 |
-
|
97 |
-
# ViT
|
98 |
-
- {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}}
|
99 |
-
- {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}}
|
100 |
-
- {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}}
|
101 |
-
- {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/experiments/phrasecut.yaml
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
configuration:
|
2 |
-
batch_size: 64
|
3 |
-
optimizer: torch.optim.AdamW
|
4 |
-
|
5 |
-
lr: 0.001
|
6 |
-
|
7 |
-
trainer: experiment_setup.train_loop
|
8 |
-
scorer: experiment_setup.score
|
9 |
-
model: models.clipseg.CLIPDensePredT
|
10 |
-
|
11 |
-
lr_scheduler: cosine
|
12 |
-
T_max: 20000
|
13 |
-
eta_min: 0.0001
|
14 |
-
|
15 |
-
max_iterations: 20000
|
16 |
-
val_interval: null
|
17 |
-
|
18 |
-
# dataset
|
19 |
-
dataset: datasets.phrasecut.PhraseCut # <-----------------
|
20 |
-
split_mode: pascal_test
|
21 |
-
split: train
|
22 |
-
mask: text_and_crop_blur_highlight352
|
23 |
-
image_size: 352
|
24 |
-
normalize: True
|
25 |
-
pre_crop_image_size: [sample, 1, 1.5]
|
26 |
-
aug: 1new
|
27 |
-
|
28 |
-
# general
|
29 |
-
mix: False # <-----------------
|
30 |
-
prompt: shuffle+
|
31 |
-
norm_cond: True
|
32 |
-
mix_text_min: 0.0
|
33 |
-
|
34 |
-
# model
|
35 |
-
out: 1
|
36 |
-
extract_layers: [3, 7, 9]
|
37 |
-
reduce_dim: 64
|
38 |
-
depth: 3
|
39 |
-
fix_shift: False
|
40 |
-
|
41 |
-
loss: torch.nn.functional.binary_cross_entropy_with_logits
|
42 |
-
amp: True
|
43 |
-
|
44 |
-
test_configuration_common:
|
45 |
-
normalize: True
|
46 |
-
image_size: 352
|
47 |
-
batch_size: 32
|
48 |
-
# max_iterations: 5
|
49 |
-
# max_iterations: 150
|
50 |
-
|
51 |
-
test_configuration:
|
52 |
-
|
53 |
-
-
|
54 |
-
name: pc # old: phrasecut
|
55 |
-
metric: metrics.FixedIntervalMetrics
|
56 |
-
test_dataset: phrasecut
|
57 |
-
split: test
|
58 |
-
mask: text
|
59 |
-
label_support: True
|
60 |
-
sigmoid: True
|
61 |
-
|
62 |
-
|
63 |
-
columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date]
|
64 |
-
|
65 |
-
|
66 |
-
individual_configurations:
|
67 |
-
|
68 |
-
# important ones
|
69 |
-
|
70 |
-
|
71 |
-
- {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5}
|
72 |
-
|
73 |
-
# this was accedentally trained using old mask
|
74 |
-
- {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01}
|
75 |
-
- {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False}
|
76 |
-
# this was accedentally trained using old mask
|
77 |
-
- {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01}
|
78 |
-
|
79 |
-
- {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003}
|
80 |
-
- {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/general_utils.py
DELETED
@@ -1,272 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import inspect
|
3 |
-
import torch
|
4 |
-
import os
|
5 |
-
import sys
|
6 |
-
import yaml
|
7 |
-
from shutil import copy, copytree
|
8 |
-
from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename
|
9 |
-
|
10 |
-
|
11 |
-
class Logger(object):
|
12 |
-
|
13 |
-
def __getattr__(self, k):
|
14 |
-
return print
|
15 |
-
|
16 |
-
log = Logger()
|
17 |
-
|
18 |
-
def training_config_from_cli_args():
|
19 |
-
experiment_name = sys.argv[1]
|
20 |
-
experiment_id = int(sys.argv[2])
|
21 |
-
|
22 |
-
yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
|
23 |
-
|
24 |
-
config = yaml_config['configuration']
|
25 |
-
config = {**config, **yaml_config['individual_configurations'][experiment_id]}
|
26 |
-
config = AttributeDict(config)
|
27 |
-
return config
|
28 |
-
|
29 |
-
|
30 |
-
def score_config_from_cli_args():
|
31 |
-
experiment_name = sys.argv[1]
|
32 |
-
experiment_id = int(sys.argv[2])
|
33 |
-
|
34 |
-
|
35 |
-
yaml_config = yaml.load(open(f'experiments/{experiment_name}'), Loader=yaml.SafeLoader)
|
36 |
-
|
37 |
-
config = yaml_config['test_configuration_common']
|
38 |
-
|
39 |
-
if type(yaml_config['test_configuration']) == list:
|
40 |
-
test_id = int(sys.argv[3])
|
41 |
-
config = {**config, **yaml_config['test_configuration'][test_id]}
|
42 |
-
else:
|
43 |
-
config = {**config, **yaml_config['test_configuration']}
|
44 |
-
|
45 |
-
if 'test_configuration' in yaml_config['individual_configurations'][experiment_id]:
|
46 |
-
config = {**config, **yaml_config['individual_configurations'][experiment_id]['test_configuration']}
|
47 |
-
|
48 |
-
train_checkpoint_id = yaml_config['individual_configurations'][experiment_id]['name']
|
49 |
-
|
50 |
-
config = AttributeDict(config)
|
51 |
-
return config, train_checkpoint_id
|
52 |
-
|
53 |
-
|
54 |
-
def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository',
|
55 |
-
local_dir='~/datasets'):
|
56 |
-
""" copies files from repository to local folder.
|
57 |
-
|
58 |
-
repo_files: list of filenames or list of tuples [filename, target path]
|
59 |
-
|
60 |
-
e.g. get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar'])
|
61 |
-
will create a folder 'MyDataset' in local_dir, and extract the content of
|
62 |
-
'<repo_dir>/data/dataset1.tar' to <local_dir>/MyDataset/other/path.
|
63 |
-
"""
|
64 |
-
|
65 |
-
local_dir = realpath(join(expanduser(local_dir), local_name))
|
66 |
-
|
67 |
-
dataset_exists = True
|
68 |
-
|
69 |
-
# check if folder is available
|
70 |
-
if not isdir(local_dir):
|
71 |
-
dataset_exists = False
|
72 |
-
|
73 |
-
if integrity_check is not None:
|
74 |
-
try:
|
75 |
-
integrity_ok = integrity_check(local_dir)
|
76 |
-
except BaseException:
|
77 |
-
integrity_ok = False
|
78 |
-
|
79 |
-
if integrity_ok:
|
80 |
-
log.hint('Passed custom integrity check')
|
81 |
-
else:
|
82 |
-
log.hint('Custom integrity check failed')
|
83 |
-
|
84 |
-
dataset_exists = dataset_exists and integrity_ok
|
85 |
-
|
86 |
-
if not dataset_exists:
|
87 |
-
|
88 |
-
repo_dir = realpath(expanduser(repo_dir))
|
89 |
-
|
90 |
-
for i, filename in enumerate(repo_files):
|
91 |
-
|
92 |
-
if type(filename) == str:
|
93 |
-
origin, target = filename, filename
|
94 |
-
archive_target = join(local_dir, basename(origin))
|
95 |
-
extract_target = join(local_dir)
|
96 |
-
else:
|
97 |
-
origin, target = filename
|
98 |
-
archive_target = join(local_dir, dirname(target), basename(origin))
|
99 |
-
extract_target = join(local_dir, dirname(target))
|
100 |
-
|
101 |
-
archive_origin = join(repo_dir, origin)
|
102 |
-
|
103 |
-
log.hint(f'copy: {archive_origin} to {archive_target}')
|
104 |
-
|
105 |
-
# make sure the path exists
|
106 |
-
os.makedirs(dirname(archive_target), exist_ok=True)
|
107 |
-
|
108 |
-
if os.path.isfile(archive_target):
|
109 |
-
# only copy if size differs
|
110 |
-
if os.path.getsize(archive_target) != os.path.getsize(archive_origin):
|
111 |
-
log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}')
|
112 |
-
copy(archive_origin, archive_target)
|
113 |
-
else:
|
114 |
-
copy(archive_origin, archive_target)
|
115 |
-
|
116 |
-
extract_archive(archive_target, extract_target, noarchive_ok=True)
|
117 |
-
|
118 |
-
# concurrent processes might have deleted the file
|
119 |
-
if os.path.isfile(archive_target):
|
120 |
-
os.remove(archive_target)
|
121 |
-
|
122 |
-
|
123 |
-
def extract_archive(filename, target_folder=None, noarchive_ok=False):
|
124 |
-
from subprocess import run, PIPE
|
125 |
-
|
126 |
-
if filename.endswith('.tgz') or filename.endswith('.tar'):
|
127 |
-
command = f'tar -xf {filename}'
|
128 |
-
command += f' -C {target_folder}' if target_folder is not None else ''
|
129 |
-
elif filename.endswith('.tar.gz'):
|
130 |
-
command = f'tar -xzf {filename}'
|
131 |
-
command += f' -C {target_folder}' if target_folder is not None else ''
|
132 |
-
elif filename.endswith('zip'):
|
133 |
-
command = f'unzip {filename}'
|
134 |
-
command += f' -d {target_folder}' if target_folder is not None else ''
|
135 |
-
else:
|
136 |
-
if noarchive_ok:
|
137 |
-
return
|
138 |
-
else:
|
139 |
-
raise ValueError(f'unsuppored file ending of {filename}')
|
140 |
-
|
141 |
-
log.hint(command)
|
142 |
-
result = run(command.split(), stdout=PIPE, stderr=PIPE)
|
143 |
-
if result.returncode != 0:
|
144 |
-
print(result.stdout, result.stderr)
|
145 |
-
|
146 |
-
|
147 |
-
class AttributeDict(dict):
|
148 |
-
"""
|
149 |
-
An extended dictionary that allows access to elements as atttributes and counts
|
150 |
-
these accesses. This way, we know if some attributes were never used.
|
151 |
-
"""
|
152 |
-
|
153 |
-
def __init__(self, *args, **kwargs):
|
154 |
-
from collections import Counter
|
155 |
-
super().__init__(*args, **kwargs)
|
156 |
-
self.__dict__['counter'] = Counter()
|
157 |
-
|
158 |
-
def __getitem__(self, k):
|
159 |
-
self.__dict__['counter'][k] += 1
|
160 |
-
return super().__getitem__(k)
|
161 |
-
|
162 |
-
def __getattr__(self, k):
|
163 |
-
self.__dict__['counter'][k] += 1
|
164 |
-
return super().get(k)
|
165 |
-
|
166 |
-
def __setattr__(self, k, v):
|
167 |
-
return super().__setitem__(k, v)
|
168 |
-
|
169 |
-
def __delattr__(self, k, v):
|
170 |
-
return super().__delitem__(k, v)
|
171 |
-
|
172 |
-
def unused_keys(self, exceptions=()):
|
173 |
-
return [k for k in super().keys() if self.__dict__['counter'][k] == 0 and k not in exceptions]
|
174 |
-
|
175 |
-
def assume_no_unused_keys(self, exceptions=()):
|
176 |
-
if len(self.unused_keys(exceptions=exceptions)) > 0:
|
177 |
-
log.warning('Unused keys:', self.unused_keys(exceptions=exceptions))
|
178 |
-
|
179 |
-
|
180 |
-
def get_attribute(name):
|
181 |
-
import importlib
|
182 |
-
|
183 |
-
if name is None:
|
184 |
-
raise ValueError('The provided attribute is None')
|
185 |
-
|
186 |
-
name_split = name.split('.')
|
187 |
-
mod = importlib.import_module('.'.join(name_split[:-1]))
|
188 |
-
return getattr(mod, name_split[-1])
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
def filter_args(input_args, default_args):
|
193 |
-
|
194 |
-
updated_args = {k: input_args[k] if k in input_args else v for k, v in default_args.items()}
|
195 |
-
used_args = {k: v for k, v in input_args.items() if k in default_args}
|
196 |
-
unused_args = {k: v for k, v in input_args.items() if k not in default_args}
|
197 |
-
|
198 |
-
return AttributeDict(updated_args), AttributeDict(used_args), AttributeDict(unused_args)
|
199 |
-
|
200 |
-
|
201 |
-
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False):
|
202 |
-
|
203 |
-
config = json.load(open(join('logs', checkpoint_id, 'config.json')))
|
204 |
-
|
205 |
-
if model_args != 'from_config' and type(model_args) != dict:
|
206 |
-
raise ValueError('model_args must either be "from_config" or a dictionary of values')
|
207 |
-
|
208 |
-
model_cls = get_attribute(config['model'])
|
209 |
-
|
210 |
-
# load model
|
211 |
-
if model_args == 'from_config':
|
212 |
-
_, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
|
213 |
-
|
214 |
-
model = model_cls(**model_args)
|
215 |
-
|
216 |
-
if weights_file is None:
|
217 |
-
weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
|
218 |
-
else:
|
219 |
-
weights_file = realpath(join('logs', checkpoint_id, weights_file))
|
220 |
-
|
221 |
-
if isfile(weights_file):
|
222 |
-
weights = torch.load(weights_file)
|
223 |
-
for _, w in weights.items():
|
224 |
-
assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
|
225 |
-
model.load_state_dict(weights, strict=strict)
|
226 |
-
else:
|
227 |
-
raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
|
228 |
-
|
229 |
-
if with_config:
|
230 |
-
return model, config
|
231 |
-
|
232 |
-
return model
|
233 |
-
|
234 |
-
|
235 |
-
class TrainingLogger(object):
|
236 |
-
|
237 |
-
def __init__(self, model, log_dir, config=None, *args):
|
238 |
-
super().__init__()
|
239 |
-
self.model = model
|
240 |
-
self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None
|
241 |
-
|
242 |
-
os.makedirs('logs/', exist_ok=True)
|
243 |
-
os.makedirs(self.base_path, exist_ok=True)
|
244 |
-
|
245 |
-
if config is not None:
|
246 |
-
json.dump(config, open(join(self.base_path, 'config.json'), 'w'))
|
247 |
-
|
248 |
-
def iter(self, i, **kwargs):
|
249 |
-
if i % 100 == 0 and 'loss' in kwargs:
|
250 |
-
loss = kwargs['loss']
|
251 |
-
print(f'iteration {i}: loss {loss:.4f}')
|
252 |
-
|
253 |
-
def save_weights(self, only_trainable=False, weight_file='weights.pth'):
|
254 |
-
if self.model is None:
|
255 |
-
raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.')
|
256 |
-
|
257 |
-
weights_path = join(self.base_path, weight_file)
|
258 |
-
|
259 |
-
weight_dict = self.model.state_dict()
|
260 |
-
|
261 |
-
if only_trainable:
|
262 |
-
weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad}
|
263 |
-
|
264 |
-
torch.save(weight_dict, weights_path)
|
265 |
-
log.info(f'Saved weights to {weights_path}')
|
266 |
-
|
267 |
-
def __enter__(self):
|
268 |
-
return self
|
269 |
-
|
270 |
-
def __exit__(self, type, value, traceback):
|
271 |
-
""" automatically stop processes if used in a context manager """
|
272 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/metrics.py
DELETED
@@ -1,271 +0,0 @@
|
|
1 |
-
from torch.functional import Tensor
|
2 |
-
from general_utils import log
|
3 |
-
from collections import defaultdict
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from torch.nn import functional as nnf
|
8 |
-
|
9 |
-
|
10 |
-
class BaseMetric(object):
|
11 |
-
|
12 |
-
def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
|
13 |
-
eval_validation=True):
|
14 |
-
self._names = tuple(metric_names)
|
15 |
-
self._eval_intermediate = eval_intermediate
|
16 |
-
self._eval_validation = eval_validation
|
17 |
-
|
18 |
-
self._pred_range = pred_range
|
19 |
-
self._pred_index = pred_index
|
20 |
-
self._gt_index = gt_index
|
21 |
-
|
22 |
-
self.predictions = []
|
23 |
-
self.ground_truths = []
|
24 |
-
|
25 |
-
def eval_intermediate(self):
|
26 |
-
return self._eval_intermediate
|
27 |
-
|
28 |
-
def eval_validation(self):
|
29 |
-
return self._eval_validation
|
30 |
-
|
31 |
-
def names(self):
|
32 |
-
return self._names
|
33 |
-
|
34 |
-
def add(self, predictions, ground_truth):
|
35 |
-
raise NotImplementedError
|
36 |
-
|
37 |
-
def value(self):
|
38 |
-
raise NotImplementedError
|
39 |
-
|
40 |
-
def scores(self):
|
41 |
-
# similar to value but returns dict
|
42 |
-
value = self.value()
|
43 |
-
if type(value) == dict:
|
44 |
-
return value
|
45 |
-
else:
|
46 |
-
assert type(value) in {list, tuple}
|
47 |
-
return list(zip(self.names(), self.value()))
|
48 |
-
|
49 |
-
def _get_pred_gt(self, predictions, ground_truth):
|
50 |
-
pred = predictions[self._pred_index]
|
51 |
-
gt = ground_truth[self._gt_index]
|
52 |
-
|
53 |
-
if self._pred_range is not None:
|
54 |
-
pred = pred[:, self._pred_range[0]: self._pred_range[1]]
|
55 |
-
|
56 |
-
return pred, gt
|
57 |
-
|
58 |
-
|
59 |
-
class FixedIntervalMetrics(BaseMetric):
|
60 |
-
|
61 |
-
def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
|
62 |
-
resize_pred=None, n_values=51, custom_threshold=None):
|
63 |
-
|
64 |
-
|
65 |
-
super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
|
66 |
-
self.intersections = []
|
67 |
-
self.unions = []
|
68 |
-
# self.threshold = threshold
|
69 |
-
self.sigmoid = sigmoid
|
70 |
-
self.resize_to = resize_to
|
71 |
-
self.resize_pred = resize_pred # resize prediction to match ground truth
|
72 |
-
self.class_count = defaultdict(lambda: 0)
|
73 |
-
self.per_class = defaultdict(lambda : [0,0])
|
74 |
-
self.ignore_mask = ignore_mask
|
75 |
-
self.custom_threshold = custom_threshold
|
76 |
-
|
77 |
-
self.scores_ap = []
|
78 |
-
self.scores_iou = []
|
79 |
-
self.gts, self.preds = [], []
|
80 |
-
self.classes = []
|
81 |
-
|
82 |
-
# [1:-1] ignores 0 and 1
|
83 |
-
self.threshold_values = np.linspace(0, 1, n_values)[1:-1]
|
84 |
-
|
85 |
-
self.metrics = dict(tp=[], fp=[], fn=[], tn=[])
|
86 |
-
|
87 |
-
def add(self, pred, gt):
|
88 |
-
|
89 |
-
pred_batch = pred[0].cpu()
|
90 |
-
|
91 |
-
if self.sigmoid:
|
92 |
-
pred_batch = torch.sigmoid(pred_batch)
|
93 |
-
|
94 |
-
gt_batch = gt[0].cpu()
|
95 |
-
mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
|
96 |
-
cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)
|
97 |
-
|
98 |
-
if self.resize_to is not None:
|
99 |
-
gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
|
100 |
-
pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)
|
101 |
-
|
102 |
-
if isinstance(cls_batch, torch.Tensor):
|
103 |
-
cls_batch = cls_batch.cpu().numpy().tolist()
|
104 |
-
|
105 |
-
assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'
|
106 |
-
|
107 |
-
for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):
|
108 |
-
|
109 |
-
if self.resize_pred:
|
110 |
-
predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)
|
111 |
-
|
112 |
-
p = predictions.flatten()
|
113 |
-
g = ground_truth.flatten()
|
114 |
-
|
115 |
-
assert len(p) == len(g)
|
116 |
-
|
117 |
-
if mask is not None:
|
118 |
-
m = mask.flatten().bool()
|
119 |
-
p = p[m]
|
120 |
-
g = g[m]
|
121 |
-
|
122 |
-
p_sorted = p.sort()
|
123 |
-
p = p_sorted.values
|
124 |
-
g = g[p_sorted.indices]
|
125 |
-
|
126 |
-
tps, fps, fns, tns = [], [], [], []
|
127 |
-
for thresh in self.threshold_values:
|
128 |
-
|
129 |
-
valid = torch.where(p > thresh)[0]
|
130 |
-
if len(valid) > 0:
|
131 |
-
n = int(valid[0])
|
132 |
-
else:
|
133 |
-
n = len(g)
|
134 |
-
|
135 |
-
fn = int(g[:n].sum())
|
136 |
-
tp = int(g[n:].sum())
|
137 |
-
fns += [fn]
|
138 |
-
tns += [n - fn]
|
139 |
-
tps += [tp]
|
140 |
-
fps += [len(g) - n - tp]
|
141 |
-
|
142 |
-
self.metrics['tp'] += [tps]
|
143 |
-
self.metrics['fp'] += [fps]
|
144 |
-
self.metrics['fn'] += [fns]
|
145 |
-
self.metrics['tn'] += [tns]
|
146 |
-
|
147 |
-
self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls]
|
148 |
-
|
149 |
-
def value(self):
|
150 |
-
|
151 |
-
import time
|
152 |
-
t_start = time.time()
|
153 |
-
|
154 |
-
if set(self.classes) == set([None]):
|
155 |
-
all_classes = None
|
156 |
-
log.warning('classes were not provided, cannot compute mIoU')
|
157 |
-
else:
|
158 |
-
all_classes = set(int(c) for c in self.classes)
|
159 |
-
# log.info(f'compute metrics for {len(all_classes)} classes')
|
160 |
-
|
161 |
-
summed = {k: [sum([self.metrics[k][i][j]
|
162 |
-
for i in range(len(self.metrics[k]))])
|
163 |
-
for j in range(len(self.threshold_values))]
|
164 |
-
for k in self.metrics.keys()}
|
165 |
-
|
166 |
-
if all_classes is not None:
|
167 |
-
|
168 |
-
assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])
|
169 |
-
# group by class
|
170 |
-
metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes}
|
171 |
-
for i in range(len(self.metrics['tp'])):
|
172 |
-
for k in self.metrics.keys():
|
173 |
-
metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]]
|
174 |
-
|
175 |
-
# sum over all instances within the classes
|
176 |
-
summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()}
|
177 |
-
|
178 |
-
|
179 |
-
# Compute average precision
|
180 |
-
|
181 |
-
assert (np.array(summed['fp']) + np.array(summed['tp']) ).sum(), 'no predictions is made'
|
182 |
-
|
183 |
-
# only consider values where a prediction is made
|
184 |
-
precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(len(self.threshold_values))
|
185 |
-
if summed['tp'][j] + summed['fp'][j] > 0]
|
186 |
-
recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))
|
187 |
-
if summed['tp'][j] + summed['fp'][j] > 0]
|
188 |
-
|
189 |
-
# remove duplicate recall-precision-pairs (and sort by recall value)
|
190 |
-
recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0]))
|
191 |
-
|
192 |
-
from scipy.integrate import simps
|
193 |
-
ap = simps(precisions, recalls)
|
194 |
-
|
195 |
-
# Compute best IoU
|
196 |
-
fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(len(self.threshold_values))]
|
197 |
-
|
198 |
-
biniou_scores = [
|
199 |
-
0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
|
200 |
-
0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
|
201 |
-
for j in range(len(self.threshold_values))
|
202 |
-
]
|
203 |
-
|
204 |
-
index_0p5 = self.threshold_values.tolist().index(0.5)
|
205 |
-
index_0p1 = self.threshold_values.tolist().index(0.1)
|
206 |
-
index_0p2 = self.threshold_values.tolist().index(0.2)
|
207 |
-
index_0p3 = self.threshold_values.tolist().index(0.3)
|
208 |
-
|
209 |
-
if self.custom_threshold is not None:
|
210 |
-
index_ct = self.threshold_values.tolist().index(self.custom_threshold)
|
211 |
-
|
212 |
-
if all_classes is not None:
|
213 |
-
# mean IoU
|
214 |
-
mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j])
|
215 |
-
for c in all_classes])
|
216 |
-
for j in range(len(self.threshold_values))]
|
217 |
-
|
218 |
-
mean_iou_dict = {
|
219 |
-
'miou_best': max(mean_ious) if all_classes is not None else None,
|
220 |
-
'miou_0.5': mean_ious[index_0p5] if all_classes is not None else None,
|
221 |
-
'miou_0.1': mean_ious[index_0p1] if all_classes is not None else None,
|
222 |
-
'miou_0.2': mean_ious[index_0p2] if all_classes is not None else None,
|
223 |
-
'miou_0.3': mean_ious[index_0p3] if all_classes is not None else None,
|
224 |
-
'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
|
225 |
-
'mean_iou_ct': mean_ious[index_ct] if all_classes is not None and self.custom_threshold is not None else None,
|
226 |
-
'mean_iou_scores': mean_ious,
|
227 |
-
}
|
228 |
-
|
229 |
-
print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')
|
230 |
-
|
231 |
-
return {
|
232 |
-
'ap': ap,
|
233 |
-
|
234 |
-
# fgiou
|
235 |
-
'fgiou_best': max(fgiou_scores),
|
236 |
-
'fgiou_0.5': fgiou_scores[index_0p5],
|
237 |
-
'fgiou_0.1': fgiou_scores[index_0p1],
|
238 |
-
'fgiou_0.2': fgiou_scores[index_0p2],
|
239 |
-
'fgiou_0.3': fgiou_scores[index_0p3],
|
240 |
-
'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],
|
241 |
-
|
242 |
-
# mean iou
|
243 |
-
|
244 |
-
|
245 |
-
# biniou
|
246 |
-
'biniou_best': max(biniou_scores),
|
247 |
-
'biniou_0.5': biniou_scores[index_0p5],
|
248 |
-
'biniou_0.1': biniou_scores[index_0p1],
|
249 |
-
'biniou_0.2': biniou_scores[index_0p2],
|
250 |
-
'biniou_0.3': biniou_scores[index_0p3],
|
251 |
-
'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],
|
252 |
-
|
253 |
-
# custom threshold
|
254 |
-
'fgiou_ct': fgiou_scores[index_ct] if self.custom_threshold is not None else None,
|
255 |
-
'biniou_ct': biniou_scores[index_ct] if self.custom_threshold is not None else None,
|
256 |
-
'ct': self.custom_threshold,
|
257 |
-
|
258 |
-
# statistics
|
259 |
-
'fgiou_scores': fgiou_scores,
|
260 |
-
'biniou_scores': biniou_scores,
|
261 |
-
'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))),
|
262 |
-
'summed_statistics': summed,
|
263 |
-
'summed_by_cls_statistics': summed_by_cls,
|
264 |
-
|
265 |
-
**mean_iou_dict
|
266 |
-
}
|
267 |
-
|
268 |
-
# ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'
|
269 |
-
|
270 |
-
# return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls}
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/models/clipseg.py
DELETED
@@ -1,552 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from os.path import basename, dirname, join, isfile
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn import functional as nnf
|
6 |
-
from torch.nn.modules.activation import ReLU
|
7 |
-
|
8 |
-
|
9 |
-
def precompute_clip_vectors():
|
10 |
-
|
11 |
-
from trails.initialization import init_dataset
|
12 |
-
lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
|
13 |
-
reduce_factor=None, add_bar=False, negative_prob=0.5)
|
14 |
-
|
15 |
-
all_names = list(lvis.category_names.values())
|
16 |
-
|
17 |
-
import clip
|
18 |
-
from models.clip_prompts import imagenet_templates
|
19 |
-
clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
|
20 |
-
prompt_vectors = {}
|
21 |
-
for name in all_names[:100]:
|
22 |
-
with torch.no_grad():
|
23 |
-
conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
|
24 |
-
text_tokens = clip.tokenize(conditionals).cuda()
|
25 |
-
cond = clip_model.encode_text(text_tokens).cpu()
|
26 |
-
|
27 |
-
for cond, vec in zip(conditionals, cond):
|
28 |
-
prompt_vectors[cond] = vec.cpu()
|
29 |
-
|
30 |
-
import pickle
|
31 |
-
|
32 |
-
pickle.dump(prompt_vectors, open('precomputed_prompt_vectors.pickle', 'wb'))
|
33 |
-
|
34 |
-
|
35 |
-
def get_prompt_list(prompt):
|
36 |
-
if prompt == 'plain':
|
37 |
-
return ['{}']
|
38 |
-
elif prompt == 'fixed':
|
39 |
-
return ['a photo of a {}.']
|
40 |
-
elif prompt == 'shuffle':
|
41 |
-
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
42 |
-
elif prompt == 'shuffle+':
|
43 |
-
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
44 |
-
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
45 |
-
'a bad photo of a {}.', 'a photo of the {}.']
|
46 |
-
elif prompt == 'shuffle_clip':
|
47 |
-
from models.clip_prompts import imagenet_templates
|
48 |
-
return imagenet_templates
|
49 |
-
else:
|
50 |
-
raise ValueError('Invalid value for prompt')
|
51 |
-
|
52 |
-
|
53 |
-
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
|
54 |
-
"""
|
55 |
-
Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
|
56 |
-
The mlp and layer norm come from CLIP.
|
57 |
-
x: input.
|
58 |
-
b: multihead attention module.
|
59 |
-
"""
|
60 |
-
|
61 |
-
x_ = b.ln_1(x)
|
62 |
-
q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
|
63 |
-
tgt_len, bsz, embed_dim = q.size()
|
64 |
-
|
65 |
-
head_dim = embed_dim // b.attn.num_heads
|
66 |
-
scaling = float(head_dim) ** -0.5
|
67 |
-
|
68 |
-
q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
69 |
-
k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
70 |
-
v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
71 |
-
|
72 |
-
q = q * scaling
|
73 |
-
|
74 |
-
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2
|
75 |
-
if attn_mask is not None:
|
76 |
-
|
77 |
-
|
78 |
-
attn_mask_type, attn_mask = attn_mask
|
79 |
-
n_heads = attn_output_weights.size(0) // attn_mask.size(0)
|
80 |
-
attn_mask = attn_mask.repeat(n_heads, 1)
|
81 |
-
|
82 |
-
if attn_mask_type == 'cls_token':
|
83 |
-
# the mask only affects similarities compared to the readout-token.
|
84 |
-
attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...]
|
85 |
-
# attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
|
86 |
-
|
87 |
-
if attn_mask_type == 'all':
|
88 |
-
# print(attn_output_weights.shape, attn_mask[:, None].shape)
|
89 |
-
attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
|
90 |
-
|
91 |
-
|
92 |
-
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
|
93 |
-
|
94 |
-
attn_output = torch.bmm(attn_output_weights, v)
|
95 |
-
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
96 |
-
attn_output = b.attn.out_proj(attn_output)
|
97 |
-
|
98 |
-
x = x + attn_output
|
99 |
-
x = x + b.mlp(b.ln_2(x))
|
100 |
-
|
101 |
-
if with_aff:
|
102 |
-
return x, attn_output_weights
|
103 |
-
else:
|
104 |
-
return x
|
105 |
-
|
106 |
-
|
107 |
-
class CLIPDenseBase(nn.Module):
|
108 |
-
|
109 |
-
def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
|
110 |
-
super().__init__()
|
111 |
-
|
112 |
-
import clip
|
113 |
-
|
114 |
-
# prec = torch.FloatTensor
|
115 |
-
self.clip_model, _ = clip.load(version, device='cpu', jit=False)
|
116 |
-
self.model = self.clip_model.visual
|
117 |
-
|
118 |
-
# if not None, scale conv weights such that we obtain n_tokens.
|
119 |
-
self.n_tokens = n_tokens
|
120 |
-
|
121 |
-
for p in self.clip_model.parameters():
|
122 |
-
p.requires_grad_(False)
|
123 |
-
|
124 |
-
# conditional
|
125 |
-
if reduce_cond is not None:
|
126 |
-
self.reduce_cond = nn.Linear(512, reduce_cond)
|
127 |
-
for p in self.reduce_cond.parameters():
|
128 |
-
p.requires_grad_(False)
|
129 |
-
else:
|
130 |
-
self.reduce_cond = None
|
131 |
-
|
132 |
-
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
133 |
-
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
134 |
-
|
135 |
-
self.reduce = nn.Linear(768, reduce_dim)
|
136 |
-
|
137 |
-
self.prompt_list = get_prompt_list(prompt)
|
138 |
-
|
139 |
-
# precomputed prompts
|
140 |
-
import pickle
|
141 |
-
if isfile('precomputed_prompt_vectors.pickle'):
|
142 |
-
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
143 |
-
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
144 |
-
else:
|
145 |
-
self.precomputed_prompts = dict()
|
146 |
-
|
147 |
-
def rescaled_pos_emb(self, new_size):
|
148 |
-
assert len(new_size) == 2
|
149 |
-
|
150 |
-
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
151 |
-
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
152 |
-
return torch.cat([self.model.positional_embedding[:1], b])
|
153 |
-
|
154 |
-
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
155 |
-
|
156 |
-
|
157 |
-
with torch.no_grad():
|
158 |
-
|
159 |
-
inp_size = x_inp.shape[2:]
|
160 |
-
|
161 |
-
if self.n_tokens is not None:
|
162 |
-
stride2 = x_inp.shape[2] // self.n_tokens
|
163 |
-
conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
|
164 |
-
x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
|
165 |
-
else:
|
166 |
-
x = self.model.conv1(x_inp) # shape = [*, width, grid, grid]
|
167 |
-
|
168 |
-
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
169 |
-
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
170 |
-
|
171 |
-
x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
172 |
-
|
173 |
-
standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
|
174 |
-
|
175 |
-
if x.shape[1] != standard_n_tokens:
|
176 |
-
new_shape = int(math.sqrt(x.shape[1]-1))
|
177 |
-
x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:]
|
178 |
-
else:
|
179 |
-
x = x + self.model.positional_embedding.to(x.dtype)
|
180 |
-
|
181 |
-
x = self.model.ln_pre(x)
|
182 |
-
|
183 |
-
x = x.permute(1, 0, 2) # NLD -> LND
|
184 |
-
|
185 |
-
activations, affinities = [], []
|
186 |
-
for i, res_block in enumerate(self.model.transformer.resblocks):
|
187 |
-
|
188 |
-
if mask is not None:
|
189 |
-
mask_layer, mask_type, mask_tensor = mask
|
190 |
-
if mask_layer == i or mask_layer == 'all':
|
191 |
-
# import ipdb; ipdb.set_trace()
|
192 |
-
size = int(math.sqrt(x.shape[0] - 1))
|
193 |
-
|
194 |
-
attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))
|
195 |
-
|
196 |
-
else:
|
197 |
-
attn_mask = None
|
198 |
-
else:
|
199 |
-
attn_mask = None
|
200 |
-
|
201 |
-
x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)
|
202 |
-
|
203 |
-
if i in extract_layers:
|
204 |
-
affinities += [aff_per_head]
|
205 |
-
|
206 |
-
#if self.n_tokens is not None:
|
207 |
-
# activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
|
208 |
-
#else:
|
209 |
-
activations += [x]
|
210 |
-
|
211 |
-
if len(extract_layers) > 0 and i == max(extract_layers) and skip:
|
212 |
-
print('early skip')
|
213 |
-
break
|
214 |
-
|
215 |
-
x = x.permute(1, 0, 2) # LND -> NLD
|
216 |
-
x = self.model.ln_post(x[:, 0, :])
|
217 |
-
|
218 |
-
if self.model.proj is not None:
|
219 |
-
x = x @ self.model.proj
|
220 |
-
|
221 |
-
return x, activations, affinities
|
222 |
-
|
223 |
-
def sample_prompts(self, words, prompt_list=None):
|
224 |
-
|
225 |
-
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
226 |
-
|
227 |
-
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
228 |
-
prompts = [prompt_list[i] for i in prompt_indices]
|
229 |
-
return [promt.format(w) for promt, w in zip(prompts, words)]
|
230 |
-
|
231 |
-
def get_cond_vec(self, conditional, batch_size):
|
232 |
-
# compute conditional from a single string
|
233 |
-
if conditional is not None and type(conditional) == str:
|
234 |
-
cond = self.compute_conditional(conditional)
|
235 |
-
cond = cond.repeat(batch_size, 1)
|
236 |
-
|
237 |
-
# compute conditional from string list/tuple
|
238 |
-
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
239 |
-
assert len(conditional) == batch_size
|
240 |
-
cond = self.compute_conditional(conditional)
|
241 |
-
|
242 |
-
# use conditional directly
|
243 |
-
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
244 |
-
cond = conditional
|
245 |
-
|
246 |
-
# compute conditional from image
|
247 |
-
elif conditional is not None and type(conditional) == torch.Tensor:
|
248 |
-
with torch.no_grad():
|
249 |
-
cond, _, _ = self.visual_forward(conditional)
|
250 |
-
else:
|
251 |
-
raise ValueError('invalid conditional')
|
252 |
-
return cond
|
253 |
-
|
254 |
-
def compute_conditional(self, conditional):
|
255 |
-
import clip
|
256 |
-
|
257 |
-
dev = next(self.parameters()).device
|
258 |
-
|
259 |
-
if type(conditional) in {list, tuple}:
|
260 |
-
text_tokens = clip.tokenize(conditional).to(dev)
|
261 |
-
cond = self.clip_model.encode_text(text_tokens)
|
262 |
-
else:
|
263 |
-
if conditional in self.precomputed_prompts:
|
264 |
-
cond = self.precomputed_prompts[conditional].float().to(dev)
|
265 |
-
else:
|
266 |
-
text_tokens = clip.tokenize([conditional]).to(dev)
|
267 |
-
cond = self.clip_model.encode_text(text_tokens)[0]
|
268 |
-
|
269 |
-
if self.shift_vector is not None:
|
270 |
-
return cond + self.shift_vector
|
271 |
-
else:
|
272 |
-
return cond
|
273 |
-
|
274 |
-
|
275 |
-
def clip_load_untrained(version):
|
276 |
-
assert version == 'ViT-B/16'
|
277 |
-
from clip.model import CLIP
|
278 |
-
from clip.clip import _MODELS, _download
|
279 |
-
model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
|
280 |
-
state_dict = model.state_dict()
|
281 |
-
|
282 |
-
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
283 |
-
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
284 |
-
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
285 |
-
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
286 |
-
image_resolution = vision_patch_size * grid_size
|
287 |
-
embed_dim = state_dict["text_projection"].shape[1]
|
288 |
-
context_length = state_dict["positional_embedding"].shape[0]
|
289 |
-
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
290 |
-
transformer_width = state_dict["ln_final.weight"].shape[0]
|
291 |
-
transformer_heads = transformer_width // 64
|
292 |
-
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
293 |
-
|
294 |
-
return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
|
295 |
-
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
|
296 |
-
|
297 |
-
|
298 |
-
class CLIPDensePredT(CLIPDenseBase):
|
299 |
-
|
300 |
-
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
301 |
-
extra_blocks=0, reduce_cond=None, fix_shift=False,
|
302 |
-
learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
|
303 |
-
add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
|
304 |
-
|
305 |
-
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
|
306 |
-
# device = 'cpu'
|
307 |
-
|
308 |
-
self.extract_layers = extract_layers
|
309 |
-
self.cond_layer = cond_layer
|
310 |
-
self.limit_to_clip_only = limit_to_clip_only
|
311 |
-
self.process_cond = None
|
312 |
-
self.rev_activations = rev_activations
|
313 |
-
|
314 |
-
depth = len(extract_layers)
|
315 |
-
|
316 |
-
if add_calibration:
|
317 |
-
self.calibration_conds = 1
|
318 |
-
|
319 |
-
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
320 |
-
|
321 |
-
self.add_activation1 = True
|
322 |
-
|
323 |
-
self.version = version
|
324 |
-
|
325 |
-
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
|
326 |
-
|
327 |
-
if fix_shift:
|
328 |
-
# self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
|
329 |
-
self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False)
|
330 |
-
# self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
|
331 |
-
else:
|
332 |
-
self.shift_vector = None
|
333 |
-
|
334 |
-
if trans_conv is None:
|
335 |
-
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
|
336 |
-
else:
|
337 |
-
# explicitly define transposed conv kernel size
|
338 |
-
trans_conv_ks = (trans_conv, trans_conv)
|
339 |
-
|
340 |
-
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
341 |
-
|
342 |
-
assert len(self.extract_layers) == depth
|
343 |
-
|
344 |
-
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
345 |
-
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
346 |
-
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
347 |
-
|
348 |
-
# refinement and trans conv
|
349 |
-
|
350 |
-
if learn_trans_conv_only:
|
351 |
-
for p in self.parameters():
|
352 |
-
p.requires_grad_(False)
|
353 |
-
|
354 |
-
for p in self.trans_conv.parameters():
|
355 |
-
p.requires_grad_(True)
|
356 |
-
|
357 |
-
self.prompt_list = get_prompt_list(prompt)
|
358 |
-
|
359 |
-
|
360 |
-
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
361 |
-
|
362 |
-
assert type(return_features) == bool
|
363 |
-
|
364 |
-
inp_image = inp_image.to(self.model.positional_embedding.device)
|
365 |
-
|
366 |
-
if mask is not None:
|
367 |
-
raise ValueError('mask not supported')
|
368 |
-
|
369 |
-
# x_inp = normalize(inp_image)
|
370 |
-
x_inp = inp_image
|
371 |
-
|
372 |
-
bs, dev = inp_image.shape[0], x_inp.device
|
373 |
-
|
374 |
-
cond = self.get_cond_vec(conditional, bs)
|
375 |
-
|
376 |
-
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
377 |
-
|
378 |
-
activation1 = activations[0]
|
379 |
-
activations = activations[1:]
|
380 |
-
|
381 |
-
_activations = activations[::-1] if not self.rev_activations else activations
|
382 |
-
|
383 |
-
a = None
|
384 |
-
for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):
|
385 |
-
|
386 |
-
if a is not None:
|
387 |
-
a = reduce(activation) + a
|
388 |
-
else:
|
389 |
-
a = reduce(activation)
|
390 |
-
|
391 |
-
if i == self.cond_layer:
|
392 |
-
if self.reduce_cond is not None:
|
393 |
-
cond = self.reduce_cond(cond)
|
394 |
-
|
395 |
-
a = self.film_mul(cond) * a + self.film_add(cond)
|
396 |
-
|
397 |
-
a = block(a)
|
398 |
-
|
399 |
-
for block in self.extra_blocks:
|
400 |
-
a = a + block(a)
|
401 |
-
|
402 |
-
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
403 |
-
|
404 |
-
size = int(math.sqrt(a.shape[2]))
|
405 |
-
|
406 |
-
a = a.view(bs, a.shape[1], size, size)
|
407 |
-
|
408 |
-
a = self.trans_conv(a)
|
409 |
-
|
410 |
-
if self.n_tokens is not None:
|
411 |
-
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)
|
412 |
-
|
413 |
-
if self.upsample_proj is not None:
|
414 |
-
a = self.upsample_proj(a)
|
415 |
-
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
416 |
-
|
417 |
-
if return_features:
|
418 |
-
return a, visual_q, cond, [activation1] + activations
|
419 |
-
else:
|
420 |
-
return a,
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
class CLIPDensePredTMasked(CLIPDensePredT):
|
425 |
-
|
426 |
-
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
|
427 |
-
prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
|
428 |
-
refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
|
429 |
-
|
430 |
-
super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim,
|
431 |
-
n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond,
|
432 |
-
fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only,
|
433 |
-
limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration,
|
434 |
-
n_tokens=n_tokens)
|
435 |
-
|
436 |
-
def visual_forward_masked(self, img_s, seg_s):
|
437 |
-
return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s))
|
438 |
-
|
439 |
-
def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
|
440 |
-
|
441 |
-
if seg_s is None:
|
442 |
-
cond = cond_or_img_s
|
443 |
-
else:
|
444 |
-
img_s = cond_or_img_s
|
445 |
-
|
446 |
-
with torch.no_grad():
|
447 |
-
cond, _, _ = self.visual_forward_masked(img_s, seg_s)
|
448 |
-
|
449 |
-
return super().forward(img_q, cond, return_features=return_features)
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
class CLIPDenseBaseline(CLIPDenseBase):
|
454 |
-
|
455 |
-
def __init__(self, version='ViT-B/32', cond_layer=0,
|
456 |
-
extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
|
457 |
-
reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
|
458 |
-
|
459 |
-
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
|
460 |
-
device = 'cpu'
|
461 |
-
|
462 |
-
# self.cond_layer = cond_layer
|
463 |
-
self.extract_layer = extract_layer
|
464 |
-
self.limit_to_clip_only = limit_to_clip_only
|
465 |
-
self.shift_vector = None
|
466 |
-
|
467 |
-
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
|
468 |
-
|
469 |
-
assert reduce2_dim is not None
|
470 |
-
|
471 |
-
self.reduce2 = nn.Sequential(
|
472 |
-
nn.Linear(reduce_dim, reduce2_dim),
|
473 |
-
nn.ReLU(),
|
474 |
-
nn.Linear(reduce2_dim, reduce_dim)
|
475 |
-
)
|
476 |
-
|
477 |
-
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
|
478 |
-
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
479 |
-
|
480 |
-
|
481 |
-
def forward(self, inp_image, conditional=None, return_features=False):
|
482 |
-
|
483 |
-
inp_image = inp_image.to(self.model.positional_embedding.device)
|
484 |
-
|
485 |
-
# x_inp = normalize(inp_image)
|
486 |
-
x_inp = inp_image
|
487 |
-
|
488 |
-
bs, dev = inp_image.shape[0], x_inp.device
|
489 |
-
|
490 |
-
cond = self.get_cond_vec(conditional, bs)
|
491 |
-
|
492 |
-
visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer])
|
493 |
-
|
494 |
-
a = activations[0]
|
495 |
-
a = self.reduce(a)
|
496 |
-
a = self.film_mul(cond) * a + self.film_add(cond)
|
497 |
-
|
498 |
-
if self.reduce2 is not None:
|
499 |
-
a = self.reduce2(a)
|
500 |
-
|
501 |
-
# the original model would execute a transformer block here
|
502 |
-
|
503 |
-
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
504 |
-
|
505 |
-
size = int(math.sqrt(a.shape[2]))
|
506 |
-
|
507 |
-
a = a.view(bs, a.shape[1], size, size)
|
508 |
-
a = self.trans_conv(a)
|
509 |
-
|
510 |
-
if return_features:
|
511 |
-
return a, visual_q, cond, activations
|
512 |
-
else:
|
513 |
-
return a,
|
514 |
-
|
515 |
-
|
516 |
-
class CLIPSegMultiLabel(nn.Module):
|
517 |
-
|
518 |
-
def __init__(self, model) -> None:
|
519 |
-
super().__init__()
|
520 |
-
|
521 |
-
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
522 |
-
|
523 |
-
self.pascal_classes = VOC
|
524 |
-
|
525 |
-
from models.clipseg import CLIPDensePredT
|
526 |
-
from general_utils import load_model
|
527 |
-
# self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
|
528 |
-
self.clipseg = load_model(model, strict=False)
|
529 |
-
|
530 |
-
self.clipseg.eval()
|
531 |
-
|
532 |
-
def forward(self, x):
|
533 |
-
|
534 |
-
bs = x.shape[0]
|
535 |
-
out = torch.ones(21, bs, 352, 352).to(x.device) * -10
|
536 |
-
|
537 |
-
for class_id, class_name in enumerate(self.pascal_classes):
|
538 |
-
|
539 |
-
fac = 3 if class_name == 'background' else 1
|
540 |
-
|
541 |
-
with torch.no_grad():
|
542 |
-
pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac
|
543 |
-
|
544 |
-
out[class_id] += pred
|
545 |
-
|
546 |
-
|
547 |
-
out = out.permute(1, 0, 2, 3)
|
548 |
-
|
549 |
-
return out
|
550 |
-
|
551 |
-
# construct output tensor
|
552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/models/vitseg.py
DELETED
@@ -1,286 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from posixpath import basename, dirname, join
|
3 |
-
# import clip
|
4 |
-
from clip.model import convert_weights
|
5 |
-
import torch
|
6 |
-
import json
|
7 |
-
from torch import nn
|
8 |
-
from torch.nn import functional as nnf
|
9 |
-
from torch.nn.modules import activation
|
10 |
-
from torch.nn.modules.activation import ReLU
|
11 |
-
from torchvision import transforms
|
12 |
-
|
13 |
-
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
14 |
-
|
15 |
-
from torchvision.models import ResNet
|
16 |
-
|
17 |
-
|
18 |
-
def process_prompts(conditional, prompt_list, conditional_map):
|
19 |
-
# DEPRECATED
|
20 |
-
|
21 |
-
# randomly sample a synonym
|
22 |
-
words = [conditional_map[int(i)] for i in conditional]
|
23 |
-
words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
|
24 |
-
words = [w.replace('_', ' ') for w in words]
|
25 |
-
|
26 |
-
if prompt_list is not None:
|
27 |
-
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
28 |
-
prompts = [prompt_list[i] for i in prompt_indices]
|
29 |
-
else:
|
30 |
-
prompts = ['a photo of {}'] * (len(words))
|
31 |
-
|
32 |
-
return [promt.format(w) for promt, w in zip(prompts, words)]
|
33 |
-
|
34 |
-
|
35 |
-
class VITDenseBase(nn.Module):
|
36 |
-
|
37 |
-
def rescaled_pos_emb(self, new_size):
|
38 |
-
assert len(new_size) == 2
|
39 |
-
|
40 |
-
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
41 |
-
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
42 |
-
return torch.cat([self.model.positional_embedding[:1], b])
|
43 |
-
|
44 |
-
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
45 |
-
|
46 |
-
with torch.no_grad():
|
47 |
-
|
48 |
-
x_inp = nnf.interpolate(x_inp, (384, 384))
|
49 |
-
|
50 |
-
x = self.model.patch_embed(x_inp)
|
51 |
-
cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
52 |
-
if self.model.dist_token is None:
|
53 |
-
x = torch.cat((cls_token, x), dim=1)
|
54 |
-
else:
|
55 |
-
x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
56 |
-
x = self.model.pos_drop(x + self.model.pos_embed)
|
57 |
-
|
58 |
-
activations = []
|
59 |
-
for i, block in enumerate(self.model.blocks):
|
60 |
-
x = block(x)
|
61 |
-
|
62 |
-
if i in extract_layers:
|
63 |
-
# permute to be compatible with CLIP
|
64 |
-
activations += [x.permute(1,0,2)]
|
65 |
-
|
66 |
-
x = self.model.norm(x)
|
67 |
-
x = self.model.head(self.model.pre_logits(x[:, 0]))
|
68 |
-
|
69 |
-
# again for CLIP compatibility
|
70 |
-
# x = x.permute(1, 0, 2)
|
71 |
-
|
72 |
-
return x, activations, None
|
73 |
-
|
74 |
-
def sample_prompts(self, words, prompt_list=None):
|
75 |
-
|
76 |
-
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
77 |
-
|
78 |
-
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
79 |
-
prompts = [prompt_list[i] for i in prompt_indices]
|
80 |
-
return [promt.format(w) for promt, w in zip(prompts, words)]
|
81 |
-
|
82 |
-
def get_cond_vec(self, conditional, batch_size):
|
83 |
-
# compute conditional from a single string
|
84 |
-
if conditional is not None and type(conditional) == str:
|
85 |
-
cond = self.compute_conditional(conditional)
|
86 |
-
cond = cond.repeat(batch_size, 1)
|
87 |
-
|
88 |
-
# compute conditional from string list/tuple
|
89 |
-
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
90 |
-
assert len(conditional) == batch_size
|
91 |
-
cond = self.compute_conditional(conditional)
|
92 |
-
|
93 |
-
# use conditional directly
|
94 |
-
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
95 |
-
cond = conditional
|
96 |
-
|
97 |
-
# compute conditional from image
|
98 |
-
elif conditional is not None and type(conditional) == torch.Tensor:
|
99 |
-
with torch.no_grad():
|
100 |
-
cond, _, _ = self.visual_forward(conditional)
|
101 |
-
else:
|
102 |
-
raise ValueError('invalid conditional')
|
103 |
-
return cond
|
104 |
-
|
105 |
-
def compute_conditional(self, conditional):
|
106 |
-
import clip
|
107 |
-
|
108 |
-
dev = next(self.parameters()).device
|
109 |
-
|
110 |
-
if type(conditional) in {list, tuple}:
|
111 |
-
text_tokens = clip.tokenize(conditional).to(dev)
|
112 |
-
cond = self.clip_model.encode_text(text_tokens)
|
113 |
-
else:
|
114 |
-
if conditional in self.precomputed_prompts:
|
115 |
-
cond = self.precomputed_prompts[conditional].float().to(dev)
|
116 |
-
else:
|
117 |
-
text_tokens = clip.tokenize([conditional]).to(dev)
|
118 |
-
cond = self.clip_model.encode_text(text_tokens)[0]
|
119 |
-
|
120 |
-
return cond
|
121 |
-
|
122 |
-
|
123 |
-
class VITDensePredT(VITDenseBase):
|
124 |
-
|
125 |
-
def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
126 |
-
depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
|
127 |
-
learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
|
128 |
-
add_calibration=False, process_cond=None, not_pretrained=False):
|
129 |
-
super().__init__()
|
130 |
-
# device = 'cpu'
|
131 |
-
|
132 |
-
self.extract_layers = extract_layers
|
133 |
-
self.cond_layer = cond_layer
|
134 |
-
self.limit_to_clip_only = limit_to_clip_only
|
135 |
-
self.process_cond = None
|
136 |
-
|
137 |
-
if add_calibration:
|
138 |
-
self.calibration_conds = 1
|
139 |
-
|
140 |
-
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
141 |
-
|
142 |
-
self.add_activation1 = True
|
143 |
-
|
144 |
-
import timm
|
145 |
-
self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
|
146 |
-
self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
|
147 |
-
|
148 |
-
for p in self.model.parameters():
|
149 |
-
p.requires_grad_(False)
|
150 |
-
|
151 |
-
import clip
|
152 |
-
self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
|
153 |
-
# del self.clip_model.visual
|
154 |
-
|
155 |
-
|
156 |
-
self.token_shape = (14, 14)
|
157 |
-
|
158 |
-
# conditional
|
159 |
-
if reduce_cond is not None:
|
160 |
-
self.reduce_cond = nn.Linear(512, reduce_cond)
|
161 |
-
for p in self.reduce_cond.parameters():
|
162 |
-
p.requires_grad_(False)
|
163 |
-
else:
|
164 |
-
self.reduce_cond = None
|
165 |
-
|
166 |
-
# self.film = AVAILABLE_BLOCKS['film'](512, 128)
|
167 |
-
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
168 |
-
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
169 |
-
|
170 |
-
# DEPRECATED
|
171 |
-
# self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
|
172 |
-
|
173 |
-
assert len(self.extract_layers) == depth
|
174 |
-
|
175 |
-
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
176 |
-
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
177 |
-
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
178 |
-
|
179 |
-
trans_conv_ks = (16, 16)
|
180 |
-
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
181 |
-
|
182 |
-
# refinement and trans conv
|
183 |
-
|
184 |
-
if learn_trans_conv_only:
|
185 |
-
for p in self.parameters():
|
186 |
-
p.requires_grad_(False)
|
187 |
-
|
188 |
-
for p in self.trans_conv.parameters():
|
189 |
-
p.requires_grad_(True)
|
190 |
-
|
191 |
-
if prompt == 'fixed':
|
192 |
-
self.prompt_list = ['a photo of a {}.']
|
193 |
-
elif prompt == 'shuffle':
|
194 |
-
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
195 |
-
elif prompt == 'shuffle+':
|
196 |
-
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
197 |
-
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
198 |
-
'a bad photo of a {}.', 'a photo of the {}.']
|
199 |
-
elif prompt == 'shuffle_clip':
|
200 |
-
from models.clip_prompts import imagenet_templates
|
201 |
-
self.prompt_list = imagenet_templates
|
202 |
-
|
203 |
-
if process_cond is not None:
|
204 |
-
if process_cond == 'clamp' or process_cond[0] == 'clamp':
|
205 |
-
|
206 |
-
val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
|
207 |
-
|
208 |
-
def clamp_vec(x):
|
209 |
-
return torch.clamp(x, -val, val)
|
210 |
-
|
211 |
-
self.process_cond = clamp_vec
|
212 |
-
|
213 |
-
elif process_cond.endswith('.pth'):
|
214 |
-
|
215 |
-
shift = torch.load(process_cond)
|
216 |
-
def add_shift(x):
|
217 |
-
return x + shift.to(x.device)
|
218 |
-
|
219 |
-
self.process_cond = add_shift
|
220 |
-
|
221 |
-
import pickle
|
222 |
-
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
223 |
-
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
224 |
-
|
225 |
-
|
226 |
-
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
227 |
-
|
228 |
-
assert type(return_features) == bool
|
229 |
-
|
230 |
-
# inp_image = inp_image.to(self.model.positional_embedding.device)
|
231 |
-
|
232 |
-
if mask is not None:
|
233 |
-
raise ValueError('mask not supported')
|
234 |
-
|
235 |
-
# x_inp = normalize(inp_image)
|
236 |
-
x_inp = inp_image
|
237 |
-
|
238 |
-
bs, dev = inp_image.shape[0], x_inp.device
|
239 |
-
|
240 |
-
inp_image_size = inp_image.shape[2:]
|
241 |
-
|
242 |
-
cond = self.get_cond_vec(conditional, bs)
|
243 |
-
|
244 |
-
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
245 |
-
|
246 |
-
activation1 = activations[0]
|
247 |
-
activations = activations[1:]
|
248 |
-
|
249 |
-
a = None
|
250 |
-
for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
|
251 |
-
|
252 |
-
if a is not None:
|
253 |
-
a = reduce(activation) + a
|
254 |
-
else:
|
255 |
-
a = reduce(activation)
|
256 |
-
|
257 |
-
if i == self.cond_layer:
|
258 |
-
if self.reduce_cond is not None:
|
259 |
-
cond = self.reduce_cond(cond)
|
260 |
-
|
261 |
-
a = self.film_mul(cond) * a + self.film_add(cond)
|
262 |
-
|
263 |
-
a = block(a)
|
264 |
-
|
265 |
-
for block in self.extra_blocks:
|
266 |
-
a = a + block(a)
|
267 |
-
|
268 |
-
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
269 |
-
|
270 |
-
size = int(math.sqrt(a.shape[2]))
|
271 |
-
|
272 |
-
a = a.view(bs, a.shape[1], size, size)
|
273 |
-
|
274 |
-
if self.trans_conv is not None:
|
275 |
-
a = self.trans_conv(a)
|
276 |
-
|
277 |
-
if self.upsample_proj is not None:
|
278 |
-
a = self.upsample_proj(a)
|
279 |
-
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
280 |
-
|
281 |
-
a = nnf.interpolate(a, inp_image_size)
|
282 |
-
|
283 |
-
if return_features:
|
284 |
-
return a, visual_q, cond, [activation1] + activations
|
285 |
-
else:
|
286 |
-
return a,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/overview.png
DELETED
Binary file (54 kB)
|
|
clipseg/score.py
DELETED
@@ -1,453 +0,0 @@
|
|
1 |
-
from torch.functional import Tensor
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import inspect
|
5 |
-
import json
|
6 |
-
import yaml
|
7 |
-
import time
|
8 |
-
import sys
|
9 |
-
|
10 |
-
from general_utils import log
|
11 |
-
|
12 |
-
import numpy as np
|
13 |
-
from os.path import expanduser, join, isfile, realpath
|
14 |
-
|
15 |
-
from torch.utils.data import DataLoader
|
16 |
-
|
17 |
-
from metrics import FixedIntervalMetrics
|
18 |
-
|
19 |
-
from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args
|
20 |
-
|
21 |
-
|
22 |
-
DATASET_CACHE = dict()
|
23 |
-
|
24 |
-
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
|
25 |
-
|
26 |
-
config = json.load(open(join('logs', checkpoint_id, 'config.json')))
|
27 |
-
|
28 |
-
if model_args != 'from_config' and type(model_args) != dict:
|
29 |
-
raise ValueError('model_args must either be "from_config" or a dictionary of values')
|
30 |
-
|
31 |
-
model_cls = get_attribute(config['model'])
|
32 |
-
|
33 |
-
# load model
|
34 |
-
if model_args == 'from_config':
|
35 |
-
_, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
|
36 |
-
|
37 |
-
model = model_cls(**model_args)
|
38 |
-
|
39 |
-
if weights_file is None:
|
40 |
-
weights_file = realpath(join('logs', checkpoint_id, 'weights.pth'))
|
41 |
-
else:
|
42 |
-
weights_file = realpath(join('logs', checkpoint_id, weights_file))
|
43 |
-
|
44 |
-
if isfile(weights_file) and not ignore_weights:
|
45 |
-
weights = torch.load(weights_file)
|
46 |
-
for _, w in weights.items():
|
47 |
-
assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
|
48 |
-
model.load_state_dict(weights, strict=strict)
|
49 |
-
else:
|
50 |
-
if not ignore_weights:
|
51 |
-
raise FileNotFoundError(f'model checkpoint {weights_file} was not found')
|
52 |
-
|
53 |
-
if with_config:
|
54 |
-
return model, config
|
55 |
-
|
56 |
-
return model
|
57 |
-
|
58 |
-
|
59 |
-
def compute_shift2(model, datasets, seed=123, repetitions=1):
|
60 |
-
""" computes shift """
|
61 |
-
|
62 |
-
model.eval()
|
63 |
-
model.cuda()
|
64 |
-
|
65 |
-
import random
|
66 |
-
random.seed(seed)
|
67 |
-
|
68 |
-
preds, gts = [], []
|
69 |
-
for i_dataset, dataset in enumerate(datasets):
|
70 |
-
|
71 |
-
loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
|
72 |
-
|
73 |
-
max_iterations = int(repetitions * len(dataset.dataset.data_list))
|
74 |
-
|
75 |
-
with torch.no_grad():
|
76 |
-
|
77 |
-
i, losses = 0, []
|
78 |
-
for i_all, (data_x, data_y) in enumerate(loader):
|
79 |
-
|
80 |
-
data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
|
81 |
-
data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]
|
82 |
-
|
83 |
-
pred, = model(data_x[0], data_x[1], data_x[2])
|
84 |
-
preds += [pred.detach()]
|
85 |
-
gts += [data_y]
|
86 |
-
|
87 |
-
i += 1
|
88 |
-
if max_iterations and i >= max_iterations:
|
89 |
-
break
|
90 |
-
|
91 |
-
from metrics import FixedIntervalMetrics
|
92 |
-
n_values = 51
|
93 |
-
thresholds = np.linspace(0, 1, n_values)[1:-1]
|
94 |
-
metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)
|
95 |
-
|
96 |
-
for p, y in zip(preds, gts):
|
97 |
-
metric.add(p.unsqueeze(1), y)
|
98 |
-
|
99 |
-
best_idx = np.argmax(metric.value()['fgiou_scores'])
|
100 |
-
best_thresh = thresholds[best_idx]
|
101 |
-
|
102 |
-
return best_thresh
|
103 |
-
|
104 |
-
|
105 |
-
def get_cached_pascal_pfe(split, config):
|
106 |
-
from datasets.pfe_dataset import PFEPascalWrapper
|
107 |
-
try:
|
108 |
-
dataset = DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)]
|
109 |
-
except KeyError:
|
110 |
-
dataset = PFEPascalWrapper(mode='val', split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support)
|
111 |
-
DATASET_CACHE[(split, config.image_size, config.label_support, config.mask)] = dataset
|
112 |
-
return dataset
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
def main():
|
118 |
-
config, train_checkpoint_id = score_config_from_cli_args()
|
119 |
-
|
120 |
-
metrics = score(config, train_checkpoint_id, None)
|
121 |
-
|
122 |
-
for dataset in metrics.keys():
|
123 |
-
for k in metrics[dataset]:
|
124 |
-
if type(metrics[dataset][k]) in {float, int}:
|
125 |
-
print(dataset, f'{k:<16} {metrics[dataset][k]:.3f}')
|
126 |
-
|
127 |
-
|
128 |
-
def score(config, train_checkpoint_id, train_config):
|
129 |
-
|
130 |
-
config = AttributeDict(config)
|
131 |
-
|
132 |
-
print(config)
|
133 |
-
|
134 |
-
# use training dataset and loss
|
135 |
-
train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json')))
|
136 |
-
|
137 |
-
cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''
|
138 |
-
|
139 |
-
|
140 |
-
model_cls = get_attribute(train_config['model'])
|
141 |
-
|
142 |
-
_, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)
|
143 |
-
|
144 |
-
model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}
|
145 |
-
|
146 |
-
strict_models = {'ConditionBase4', 'PFENetWrapper'}
|
147 |
-
model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
|
148 |
-
weights_file=f'weights{cp_str}.pth', )
|
149 |
-
|
150 |
-
|
151 |
-
model.eval()
|
152 |
-
model.cuda()
|
153 |
-
|
154 |
-
metric_args = dict()
|
155 |
-
|
156 |
-
if 'threshold' in config:
|
157 |
-
if config.metric.split('.')[-1] == 'SkLearnMetrics':
|
158 |
-
metric_args['threshold'] = config.threshold
|
159 |
-
|
160 |
-
if 'resize_to' in config:
|
161 |
-
metric_args['resize_to'] = config.resize_to
|
162 |
-
|
163 |
-
if 'sigmoid' in config:
|
164 |
-
metric_args['sigmoid'] = config.sigmoid
|
165 |
-
|
166 |
-
if 'custom_threshold' in config:
|
167 |
-
metric_args['custom_threshold'] = config.custom_threshold
|
168 |
-
|
169 |
-
if config.test_dataset == 'pascal':
|
170 |
-
|
171 |
-
loss_fn = get_attribute(train_config.loss)
|
172 |
-
# assume that if no split is specified in train_config, test on all splits,
|
173 |
-
|
174 |
-
if 'splits' in config:
|
175 |
-
splits = config.splits
|
176 |
-
else:
|
177 |
-
if 'split' in train_config and type(train_config.split) == int:
|
178 |
-
# unless train_config has a split set, in that case assume train mode in training
|
179 |
-
splits = [train_config.split]
|
180 |
-
assert train_config.mode == 'train'
|
181 |
-
else:
|
182 |
-
splits = [0,1,2,3]
|
183 |
-
|
184 |
-
log.info('Test on these splits', splits)
|
185 |
-
|
186 |
-
scores = dict()
|
187 |
-
for split in splits:
|
188 |
-
|
189 |
-
shift = config.shift if 'shift' in config else 0
|
190 |
-
|
191 |
-
# automatic shift
|
192 |
-
if shift == 'auto':
|
193 |
-
shift_compute_t = time.time()
|
194 |
-
shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
|
195 |
-
log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')
|
196 |
-
|
197 |
-
dataset = get_cached_pascal_pfe(split, config)
|
198 |
-
|
199 |
-
eval_start_t = time.time()
|
200 |
-
|
201 |
-
loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
|
202 |
-
|
203 |
-
assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'
|
204 |
-
|
205 |
-
metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)
|
206 |
-
|
207 |
-
with torch.no_grad():
|
208 |
-
|
209 |
-
i, losses = 0, []
|
210 |
-
for i_all, (data_x, data_y) in enumerate(loader):
|
211 |
-
|
212 |
-
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
|
213 |
-
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
|
214 |
-
|
215 |
-
if config.mask == 'separate': # for old CondBase model
|
216 |
-
pred, = model(data_x[0], data_x[1], data_x[2])
|
217 |
-
else:
|
218 |
-
# assert config.mask in {'text', 'highlight'}
|
219 |
-
pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
|
220 |
-
|
221 |
-
# loss = loss_fn(pred, data_y[0])
|
222 |
-
metric.add(pred.unsqueeze(1) + shift, data_y)
|
223 |
-
|
224 |
-
# losses += [float(loss)]
|
225 |
-
|
226 |
-
i += 1
|
227 |
-
if config.max_iterations and i >= config.max_iterations:
|
228 |
-
break
|
229 |
-
|
230 |
-
#scores[split] = {m: s for m, s in zip(metric.names(), metric.value())}
|
231 |
-
|
232 |
-
log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')
|
233 |
-
|
234 |
-
print(metric.value()['mean_iou_scores'])
|
235 |
-
|
236 |
-
scores[split] = metric.scores()
|
237 |
-
|
238 |
-
log.info(f'Completed split {split}')
|
239 |
-
|
240 |
-
key_prefix = config['name'] if 'name' in config else 'pas'
|
241 |
-
|
242 |
-
all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])
|
243 |
-
|
244 |
-
valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) for v in scores.values())]
|
245 |
-
|
246 |
-
return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}
|
247 |
-
|
248 |
-
|
249 |
-
if config.test_dataset == 'coco':
|
250 |
-
from datasets.coco_wrapper import COCOWrapper
|
251 |
-
|
252 |
-
coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
|
253 |
-
with_class_label=True)
|
254 |
-
|
255 |
-
log.info('Dataset length', len(coco_dataset))
|
256 |
-
loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
|
257 |
-
|
258 |
-
metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
|
259 |
-
|
260 |
-
shift = config.shift if 'shift' in config else 0
|
261 |
-
|
262 |
-
with torch.no_grad():
|
263 |
-
|
264 |
-
i, losses = 0, []
|
265 |
-
for i_all, (data_x, data_y) in enumerate(loader):
|
266 |
-
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
|
267 |
-
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
|
268 |
-
|
269 |
-
if config.mask == 'separate': # for old CondBase model
|
270 |
-
pred, = model(data_x[0], data_x[1], data_x[2])
|
271 |
-
else:
|
272 |
-
# assert config.mask in {'text', 'highlight'}
|
273 |
-
pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
|
274 |
-
|
275 |
-
metric.add([pred + shift], data_y)
|
276 |
-
|
277 |
-
i += 1
|
278 |
-
if config.max_iterations and i >= config.max_iterations:
|
279 |
-
break
|
280 |
-
|
281 |
-
key_prefix = config['name'] if 'name' in config else 'coco'
|
282 |
-
return {key_prefix: metric.scores()}
|
283 |
-
#return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
|
284 |
-
|
285 |
-
|
286 |
-
if config.test_dataset == 'phrasecut':
|
287 |
-
from datasets.phrasecut import PhraseCut
|
288 |
-
|
289 |
-
only_visual = config.only_visual is not None and config.only_visual
|
290 |
-
with_visual = config.with_visual is not None and config.with_visual
|
291 |
-
|
292 |
-
dataset = PhraseCut('test',
|
293 |
-
image_size=train_config.image_size,
|
294 |
-
mask=config.mask,
|
295 |
-
with_visual=with_visual, only_visual=only_visual, aug_crop=False,
|
296 |
-
aug_color=False)
|
297 |
-
|
298 |
-
loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
|
299 |
-
metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
|
300 |
-
|
301 |
-
shift = config.shift if 'shift' in config else 0
|
302 |
-
|
303 |
-
|
304 |
-
with torch.no_grad():
|
305 |
-
|
306 |
-
i, losses = 0, []
|
307 |
-
for i_all, (data_x, data_y) in enumerate(loader):
|
308 |
-
data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
|
309 |
-
data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
|
310 |
-
|
311 |
-
pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
|
312 |
-
metric.add([pred + shift], data_y)
|
313 |
-
|
314 |
-
i += 1
|
315 |
-
if config.max_iterations and i >= config.max_iterations:
|
316 |
-
break
|
317 |
-
|
318 |
-
key_prefix = config['name'] if 'name' in config else 'phrasecut'
|
319 |
-
return {key_prefix: metric.scores()}
|
320 |
-
#return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
|
321 |
-
|
322 |
-
if config.test_dataset == 'pascal_zs':
|
323 |
-
from third_party.JoEm.model.metric import Evaluator
|
324 |
-
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
325 |
-
from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS
|
326 |
-
|
327 |
-
from models.clipseg import CLIPSegMultiLabel
|
328 |
-
|
329 |
-
n_unseen = train_config.remove_classes[1]
|
330 |
-
|
331 |
-
pz = PascalZeroShot('val', n_unseen, image_size=352)
|
332 |
-
m = CLIPSegMultiLabel(model=train_config.name).cuda()
|
333 |
-
m.eval();
|
334 |
-
|
335 |
-
print(len(pz), n_unseen)
|
336 |
-
print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])
|
337 |
-
|
338 |
-
print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
|
339 |
-
print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])
|
340 |
-
|
341 |
-
loader = DataLoader(pz, batch_size=8)
|
342 |
-
evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))
|
343 |
-
|
344 |
-
for i, (data_x, data_y) in enumerate(loader):
|
345 |
-
pred = m(data_x[0].cuda())
|
346 |
-
evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())
|
347 |
-
|
348 |
-
if config.max_iter is not None and i > config.max_iter:
|
349 |
-
break
|
350 |
-
|
351 |
-
scores = evaluator.Mean_Intersection_over_Union()
|
352 |
-
key_prefix = config['name'] if 'name' in config else 'pas_zs'
|
353 |
-
|
354 |
-
return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}
|
355 |
-
|
356 |
-
elif config.test_dataset in {'same_as_training', 'affordance'}:
|
357 |
-
loss_fn = get_attribute(train_config.loss)
|
358 |
-
|
359 |
-
metric_cls = get_attribute(config.metric)
|
360 |
-
metric = metric_cls(**metric_args)
|
361 |
-
|
362 |
-
if config.test_dataset == 'same_as_training':
|
363 |
-
dataset_cls = get_attribute(train_config.dataset)
|
364 |
-
elif config.test_dataset == 'affordance':
|
365 |
-
dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
|
366 |
-
dataset_name = 'aff'
|
367 |
-
else:
|
368 |
-
dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot')
|
369 |
-
dataset_name = 'lvis'
|
370 |
-
|
371 |
-
_, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
|
372 |
-
|
373 |
-
dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation
|
374 |
-
|
375 |
-
if model.__class__.__name__ == 'PFENetWrapper':
|
376 |
-
dataset_args['image_size'] = config.image_size
|
377 |
-
|
378 |
-
log.info('init dataset', str(dataset_cls))
|
379 |
-
dataset = dataset_cls(**dataset_args)
|
380 |
-
|
381 |
-
log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')
|
382 |
-
|
383 |
-
data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)
|
384 |
-
|
385 |
-
# explicitly set prompts
|
386 |
-
if config.prompt == 'plain':
|
387 |
-
model.prompt_list = ['{}']
|
388 |
-
elif config.prompt == 'fixed':
|
389 |
-
model.prompt_list = ['a photo of a {}.']
|
390 |
-
elif config.prompt == 'shuffle':
|
391 |
-
model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
392 |
-
elif config.prompt == 'shuffle_clip':
|
393 |
-
from models.clip_prompts import imagenet_templates
|
394 |
-
model.prompt_list = imagenet_templates
|
395 |
-
|
396 |
-
config.assume_no_unused_keys(exceptions=['max_iterations'])
|
397 |
-
|
398 |
-
t_start = time.time()
|
399 |
-
|
400 |
-
with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9)
|
401 |
-
i, losses = 0, []
|
402 |
-
for data_x, data_y in data_loader:
|
403 |
-
|
404 |
-
data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
|
405 |
-
data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
|
406 |
-
|
407 |
-
if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
|
408 |
-
pred, = model(data_x[0], data_x[1], data_x[2])
|
409 |
-
visual_q = None
|
410 |
-
else:
|
411 |
-
pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True)
|
412 |
-
|
413 |
-
loss = loss_fn(pred, data_y[0])
|
414 |
-
|
415 |
-
metric.add([pred], data_y)
|
416 |
-
|
417 |
-
losses += [float(loss)]
|
418 |
-
|
419 |
-
i += 1
|
420 |
-
if config.max_iterations and i >= config.max_iterations:
|
421 |
-
break
|
422 |
-
|
423 |
-
# scores = {m: s for m, s in zip(metric.names(), metric.value())}
|
424 |
-
scores = metric.scores()
|
425 |
-
|
426 |
-
keys = set(scores.keys())
|
427 |
-
if dataset.negative_prob > 0 and 'mIoU' in keys:
|
428 |
-
keys.remove('mIoU')
|
429 |
-
|
430 |
-
name_mask = dataset.mask.replace('text_label', 'txt')[:3]
|
431 |
-
name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)
|
432 |
-
|
433 |
-
score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'
|
434 |
-
|
435 |
-
scores = {score_name: {k: v for k,v in scores.items() if k in keys}}
|
436 |
-
scores[score_name].update({'test_loss': np.mean(losses)})
|
437 |
-
|
438 |
-
log.info(f'Evaluation took {time.time() - t_start:.1f}s')
|
439 |
-
|
440 |
-
return scores
|
441 |
-
else:
|
442 |
-
raise ValueError('invalid test dataset')
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
if __name__ == '__main__':
|
453 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/setup.py
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
from setuptools import setup
|
2 |
-
|
3 |
-
with open("README.md", "r", encoding="utf-8") as readme_file:
|
4 |
-
readme = readme_file.read()
|
5 |
-
|
6 |
-
requirements = [
|
7 |
-
"numpy",
|
8 |
-
"scipy",
|
9 |
-
"matplotlib",
|
10 |
-
"torch",
|
11 |
-
"torchvision",
|
12 |
-
"opencv-python",
|
13 |
-
"CLIP @ git+https://github.com/openai/CLIP.git"
|
14 |
-
]
|
15 |
-
|
16 |
-
setup(
|
17 |
-
name='clipseg',
|
18 |
-
packages=['clipseg'],
|
19 |
-
package_dir={'clipseg': 'models'},
|
20 |
-
package_data={'clipseg': [
|
21 |
-
"../weights/*.pth",
|
22 |
-
]},
|
23 |
-
version='0.0.1',
|
24 |
-
url='https://github.com/timojl/clipseg',
|
25 |
-
python_requires='>=3.9',
|
26 |
-
install_requires=requirements,
|
27 |
-
description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
|
28 |
-
long_description=readme,
|
29 |
-
long_description_content_type="text/markdown",
|
30 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clipseg/training.py
DELETED
@@ -1,266 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import inspect
|
3 |
-
import json
|
4 |
-
import yaml
|
5 |
-
import math
|
6 |
-
import os
|
7 |
-
import sys
|
8 |
-
|
9 |
-
from general_utils import log
|
10 |
-
|
11 |
-
import numpy as np
|
12 |
-
from functools import partial
|
13 |
-
from os.path import expanduser, join, isfile, basename
|
14 |
-
|
15 |
-
from torch.cuda.amp import autocast, GradScaler
|
16 |
-
from torch.optim.lr_scheduler import LambdaLR
|
17 |
-
from contextlib import nullcontext
|
18 |
-
from torch.utils.data import DataLoader
|
19 |
-
|
20 |
-
from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
|
21 |
-
|
22 |
-
|
23 |
-
def cosine_warmup_lr(i, warmup=10, max_iter=90):
|
24 |
-
""" Cosine LR with Warmup """
|
25 |
-
if i < warmup:
|
26 |
-
return (i+1)/(warmup+1)
|
27 |
-
else:
|
28 |
-
return 0.5 + 0.5*math.cos(math.pi*(((i-warmup)/(max_iter- warmup))))
|
29 |
-
|
30 |
-
|
31 |
-
def validate(model, dataset, config):
|
32 |
-
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
|
33 |
-
|
34 |
-
metric_class, use_metric = config.val_metric_class, config.use_val_metric
|
35 |
-
loss_fn = get_attribute(config.loss)
|
36 |
-
|
37 |
-
model.eval()
|
38 |
-
model.cuda()
|
39 |
-
|
40 |
-
if metric_class is not None:
|
41 |
-
metric = get_attribute(metric_class)()
|
42 |
-
|
43 |
-
with torch.no_grad():
|
44 |
-
|
45 |
-
i, losses = 0, []
|
46 |
-
for data_x, data_y in data_loader:
|
47 |
-
|
48 |
-
data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
|
49 |
-
data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
|
50 |
-
|
51 |
-
prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
|
52 |
-
pred, visual_q, _, _ = model(data_x[0], prompts, return_features=True)
|
53 |
-
|
54 |
-
if metric_class is not None:
|
55 |
-
metric.add([pred], data_y)
|
56 |
-
|
57 |
-
# pred = model(data_x[0], prompts)
|
58 |
-
# loss = loss_fn(pred[0], data_y[0])
|
59 |
-
loss = loss_fn(pred, data_y[0])
|
60 |
-
losses += [float(loss)]
|
61 |
-
|
62 |
-
i += 1
|
63 |
-
|
64 |
-
if config.val_max_iterations is not None and i > config.val_max_iterations:
|
65 |
-
break
|
66 |
-
|
67 |
-
if use_metric is None:
|
68 |
-
return np.mean(losses), {}, False
|
69 |
-
else:
|
70 |
-
metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
|
71 |
-
return np.mean(losses), metric_scores, True
|
72 |
-
|
73 |
-
|
74 |
-
def main():
|
75 |
-
|
76 |
-
config = training_config_from_cli_args()
|
77 |
-
|
78 |
-
val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')
|
79 |
-
|
80 |
-
model_cls = get_attribute(config.model)
|
81 |
-
_, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
|
82 |
-
model = model_cls(**model_args).cuda()
|
83 |
-
|
84 |
-
dataset_cls = get_attribute(config.dataset)
|
85 |
-
_, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
|
86 |
-
|
87 |
-
dataset = dataset_cls(**dataset_args)
|
88 |
-
|
89 |
-
log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')
|
90 |
-
|
91 |
-
if val_interval is not None:
|
92 |
-
dataset_val_args = {k[4:]: v for k,v in config.items() if k.startswith('val_') and k != 'val_interval'}
|
93 |
-
_, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
|
94 |
-
print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
|
95 |
-
|
96 |
-
dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})
|
97 |
-
|
98 |
-
# optimizer
|
99 |
-
opt_cls = get_attribute(config.optimizer)
|
100 |
-
if config.optimize == 'torch.optim.SGD':
|
101 |
-
opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
|
102 |
-
else:
|
103 |
-
opt_args = {}
|
104 |
-
opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)
|
105 |
-
|
106 |
-
if config.lr_scheduler == 'cosine':
|
107 |
-
assert config.T_max is not None and config.eta_min is not None
|
108 |
-
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
|
109 |
-
elif config.lr_scheduler == 'warmup_cosine':
|
110 |
-
lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
|
111 |
-
else:
|
112 |
-
lr_scheduler = None
|
113 |
-
|
114 |
-
batch_size, max_iterations = config.batch_size, config.max_iterations
|
115 |
-
|
116 |
-
loss_fn = get_attribute(config.loss)
|
117 |
-
|
118 |
-
if config.amp:
|
119 |
-
log.info('Using AMP')
|
120 |
-
autocast_fn = autocast
|
121 |
-
scaler = GradScaler()
|
122 |
-
else:
|
123 |
-
autocast_fn, scaler = nullcontext, None
|
124 |
-
|
125 |
-
|
126 |
-
save_only_trainable = True
|
127 |
-
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
|
128 |
-
|
129 |
-
# disable config when hyperparam. opt. to avoid writing logs.
|
130 |
-
tracker_config = config if not config.hyperparameter_optimization else None
|
131 |
-
|
132 |
-
with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:
|
133 |
-
|
134 |
-
i = 0
|
135 |
-
while True:
|
136 |
-
for data_x, data_y in data_loader:
|
137 |
-
|
138 |
-
# between caption and output feature.
|
139 |
-
# 1. Sample random captions
|
140 |
-
# 2. Check alignment with CLIP
|
141 |
-
|
142 |
-
# randomly mix text and visual support conditionals
|
143 |
-
if config.mix:
|
144 |
-
|
145 |
-
assert config.mask.startswith('text_and')
|
146 |
-
|
147 |
-
with autocast_fn():
|
148 |
-
# data_x[1] = text label
|
149 |
-
prompts = model.sample_prompts(data_x[1])
|
150 |
-
|
151 |
-
# model.clip_model()
|
152 |
-
|
153 |
-
text_cond = model.compute_conditional(prompts)
|
154 |
-
if model.__class__.__name__ == 'CLIPDensePredTMasked':
|
155 |
-
# when mask=='separate'
|
156 |
-
visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
|
157 |
-
else:
|
158 |
-
# data_x[2] = visual prompt
|
159 |
-
visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())
|
160 |
-
|
161 |
-
max_txt = config.mix_text_max if config.mix_text_max is not None else 1
|
162 |
-
batch_size = text_cond.shape[0]
|
163 |
-
|
164 |
-
# sample weights for each element in batch
|
165 |
-
text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((batch_size,))[:, None]
|
166 |
-
text_weights = text_weights.cuda()
|
167 |
-
|
168 |
-
if dataset.__class__.__name__ == 'PhraseCut':
|
169 |
-
# give full weight to text where support_image is invalid
|
170 |
-
visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
|
171 |
-
text_weights = torch.max(text_weights[:,0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)
|
172 |
-
|
173 |
-
cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)
|
174 |
-
|
175 |
-
else:
|
176 |
-
# no mix
|
177 |
-
|
178 |
-
if model.__class__.__name__ == 'CLIPDensePredTMasked':
|
179 |
-
# compute conditional vector using CLIP masking
|
180 |
-
with autocast_fn():
|
181 |
-
assert config.mask == 'separate'
|
182 |
-
cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
|
183 |
-
else:
|
184 |
-
cond = data_x[1]
|
185 |
-
if isinstance(cond, torch.Tensor):
|
186 |
-
cond = cond.cuda()
|
187 |
-
|
188 |
-
with autocast_fn():
|
189 |
-
visual_q = None
|
190 |
-
|
191 |
-
pred, visual_q, _, _ = model(data_x[0].cuda(), cond, return_features=True)
|
192 |
-
|
193 |
-
loss = loss_fn(pred, data_y[0].cuda())
|
194 |
-
|
195 |
-
if torch.isnan(loss) or torch.isinf(loss):
|
196 |
-
# skip if loss is nan
|
197 |
-
log.warning('Training stopped due to inf/nan loss.')
|
198 |
-
sys.exit(-1)
|
199 |
-
|
200 |
-
extra_loss = 0
|
201 |
-
loss += extra_loss
|
202 |
-
|
203 |
-
opt.zero_grad()
|
204 |
-
|
205 |
-
if scaler is None:
|
206 |
-
loss.backward()
|
207 |
-
opt.step()
|
208 |
-
else:
|
209 |
-
scaler.scale(loss).backward()
|
210 |
-
scaler.step(opt)
|
211 |
-
scaler.update()
|
212 |
-
|
213 |
-
if lr_scheduler is not None:
|
214 |
-
lr_scheduler.step()
|
215 |
-
if i % 2000 == 0:
|
216 |
-
current_lr = [g['lr'] for g in opt.param_groups][0]
|
217 |
-
log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')
|
218 |
-
|
219 |
-
logger.iter(i=i, loss=loss)
|
220 |
-
i += 1
|
221 |
-
|
222 |
-
if i >= max_iterations:
|
223 |
-
|
224 |
-
if not isfile(join(logger.base_path, 'weights.pth')):
|
225 |
-
# only write if no weights were already written
|
226 |
-
logger.save_weights(only_trainable=save_only_trainable)
|
227 |
-
|
228 |
-
sys.exit(0)
|
229 |
-
|
230 |
-
|
231 |
-
if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
|
232 |
-
logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')
|
233 |
-
|
234 |
-
|
235 |
-
if val_interval is not None and i % val_interval == val_interval - 1:
|
236 |
-
|
237 |
-
val_loss, val_scores, maximize = validate(model, dataset_val, config)
|
238 |
-
|
239 |
-
if len(val_scores) > 0:
|
240 |
-
|
241 |
-
score_str = f', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())
|
242 |
-
|
243 |
-
if maximize and val_scores[config.use_val_metric] > best_val_score:
|
244 |
-
logger.save_weights(only_trainable=save_only_trainable)
|
245 |
-
best_val_score = val_scores[config.use_val_metric]
|
246 |
-
|
247 |
-
elif not maximize and val_scores[config.use_val_metric] < best_val_score:
|
248 |
-
logger.save_weights(only_trainable=save_only_trainable)
|
249 |
-
best_val_score = val_scores[config.use_val_metric]
|
250 |
-
|
251 |
-
else:
|
252 |
-
score_str = ''
|
253 |
-
# if no score is used, fall back to loss
|
254 |
-
if val_loss < best_val_loss:
|
255 |
-
logger.save_weights(only_trainable=save_only_trainable)
|
256 |
-
best_val_loss = val_loss
|
257 |
-
|
258 |
-
log.info(f'Validation loss: {val_loss}' + score_str)
|
259 |
-
logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
|
260 |
-
model.train()
|
261 |
-
|
262 |
-
print('epoch complete')
|
263 |
-
|
264 |
-
|
265 |
-
if __name__ == '__main__':
|
266 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
init_image.png
DELETED
Binary file (407 kB)
|
|
inpainting.py
DELETED
@@ -1,194 +0,0 @@
|
|
1 |
-
import inspect
|
2 |
-
from typing import List, Optional, Union
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
|
7 |
-
import PIL
|
8 |
-
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
|
9 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
10 |
-
from tqdm.auto import tqdm
|
11 |
-
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
12 |
-
|
13 |
-
|
14 |
-
def preprocess_image(image):
|
15 |
-
w, h = image.size
|
16 |
-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
17 |
-
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
18 |
-
image = np.array(image).astype(np.float32) / 255.0
|
19 |
-
image = image[None].transpose(0, 3, 1, 2)
|
20 |
-
image = torch.from_numpy(image)
|
21 |
-
return 2.0 * image - 1.0
|
22 |
-
|
23 |
-
|
24 |
-
def preprocess_mask(mask):
|
25 |
-
mask = mask.convert("L")
|
26 |
-
w, h = mask.size
|
27 |
-
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
28 |
-
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
|
29 |
-
mask = np.array(mask).astype(np.float32) / 255.0
|
30 |
-
mask = np.tile(mask, (4, 1, 1))
|
31 |
-
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
32 |
-
mask = 1 - mask # repaint white, keep black
|
33 |
-
mask = torch.from_numpy(mask)
|
34 |
-
return mask
|
35 |
-
|
36 |
-
class StableDiffusionInpaintingPipeline(DiffusionPipeline):
|
37 |
-
def __init__(
|
38 |
-
self,
|
39 |
-
vae: AutoencoderKL,
|
40 |
-
text_encoder: CLIPTextModel,
|
41 |
-
tokenizer: CLIPTokenizer,
|
42 |
-
unet: UNet2DConditionModel,
|
43 |
-
scheduler: Union[DDIMScheduler, PNDMScheduler],
|
44 |
-
safety_checker: StableDiffusionSafetyChecker,
|
45 |
-
feature_extractor: CLIPFeatureExtractor,
|
46 |
-
):
|
47 |
-
super().__init__()
|
48 |
-
scheduler = scheduler.set_format("pt")
|
49 |
-
self.register_modules(
|
50 |
-
vae=vae,
|
51 |
-
text_encoder=text_encoder,
|
52 |
-
tokenizer=tokenizer,
|
53 |
-
unet=unet,
|
54 |
-
scheduler=scheduler,
|
55 |
-
safety_checker=safety_checker,
|
56 |
-
feature_extractor=feature_extractor,
|
57 |
-
)
|
58 |
-
|
59 |
-
@torch.no_grad()
|
60 |
-
def __call__(
|
61 |
-
self,
|
62 |
-
prompt: Union[str, List[str]],
|
63 |
-
init_image: torch.FloatTensor,
|
64 |
-
mask_image: torch.FloatTensor,
|
65 |
-
strength: float = 0.8,
|
66 |
-
num_inference_steps: Optional[int] = 50,
|
67 |
-
guidance_scale: Optional[float] = 7.5,
|
68 |
-
eta: Optional[float] = 0.0,
|
69 |
-
generator: Optional[torch.Generator] = None,
|
70 |
-
output_type: Optional[str] = "pil",
|
71 |
-
):
|
72 |
-
|
73 |
-
if isinstance(prompt, str):
|
74 |
-
batch_size = 1
|
75 |
-
elif isinstance(prompt, list):
|
76 |
-
batch_size = len(prompt)
|
77 |
-
else:
|
78 |
-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
79 |
-
|
80 |
-
if strength < 0 or strength > 1:
|
81 |
-
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
82 |
-
|
83 |
-
# set timesteps
|
84 |
-
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
|
85 |
-
extra_set_kwargs = {}
|
86 |
-
offset = 0
|
87 |
-
if accepts_offset:
|
88 |
-
offset = 1
|
89 |
-
extra_set_kwargs["offset"] = 1
|
90 |
-
|
91 |
-
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
92 |
-
|
93 |
-
# preprocess image
|
94 |
-
init_image = preprocess_image(init_image).to(self.device)
|
95 |
-
|
96 |
-
# encode the init image into latents and scale the latents
|
97 |
-
init_latent_dist = self.vae.encode(init_image).latent_dist
|
98 |
-
init_latents = init_latent_dist.sample(generator=generator)
|
99 |
-
init_latents = 0.18215 * init_latents
|
100 |
-
|
101 |
-
# prepare init_latents noise to latents
|
102 |
-
init_latents = torch.cat([init_latents] * batch_size)
|
103 |
-
init_latents_orig = init_latents
|
104 |
-
|
105 |
-
# preprocess mask
|
106 |
-
mask = preprocess_mask(mask_image).to(self.device)
|
107 |
-
mask = torch.cat([mask] * batch_size)
|
108 |
-
|
109 |
-
# check sizes
|
110 |
-
if not mask.shape == init_latents.shape:
|
111 |
-
raise ValueError(f"The mask and init_image should be the same size!")
|
112 |
-
|
113 |
-
# get the original timestep using init_timestep
|
114 |
-
init_timestep = int(num_inference_steps * strength) + offset
|
115 |
-
init_timestep = min(init_timestep, num_inference_steps)
|
116 |
-
timesteps = self.scheduler.timesteps[-init_timestep]
|
117 |
-
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
|
118 |
-
|
119 |
-
# add noise to latents using the timesteps
|
120 |
-
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
|
121 |
-
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
122 |
-
|
123 |
-
# get prompt text embeddings
|
124 |
-
text_input = self.tokenizer(
|
125 |
-
prompt,
|
126 |
-
padding="max_length",
|
127 |
-
max_length=self.tokenizer.model_max_length,
|
128 |
-
truncation=True,
|
129 |
-
return_tensors="pt",
|
130 |
-
)
|
131 |
-
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
132 |
-
|
133 |
-
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
134 |
-
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
135 |
-
# corresponds to doing no classifier free guidance.
|
136 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
137 |
-
# get unconditional embeddings for classifier free guidance
|
138 |
-
if do_classifier_free_guidance:
|
139 |
-
max_length = text_input.input_ids.shape[-1]
|
140 |
-
uncond_input = self.tokenizer(
|
141 |
-
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
142 |
-
)
|
143 |
-
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
144 |
-
|
145 |
-
# For classifier free guidance, we need to do two forward passes.
|
146 |
-
# Here we concatenate the unconditional and text embeddings into a single batch
|
147 |
-
# to avoid doing two forward passes
|
148 |
-
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
149 |
-
|
150 |
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
151 |
-
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
152 |
-
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
153 |
-
# and should be between [0, 1]
|
154 |
-
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
155 |
-
extra_step_kwargs = {}
|
156 |
-
if accepts_eta:
|
157 |
-
extra_step_kwargs["eta"] = eta
|
158 |
-
|
159 |
-
latents = init_latents
|
160 |
-
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
161 |
-
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
|
162 |
-
# expand the latents if we are doing classifier free guidance
|
163 |
-
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
164 |
-
|
165 |
-
# predict the noise residual
|
166 |
-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
167 |
-
|
168 |
-
# perform guidance
|
169 |
-
if do_classifier_free_guidance:
|
170 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
171 |
-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
172 |
-
|
173 |
-
# compute the previous noisy sample x_t -> x_t-1
|
174 |
-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
|
175 |
-
|
176 |
-
# masking
|
177 |
-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
|
178 |
-
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
179 |
-
|
180 |
-
# scale and decode the image latents with vae
|
181 |
-
latents = 1 / 0.18215 * latents
|
182 |
-
image = self.vae.decode(latents).sample
|
183 |
-
|
184 |
-
image = (image / 2 + 0.5).clamp(0, 1)
|
185 |
-
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
186 |
-
|
187 |
-
# run safety checker
|
188 |
-
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
189 |
-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
|
190 |
-
|
191 |
-
if output_type == "pil":
|
192 |
-
image = self.numpy_to_pil(image)
|
193 |
-
|
194 |
-
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mask_image.png
DELETED
Binary file (11.9 kB)
|
|
requirements.txt
CHANGED
@@ -1,11 +1,8 @@
|
|
1 |
-
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
-
torch
|
3 |
torchvision
|
4 |
-
diffusers
|
5 |
-
transformers
|
|
|
6 |
ftfy
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
opencv-python
|
11 |
-
git+https://github.com/openai/CLIP.git
|
|
|
|
|
|
|
1 |
torchvision
|
2 |
+
diffusers
|
3 |
+
transformers
|
4 |
+
accelerate
|
5 |
ftfy
|
6 |
+
scipy
|
7 |
+
imageio
|
8 |
+
invisible_watermark
|
|
|
|