akhaliq HF staff committed on
Commit
b7133f4
·
1 Parent(s): 6b27ec4
app.py CHANGED
@@ -1,54 +1,174 @@
1
- from diffusers import StableDiffusionInpaintPipeline
2
  import gradio as gr
3
- import numpy as np
4
- import imageio
5
- from PIL import Image
6
  from io import BytesIO
 
 
 
 
7
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
10
-
11
-
12
- print("hello sylvain")
13
-
14
- YOUR_TOKEN=MY_SECRET_TOKEN
15
-
16
- device="cpu"
17
-
18
- pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
19
- pipe.to(device)
20
-
21
- source_img = gr.Image(source="upload", type="numpy", tool="sketch", elem_id="source_container");
22
- gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
23
-
24
- def resize(height,img):
25
- baseheight = height
26
- img = Image.open(img)
27
- hpercent = (baseheight/float(img.size[1]))
28
- wsize = int((float(img.size[0])*float(hpercent)))
29
- img = img.resize((wsize,baseheight), Image.Resampling.LANCZOS)
30
- return img
31
-
32
- def predict(source_img, prompt):
33
- imageio.imwrite("data.png", source_img["image"])
34
- imageio.imwrite("data_mask.png", source_img["mask"])
35
-
36
- src = resize(512, "data.png")
37
- src.save("src.png")
38
- mask = resize(512, "data_mask.png")
39
- mask.save("mask.png")
40
-
41
- images_list = pipe([prompt] * 2, init_image=src, mask_image=mask, strength=0.75)
42
- images = []
43
- safe_image = Image.open(r"unsafe.png")
44
- for i, image in enumerate(images_list["sample"]):
45
- if(images_list["nsfw_content_detected"][i]):
46
- images.append(safe_image)
47
- else:
48
- images.append(image)
49
- return images
50
-
51
- custom_css="style.css"
52
- title="InPainting Stable Diffusion CPU"
53
- description="Inpainting Stable Diffusion example using CPU and HF token. <br />Warning: Slow process... ~5/10 min inference time. <b>NSFW filter enabled.</b><br />Please use 512*512 square image as input to avoid memory error !"
54
- gr.Interface(fn=predict, inputs=[source_img, "text"], outputs=gallery, css=custom_css, title=title, description=description, allow_flagging="manual").launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+
 
 
3
  from io import BytesIO
4
+ import requests
5
+ import PIL
6
+ from PIL import Image
7
+ import numpy as np
8
  import os
9
+ import uuid
10
+ import torch
11
+ from torch import autocast
12
+ import cv2
13
+ from matplotlib import pyplot as plt
14
+ from inpainting import StableDiffusionInpaintingPipeline
15
+ from torchvision import transforms
16
+ from clipseg.models.clipseg import CLIPDensePredT
17
+
18
+ auth_token = os.environ.get("API_TOKEN") or True
19
+
20
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would hang the
            worker on a stalled connection.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here on an error status instead of handing an error page to PIL,
    # which would only report that it cannot identify the image.
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")
23
+
24
# Run the diffusion pipeline on the GPU when one is available, else on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The half-precision checkpoint (revision="fp16") is always the one fetched,
# but the weights are only kept in float16 on CUDA. On the CPU fallback they
# are upcast to float32, because float16 inference is poorly supported there.
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_auth_token=auth_token,
).to(device)

# Text-driven segmentation model used to turn a phrase into a mask.
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
# The model is never moved off the CPU in this file and predict() feeds it CPU
# tensors, so the checkpoint is mapped to the CPU as well. Mapping it to
# 'cuda' unconditionally breaks start-up on hosts without a GPU.
# strict=False: the checkpoint is loaded non-strictly, as in the upstream
# CLIPSeg quickstart (presumably it only contains the decoder weights).
model.load_state_dict(
    torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device('cpu')),
    strict=False,
)

# Preprocessing for the segmentation model: tensor conversion, per-channel
# normalisation (three channels, so the input must be RGB), then a resize to
# the 512x512 working resolution used throughout predict().
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((512, 512)),
])
41
+
42
def predict(radio, dict, word_mask, prompt=""):
    """Inpaint the uploaded image using either a drawn or a text-derived mask.

    Args:
        radio: Selected masking mode. "draw a mask above" uses the mask the
            user sketched; any other value derives the mask from ``word_mask``.
        dict: Payload of the sketch-enabled image component, holding PIL
            images under "image" and (in draw mode) "mask". The name shadows
            the builtin but is kept so existing callers keep working.
        word_mask: Phrase describing the region to replace (text mode only).
        prompt: Description of what to paint into the masked region.

    Returns:
        The first image produced by the inpainting pipeline.
    """
    # Convert once up front: the 3-channel Normalize in ``transform`` and the
    # pipeline both need RGB, while uploads may be RGBA or greyscale.
    init_rgb = dict["image"].convert("RGB")
    init_image = init_rgb.resize((512, 512))

    if radio == "draw a mask above":
        mask = dict["mask"].convert("RGB").resize((512, 512))
    else:
        img = transform(init_rgb).unsqueeze(0)
        word_masks = [word_mask]
        with torch.no_grad():
            preds = model(img.repeat(len(word_masks), 1, 1, 1), word_masks)[0]

        # The prediction is rendered to a PNG by matplotlib and read back with
        # OpenCV. The round trip is kept as-is because the threshold below
        # operates on that rendered image (presumably tuned against it).
        filename = f"{uuid.uuid4()}.png"
        try:
            plt.imsave(filename, torch.sigmoid(preds[0][0]))
            img2 = cv2.imread(filename)
        finally:
            # Always clean up, even if rendering or reading failed.
            if os.path.exists(filename):
                os.remove(filename)

        gray_image = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
        _, bw_image = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
        # The single-channel binary mask is expanded to RGB by PIL; the former
        # cv2.cvtColor(bw_image, COLOR_BGR2RGB) call discarded its result.
        mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')

    with autocast("cuda"):
        images = pipe(prompt=prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
    return images[0]
64
+
65
+ # examples = [[dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]]
66
+ css = '''
67
+ .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
68
+ #image_upload{min-height:400px}
69
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
70
+ #mask_radio .gr-form{background:transparent; border: none}
71
+ #word_mask{margin-top: .75em !important}
72
+ #word_mask textarea:disabled{opacity: 0.3}
73
+ .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
74
+ .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
75
+ .dark .footer {border-color: #303030}
76
+ .dark .footer>p {background: #0b0f19}
77
+ .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
78
+ #image_upload .touch-none{display: flex}
79
+ '''
80
def swap_word_mask(radio_option):
    """Build the update for the word-mask textbox when the mode changes.

    The textbox is made interactive only when ``radio_option`` is
    "type what to mask below"; for any other value it is disabled and shows
    a "Disabled" placeholder.
    """
    typing_mode = radio_option == "type what to mask below"
    if not typing_mode:
        return gr.update(interactive=False, placeholder="Disabled")
    return gr.update(interactive=True, placeholder="A cat")
85
 
86
+ image_blocks = gr.Blocks(css=css)
87
+ with image_blocks as demo:
88
+ gr.HTML(
89
+ """
90
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
91
+ <div
92
+ style="
93
+ display: inline-flex;
94
+ align-items: center;
95
+ gap: 0.8rem;
96
+ font-size: 1.75rem;
97
+ "
98
+ >
99
+ <svg
100
+ width="0.65em"
101
+ height="0.65em"
102
+ viewBox="0 0 115 115"
103
+ fill="none"
104
+ xmlns="http://www.w3.org/2000/svg"
105
+ >
106
+ <rect width="23" height="23" fill="white"></rect>
107
+ <rect y="69" width="23" height="23" fill="white"></rect>
108
+ <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
109
+ <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
110
+ <rect x="46" width="23" height="23" fill="white"></rect>
111
+ <rect x="46" y="69" width="23" height="23" fill="white"></rect>
112
+ <rect x="69" width="23" height="23" fill="black"></rect>
113
+ <rect x="69" y="69" width="23" height="23" fill="black"></rect>
114
+ <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
115
+ <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
116
+ <rect x="115" y="46" width="23" height="23" fill="white"></rect>
117
+ <rect x="115" y="115" width="23" height="23" fill="white"></rect>
118
+ <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
119
+ <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
120
+ <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
121
+ <rect x="92" y="69" width="23" height="23" fill="white"></rect>
122
+ <rect x="69" y="46" width="23" height="23" fill="white"></rect>
123
+ <rect x="69" y="115" width="23" height="23" fill="white"></rect>
124
+ <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
125
+ <rect x="46" y="46" width="23" height="23" fill="black"></rect>
126
+ <rect x="46" y="115" width="23" height="23" fill="black"></rect>
127
+ <rect x="46" y="69" width="23" height="23" fill="black"></rect>
128
+ <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
129
+ <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
130
+ <rect x="23" y="69" width="23" height="23" fill="black"></rect>
131
+ </svg>
132
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
133
+ Stable Diffusion Multi Inpainting
134
+ </h1>
135
+ </div>
136
+ <p style="margin-bottom: 10px; font-size: 94%">
137
+ Inpaint Stable Diffusion by either drawing a mask or typing what to replace
138
+ </p>
139
+ </div>
140
+ """
141
+ )
142
+ with gr.Row():
143
+ with gr.Column():
144
+ image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
145
+ with gr.Box(elem_id="mask_radio").style(border=False):
146
+ radio = gr.Radio(["draw a mask above", "type what to mask below"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
147
+ word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
148
+ prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
149
+ radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
150
+ radio.change(None, inputs=[], outputs=image_blocks, _js = """
151
+ () => {
152
+ css_style = document.styleSheets[document.styleSheets.length - 1]
153
+ last_item = css_style.cssRules[css_style.cssRules.length - 1]
154
+ last_item.style.display = ["flex", ""].includes(last_item.style.display) ? "none" : "flex";
155
+ }""")
156
+ btn = gr.Button("Run")
157
+ with gr.Column():
158
+ result = gr.Image(label="Result")
159
+ btn.click(fn=predict, inputs=[radio, image, word_mask, prompt], outputs=result)
160
+ gr.HTML(
161
+ """
162
+ <div class="footer">
163
+ <p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Inpainting by <a href="https://github.com/nagolinc" style="text-decoration: underline;" target="_blank">nagolinc</a> and <a href="https://github.com/patil-suraj" style="text-decoration: underline;">patil-suraj</a>, inpainting with words by <a href="https://twitter.com/yvrjsharma/" style="text-decoration: underline;" target="_blank">@yvrjsharma</a> and <a href="https://twitter.com/1littlecoder" style="text-decoration: underline;">@1littlecoder</a> - Gradio Demo by πŸ€— Hugging Face
164
+ </p>
165
+ </div>
166
+ <div class="acknowledgments">
167
+ <p><h4>LICENSE</h4>
168
+ The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
169
+ <p><h4>Biases and content acknowledgment</h4>
170
+ Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a></p>
171
+ </div>
172
+ """
173
+ )
174
+ demo.launch()
clipseg/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ This license does not apply to the model weights.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
clipseg/Quickstart.ipynb ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import requests\n",
11
+ "\n",
12
+ "! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
13
+ "! unzip -d weights -j weights.zip\n",
14
+ "from models.clipseg import CLIPDensePredT\n",
15
+ "from PIL import Image\n",
16
+ "from torchvision import transforms\n",
17
+ "from matplotlib import pyplot as plt\n",
18
+ "\n",
19
+ "# load model\n",
20
+ "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
21
+ "model.eval();\n",
22
+ "\n",
23
+ "# non-strict, because we only stored decoder weights (not CLIP weights)\n",
24
+ "model.load_state_dict(torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False);"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "Load and normalize `example_image.jpg`. You can also load through an URL."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "# load and normalize image\n",
41
+ "input_image = Image.open('example_image.jpg')\n",
42
+ "\n",
43
+ "# or load from URL...\n",
44
+ "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
45
+ "# input_image = Image.open(requests.get(image_url, stream=True).raw)\n",
46
+ "\n",
47
+ "transform = transforms.Compose([\n",
48
+ " transforms.ToTensor(),\n",
49
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
50
+ " transforms.Resize((352, 352)),\n",
51
+ "])\n",
52
+ "img = transform(input_image).unsqueeze(0)"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {},
58
+ "source": [
59
+ "Predict and visualize (this might take a few seconds if running without GPU support)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
69
+ "\n",
70
+ "# predict\n",
71
+ "with torch.no_grad():\n",
72
+ " preds = model(img.repeat(4,1,1,1), prompts)[0]\n",
73
+ "\n",
74
+ "# visualize prediction\n",
75
+ "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
76
+ "[a.axis('off') for a in ax.flatten()]\n",
77
+ "ax[0].imshow(input_image)\n",
78
+ "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)];\n",
79
+ "[ax[i+1].text(0, -15, prompts[i]) for i in range(4)];"
80
+ ]
81
+ }
82
+ ],
83
+ "metadata": {
84
+ "interpreter": {
85
+ "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
86
+ },
87
+ "kernelspec": {
88
+ "display_name": "Python 3",
89
+ "language": "python",
90
+ "name": "python3"
91
+ },
92
+ "language_info": {
93
+ "codemirror_mode": {
94
+ "name": "ipython",
95
+ "version": 3
96
+ },
97
+ "file_extension": ".py",
98
+ "mimetype": "text/x-python",
99
+ "name": "python",
100
+ "nbconvert_exporter": "python",
101
+ "pygments_lexer": "ipython3",
102
+ "version": "3.8.10"
103
+ }
104
+ },
105
+ "nbformat": 4,
106
+ "nbformat_minor": 4
107
+ }
clipseg/Readme.md ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image Segmentation Using Text and Image Prompts
2
+ This repository contains the code used in the paper ["Image Segmentation Using Text and Image Prompts"](https://arxiv.org/abs/2112.10003).
3
+
4
+ **The Paper has been accepted to CVPR 2022!**
5
+
6
+ <img src="overview.png" alt="drawing" height="200em"/>
7
+
8
+ The system allows creating segmentation models without training, based on:
9
+ - An arbitrary text query
10
+ - Or an image with a mask highlighting stuff or an object.
11
+
12
+ ### Quick Start
13
+
14
+ In the `Quickstart.ipynb` notebook we provide the code for using a pre-trained CLIPSeg model. If you run the notebook locally, make sure you have downloaded the `rd64-uni.pth` weights, either manually or via the git lfs extension.
15
+ It can also be used interactively using [MyBinder](https://mybinder.org/v2/gh/timojl/clipseg/HEAD?labpath=Quickstart.ipynb)
16
+ (please note that the VM does not use a GPU, thus inference takes a few seconds).
17
+
18
+
19
+ ### Dependencies
20
+ This code base depends on pytorch, torchvision and clip (`pip install git+https://github.com/openai/CLIP.git`).
21
+ Additional dependencies are hidden for double blind review.
22
+
23
+
24
+ ### Datasets
25
+
26
+ * `PhraseCut` and `PhraseCutPlus`: Referring expression dataset
27
+ * `PFEPascalWrapper`: Wrapper class for PFENet's Pascal-5i implementation
28
+ * `PascalZeroShot`: Wrapper class for PascalZeroShot
29
+ * `COCOWrapper`: Wrapper class for COCO.
30
+
31
+ ### Models
32
+
33
+ * `CLIPDensePredT`: CLIPSeg model with transformer-based decoder.
34
+ * `ViTDensePredT`: ViTSeg model (ViT backbone instead of CLIP) with transformer-based decoder.
35
+
36
+ ### Third Party Dependencies
37
+ For some of the datasets third party dependencies are required. Run the following commands in the `third_party` folder.
38
+ ```bash
39
+ git clone https://github.com/cvlab-yonsei/JoEm
40
+ git clone https://github.com/Jia-Research-Lab/PFENet.git
41
+ git clone https://github.com/ChenyunWu/PhraseCutDataset.git
42
+ git clone https://github.com/juhongm999/hsnet.git
43
+ ```
44
+
45
+ ### Weights
46
+
47
+ The MIT license does not apply to these weights.
48
+
49
+ We provide two model weights, for D=64 (4.1MB) and D=16 (1.1MB).
50
+ ```
51
+ wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip
52
+ unzip -d weights -j weights.zip
53
+ ```
54
+
55
+
56
+ ### Training and Evaluation
57
+
58
+ To train use the `training.py` script with experiment file and experiment id parameters. E.g. `python training.py phrasecut.yaml 0` will train the first phrasecut experiment which is defined by the `configuration` and first `individual_configurations` parameters. Model weights will be written in `logs/`.
59
+
60
+ For evaluation use `score.py`. E.g. `python score.py phrasecut.yaml 0 0` will evaluate the first phrasecut experiment of `test_configuration` and the first configuration in `individual_configurations`.
61
+
62
+
63
+ ### Usage of PFENet Wrappers
64
+
65
+ In order to use the dataset and model wrappers for PFENet, the PFENet repository needs to be cloned to the root folder.
66
+ `git clone https://github.com/Jia-Research-Lab/PFENet.git`
67
+
68
+
69
+ ### License
70
+
71
+ The source code files in this repository (excluding model weights) are released under MIT license.
72
+
73
+ ### Citation
74
+ ```
75
+ @InProceedings{lueddecke22_cvpr,
76
+ author = {L\"uddecke, Timo and Ecker, Alexander},
77
+ title = {Image Segmentation Using Text and Image Prompts},
78
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
79
+ month = {June},
80
+ year = {2022},
81
+ pages = {7086-7096}
82
+ }
83
+
84
+ ```
clipseg/Tables.ipynb ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2\n",
11
+ "\n",
12
+ "import clip\n",
13
+ "from evaluation_utils import norm, denorm\n",
14
+ "from general_utils import *\n",
15
+ "from datasets.lvis_oneshot3 import LVIS_OneShot3, LVIS_OneShot"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "# PhraseCut"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "pc = experiment('experiments/phrasecut.yaml', nums=':6').dataframe()"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "tab1 = pc[['name', 'pc_miou_best', 'pc_fgiou_best', 'pc_ap']]"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "cols = ['pc_miou_0.3', 'pc_fgiou_0.3', 'pc_ap']\n",
50
+ "tab1 = pc[['name'] + cols]\n",
51
+ "for k in cols:\n",
52
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
53
+ "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
54
+ "tab1.insert(1, 't', [0.3]*tab1.shape[0])\n",
55
+ "print(tab1.to_latex(header=False, index=False))"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "For 0.1 threshold"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "cols = ['pc_miou_0.1', 'pc_fgiou_0.1', 'pc_ap']\n",
72
+ "tab1 = pc[['name'] + cols]\n",
73
+ "for k in cols:\n",
74
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
75
+ "tab1.loc[:, 'name'] = ['CLIPSeg (PC+)', 'CLIPSeg (PC, $D=128$)', 'CLIPSeg (PC)', 'CLIP-Deconv', 'ViTSeg (PC+)', 'ViTSeg (PC)']\n",
76
+ "tab1.insert(1, 't', [0.1]*tab1.shape[0])\n",
77
+ "print(tab1.to_latex(header=False, index=False))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {},
83
+ "source": [
84
+ "# One-shot"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "metadata": {},
90
+ "source": [
91
+ "### Pascal"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':19').dataframe()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "pas[['name', 'pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap', 'pas_h2_fgiou_ct']]"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
119
+ "tab1 = pas[['pas_h2_miou_0.3', 'pas_h2_biniou_0.3', 'pas_h2_ap']]\n",
120
+ "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
121
+ "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
122
+ "\n",
123
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
124
+ "tab1 = pas[['pas_h2_miou_0.2', 'pas_h2_biniou_0.2', 'pas_h2_ap']]\n",
125
+ "print('CLIP-Deconv (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
126
+ "\n",
127
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
128
+ "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
129
+ "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "metadata": {},
135
+ "source": [
136
+ "#### Pascal Zero-shot (in one-shot setting)\n",
137
+ "\n",
138
+ "Using the same setting as one-shot (hence different from the other zero-shot benchmark)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
148
+ "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
149
+ "print('CLIPSeg (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
150
+ "print('CLIPSeg (PC) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
151
+ "\n",
152
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
153
+ "tab1 = pas[['pas_t_miou_0.3', 'pas_t_biniou_0.3', 'pas_t_ap']]\n",
154
+ "print('CLIP-Deconv (PC+) & 0.3 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
155
+ "\n",
156
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='16:20').dataframe()\n",
157
+ "tab1 = pas[['pas_t_miou_0.2', 'pas_t_biniou_0.2', 'pas_t_ap']]\n",
158
+ "print('ViTSeg (PC+) & 0.2 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "# without fixed thresholds...\n",
168
+ "\n",
169
+ "pas = experiment('experiments/pascal_1shot.yaml', nums=':8').dataframe()\n",
170
+ "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
171
+ "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')\n",
172
+ "print('CLIPSeg (PC) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
173
+ "\n",
174
+ "pas = experiment('experiments/pascal_1shot.yaml', nums='12:16').dataframe()\n",
175
+ "tab1 = pas[['pas_t_best_miou', 'pas_t_best_biniou', 'pas_t_ap']]\n",
176
+ "print('CLIP-Deconv (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[0:4].mean(0).values), '\\\\\\\\')"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {},
182
+ "source": [
183
+ "### COCO"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "coco = experiment('experiments/coco.yaml', nums=':29').dataframe()"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "tab1 = coco[['coco_h2_miou_0.1', 'coco_h2_biniou_0.1', 'coco_h2_ap']]\n",
202
+ "tab2 = coco[['coco_h2_miou_0.2', 'coco_h2_biniou_0.2', 'coco_h2_ap']]\n",
203
+ "tab3 = coco[['coco_h2_miou_best', 'coco_h2_biniou_best', 'coco_h2_ap']]\n",
204
+ "print('CLIPSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[:4].mean(0).values), '\\\\\\\\')\n",
205
+ "print('CLIPSeg (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:8].mean(0).values), '\\\\\\\\')\n",
206
+ "print('CLIP-Deconv (COCO+N) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[12:16].mean(0).values), '\\\\\\\\')\n",
207
+ "print('ViTSeg (COCO) & 0.1 & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:12].mean(0).values), '\\\\\\\\')"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {},
213
+ "source": [
214
+ "# Zero-shot"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "zs = experiment('experiments/pascal_0shot.yaml', nums=':11').dataframe()"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": [
232
+ "\n",
233
+ "tab1 = zs[['pas_zs_seen', 'pas_zs_unseen']]\n",
234
+ "print('CLIPSeg (PC+) & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[8:9].values[0].tolist() + tab1[10:11].values[0].tolist()), '\\\\\\\\')\n",
235
+ "print('CLIP-Deconv & CLIP & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[2:3].values[0].tolist() + tab1[3:4].values[0].tolist()), '\\\\\\\\')\n",
236
+ "print('ViTSeg & ImageNet-1K & ' + ' & '.join(f'{x*100:.1f}' for x in tab1[4:5].values[0].tolist() + tab1[5:6].values[0].tolist()), '\\\\\\\\')"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "metadata": {},
242
+ "source": [
243
+ "# Ablation"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "ablation = experiment('experiments/ablation.yaml', nums=':8').dataframe()"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "tab1 = ablation[['name', 'pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']]\n",
262
+ "for k in ['pc_miou_best', 'pc_ap', 'pc-vis_miou_best', 'pc-vis_ap']:\n",
263
+ " tab1.loc[:, k] = (100 * tab1.loc[:, k]).round(1)\n",
264
+ "tab1.loc[:, 'name'] = ['CLIPSeg', 'no CLIP pre-training', 'no-negatives', '50% negatives', 'no visual', '$D=16$', 'only layer 3', 'highlight mask']"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {},
271
+ "outputs": [],
272
+ "source": [
273
+ "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "print(tab1.loc[[0,1,4,5,6,7],:].to_latex(header=False, index=False))"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "metadata": {},
288
+ "source": [
289
+ "# Generalization"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "generalization = experiment('experiments/generalize.yaml').dataframe()"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "gen = generalization[['aff_best_fgiou', 'aff_ap', 'ability_best_fgiou', 'ability_ap', 'part_best_fgiou', 'part_ap']].values"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "print(\n",
317
+ " 'CLIPSeg (PC+) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[1]) + ' \\\\\\\\ \\n' + \\\n",
318
+ " 'CLIPSeg (LVIS) & ' + ' & '.join(f'{x*100:.1f}' for x in gen[0]) + ' \\\\\\\\ \\n' + \\\n",
319
+ " 'CLIP-Deconv & ' + ' & '.join(f'{x*100:.1f}' for x in gen[2]) + ' \\\\\\\\ \\n' + \\\n",
320
+ " 'VITSeg & ' + ' & '.join(f'{x*100:.1f}' for x in gen[3]) + ' \\\\\\\\'\n",
321
+ ")"
322
+ ]
323
+ }
324
+ ],
325
+ "metadata": {
326
+ "interpreter": {
327
+ "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
328
+ },
329
+ "kernelspec": {
330
+ "display_name": "env2",
331
+ "language": "python",
332
+ "name": "env2"
333
+ },
334
+ "language_info": {
335
+ "codemirror_mode": {
336
+ "name": "ipython",
337
+ "version": 3
338
+ },
339
+ "file_extension": ".py",
340
+ "mimetype": "text/x-python",
341
+ "name": "python",
342
+ "nbconvert_exporter": "python",
343
+ "pygments_lexer": "ipython3",
344
+ "version": "3.8.8"
345
+ }
346
+ },
347
+ "nbformat": 4,
348
+ "nbformat_minor": 4
349
+ }
clipseg/Visual_Feature_Engineering.ipynb ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Systematic"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "%load_ext autoreload\n",
17
+ "%autoreload 2\n",
18
+ "\n",
19
+ "import clip\n",
20
+ "from evaluation_utils import norm, denorm\n",
21
+ "from general_utils import *\n",
22
+ "from datasets.lvis_oneshot3 import LVIS_OneShot3\n",
23
+ "\n",
24
+ "clip_device = 'cuda'\n",
25
+ "clip_model, preprocess = clip.load(\"ViT-B/16\", device=clip_device)\n",
26
+ "clip_model.eval();\n",
27
+ "\n",
28
+ "from models.clipseg import CLIPDensePredTMasked\n",
29
+ "\n",
30
+ "clip_mask_model = CLIPDensePredTMasked(version='ViT-B/16').to(clip_device)\n",
31
+ "clip_mask_model.eval();"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "lvis = LVIS_OneShot3('train_fixed', mask='separate', normalize=True, with_class_label=True, add_bar=False, \n",
41
+ " text_class_labels=True, image_size=352, min_area=0.1,\n",
42
+ " min_frac_s=0.05, min_frac_q=0.05, fix_find_crop=True)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "plot_data(lvis)"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "from collections import defaultdict\n",
61
+ "import json\n",
62
+ "\n",
63
+ "lvis_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_train.json')))\n",
64
+ "lvis_val_raw = json.load(open(expanduser('~/datasets/LVIS/lvis_v1_val.json')))\n",
65
+ "\n",
66
+ "objects_per_image = defaultdict(lambda : set())\n",
67
+ "for ann in lvis_raw['annotations']:\n",
68
+ " objects_per_image[ann['image_id']].add(ann['category_id'])\n",
69
+ " \n",
70
+ "for ann in lvis_val_raw['annotations']:\n",
71
+ " objects_per_image[ann['image_id']].add(ann['category_id']) \n",
72
+ " \n",
73
+ "objects_per_image = {o: [lvis.category_names[o] for o in v] for o, v in objects_per_image.items()}\n",
74
+ "\n",
75
+ "del lvis_raw, lvis_val_raw"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "#bs = 32\n",
85
+ "#batches = [get_batch(lvis, i*bs, (i+1)*bs, cuda=True) for i in range(10)]"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "from general_utils import get_batch\n",
95
+ "from functools import partial\n",
96
+ "from evaluation_utils import img_preprocess\n",
97
+ "import torch\n",
98
+ "\n",
99
+ "def get_similarities(batches_or_dataset, process, mask=lambda x: None, clipmask=False):\n",
100
+ "\n",
101
+ " # base_words = [f'a photo of {x}' for x in ['a person', 'an animal', 'a knife', 'a cup']]\n",
102
+ "\n",
103
+ " all_prompts = []\n",
104
+ " \n",
105
+ " with torch.no_grad():\n",
106
+ " valid_sims = []\n",
107
+ " torch.manual_seed(571)\n",
108
+ " \n",
109
+ " if type(batches_or_dataset) == list:\n",
110
+ " loader = batches_or_dataset # already loaded\n",
111
+ " max_iter = float('inf')\n",
112
+ " else:\n",
113
+ " loader = DataLoader(batches_or_dataset, shuffle=False, batch_size=32)\n",
114
+ " max_iter = 50\n",
115
+ " \n",
116
+ " global batch\n",
117
+ " for i_batch, (batch, batch_y) in enumerate(loader):\n",
118
+ " \n",
119
+ " if i_batch >= max_iter: break\n",
120
+ " \n",
121
+ " processed_batch = process(batch)\n",
122
+ " if type(processed_batch) == dict:\n",
123
+ " \n",
124
+ " # processed_batch = {k: v.to(clip_device) for k, v in processed_batch.items()}\n",
125
+ " image_features = clip_mask_model.visual_forward(**processed_batch)[0].to(clip_device).half()\n",
126
+ " else:\n",
127
+ " processed_batch = process(batch).to(clip_device)\n",
128
+ " processed_batch = nnf.interpolate(processed_batch, (224, 224), mode='bilinear')\n",
129
+ " #image_features = clip_model.encode_image(processed_batch.to(clip_device)) \n",
130
+ " image_features = clip_mask_model.visual_forward(processed_batch)[0].to(clip_device).half()\n",
131
+ " \n",
132
+ " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
133
+ " bs = len(batch[0])\n",
134
+ " for j in range(bs):\n",
135
+ " \n",
136
+ " c, _, sid, qid = lvis.sample_ids[bs * i_batch + j]\n",
137
+ " support_image = basename(lvis.samples[c][sid])\n",
138
+ " \n",
139
+ " img_objs = [o for o in objects_per_image[int(support_image)]]\n",
140
+ " img_objs = [o.replace('_', ' ') for o in img_objs]\n",
141
+ " \n",
142
+ " other_words = [f'a photo of a {o.replace(\"_\", \" \")}' for o in img_objs \n",
143
+ " if o != batch_y[2][j]]\n",
144
+ " \n",
145
+ " prompts = [f'a photo of a {batch_y[2][j]}'] + other_words\n",
146
+ " all_prompts += [prompts]\n",
147
+ " \n",
148
+ " text_cond = clip_model.encode_text(clip.tokenize(prompts).to(clip_device))\n",
149
+ " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True) \n",
150
+ "\n",
151
+ " global logits\n",
152
+ " logits = clip_model.logit_scale.exp() * image_features[j] @ text_cond.T\n",
153
+ "\n",
154
+ " global sim\n",
155
+ " sim = torch.softmax(logits, dim=-1)\n",
156
+ " \n",
157
+ " valid_sims += [sim]\n",
158
+ " \n",
159
+ " #valid_sims = torch.stack(valid_sims)\n",
160
+ " return valid_sims, all_prompts\n",
161
+ " \n",
162
+ "\n",
163
+ "def new_img_preprocess(x):\n",
164
+ " return {'x_inp': x[1], 'mask': (11, 'cls_token', x[2])}\n",
165
+ " \n",
166
+ "#get_similarities(lvis, partial(img_preprocess, center_context=0.5));\n",
167
+ "get_similarities(lvis, lambda x: x[1]);"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "preprocessing_functions = [\n",
177
+ "# ['clip mask CLS L11', lambda x: {'x_inp': x[1].cuda(), 'mask': (11, 'cls_token', x[2].cuda())}],\n",
178
+ "# ['clip mask CLS all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'cls_token', x[2].cuda())}],\n",
179
+ "# ['clip mask all all', lambda x: {'x_inp': x[1].cuda(), 'mask': ('all', 'all', x[2].cuda())}],\n",
180
+ "# ['colorize object red', partial(img_preprocess, colorize=True)],\n",
181
+ "# ['add red outline', partial(img_preprocess, outline=True)],\n",
182
+ " \n",
183
+ "# ['BG brightness 50%', partial(img_preprocess, bg_fac=0.5)],\n",
184
+ "# ['BG brightness 10%', partial(img_preprocess, bg_fac=0.1)],\n",
185
+ "# ['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)],\n",
186
+ "# ['BG blur', partial(img_preprocess, blur=3)],\n",
187
+ "# ['BG blur & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
188
+ " \n",
189
+ "# ['crop large context', partial(img_preprocess, center_context=0.5)],\n",
190
+ "# ['crop small context', partial(img_preprocess, center_context=0.1)],\n",
191
+ " ['crop & background blur', partial(img_preprocess, blur=3, center_context=0.5)],\n",
192
+ " ['crop & intensity 10%', partial(img_preprocess, blur=3, bg_fac=0.1)],\n",
193
+ "# ['crop & background blur & intensity 10%', partial(img_preprocess, blur=3, center_context=0.1, bg_fac=0.1)],\n",
194
+ "]\n",
195
+ "\n",
196
+ "preprocessing_functions = preprocessing_functions\n",
197
+ "\n",
198
+ "base, base_p = get_similarities(lvis, lambda x: x[1])\n",
199
+ "outs = [get_similarities(lvis, fun) for _, fun in preprocessing_functions]"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "outs2 = [get_similarities(lvis, fun) for _, fun in [['BG brightness 0%', partial(img_preprocess, bg_fac=0.0)]]]"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "for j in range(1):\n",
218
+ " print(np.mean([outs2[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3]))"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "from pandas import DataFrame\n",
228
+ "tab = dict()\n",
229
+ "for j, (name, _) in enumerate(preprocessing_functions):\n",
230
+ " tab[name] = np.mean([outs[j][0][i][0].cpu() - base[i][0].cpu() for i in range(len(base)) if len(base_p[i]) >= 3])\n",
231
+ " \n",
232
+ " \n",
233
+ "print('\\n'.join(f'{k} & {v*100:.2f} \\\\\\\\' for k,v in tab.items())) "
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {},
239
+ "source": [
240
+ "# Visual"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": null,
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "from evaluation_utils import denorm, norm"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "def load_sample(filename, filename2):\n",
259
+ " from os.path import join\n",
260
+ " bp = expanduser('~/cloud/resources/sample_images')\n",
261
+ " tf = transforms.Compose([\n",
262
+ " transforms.ToTensor(),\n",
263
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
264
+ " transforms.Resize(224),\n",
265
+ " transforms.CenterCrop(224)\n",
266
+ " ])\n",
267
+ " tf2 = transforms.Compose([\n",
268
+ " transforms.ToTensor(),\n",
269
+ " transforms.Resize(224),\n",
270
+ " transforms.CenterCrop(224)\n",
271
+ " ])\n",
272
+ " inp1 = [None, tf(Image.open(join(bp, filename))), tf2(Image.open(join(bp, filename2)))]\n",
273
+ " inp1[1] = inp1[1].unsqueeze(0)\n",
274
+ " inp1[2] = inp1[2][:1] \n",
275
+ " return inp1\n",
276
+ "\n",
277
+ "def all_preprocessing(inp1):\n",
278
+ " return [\n",
279
+ " img_preprocess(inp1),\n",
280
+ " img_preprocess(inp1, colorize=True),\n",
281
+ " img_preprocess(inp1, outline=True), \n",
282
+ " img_preprocess(inp1, blur=3),\n",
283
+ " img_preprocess(inp1, bg_fac=0.1),\n",
284
+ " #img_preprocess(inp1, bg_fac=0.5),\n",
285
+ " #img_preprocess(inp1, blur=3, bg_fac=0.5), \n",
286
+ " img_preprocess(inp1, blur=3, bg_fac=0.5, center_context=0.5),\n",
287
+ " ]\n",
288
+ "\n"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "from torchvision import transforms\n",
298
+ "from PIL import Image\n",
299
+ "from matplotlib import pyplot as plt\n",
300
+ "from evaluation_utils import img_preprocess\n",
301
+ "import clip\n",
302
+ "\n",
303
+ "images_queries = [\n",
304
+ " [load_sample('things1.jpg', 'things1_jar.png'), ['jug', 'knife', 'car', 'animal', 'sieve', 'nothing']],\n",
305
+ " [load_sample('own_photos/IMG_2017s_square.jpg', 'own_photos/IMG_2017s_square_trash_can.png'), ['trash bin', 'house', 'car', 'bike', 'window', 'nothing']],\n",
306
+ "]\n",
307
+ "\n",
308
+ "\n",
309
+ "_, ax = plt.subplots(2 * len(images_queries), 6, figsize=(14, 4.5 * len(images_queries)))\n",
310
+ "\n",
311
+ "for j, (images, objects) in enumerate(images_queries):\n",
312
+ " \n",
313
+ " joint_image = all_preprocessing(images)\n",
314
+ " \n",
315
+ " joint_image = torch.stack(joint_image)[:,0]\n",
316
+ " clip_model, preprocess = clip.load(\"ViT-B/16\", device='cpu')\n",
317
+ " image_features = clip_model.encode_image(joint_image)\n",
318
+ " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
319
+ " \n",
320
+ " prompts = [f'a photo of a {obj}'for obj in objects]\n",
321
+ " text_cond = clip_model.encode_text(clip.tokenize(prompts))\n",
322
+ " text_cond = text_cond / text_cond.norm(dim=-1, keepdim=True)\n",
323
+ " logits = clip_model.logit_scale.exp() * image_features @ text_cond.T\n",
324
+ " sim = torch.softmax(logits, dim=-1).detach().cpu()\n",
325
+ "\n",
326
+ " for i, img in enumerate(joint_image):\n",
327
+ " ax[2*j, i].axis('off')\n",
328
+ " \n",
329
+ " ax[2*j, i].imshow(torch.clamp(denorm(joint_image[i]).permute(1,2,0), 0, 1))\n",
330
+ " ax[2*j+ 1, i].grid(True)\n",
331
+ " \n",
332
+ " ax[2*j + 1, i].set_ylim(0,1)\n",
333
+ " ax[2*j + 1, i].set_yticklabels([])\n",
334
+ " ax[2*j + 1, i].set_xticks([]) # set_xticks(range(len(prompts)))\n",
335
+ "# ax[1, i].set_xticklabels(objects, rotation=90)\n",
336
+ " for k in range(len(sim[i])):\n",
337
+ " ax[2*j + 1, i].bar(k, sim[i][k], color=plt.cm.tab20(1) if k!=0 else plt.cm.tab20(3))\n",
338
+ " ax[2*j + 1, i].text(k, 0.07, objects[k], rotation=90, ha='center', fontsize=15)\n",
339
+ "\n",
340
+ "plt.tight_layout()\n",
341
+ "plt.savefig('figures/prompt_engineering.pdf', bbox_inches='tight')"
342
+ ]
343
+ }
344
+ ],
345
+ "metadata": {
346
+ "kernelspec": {
347
+ "display_name": "env2",
348
+ "language": "python",
349
+ "name": "env2"
350
+ },
351
+ "language_info": {
352
+ "codemirror_mode": {
353
+ "name": "ipython",
354
+ "version": 3
355
+ },
356
+ "file_extension": ".py",
357
+ "mimetype": "text/x-python",
358
+ "name": "python",
359
+ "nbconvert_exporter": "python",
360
+ "pygments_lexer": "ipython3",
361
+ "version": "3.8.8"
362
+ }
363
+ },
364
+ "nbformat": 4,
365
+ "nbformat_minor": 4
366
+ }
clipseg/datasets/coco_wrapper.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from types import new_class
3
+ import torch
4
+ import numpy as np
5
+ import os
6
+ import json
7
+
8
+ from os.path import join, dirname, isdir, isfile, expanduser, realpath, basename
9
+ from random import shuffle, seed as set_seed
10
+ from PIL import Image
11
+
12
+ from itertools import combinations
13
+ from torchvision import transforms
14
+ from torchvision.transforms.transforms import Resize
15
+
16
+ from datasets.utils import blend_image_segmentation
17
+ from general_utils import get_from_repository
18
+
19
# Names of the 80 COCO categories; the dict key is the zero-based position
# of the name in the tuple below.
COCO_CLASSES = dict(enumerate((
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
)))
20
+
21
class COCOWrapper(object):
    """Adapter around the HSNet ``DatasetCOCO`` loader (COCO-20i folds).

    Indexing yields ``(inputs, targets)``:

    * ``inputs``  -- the query image followed by the support representation
      selected by ``mask`` (image/mask pair, class id, class name, or a
      blended image, optionally preceded by the class name).
    * ``targets`` -- the query mask with an extra leading dimension, an empty
      placeholder tensor and, if ``with_class_label`` is set, the class id.
    """

    def __init__(self, split, fold=0, image_size=400, aug=None, mask='separate', negative_prob=0,
                 with_class_label=False):
        # NOTE: ``aug`` is accepted but never read in this class.
        super().__init__()

        self.negative_prob = negative_prob
        self.with_class_label = with_class_label
        self.mask = mask

        from third_party.hsnet.data.coco import DatasetCOCO

        get_from_repository('COCO-20i', ['COCO-20i.tar'])

        foldpath = join(dirname(__file__), '../third_party/hsnet/data/splits/coco/%s/fold%d.pkl')

        def build_img_metadata_classwise(self):
            # Installed on DatasetCOCO below; ``self`` is therefore the
            # DatasetCOCO instance, whose split/fold pick the pickle to read.
            with open(foldpath % (self.split, self.fold), 'rb') as handle:
                return pickle.load(handle)

        DatasetCOCO.build_img_metadata_classwise = build_img_metadata_classwise

        preprocessing = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.coco = DatasetCOCO(expanduser('~/datasets/COCO-20i/'), fold, preprocessing, split, 1, False)

        self.all_classes = [self.coco.class_ids]
        self.coco.base_path = join(expanduser('~/datasets/COCO-20i'))

    def __len__(self):
        return len(self.coco)

    def __getitem__(self, i):
        sample = self.coco[i]
        label_name = COCO_CLASSES[int(sample['class_id'])]

        img_s = sample['support_imgs'][0]
        seg_s = sample['support_masks'][0]

        # Negative example: swap in the support image of a sample from a
        # different class and blank the support mask.
        if self.negative_prob > 0 and torch.rand(1).item() < self.negative_prob:
            other_class_id = sample['class_id']
            while other_class_id == sample['class_id']:
                other = self.coco[torch.randint(0, len(self), (1,)).item()]
                other_class_id = other['class_id']
            img_s = other['support_imgs'][0]
            seg_s = torch.zeros_like(seg_s)

        mode = self.mask
        if mode == 'separate':
            support = (img_s, seg_s)
        elif mode == 'text_label':
            # DEPRECATED
            support = [int(sample['class_id'])]
        elif mode == 'text':
            support = [label_name]
        else:
            # Any other value is a blending mode, optionally prefixed with
            # 'text_and_' to additionally prepend the class name.
            prefix = []
            if mode.startswith('text_and_'):
                mode = mode[9:]
                prefix = [label_name]
            support = prefix + blend_image_segmentation(img_s, seg_s, mode=mode)

        extras = (torch.zeros(0),)
        if self.with_class_label:
            extras = extras + (sample['class_id'],)

        inputs = (sample['query_img'],) + tuple(support)
        targets = (sample['query_mask'].unsqueeze(0),) + extras
        return inputs, targets
clipseg/datasets/pascal_classes.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"id": 1, "synonyms": ["aeroplane"]}, {"id": 2, "synonyms": ["bicycle"]}, {"id": 3, "synonyms": ["bird"]}, {"id": 4, "synonyms": ["boat"]}, {"id": 5, "synonyms": ["bottle"]}, {"id": 6, "synonyms": ["bus"]}, {"id": 7, "synonyms": ["car"]}, {"id": 8, "synonyms": ["cat"]}, {"id": 9, "synonyms": ["chair"]}, {"id": 10, "synonyms": ["cow"]}, {"id": 11, "synonyms": ["diningtable"]}, {"id": 12, "synonyms": ["dog"]}, {"id": 13, "synonyms": ["horse"]}, {"id": 14, "synonyms": ["motorbike"]}, {"id": 15, "synonyms": ["person"]}, {"id": 16, "synonyms": ["pottedplant"]}, {"id": 17, "synonyms": ["sheep"]}, {"id": 18, "synonyms": ["sofa"]}, {"id": 19, "synonyms": ["train"]}, {"id": 20, "synonyms": ["tvmonitor"]}]
clipseg/datasets/pascal_zeroshot.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import expanduser
2
+ import torch
3
+ import json
4
+ import torchvision
5
+ from general_utils import get_from_repository
6
+ from general_utils import log
7
+ from torchvision import transforms
8
+
9
# Pairs of class identifiers, one pair per list entry.
PASCAL_VOC_CLASSES_ZS = [
    ['cattle.n.01', 'motorcycle.n.01'],
    ['aeroplane.n.01', 'sofa.n.01'],
    ['cat.n.01', 'television.n.03'],
    ['train.n.01', 'bottle.n.01'],
    ['chair.n.01', 'pot_plant.n.01'],
]
12
+
13
+
14
class PascalZeroShot(object):
    """Pascal VOC segmentation samples served through JoEm's ``VOCSegmentation``.

    Indexing yields ``((image,), (label,))`` where both image and label have
    been resized to ``image_size`` x ``image_size`` (nearest-neighbour for the
    label so that class ids are not interpolated).

    Args:
        split: ``'train'`` or ``'val'``.
        n_unseen: forwarded to JoEm's ``get_unseen_idx`` / ``get_seen_idx``.
        image_size: side length of the square output.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, split, n_unseen, image_size=224) -> None:
        super().__init__()

        # Fail fast: an unknown split used to leave ``self.voc`` unset, which
        # only surfaced later as an AttributeError in __len__/__getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        import sys
        # Do not append the same entry again on every instantiation.
        if 'third_party/JoEm' not in sys.path:
            sys.path.append('third_party/JoEm')
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
                                       ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
        else:  # split == 'val'
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=False,
                                       ignore_bg=False, ignore_unseen=False)

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        return len(self.voc)

    def __getitem__(self, i):
        sample = self.voc[i]
        label = sample['label'].long()

        image = self.transform(sample['image'])

        # Resize expects a leading dimension; add it and strip it again.
        resize_label = transforms.Resize((self.image_size, self.image_size),
                                         interpolation=torchvision.transforms.InterpolationMode.NEAREST)
        label = resize_label(label.unsqueeze(0))[0]

        return (image,), (label, )
59
+
60
+
clipseg/datasets/pfe_dataset.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import expanduser
2
+ import torch
3
+ import json
4
+ from general_utils import get_from_repository
5
+ from datasets.lvis_oneshot3 import blend_image_segmentation
6
+ from general_utils import log
7
+
8
# Maps each class 'id' in pascal_classes.json to its 'synonyms' list.
# NOTE: the path is resolved against the current working directory, so this
# module must be imported with the project root as cwd.
with open('datasets/pascal_classes.json') as _classes_file:  # close the handle deterministically
    PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(_classes_file)}
9
+
10
+
11
class PFEPascalWrapper(object):
    """One-shot Pascal dataset built on PFENet's ``SemData``.

    Indexing yields ``(inputs, targets)``:

    * ``inputs``  -- the query image followed by the support representation
      selected by ``mask`` (text prompt, ``(s_x, s_y)`` pair, or a blended
      image optionally preceded by the class name).
    * ``targets`` -- foreground mask (``label == 1``), validity mask
      (``label != 255``) and the first entry of the sample's class list.

    Args:
        mode: ``'train'`` or ``'val'``; selects file list and transforms.
        split: forwarded to ``SemData``.
        mask: support representation, see above.
        image_size: crop/resize size, or ``'original'`` (val only).
        label_support: deprecated, only triggers a warning.
        size: stored but not otherwise used in this class.
        p_negative: probability of replacing the support by one drawn from a
            sample whose first class differs from the query's.
        aug: accepted but not used in this class.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
        # Fail fast: an unknown mode previously fell through both branches
        # below and crashed with a NameError on ``data_transform``.
        if mode not in ('train', 'val'):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

        # sys.path.append(expanduser('~/projects/new_one_shot'))
        from third_party.PFENet.util.dataset import SemData

        get_from_repository('PascalVOC2012', ['Pascal5i.tar'])

        self.p_negative = p_negative
        self.size = size
        self.mode = mode
        self.image_size = image_size

        if label_support in {True, False}:
            log.warning('label_support argument is deprecated. Use mask instead.')

        self.mask = mask

        # PFENet's transforms work on a 0..255 value scale, so the usual
        # mean/std constants are scaled accordingly.
        value_scale = 255
        mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
        std = [item * value_scale for item in [0.229, 0.224, 0.225]]

        import third_party.PFENet.util.transform as transform

        if mode == 'val':
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')

            data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
            data_transform += [
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        else:  # mode == 'train'
            data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')

            assert image_size != 'original'

            data_transform = [
                transform.RandScale([0.9, 1.1]),
                transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
                transform.RandomGaussianBlur(),
                transform.RandomHorizontalFlip(),
                transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
                transform.ToTensor(),
                transform.Normalize(mean=mean, std=std)
            ]

        data_transform = transform.Compose(data_transform)

        self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
                               data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)

        self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list

        # verify that subcls_list always has length 1
        # assert len(set([len(d[4]) for d in self.dataset])) == 1

        print('actual length', len(self.dataset.data_list))

    def __len__(self):
        # Both modes report the length of the underlying file list.
        return len(self.dataset.data_list)

    def __getitem__(self, index):
        n_items = len(self.dataset.data_list)

        if self.dataset.mode == 'train':
            image, label, s_x, s_y, subcls_list = self.dataset[index % n_items]
        elif self.dataset.mode == 'val':
            image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % n_items]
            ori_label = torch.from_numpy(ori_label).unsqueeze(0)

            if self.image_size != 'original':
                # Pad the original-resolution label to a square filled with
                # 255, the value excluded from ``val_mask`` below.
                longerside = max(ori_label.size(1), ori_label.size(2))
                # Keep the previous behaviour (tensor on the GPU) when CUDA is
                # present, but no longer crash on CPU-only hosts.
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                backmask = torch.ones(ori_label.size(0), longerside, longerside, device=device) * 255
                backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
                label = backmask.clone().long()
            else:
                label = label.unsqueeze(0)

        if self.p_negative > 0 and torch.rand(1).item() < self.p_negative:
            while True:
                idx = torch.randint(0, n_items, (1,)).item()
                # Items have 5 fields in train mode and 6 in val mode (see the
                # unpacking above); index instead of unpacking a fixed count
                # so negative sampling works in both modes.
                negative = self.dataset[idx]
                s_x, s_y, subcls_list_tmp = negative[2], negative[3], negative[4]
                if subcls_list[0] != subcls_list_tmp[0]:
                    break

        s_x = s_x[0]
        s_y = (s_y == 1)[0]
        label_fg = (label == 1).float()
        val_mask = (label != 255).float()

        class_id = self.class_list[subcls_list[0]]

        label_name = PASCAL_CLASSES[class_id][0]
        label_add = ()
        mask = self.mask

        if mask == 'text':
            support = ('a photo of a ' + label_name + '.',)
        elif mask == 'separate':
            support = (s_x, s_y)
        else:
            # Any other value is a blending mode, optionally prefixed with
            # 'text_and_' to additionally prepend the class name.
            if mask.startswith('text_and_'):
                label_add = (label_name,)
                mask = mask[9:]

            support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)

        return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
clipseg/datasets/phrasecut.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+
6
+ from os.path import join, isdir, isfile, expanduser
7
+ from PIL import Image
8
+
9
+ from torchvision import transforms
10
+ from torchvision.transforms.transforms import Resize
11
+
12
+ from torch.nn import functional as nnf
13
+ from general_utils import get_from_repository
14
+
15
+ from skimage.draw import polygon2mask
16
+
17
+
18
+
19
def random_crop_slices(origin_size, target_size):
    """Pick a random crop window of ``target_size`` inside ``origin_size``.

    Only the first two entries of each size are used. Returns a pair of
    slices ``(rows, columns)``; the offsets are drawn with ``torch.randint``.
    """
    assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], f'actual size: {origin_size}, target size: {target_size}'

    crop_h, crop_w = target_size[0], target_size[1]

    # randint's upper bound is exclusive, hence the +1
    top = torch.randint(0, origin_size[0] - crop_h + 1, (1,)).item()
    left = torch.randint(0, origin_size[1] - crop_w + 1, (1,)).item()

    rows = slice(top, top + crop_h)
    columns = slice(left, left + crop_w)
    return rows, columns
27
+
28
+
29
def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
    """Search random crops of ``seg`` for one that contains enough foreground.

    Args:
        seg: 2D array-like with ``astype`` / ``shape`` (cast to bool here).
        image_size: ``(height, width)`` of the crop.
        iterations: maximum number of random crops to try.
        min_frac: if given, a crop is accepted only when its foreground pixel
            count exceeds ``min_frac`` times the pixel count of the whole
            ``seg``; otherwise any crop with at least one foreground pixel is
            accepted.
        best_of: if ``None``, the first accepted crop is returned. Otherwise
            up to ``best_of`` accepted crops are collected and the one with
            the most foreground is returned.

    Returns:
        ``(slice_y, slice_x, failed)``. ``failed`` is ``False`` for an
        accepted crop. If no crop was accepted, the best rejected crop is
        returned with ``failed`` set; with ``iterations == 0`` the slices are
        ``None``.
    """
    seg = seg.astype('bool')

    min_sum = 0
    if min_frac is not None:
        min_sum = seg.shape[0] * seg.shape[1] * min_frac

    accepted = []
    best_rejected = float('-inf'), None, None

    for _ in range(iterations):
        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
        fg_pixels = seg[sl_y, sl_x].sum()

        if fg_pixels > min_sum:
            if best_of is None:
                return sl_y, sl_x, False

            accepted.append((fg_pixels, sl_y, sl_x))
            if len(accepted) >= best_of:
                break
        elif fg_pixels > best_rejected[0]:
            best_rejected = fg_pixels, sl_y, sl_x

    # Also reached when the iteration budget ran out with fewer than
    # ``best_of`` accepted crops: previously those were thrown away and a
    # rejected crop (or None slices) was reported as a failure.
    if accepted:
        _, sl_y, sl_x = max(accepted, key=lambda crop: crop[0])
        return sl_y, sl_x, False

    # return best segmentation found
    return best_rejected[1:] + (best_rejected[0] <= min_sum,)
66
+
67
+
68
+ class PhraseCut(object):
69
+
70
+ def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
71
+ min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
72
+ super().__init__()
73
+
74
+ self.negative_prob = negative_prob
75
+ self.image_size = image_size
76
+ self.with_visual = with_visual
77
+ self.only_visual = only_visual
78
+ self.phrase_form = '{}'
79
+ self.mask = mask
80
+ self.aug_crop = aug_crop
81
+
82
+ if aug_color:
83
+ self.aug_color = transforms.Compose([
84
+ transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
85
+ ])
86
+ else:
87
+ self.aug_color = None
88
+
89
+ get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
90
+ isdir(join(local_dir, 'VGPhraseCut_v0')),
91
+ isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
92
+ isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
93
+ len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
94
+ ]))
95
+
96
+ from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
97
+ self.refvg_loader = RefVGLoader(split=split)
98
+
99
+ # img_ids where the size in the annotations does not match actual size
100
+ invalid_img_ids = set([150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
101
+ 150333, 286065, 285814, 498187, 285761, 498042])
102
+
103
+ mean = [0.485, 0.456, 0.406]
104
+ std = [0.229, 0.224, 0.225]
105
+ self.normalize = transforms.Normalize(mean, std)
106
+
107
+ self.sample_ids = [(i, j)
108
+ for i in self.refvg_loader.img_ids
109
+ for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
110
+ if i not in invalid_img_ids]
111
+
112
+
113
+ # self.all_phrases = list(set([p for i in self.refvg_loader.img_ids for p in self.refvg_loader.get_img_ref_data(i)['phrases']]))
114
+
115
+ from nltk.stem import WordNetLemmatizer
116
+ wnl = WordNetLemmatizer()
117
+
118
+ # Filter by class (if remove_classes is set)
119
+ if remove_classes is None:
120
+ pass
121
+ else:
122
+ from datasets.generate_lvis_oneshot import PASCAL_SYNSETS, traverse_lemmas, traverse_lemmas_hypo
123
+ from nltk.corpus import wordnet
124
+
125
+ print('remove pascal classes...')
126
+
127
+ get_data = self.refvg_loader.get_img_ref_data # shortcut
128
+ keep_sids = None
129
+
130
+ if remove_classes[0] == 'pas5i':
131
+ subset_id = remove_classes[1]
132
+ from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
133
+ avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]
134
+
135
+
136
+ elif remove_classes[0] == 'zs':
137
+ stop = remove_classes[1]
138
+
139
+ from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
140
+
141
+ avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
142
+ print(avoid)
143
+
144
+ elif remove_classes[0] == 'aff':
145
+ # avoid = ['drink.v.01', 'sit.v.01', 'ride.v.02']
146
+ # all_lemmas = set(['drink', 'sit', 'ride'])
147
+ avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
148
+ 'ride', 'rides', 'riding',
149
+ 'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
150
+ 'swim', 'swims', 'swimming',
151
+ 'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
152
+ keep_sids = [(i, j) for i, j in self.sample_ids if
153
+ all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
154
+
155
+ print('avoid classes:', avoid)
156
+
157
+
158
+ if keep_sids is None:
159
+ all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
160
+ all_lemmas = list(set(all_lemmas))
161
+ all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
162
+ all_lemmas = set(all_lemmas)
163
+
164
+ # divide into multi word and single word
165
+ all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
166
+ all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)
167
+
168
+ # new3
169
+ phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
170
+ remove_sids = set((i,j) for (i,j), phrase in zip(self.sample_ids, phrases)
171
+ if any(l in phrase for l in all_lemmas_m) or
172
+ len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0
173
+ )
174
+ keep_sids = [(i, j) for i, j in self.sample_ids if (i,j) not in remove_sids]
175
+
176
+ print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
177
+ removed_ids = set(self.sample_ids) - set(keep_sids)
178
+
179
+ print('Examples of removed', len(removed_ids))
180
+ for i, j in list(removed_ids)[:20]:
181
+ print(i, get_data(i)['phrases'][j])
182
+
183
+ self.sample_ids = keep_sids
184
+
185
+ from itertools import groupby
186
+ samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
187
+ for i, j in self.sample_ids]
188
+ samples_by_phrase = sorted(samples_by_phrase)
189
+ samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])
190
+
191
+ self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids] for prompt, prompt_sample_ids in samples_by_phrase}
192
+
193
+ self.all_phrases = list(set(self.samples_by_phrase.keys()))
194
+
195
+
196
+ if self.only_visual:
197
+ assert self.with_visual
198
+ self.sample_ids = [(i, j) for i, j in self.sample_ids
199
+ if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]
200
+
201
+ # Filter by size (if min_size is set)
202
+ sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
203
+ image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height'] for i, j in self.sample_ids]
204
+ #self.sizes = [sum([(s[2] - s[0]) * (s[3] - s[1]) for s in size]) for size in sizes]
205
+ self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]
206
+
207
+ if min_size:
208
+ print('filter by size')
209
+
210
+ self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]
211
+
212
+ self.base_path = join(expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/'))
213
+
214
+ def __len__(self):
215
+ return len(self.sample_ids)
216
+
217
+
218
def load_sample(self, sample_i, j):
    """Load image, segmentation mask and formatted phrase for one sample.

    Args:
        sample_i: image id passed to ``self.refvg_loader.get_img_ref_data``.
        j: index of the phrase (and its polygons) within that image's data.

    Returns:
        ``(img, seg, phrase)``: ``img`` is resized to
        ``(self.image_size, self.image_size)``, divided by 255, optionally
        passed through ``self.aug_color`` and then ``self.normalize``;
        ``seg`` is the nearest-neighbour resized union of all polygon masks;
        ``phrase`` is the annotation inserted into ``self.phrase_form``.
    """
    ref = self.refvg_loader.get_img_ref_data(sample_i)

    phrase_polygons = ref['gt_Polygons'][j]
    phrase = self.phrase_form.format(ref['phrases'][j])

    # One mask per polygon; the points are reversed (x,y swapped) before rasterising.
    masks = [
        polygon2mask((ref['height'], ref['width']), [pt[::-1] for pt in poly])
        for polys in phrase_polygons
        for poly in polys
    ]

    # Union of all polygon masks.
    seg = np.stack(masks).max(0)

    image_path = join(self.base_path, str(ref['image_id']) + '.jpg')
    img = np.array(Image.open(image_path))

    short_side = min(img.shape[:2])

    if self.aug_crop:
        # find_crop is defined elsewhere; its third return value is not needed here.
        sly, slx, _ = find_crop(seg, (short_side, short_side), iterations=50, min_frac=0.05)
    else:
        sly, slx = slice(0, None), slice(0, None)

    seg = seg[sly, slx].astype('uint8')
    img = img[sly, slx]

    seg = torch.from_numpy(seg).view(1, 1, *seg.shape)

    # Single-channel images are replicated to three channels.
    if img.ndim == 2:
        img = np.dstack([img] * 3)

    img = torch.from_numpy(img)
    img = img.permute(2, 0, 1).unsqueeze(0).float()

    target_size = (self.image_size, self.image_size)
    seg = nnf.interpolate(seg, target_size, mode='nearest')[0, 0]
    img = nnf.interpolate(img, target_size, mode='bilinear', align_corners=True)[0]

    img = img / 255.0

    if self.aug_color is not None:
        img = self.aug_color(img)

    img = self.normalize(img)

    return img, seg, phrase
267
+
268
def __getitem__(self, i):
    """Assemble the model input and target for dataset index ``i``.

    With probability ``self.negative_prob`` the phrase is replaced by a
    different entry of ``self.all_phrases`` and the target mask is zeroed.
    When ``self.with_visual`` is set, a support sample sharing the phrase is
    drawn from ``self.samples_by_phrase`` (or an all-zero placeholder is used
    if fewer than two samples share it) and encoded according to ``self.mask``.

    Returns:
        ``(data_x, (seg, torch.zeros(0), i))`` where ``data_x`` is a tuple
        starting with the query image followed by the conditioning entries.
    """
    sample_i, j = self.sample_ids[i]
    img, seg, phrase = self.load_sample(sample_i, j)

    # Negative sample: swap in a different phrase and blank the target.
    if self.negative_prob > 0 and torch.rand((1,)).item() < self.negative_prob:
        replacement = None
        while replacement is None or replacement == phrase:
            pick = torch.randint(0, len(self.all_phrases), (1,)).item()
            replacement = self.all_phrases[pick]
        phrase = replacement
        seg = torch.zeros_like(seg)

    if not self.with_visual:
        assert self.mask == 'text'
        vis_s = [phrase]
    else:
        by_phrase = self.samples_by_phrase
        if phrase in by_phrase and len(by_phrase[phrase]) > 1:
            # Draw a support sample that carries the same phrase.
            candidates = by_phrase[phrase]
            pick = torch.randint(0, len(candidates), (1,)).item()
            img_s, seg_s, _ = self.load_sample(*candidates[pick])

            from datasets.utils import blend_image_segmentation

            if self.mask in {'separate', 'text_and_separate'}:
                prefix = [phrase] if self.mask == 'text_and_separate' else []
                vis_s = prefix + [img_s, seg_s, True]
            else:
                has_text = self.mask.startswith('text_and_')
                # Strip the 9-character 'text_and_' prefix to get the blend mode.
                mask_mode = self.mask[9:] if has_text else self.mask
                prefix = [phrase] if has_text else []

                blended = blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)
                vis_s = prefix + [torch.from_numpy(blended[0]), True]
        else:
            # Phrase is unique: use an all-zero support image flagged as invalid.
            blank = torch.zeros_like(img)

            if self.mask in {'separate', 'text_and_separate'}:
                prefix = [phrase] if self.mask == 'text_and_separate' else []
                empty_seg = torch.zeros(*blank.shape[1:], dtype=torch.uint8)
                vis_s = prefix + [blank, empty_seg, False]
            elif self.mask.startswith('text_and_'):
                vis_s = [phrase, blank, False]
            else:
                vis_s = [blank, False]

    seg = seg.unsqueeze(0).float()
    data_x = (img,) + tuple(vis_s)

    return data_x, (seg, torch.zeros(0), i)
329
+
330
+
331
class PhraseCutPlus(PhraseCut):
    """PhraseCut with ``negative_prob=0.2`` and ``with_visual=True`` fixed.

    All other constructor arguments are forwarded unchanged to
    ``PhraseCut.__init__``.
    """

    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0,
                 remove_classes=None, only_visual=False, mask=None):
        # Settings this subclass pins; everything else is forwarded from the caller.
        pinned = dict(negative_prob=0.2, with_visual=True)
        forwarded = dict(
            image_size=image_size,
            aug=aug,
            aug_color=aug_color,
            aug_crop=aug_crop,
            min_size=min_size,
            remove_classes=remove_classes,
            only_visual=only_visual,
            mask=mask,
        )
        super().__init__(split, **forwarded, **pinned)
clipseg/datasets/utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
def blend_image_segmentation(img, seg, mode, image_size=224):
    """Combine an image and its segmentation into a visual prompt.

    Args:
        img: image as ``np.ndarray`` or ``torch.Tensor``; the highlight and
            concat modes index it as channel-first (``seg[None, :, :]`` is
            broadcast against it).
        seg: 2-D segmentation mask as ``np.ndarray`` or ``torch.Tensor``.
        mode: blending mode name, or ``None`` (same as ``'image_only'``).
        image_size: output size forwarded to ``img_preprocess`` for the
            ``'crop'`` and ``'crop_blur_highlight'`` modes
            (``'crop_blur_highlight352'`` always uses 352).

    Returns:
        A list of numpy arrays: one element for most modes, two
        (image, segmentation) for the ``'separate*'`` modes.

    Raises:
        ValueError: if ``mode`` is not a known mode.
    """
    # Modes for which numpy inputs are converted to tensors up front.
    tensor_first = {'blur_highlight', 'blur3_highlight', 'blur3_highlight01', 'blur_highlight_random', 'crop'}
    preprocess_modes = tensor_first | {'crop_blur_highlight', 'crop_blur_highlight352'}

    if mode in preprocess_modes:
        from evaluation_utils import img_preprocess

        if mode in tensor_first:
            if isinstance(img, np.ndarray):
                img = torch.from_numpy(img)
            if isinstance(seg, np.ndarray):
                seg = torch.from_numpy(seg)

        batch = (None, [img], [seg])

        if mode.startswith('crop'):
            if mode == 'crop':
                kwargs = dict(blur=1, center_context=0.1, image_size=image_size)
            elif mode == 'crop_blur_highlight':
                kwargs = dict(blur=3, center_context=0.1, bg_fac=0.1, image_size=image_size)
            else:  # 'crop_blur_highlight352'
                kwargs = dict(blur=3, center_context=0.1, bg_fac=0.1, image_size=352)
            return [img_preprocess(batch, **kwargs)[0].numpy()]

        if mode == 'blur_highlight':
            kwargs = dict(blur=1, bg_fac=0.5)
        elif mode == 'blur3_highlight':
            kwargs = dict(blur=3, bg_fac=0.5)
        elif mode == 'blur3_highlight01':
            kwargs = dict(blur=3, bg_fac=0.1)
        else:  # 'blur_highlight_random': the blur is drawn before the background factor
            kwargs = dict(blur=0 + torch.randint(0, 3, (1,)).item(),
                          bg_fac=0.1 + 0.8*torch.rand(1).item())
        # The constant 0.01 offset is kept from the original implementation.
        return [img_preprocess(batch, **kwargs).numpy()[0] - 0.01]

    # The remaining modes use numpy-only operations (``astype``); convert
    # tensors so that callers holding torch data do not hit AttributeError.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    if isinstance(seg, torch.Tensor):
        seg = seg.detach().cpu().numpy()

    if mode == 'overlay':
        return [(img * seg).astype('float32')]
    if mode == 'highlight':
        return [(img * seg[None, :, :] * 0.85 + 0.15 * img).astype('float32')]
    if mode == 'highlight2':
        half = img / 2
        return [((half+0.1) * seg[None, :, :] + 0.3 * half).astype('float32')]
    if mode == 'shape':
        return [np.stack([seg[:, :]]*3).astype('float32')]
    if mode == 'concat':
        return [np.concatenate([img, seg[None, :, :]]).astype('float32')]
    if mode is None or mode == 'image_only':
        return [img.astype('float32')]
    if mode == 'image_black':
        return [img.astype('float32')*0]
    if mode == 'separate':
        return [img.astype('float32'), seg.astype('int64')]
    if mode == 'separate_img_black':
        return [img.astype('float32')*0, seg.astype('int64')]
    if mode == 'separate_seg_ones':
        return [img.astype('float32'), np.ones_like(seg).astype('int64')]
    if mode == 'separate_both_black':
        return [img.astype('float32')*0, seg.astype('int64')*0]

    raise ValueError(f'invalid mode: {mode}')
clipseg/environment.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: clipseg-environment
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ dependencies:
6
+ - numpy
7
+ - scipy
8
+ - matplotlib-base
9
+ - pip
10
+ - pip:
11
+ - --find-links https://download.pytorch.org/whl/torch_stable.html
12
+ - torch==1.10.0+cpu
13
+ - torchvision==0.11.1+cpu
14
+ - opencv-python
15
+ - git+https://github.com/openai/CLIP.git
clipseg/evaluation_utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.functional import Tensor
2
+ from general_utils import load_model
3
+ from torch.utils.data import DataLoader
4
+ import torch
5
+ import numpy as np
6
+
7
def denorm(img):
    """Undo the per-channel mean/std normalisation and clip to [0, 1].

    Args:
        img: channel-first image (3 channels on the third-from-last axis) as
            ``np.ndarray`` or ``torch.Tensor``.

    Returns:
        The denormalised image, of the same container type as the input.
    """
    np_input = isinstance(img, np.ndarray)
    if np_input:
        img = torch.from_numpy(img)

    # Create the constants on the image's device so non-CPU tensors work too.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device)
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device)

    img_denorm = (img*std[:, None, None]) + mean[:, None, None]

    if np_input:
        return np.clip(img_denorm.numpy(), 0, 1)
    return torch.clamp(img_denorm, 0, 1)
25
+
26
+
27
def norm(img):
    """Normalise a channel-first 3-channel image with the fixed mean/std.

    Args:
        img: image with the 3 channels on the third-from-last axis.

    Returns:
        ``(img - mean) / std`` with mean/std broadcast over the last two axes.
    """
    # Match the image's device when it is a tensor; otherwise use the default.
    device = img.device if isinstance(img, torch.Tensor) else None
    mean = torch.tensor([0.485, 0.456, 0.406], device=device)
    std = torch.tensor([0.229, 0.224, 0.225], device=device)
    return (img - mean[:, None, None]) / std[:, None, None]
31
+
32
+
33
def fast_iou_curve(p, g):
    """Compute IoU at 50 evenly spaced sigmoid thresholds in [0, 1].

    Args:
        p: 1-D tensor of raw prediction scores (a sigmoid is applied here).
        g: 1-D tensor of ground-truth labels aligned with ``p``.

    Returns:
        ``(vals, scores)``: the thresholds (numpy array) and a list with the
        IoU for each threshold. An IoU with a zero denominator is NaN.
    """
    # Sort once; predictions below a threshold then form a prefix.
    order = p.sort()
    g = g[order.indices]
    p = torch.sigmoid(order.values)

    scores = []
    vals = np.linspace(0, 1, 50)

    for q in vals:
        above = torch.where(p > q)[0]
        # n = number of predictions at or below the threshold (predicted negative).
        n = int(above[0]) if len(above) > 0 else len(g)

        fn = g[:n].sum()
        tp = g[n:].sum()
        fp = len(g) - n - tp

        scores += [tp / (tp + fn + fp)]

    return vals, scores
64
+
65
+
66
def fast_rp_curve(p, g, step=100000):
    """Compute recall/precision pairs at thresholds sampled from the predictions.

    Args:
        p: 1-D tensor of raw prediction scores (a sigmoid is applied here).
        g: 1-D tensor of ground-truth labels aligned with ``p``.
        step: every ``step``-th sorted prediction is used as a threshold.
            The default keeps the original fixed stride of 100000.

    Returns:
        ``(recalls, precisions)``: two lists with one entry per threshold.
        Entries with a zero denominator are NaN.
    """
    # Sort once; predictions below a threshold then form a prefix.
    order = p.sort()
    g = g[order.indices]
    p = torch.sigmoid(order.values)

    precisions, recalls = [], []

    for q in p[::step]:
        above = torch.where(p > q)[0]
        # n = number of predictions at or below the threshold (predicted negative).
        n = int(above[0]) if len(above) > 0 else len(g)

        fn = g[:n].sum()
        tp = g[n:].sum()
        fp = len(g) - n - tp

        precisions += [tp / (tp + fp)]
        recalls += [tp / (tp + fn)]

    return recalls, precisions
98
+
99
+
100
+ # Image processing
101
+
102
def img_preprocess(batch, blur=0, grayscale=False, center_context=None, rect=False, rect_color=(255,0,0), rect_width=2,
                   brightness=1.0, bg_fac=1, colorize=False, outline=False, image_size=224):
    """Build visual prompts by emphasising the masked object in each image.

    Args:
        batch: sequence whose items 1 and 2 are the images and their masks
            (item 0 is not used). Images are indexed channel-first.
        blur: Gaussian sigma applied to the background; 0 disables blurring.
        grayscale: if True, use only channel 1 of the background.
        center_context: if not None, crop around the object with this context
            fraction and resize to ``image_size``.
        rect: if True, draw the object's bounding box in ``rect_color``.
        rect_color: colour used for the bounding box.
        rect_width: line width of the bounding box.
        brightness: factor the image is multiplied by.
        bg_fac: factor the background (outside the mask) is multiplied by.
        colorize: if True, replace the result by a tinted grayscale rendering.
        outline: if True, draw the mask contour on the result.
        image_size: output size used with ``center_context``.

    Returns:
        A tensor stacking the processed images.
    """
    # OpenCV is only needed by the blur / colorize / outline branches.
    if blur > 0 or colorize or outline:
        import cv2

    rw = rect_width

    out = []
    for img, mask in zip(batch[1], batch[2]):

        img = img.cpu() if isinstance(img, torch.Tensor) else torch.from_numpy(img)
        mask = mask.cpu() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)

        # Out-of-place on purpose: ``img`` may share memory with the caller's
        # tensor or numpy array, which must not be modified.
        img = img * brightness
        img_bl = img
        if blur > 0:  # best 5
            img_bl = torch.from_numpy(cv2.GaussianBlur(img.permute(1,2,0).numpy(), (15, 15), blur)).permute(2,0,1)

        if grayscale:
            img_bl = img_bl[1][None]

        # Object at full intensity, background scaled by bg_fac.
        img_inp = img*mask + (bg_fac) * img_bl * (1-mask)

        if rect:
            _, bbox = crop_mask(img, mask, context=0.1)
            color = torch.tensor(rect_color)[:,None,None]
            img_inp[:, bbox[2]: bbox[3], max(0, bbox[0]-rw):bbox[0]+rw] = color
            img_inp[:, bbox[2]: bbox[3], max(0, bbox[1]-rw):bbox[1]+rw] = color
            img_inp[:, max(0, bbox[2]-1): bbox[2]+rw, bbox[0]:bbox[1]] = color
            img_inp[:, max(0, bbox[3]-1): bbox[3]+rw, bbox[0]:bbox[1]] = color

        if center_context is not None:
            img_inp = object_crop(img_inp, mask, context=center_context, image_size=image_size)

        if colorize:
            img_gray = denorm(img)
            img_gray = cv2.cvtColor(img_gray.permute(1,2,0).numpy(), cv2.COLOR_RGB2GRAY)
            img_gray = torch.stack([torch.from_numpy(img_gray)]*3)
            img_inp = torch.tensor([1,0.2,0.2])[:,None,None] * img_gray * mask + bg_fac * img_gray * (1-mask)
            img_inp = norm(img_inp)

        if outline:
            cont = cv2.findContours(mask.byte().numpy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            outline_img = np.zeros(mask.shape, dtype=np.uint8)
            cv2.drawContours(outline_img, cont[0], -1, thickness=5, color=(255, 255, 255))
            outline_img = torch.stack([torch.from_numpy(outline_img)]*3).float() / 255.
            img_inp = torch.tensor([1,0,0])[:,None,None] * outline_img + denorm(img_inp) * (1- outline_img)
            img_inp = norm(img_inp)

        out += [img_inp]

    return torch.stack(out)
155
+
156
+
157
def object_crop(img, mask, context=0.0, square=False, image_size=224):
    """Crop ``img`` to the mask's bounding box, pad to a square and resize.

    The crop comes from ``crop_mask`` (its bounding box is discarded), is
    padded with ``pad_to_square`` (channel-first) and interpolated to
    ``(image_size, image_size)``.
    """
    cropped, _ = crop_mask(img, mask, context=context, square=square)
    squared = pad_to_square(cropped, channel_dim=0)
    # interpolate expects a batch axis; add it and strip it again afterwards.
    batched = squared.unsqueeze(0)
    resized = torch.nn.functional.interpolate(batched, (image_size, image_size))
    return resized.squeeze(0)
162
+
163
+
164
def crop_mask(img, mask, context=0.0, square=False):
    """Crop ``img`` to the bounding box of the non-zero region of ``mask``.

    Args:
        img: channel-first image tensor whose last two axes match ``mask``.
        mask: 2-D tensor (rows, columns).
        context: fraction of the box extent added as margin on each side.
        square: if True, additionally adjust the shorter side of the box.

    Returns:
        ``(img_crop, bbox)`` with ``bbox = [x_start, x_end, y_start, y_end]``
        (end indices exclusive; x indexes columns, y indexes rows).
    """
    assert img.shape[1:] == mask.shape

    n_rows, n_cols = mask.size(0), mask.size(1)
    col_hits = mask.max(0).values  # one entry per column
    row_hits = mask.max(1).values  # one entry per row

    # The end index is counted back from the matching axis length. The
    # original used the other axis, which is only correct for square masks.
    bbox = [col_hits.argmax(), n_cols - col_hits.flip(0).argmax()]
    bbox += [row_hits.argmax(), n_rows - row_hits.flip(0).argmax()]
    bbox = [int(x) for x in bbox]

    x_extent = bbox[1] - bbox[0]
    y_extent = bbox[3] - bbox[2]

    # Add the context margin, clamped to the mask bounds.
    bbox[0] = int(max(0, bbox[0] - context * x_extent))
    bbox[1] = int(min(n_cols, bbox[1] + context * x_extent))
    bbox[2] = int(max(0, bbox[2] - context * y_extent))
    bbox[3] = int(min(n_rows, bbox[3] + context * y_extent))

    if square:
        x_extent = bbox[1] - bbox[0]
        y_extent = bbox[3] - bbox[2]
        # NOTE(review): the shorter side is shifted by half of the *longer*
        # extent, which does not centre the box; kept as in the original.
        if x_extent > y_extent:
            bbox[2] = int(max(0, (bbox[2] - 0.5*x_extent)))
            bbox[3] = bbox[2] + x_extent
        else:
            bbox[0] = int(max(0, (bbox[0] - 0.5*y_extent)))
            bbox[1] = bbox[0] + y_extent

    img_crop = img[:, bbox[2]: bbox[3], bbox[0]: bbox[1]]
    return img_crop, bbox
197
+
198
+
199
def pad_to_square(img, channel_dim=2, fill=0):
    """Pad a 3-D image tensor with a constant so height equals width.

    Args:
        img: image tensor, channel-last (``channel_dim=2``) or channel-first
            (``channel_dim=0``).
        channel_dim: position of the channel axis, 0 or 2.
        fill: constant value written into the padded area.

    Returns:
        The padded tensor in the same axis layout as the input; the padding
        is split evenly, with the extra pixel (if any) on the right/bottom.

    Raises:
        ValueError: if ``channel_dim`` is neither 0 nor 2.
    """
    if channel_dim == 2:
        img = img.permute(2, 0, 1)
    elif channel_dim != 0:
        raise ValueError('invalid channel_dim')

    h, w = img.shape[1:]
    pad_top = pad_bottom = pad_left = pad_right = 0

    if h > w:
        pad_left = (h - w) // 2
        pad_right = h - w - pad_left
    elif w > h:
        pad_top = (w - h) // 2
        pad_bottom = w - h - pad_top

    # torch's pad takes (left, right, top, bottom) for the last two axes.
    # ``fill`` is passed on here; the original accepted it but ignored it.
    img_padded = torch.nn.functional.pad(
        img, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=fill)

    if channel_dim == 2:
        img_padded = img_padded.permute(1, 2, 0)

    return img_padded
230
+
231
+
232
+ # qualitative
233
+
234
def split_sentence(inp, limit=9):
    """Insert line breaks into ``inp`` once a line exceeds ``limit`` characters.

    Words are split on single spaces. Each word is emitted followed by a
    space, and a newline is added after the word that pushes the running
    line length above ``limit``, except after the final word.

    Args:
        inp: text to wrap.
        limit: line length (counting one trailing space per word) above
            which a break is inserted.

    Returns:
        The re-joined text.
    """
    words = inp.split(' ')  # split once instead of once per iteration
    last = len(words) - 1

    pieces, current_len = [], 0
    for k, word in enumerate(words):
        current_len += len(word) + 1
        pieces.append(word + ' ')
        # No break after the last word.
        if current_len > limit and k != last:
            current_len = 0
            pieces.append('\n')

    return ''.join(pieces)
246
+
247
+
248
+ from matplotlib import pyplot as plt
249
+
250
+
251
def plot(imgs, *preds, labels=None, scale=1, cmap=plt.cm.magma, aps=None, gt_labels=None, vmax=None):
    """Show images next to one or more prediction maps in a grid.

    Args:
        imgs: images shown in the first column (one row per image).
        *preds: prediction batches; ``preds[j][i][0]`` is shown in column
            ``1 + j`` of row ``i``.
        labels: optional column titles, rendered in an extra first row.
        scale: scaling factor for the figure and label font size.
        cmap: colormap for the prediction maps.
        aps: optional per-image values printed as ``AP`` on columns whose
            label matches the image's ground-truth label.
        gt_labels: optional per-image ground-truth labels; matching columns
            get a red frame.
        vmax: upper colour limit: ``None`` (1), ``'per_prompt'`` (maximum
            over the column), ``'per_image'`` (maximum over the row), or any
            other value, which is used directly.
    """
    row_off = 0 if labels is None else 1
    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    _, ax = plt.subplots(len(imgs) + row_off, 1 + len(preds),
                         figsize=(scale * float(1 + 2*len(preds)), scale * float(len(imgs)*2)),
                         squeeze=False)
    for a in ax.flatten():
        a.axis('off')

    if labels is not None:
        for j in range(len(labels)):
            t_new = split_sentence(labels[j], limit=6)
            ax[0, 1 + j].text(0.5, 0.1, t_new, ha='center', fontsize=3 + 10*scale)

    for i in range(len(imgs)):
        ax[i + row_off, 0].imshow(imgs[i])
        for j in range(len(preds)):
            cell = ax[i + row_off, 1 + j]
            img = preds[j][i][0].detach().cpu().numpy()

            is_gt = gt_labels is not None and labels is not None and labels[j] == gt_labels[i]
            if is_gt:
                print(j, labels[j], gt_labels[i])
                edgecolor = 'red'
                if aps is not None:
                    cell.text(30, 70, f'AP: {aps[i]:.3f}', color='red', fontsize=8)
            else:
                edgecolor = 'k'

            # Rectangle takes (xy, width, height): width is the column count.
            rect = plt.Rectangle([0, 0], img.shape[1], img.shape[0], facecolor="none",
                                 edgecolor=edgecolor, linewidth=3)
            cell.add_patch(rect)

            if vmax is None:
                this_vmax = 1
            elif vmax == 'per_prompt':
                this_vmax = float(max([preds[j][_i][0].max() for _i in range(len(imgs))]))
            elif vmax == 'per_image':
                this_vmax = float(max([preds[_j][i][0].max() for _j in range(len(preds))]))
            else:
                # Any other value is taken as the colour limit itself.
                this_vmax = vmax

            cell.imshow(img, vmin=0, vmax=this_vmax, cmap=cmap)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
clipseg/example_image.jpg ADDED
clipseg/experiments/ablation.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000 # <-##########################################
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.phrasecut.PhraseCut # <-----------------
20
+ split_mode: pascal_test
21
+ split: train
22
+ mask: text_and_crop_blur_highlight352
23
+ image_size: 352
24
+ negative_prob: 0.2
25
+ mix_text_max: 0.5
26
+
27
+ # general
28
+ mix: True # <-----------------
29
+ prompt: shuffle+
30
+ norm_cond: True
31
+ mix_text_min: 0.0
32
+ with_visual: True
33
+
34
+ # model
35
+ version: 'ViT-B/16'
36
+ extract_layers: [3, 7, 9]
37
+ reduce_dim: 64
38
+ depth: 3
39
+ fix_shift: False # <-##########################################
40
+
41
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
42
+ amp: True
43
+
44
+ test_configuration_common:
45
+ normalize: True
46
+ image_size: 352
47
+ batch_size: 32
48
+ sigmoid: True
49
+ split: test
50
+ label_support: True
51
+
52
+ test_configuration:
53
+
54
+ -
55
+ name: pc
56
+ metric: metrics.FixedIntervalMetrics
57
+ test_dataset: phrasecut
58
+ mask: text
59
+
60
+ -
61
+ name: pc-vis
62
+ metric: metrics.FixedIntervalMetrics
63
+ test_dataset: phrasecut
64
+ mask: crop_blur_highlight352
65
+ with_visual: True
66
+ visual_only: True
67
+
68
+
69
+ columns: [name,
70
+ pc_fgiou_best, pc_miou_best, pc_fgiou_0.5,
71
+ pc-vis_fgiou_best, pc-vis_miou_best, pc-vis_fgiou_0.5,
72
+ duration]
73
+
74
+
75
+ individual_configurations:
76
+
77
+ - {name: rd64-uni}
78
+ - {name: rd64-no-pretrain, not_pretrained: True, lr: 0.0003}
79
+ - {name: rd64-no-negatives, negative_prob: 0.0}
80
+ - {name: rd64-neg0.5, negative_prob: 0.5}
81
+ - {name: rd64-no-visual, with_visual: False, mix: False}
82
+ - {name: rd16-uni, reduce_dim: 16}
83
+ - {name: rd64-layer3, extract_layers: [3], depth: 1}
84
+ - {name: rd64-blur-highlight, mask: text_and_blur_highlight, test_configuration: {mask: blur_highlight}}
clipseg/experiments/coco.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.coco_wrapper.COCOWrapper
20
+ # split_mode: pascal_test
21
+ split: train
22
+ mask: text_and_blur3_highlight01
23
+ image_size: 352
24
+ normalize: True
25
+ pre_crop_image_size: [sample, 1, 1.5]
26
+ aug: 1new
27
+
28
+ # general
29
+ mix: True
30
+ prompt: shuffle+
31
+ norm_cond: True
32
+ mix_text_min: 0.0
33
+
34
+ # model
35
+ out: 1
36
+ extract_layers: [3, 7, 9]
37
+ reduce_dim: 64
38
+ depth: 3
39
+ fix_shift: False
40
+
41
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
42
+ amp: True
43
+
44
+ test_configuration_common:
45
+ normalize: True
46
+ image_size: 352
47
+ # max_iterations: 10
48
+ batch_size: 8
49
+ sigmoid: True
50
+ test_dataset: coco
51
+ metric: metrics.FixedIntervalMetrics
52
+
53
+ test_configuration:
54
+
55
+ -
56
+ name: coco_t
57
+ mask: text
58
+
59
+ -
60
+ name: coco_h
61
+ mask: blur3_highlight01
62
+
63
+ -
64
+ name: coco_h2
65
+ mask: crop_blur_highlight352
66
+
67
+
68
+ columns: [i, name,
69
+ coco_t_fgiou_best, coco_t_miou_best, coco_t_fgiou_0.5,
70
+ coco_h_fgiou_best, coco_h_miou_best, coco_h_fgiou_0.5,
71
+ coco_h2_fgiou_best, coco_h2_miou_best, coco_h2_fgiou_0.5, coco_h2_fgiou_best_t,
72
+ train_loss, duration, date
73
+ ]
74
+
75
+ individual_configurations:
76
+
77
+
78
+ - {name: rd64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
79
+ - {name: rd64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
80
+ - {name: rd64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
81
+ - {name: rd64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
82
+
83
+
84
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-0, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
85
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-1, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
86
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-2, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
87
+ - {name: rd64-7K-vit16-cbh-neg0.2-coco-3, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
88
+
89
+
90
+ # ViT
91
+ - {name: vit64-7K-vit16-cbh-coco-0, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
92
+ - {name: vit64-7K-vit16-cbh-coco-1, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
93
+ - {name: vit64-7K-vit16-cbh-coco-2, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
94
+ - {name: vit64-7K-vit16-cbh-coco-3, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000, lr: 0.0001}
95
+
96
+
97
+ # BASELINE
98
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-0, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 0, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
99
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-1, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 1, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
100
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-2, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 2, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
101
+ - {name: bl64-7K-vit16-cbh-neg0.2-coco-3, model: models.clipseg.CLIPDenseBaseline, reduce2_dim: 64, version: 'ViT-B/16', negative_prob: 0.2, fold: 3, reduce_dim: 64, mask: text_and_crop_blur_highlight352, T_max: 7000, max_iterations: 7000}
clipseg/experiments/pascal_1shot.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000 # <-##########################################
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.phrasecut.PhraseCut
20
+ split_mode: pascal_test
21
+ mode: train
22
+ mask: text_and_crop_blur_highlight352
23
+ image_size: 352
24
+ normalize: True
25
+ pre_crop_image_size: [sample, 1, 1.5]
26
+ aug: 1new
27
+ with_visual: True
28
+ split: train
29
+
30
+ # general
31
+ mix: True
32
+ prompt: shuffle+
33
+ norm_cond: True
34
+ mix_text_min: 0.0
35
+
36
+ # model
37
+ out: 1
38
+ version: 'ViT-B/16'
39
+ extract_layers: [3, 7, 9]
40
+ reduce_dim: 64
41
+ depth: 3
42
+
43
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
44
+ amp: True
45
+
46
+ test_configuration_common:
47
+ normalize: True
48
+ image_size: 352
49
+ metric: metrics.FixedIntervalMetrics
50
+ batch_size: 1
51
+ test_dataset: pascal
52
+ sigmoid: True
53
+ # max_iterations: 250
54
+
55
+ test_configuration:
56
+
57
+ -
58
+ name: pas_t
59
+ mask: text
60
+
61
+ -
62
+ name: pas_h
63
+ mask: blur3_highlight01
64
+
65
+ -
66
+ name: pas_h2
67
+ mask: crop_blur_highlight352
68
+
69
+
70
+ columns: [name,
71
+ pas_t_fgiou_best, pas_t_miou_best, pas_t_fgiou_ct,
72
+ pas_h_fgiou_best, pas_h_miou_best, pas_h_fgiou_ct,
73
+ pas_h2_fgiou_best, pas_h2_miou_best, pas_h2_fgiou_ct, pas_h2_fgiou_best_t,
74
+ train_loss, duration, date
75
+ ]
76
+
77
+ individual_configurations:
78
+
79
+ - {name: rd64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [0], custom_threshold: 0.24}}
80
+ - {name: rd64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [1], custom_threshold: 0.24}}
81
+ - {name: rd64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [2], custom_threshold: 0.24}}
82
+ - {name: rd64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.2, mix_text_max: 0.5, test_configuration: {splits: [3], custom_threshold: 0.24}}
83
+
84
+
85
+ - {name: rd64-phrasepas5i-0, remove_classes: [pas5i, 0], negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.28}}
86
+ - {name: rd64-phrasepas5i-1, remove_classes: [pas5i, 1], negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.28}}
87
+ - {name: rd64-phrasepas5i-2, remove_classes: [pas5i, 2], negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.28}}
88
+ - {name: rd64-phrasepas5i-3, remove_classes: [pas5i, 3], negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.28}}
89
+
90
+
91
+ # baseline
92
+ - {name: bl64-phrasepas5i-0, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 0], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [0], custom_threshold: 0.24}}
93
+ - {name: bl64-phrasepas5i-1, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 1], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [1], custom_threshold: 0.24}}
94
+ - {name: bl64-phrasepas5i-2, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 2], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [2], custom_threshold: 0.24}}
95
+ - {name: bl64-phrasepas5i-3, model: models.clipseg.CLIPDenseBaseline, remove_classes: [pas5i, 3], reduce2_dim: 64, negative_prob: 0.0, test_configuration: {splits: [3], custom_threshold: 0.24}}
96
+
97
+ # ViT
98
+ - {name: vit64-uni-phrasepas5i-0, remove_classes: [pas5i, 0], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [0], custom_threshold: 0.02}}
99
+ - {name: vit64-uni-phrasepas5i-1, remove_classes: [pas5i, 1], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [1], custom_threshold: 0.02}}
100
+ - {name: vit64-uni-phrasepas5i-2, remove_classes: [pas5i, 2], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [2], custom_threshold: 0.02}}
101
+ - {name: vit64-uni-phrasepas5i-3, remove_classes: [pas5i, 3], model: models.vitseg.VITDensePredT, negative_prob: 0.2, mix_text_max: 0.5, lr: 0.0001, test_configuration: {splits: [3], custom_threshold: 0.02}}
clipseg/experiments/phrasecut.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ configuration:
2
+ batch_size: 64
3
+ optimizer: torch.optim.AdamW
4
+
5
+ lr: 0.001
6
+
7
+ trainer: experiment_setup.train_loop
8
+ scorer: experiment_setup.score
9
+ model: models.clipseg.CLIPDensePredT
10
+
11
+ lr_scheduler: cosine
12
+ T_max: 20000
13
+ eta_min: 0.0001
14
+
15
+ max_iterations: 20000
16
+ val_interval: null
17
+
18
+ # dataset
19
+ dataset: datasets.phrasecut.PhraseCut # <-----------------
20
+ split_mode: pascal_test
21
+ split: train
22
+ mask: text_and_crop_blur_highlight352
23
+ image_size: 352
24
+ normalize: True
25
+ pre_crop_image_size: [sample, 1, 1.5]
26
+ aug: 1new
27
+
28
+ # general
29
+ mix: False # <-----------------
30
+ prompt: shuffle+
31
+ norm_cond: True
32
+ mix_text_min: 0.0
33
+
34
+ # model
35
+ out: 1
36
+ extract_layers: [3, 7, 9]
37
+ reduce_dim: 64
38
+ depth: 3
39
+ fix_shift: False
40
+
41
+ loss: torch.nn.functional.binary_cross_entropy_with_logits
42
+ amp: True
43
+
44
+ test_configuration_common:
45
+ normalize: True
46
+ image_size: 352
47
+ batch_size: 32
48
+ # max_iterations: 5
49
+ # max_iterations: 150
50
+
51
+ test_configuration:
52
+
53
+ -
54
+ name: pc # old: phrasecut
55
+ metric: metrics.FixedIntervalMetrics
56
+ test_dataset: phrasecut
57
+ split: test
58
+ mask: text
59
+ label_support: True
60
+ sigmoid: True
61
+
62
+
63
+ columns: [i, name, pc_miou_0.3, pc_fgiou_0.3, pc_fgiou_0.5, pc_ap, duration, date]
64
+
65
+
66
+ individual_configurations:
67
+
68
+ # important ones
69
+
70
+
71
+ - {name: rd64-uni, version: 'ViT-B/16', reduce_dim: 64, with_visual: True, negative_prob: 0.2, mix: True, mix_text_max: 0.5}
72
+
73
+ # this was accidentally trained using old mask
74
+ - {name: rd128-vit16-phrasecut, version: 'ViT-B/16', reduce_dim: 128, mask: text_and_blur3_highlight01}
75
+ - {name: rd64-uni-novis, version: 'ViT-B/16', reduce_dim: 64, with_visual: False, negative_prob: 0.2, mix: False}
76
+ # this was accidentally trained using old mask
77
+ - {name: baseline3-vit16-phrasecut, model: models.clipseg.CLIPDenseBaseline, version: 'ViT-B/16', reduce_dim: 64, reduce2_dim: 64, mask: text_and_blur3_highlight01}
78
+
79
+ - {name: vit64-uni, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, reduce_dim: 64, with_visual: True, only_visual: True, negative_prob: 0.2, mask: crop_blur_highlight352, lr: 0.0003}
80
+ - {name: vit64-uni-novis, version: 'ViT-B/16', model: models.vitseg.VITDensePredT, with_visual: False, reduce_dim: 64, lr: 0.0001}
clipseg/general_utils.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import inspect
3
+ import torch
4
+ import os
5
+ import sys
6
+ import yaml
7
+ from shutil import copy, copytree
8
+ from os.path import join, dirname, realpath, expanduser, isfile, isdir, basename
9
+
10
+
11
class Logger:
    """Catch-all logger stand-in.

    Any attribute that normal lookup does not find (``info``, ``hint``,
    ``warning``, ...) resolves to the built-in ``print``, so every log call
    simply prints its arguments.
    """

    def __getattr__(self, level):
        # __getattr__ only fires for missing attributes, i.e. for whatever
        # log-level name the caller uses.
        return print
15
+
16
+ log = Logger()
17
+
18
def training_config_from_cli_args():
    """Build the training configuration from command-line arguments.

    ``sys.argv[1]`` is the experiment file name (looked up inside the
    ``experiments/`` folder) and ``sys.argv[2]`` is the integer index into the
    file's ``individual_configurations`` list.

    Returns:
        AttributeDict: the shared ``configuration`` section with the selected
        individual configuration merged on top (individual values win).
    """
    experiment_name = sys.argv[1]
    experiment_id = int(sys.argv[2])

    # Context manager so the file handle is closed deterministically.
    with open(f'experiments/{experiment_name}') as fh:
        yaml_config = yaml.load(fh, Loader=yaml.SafeLoader)

    config = {**yaml_config['configuration'],
              **yaml_config['individual_configurations'][experiment_id]}
    return AttributeDict(config)
28
+
29
+
30
def score_config_from_cli_args():
    """Build the scoring (test) configuration from command-line arguments.

    ``sys.argv[1]`` names the experiment file inside ``experiments/`` and
    ``sys.argv[2]`` is the integer index into its ``individual_configurations``
    list. When the file's ``test_configuration`` section is a list,
    ``sys.argv[3]`` selects the entry to use.

    Later sources override earlier ones: ``test_configuration_common``, then
    the selected ``test_configuration``, then the ``test_configuration`` of the
    individual configuration (if it has one).

    Returns:
        tuple: ``(config, train_checkpoint_id)`` where ``config`` is an
        AttributeDict and ``train_checkpoint_id`` is the ``name`` of the
        selected individual configuration.
    """
    experiment_name = sys.argv[1]
    experiment_id = int(sys.argv[2])

    # Context manager so the file handle is closed deterministically.
    with open(f'experiments/{experiment_name}') as fh:
        yaml_config = yaml.load(fh, Loader=yaml.SafeLoader)

    config = yaml_config['test_configuration_common']

    test_config = yaml_config['test_configuration']
    if isinstance(test_config, list):
        # several test setups are defined; the third CLI argument picks one
        test_config = test_config[int(sys.argv[3])]
    config = {**config, **test_config}

    individual = yaml_config['individual_configurations'][experiment_id]
    if 'test_configuration' in individual:
        config = {**config, **individual['test_configuration']}

    train_checkpoint_id = individual['name']

    return AttributeDict(config), train_checkpoint_id
52
+
53
+
54
def get_from_repository(local_name, repo_files, integrity_check=None, repo_dir='~/dataset_repository',
                        local_dir='~/datasets'):
    """Copy dataset files from a repository folder into a local folder.

    Nothing is copied when ``<local_dir>/<local_name>`` already exists and, if
    one is given, ``integrity_check`` accepts it.

    Args:
        local_name: name of the dataset folder inside ``local_dir``.
        repo_files: list whose entries are either a filename (relative to
            ``repo_dir``) or a pair ``[filename, target path]``.
        integrity_check: optional callable receiving the local dataset folder;
            a truthy return value means the dataset is intact. An exception
            raised by it counts as a failed check.
        repo_dir: folder holding the original files (``~`` is expanded).
        local_dir: folder below which the dataset folder lives (``~`` is
            expanded).

    Example:
        ``get_from_repository('MyDataset', [['data/dataset1.tar', 'other/path/ds03.tar']])``
        creates a folder ``MyDataset`` in ``local_dir`` and extracts the content
        of ``<repo_dir>/data/dataset1.tar`` to ``<local_dir>/MyDataset/other/path``.
    """

    local_dir = realpath(join(expanduser(local_dir), local_name))

    # the dataset counts as present only if its folder is available
    dataset_exists = isdir(local_dir)

    if integrity_check is not None:
        try:
            integrity_ok = integrity_check(local_dir)
        except Exception:
            # Exception rather than BaseException, so that Ctrl-C and
            # SystemExit still propagate instead of reading as a failed check.
            integrity_ok = False

        if integrity_ok:
            log.hint('Passed custom integrity check')
        else:
            log.hint('Custom integrity check failed')

        dataset_exists = dataset_exists and integrity_ok

    if dataset_exists:
        return

    repo_dir = realpath(expanduser(repo_dir))

    for filename in repo_files:

        if isinstance(filename, str):
            origin = filename
            archive_target = join(local_dir, basename(origin))
            extract_target = join(local_dir)
        else:
            origin, target = filename
            archive_target = join(local_dir, dirname(target), basename(origin))
            extract_target = join(local_dir, dirname(target))

        archive_origin = join(repo_dir, origin)

        log.hint(f'copy: {archive_origin} to {archive_target}')

        # make sure the path exists
        os.makedirs(dirname(archive_target), exist_ok=True)

        if os.path.isfile(archive_target):
            # only copy if size differs
            if os.path.getsize(archive_target) != os.path.getsize(archive_origin):
                log.hint(f'file exists but filesize differs: target {os.path.getsize(archive_target)} vs. origin {os.path.getsize(archive_origin)}')
                copy(archive_origin, archive_target)
        else:
            copy(archive_origin, archive_target)

        extract_archive(archive_target, extract_target, noarchive_ok=True)

        # concurrent processes might have deleted the file
        if os.path.isfile(archive_target):
            os.remove(archive_target)
121
+
122
+
123
def extract_archive(filename, target_folder=None, noarchive_ok=False):
    """Extract ``filename`` with the external ``tar`` or ``unzip`` tool.

    Args:
        filename: path of the archive; the tool is chosen by file ending
            (``.tgz``/``.tar``, ``.tar.gz`` or ``zip``).
        target_folder: optional folder to extract into (passed as ``-C`` to
            ``tar`` and ``-d`` to ``unzip``).
        noarchive_ok: when True, a file with any other ending is silently
            ignored instead of raising.

    Raises:
        ValueError: for an unsupported file ending, unless ``noarchive_ok``.

    A non-zero exit status of the tool is not raised; the captured stdout and
    stderr are printed instead.
    """
    from subprocess import run, PIPE

    # The command is assembled as an argument list (not a string that is split
    # afterwards) so that paths containing whitespace stay a single argument.
    if filename.endswith(('.tgz', '.tar')):
        command, target_flag = ['tar', '-xf', filename], '-C'
    elif filename.endswith('.tar.gz'):
        command, target_flag = ['tar', '-xzf', filename], '-C'
    elif filename.endswith('zip'):
        command, target_flag = ['unzip', filename], '-d'
    else:
        if noarchive_ok:
            return
        raise ValueError(f'unsuppored file ending of {filename}')

    if target_folder is not None:
        command += [target_flag, target_folder]

    log.hint(' '.join(command))
    result = run(command, stdout=PIPE, stderr=PIPE)
    if result.returncode != 0:
        print(result.stdout, result.stderr)
145
+
146
+
147
class AttributeDict(dict):
    """
    An extended dictionary that allows access to elements as attributes and
    counts these accesses. This way, we know if some attributes were never used.

    Reads through ``d[key]`` and ``d.key`` are counted; ``d.key`` returns None
    for a missing key instead of raising. The counter lives in the instance
    ``__dict__`` so that it never shows up as a dictionary entry.
    """

    def __init__(self, *args, **kwargs):
        from collections import Counter
        super().__init__(*args, **kwargs)
        # written via __dict__ because __setattr__ is redirected to __setitem__
        self.__dict__['counter'] = Counter()

    def __getitem__(self, k):
        self.__dict__['counter'][k] += 1
        return super().__getitem__(k)

    def __getattr__(self, k):
        # only reached for names normal attribute lookup did not find
        self.__dict__['counter'][k] += 1
        return super().get(k)

    def __setattr__(self, k, v):
        return super().__setitem__(k, v)

    def __delattr__(self, k):
        # Python calls __delattr__(name) with a single argument; the previous
        # two-argument signature made ``del d.key`` raise TypeError.
        return super().__delitem__(k)

    def unused_keys(self, exceptions=()):
        """Return the keys that were never read and are not in ``exceptions``."""
        counter = self.__dict__['counter']
        return [k for k in dict.keys(self) if counter[k] == 0 and k not in exceptions]

    def assume_no_unused_keys(self, exceptions=()):
        """Log a warning listing the keys that were never read, if any."""
        unused = self.unused_keys(exceptions=exceptions)
        if len(unused) > 0:
            log.warning('Unused keys:', unused)
178
+
179
+
180
def get_attribute(name):
    """Resolve a dotted path such as ``package.module.attr`` to the object it names.

    Everything before the last dot is imported as a module; the part after it
    is looked up on that module.

    Raises:
        ValueError: if ``name`` is None.
    """
    import importlib

    if name is None:
        raise ValueError('The provided attribute is None')

    module_path, _, attribute = name.rpartition('.')
    return getattr(importlib.import_module(module_path), attribute)
189
+
190
+
191
+
192
def filter_args(input_args, default_args):
    """Split ``input_args`` according to the keys of ``default_args``.

    Returns:
        Three AttributeDicts:
        ``updated`` - every key of ``default_args``, with the value taken from
        ``input_args`` when present and from ``default_args`` otherwise;
        ``used`` - the entries of ``input_args`` whose key is in ``default_args``;
        ``unused`` - the remaining entries of ``input_args``.
    """
    updated, used, unused = {}, {}, {}

    for key, default in default_args.items():
        updated[key] = input_args[key] if key in input_args else default

    for key, value in input_args.items():
        bucket = used if key in default_args else unused
        bucket[key] = value

    return AttributeDict(updated), AttributeDict(used), AttributeDict(unused)
199
+
200
+
201
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False):
    """Instantiate a model from ``logs/<checkpoint_id>`` and load its weights.

    Args:
        checkpoint_id: name of the folder below ``logs/`` that holds
            ``config.json`` and the weights file.
        weights_file: weights file name inside that folder; ``weights.pth``
            when None.
        strict: forwarded to ``load_state_dict``.
        model_args: ``'from_config'`` to take the constructor arguments from
            the stored config (restricted to the parameters the model class
            accepts), or a dictionary of constructor arguments.
        with_config: when True, return ``(model, config)`` instead of the model.

    Raises:
        ValueError: if ``model_args`` is neither ``'from_config'`` nor a dict.
        FileNotFoundError: if the weights file does not exist.
    """

    with open(join('logs', checkpoint_id, 'config.json')) as fh:
        config = json.load(fh)

    if model_args != 'from_config' and not isinstance(model_args, dict):
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    model_cls = get_attribute(config['model'])

    # load model
    if model_args == 'from_config':
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)

    model = model_cls(**model_args)

    weights_name = 'weights.pth' if weights_file is None else weights_file
    weights_file = realpath(join('logs', checkpoint_id, weights_name))

    if not isfile(weights_file):
        raise FileNotFoundError(f'model checkpoint {weights_file} was not found')

    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # hosts; load_state_dict copies the values into the model's own parameters.
    weights = torch.load(weights_file, map_location='cpu')
    for w in weights.values():
        assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
    model.load_state_dict(weights, strict=strict)

    if with_config:
        return model, config

    return model
233
+
234
+
235
class TrainingLogger(object):
    """Stores the run configuration and model weights below ``logs/<log_dir>``.

    Can be used as a context manager; entering returns the logger itself and
    leaving does nothing.
    """

    def __init__(self, model, log_dir, config=None, *args):
        """
        Args:
            model: model whose weights ``save_weights`` writes (may be None if
                weights are never saved).
            log_dir: folder name below ``logs/``; with None no run folder is
                created and nothing is written.
            config: optional JSON-serialisable configuration, written to
                ``config.json`` in the run folder.
        """
        super().__init__()
        self.model = model
        self.base_path = join(f'logs/{log_dir}') if log_dir is not None else None

        os.makedirs('logs/', exist_ok=True)

        # Without a log_dir there is no run folder; previously this case
        # crashed in os.makedirs(None).
        if self.base_path is not None:
            os.makedirs(self.base_path, exist_ok=True)

            if config is not None:
                # context manager so the file is flushed and closed right away
                with open(join(self.base_path, 'config.json'), 'w') as fh:
                    json.dump(config, fh)

    def iter(self, i, **kwargs):
        """Print the loss on every 100th iteration, if a ``loss`` is passed."""
        if i % 100 == 0 and 'loss' in kwargs:
            loss = kwargs['loss']
            print(f'iteration {i}: loss {loss:.4f}')

    def save_weights(self, only_trainable=False, weight_file='weights.pth'):
        """Save the model's state dict to ``weight_file`` in the run folder.

        With ``only_trainable`` only the parameters that require gradients are
        stored.
        """
        if self.model is None:
            raise AttributeError('You need to provide a model reference when initializing TrainingTracker to save weights.')

        if self.base_path is None:
            raise ValueError('Cannot save weights: no log_dir was provided.')

        weights_path = join(self.base_path, weight_file)

        weight_dict = self.model.state_dict()

        if only_trainable:
            weight_dict = {n: weight_dict[n] for n, p in self.model.named_parameters() if p.requires_grad}

        torch.save(weight_dict, weights_path)
        log.info(f'Saved weights to {weights_path}')

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """ automatically stop processes if used in a context manager """
        pass
clipseg/metrics.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.functional import Tensor
2
+ from general_utils import log
3
+ from collections import defaultdict
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch.nn import functional as nnf
8
+
9
+
10
class BaseMetric(object):
    """Base class for metrics; subclasses implement ``add`` and ``value``."""

    def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
                 eval_validation=True):
        """
        Args:
            metric_names: names of the values the metric reports.
            pred_range: optional ``(start, stop)`` slice applied to the second
                axis of the selected prediction in ``_get_pred_gt``.
            gt_index: index of the ground-truth entry to use.
            pred_index: index of the prediction entry to use.
            eval_intermediate: value reported by ``eval_intermediate()``.
            eval_validation: value reported by ``eval_validation()``.
        """
        self._names = tuple(metric_names)
        self._eval_intermediate = eval_intermediate
        self._eval_validation = eval_validation

        self._pred_range = pred_range
        self._pred_index = pred_index
        self._gt_index = gt_index

        self.predictions = []
        self.ground_truths = []

    def eval_intermediate(self):
        return self._eval_intermediate

    def eval_validation(self):
        return self._eval_validation

    def names(self):
        return self._names

    def add(self, predictions, ground_truth):
        raise NotImplementedError

    def value(self):
        raise NotImplementedError

    def scores(self):
        """Return ``value()`` keyed by metric name.

        A dict returned by ``value()`` is passed through unchanged; a list or
        tuple is paired with ``names()`` and returned as a list of
        ``(name, value)`` pairs.
        """
        # value() may be expensive, so it is evaluated exactly once
        value = self.value()
        if isinstance(value, dict):
            return value

        assert isinstance(value, (list, tuple))
        return list(zip(self.names(), value))

    def _get_pred_gt(self, predictions, ground_truth):
        """Select the configured prediction/ground-truth pair (and slice the prediction)."""
        pred = predictions[self._pred_index]
        gt = ground_truth[self._gt_index]

        if self._pred_range is not None:
            pred = pred[:, self._pred_range[0]: self._pred_range[1]]

        return pred, gt
57
+
58
+
59
class FixedIntervalMetrics(BaseMetric):
    """Segmentation metrics evaluated on a fixed grid of thresholds.

    ``add`` records, per sample and per threshold, the number of true/false
    positives and negatives; ``value`` aggregates them into average precision,
    foreground IoU, binary IoU and (when classes are provided) mean IoU.
    """

    def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
                 resize_pred=None, n_values=51, custom_threshold=None):
        """
        Args:
            sigmoid: apply a sigmoid to the predictions before thresholding.
            ignore_mask: ignore the validity mask in ``gt[1]`` even if present.
            resize_to: optional size both ground truth and prediction batches
                are resized to.
            resize_pred: resize each prediction to the size of its ground truth.
            n_values: number of points of ``np.linspace(0, 1, n_values)``; the
                end points 0 and 1 are dropped from the threshold grid.
            custom_threshold: additional threshold reported under the ``*_ct``
                keys; must be one of the grid values.
        """

        super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
        self.intersections = []
        self.unions = []
        # self.threshold = threshold
        self.sigmoid = sigmoid
        self.resize_to = resize_to
        self.resize_pred = resize_pred  # resize prediction to match ground truth
        self.class_count = defaultdict(lambda: 0)
        self.per_class = defaultdict(lambda: [0, 0])
        self.ignore_mask = ignore_mask
        self.custom_threshold = custom_threshold

        self.scores_ap = []
        self.scores_iou = []
        self.gts, self.preds = [], []
        self.classes = []

        # [1:-1] ignores 0 and 1
        self.threshold_values = np.linspace(0, 1, n_values)[1:-1]

        self.metrics = dict(tp=[], fp=[], fn=[], tn=[])

    def add(self, pred, gt):
        """Accumulate per-threshold statistics for one batch.

        Args:
            pred: sequence whose first entry is the prediction batch.
            gt: sequence with the ground-truth batch first, optionally followed
                by a batch of validity masks and a batch of class ids.
        """

        pred_batch = pred[0].cpu()

        if self.sigmoid:
            pred_batch = torch.sigmoid(pred_batch)

        gt_batch = gt[0].cpu()
        mask_batch = gt[1] if len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0 else ([None] * len(pred_batch))
        cls_batch = gt[2] if len(gt) > 2 else [None] * len(pred_batch)

        if self.resize_to is not None:
            gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
            pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)

        if isinstance(cls_batch, torch.Tensor):
            cls_batch = cls_batch.cpu().numpy().tolist()

        assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'

        for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):

            if self.resize_pred:
                predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)

            p = predictions.flatten()
            g = ground_truth.flatten()

            assert len(p) == len(g)

            if mask is not None:
                # moved to the CPU like predictions and ground truth above
                m = mask.flatten().bool().cpu()
                p = p[m]
                g = g[m]

            # sort by score: for every threshold, all elements from the first
            # score above it onwards are the predicted positives
            p_sorted = p.sort()
            p = p_sorted.values
            g = g[p_sorted.indices]

            tps, fps, fns, tns = [], [], [], []
            for thresh in self.threshold_values:

                valid = torch.where(p > thresh)[0]
                n = int(valid[0]) if len(valid) > 0 else len(g)

                fn = int(g[:n].sum())
                tp = int(g[n:].sum())
                fns += [fn]
                tns += [n - fn]
                tps += [tp]
                fps += [len(g) - n - tp]

            self.metrics['tp'] += [tps]
            self.metrics['fp'] += [fps]
            self.metrics['fn'] += [fns]
            self.metrics['tn'] += [tns]

            self.classes += [cls.item() if isinstance(cls, torch.Tensor) else cls]

    def _threshold_index(self, threshold):
        """Return the index of ``threshold`` in the threshold grid.

        Matches with a floating point tolerance instead of exact equality.

        Raises:
            ValueError: if no grid value matches ``threshold``.
        """
        idx = int(np.argmin(np.abs(self.threshold_values - threshold)))
        if not np.isclose(self.threshold_values[idx], threshold):
            raise ValueError(f'threshold {threshold} is not among the evaluated threshold values')
        return idx

    def value(self):
        """Aggregate the accumulated statistics into a dictionary of metrics.

        The mean-IoU entries and ``summed_by_cls_statistics`` are None when no
        class ids were provided to ``add``.
        """

        import time
        t_start = time.time()

        if set(self.classes) == set([None]):
            all_classes = None
            log.warning('classes were not provided, cannot compute mIoU')
        else:
            all_classes = set(int(c) for c in self.classes)
            # log.info(f'compute metrics for {len(all_classes)} classes')

        n_thresholds = len(self.threshold_values)

        # sum over all samples, per statistic and threshold
        summed = {k: [sum([self.metrics[k][i][j]
                           for i in range(len(self.metrics[k]))])
                      for j in range(n_thresholds)]
                  for k in self.metrics.keys()}

        summed_by_cls = None
        if all_classes is not None:

            assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])
            # group by class
            metrics_by_class = {c: {k: [] for k in self.metrics.keys()} for c in all_classes}
            for i in range(len(self.metrics['tp'])):
                for k in self.metrics.keys():
                    metrics_by_class[self.classes[i]][k] += [self.metrics[k][i]]

            # sum over all instances within the classes
            summed_by_cls = {k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes} for k in self.metrics.keys()}

        # Compute average precision

        assert (np.array(summed['fp']) + np.array(summed['tp'])).sum(), 'no predictions is made'

        # only consider values where a prediction is made
        precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in range(n_thresholds)
                      if summed['tp'][j] + summed['fp'][j] > 0]
        recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in range(n_thresholds)
                   if summed['tp'][j] + summed['fp'][j] > 0]

        # remove duplicate recall-precision-pairs (and sort by recall value)
        recalls, precisions = zip(*sorted(list(set(zip(recalls, precisions))), key=lambda x: x[0]))

        # scipy.integrate.simps was removed in SciPy 1.14; simpson is its
        # replacement. Fall back to the old name on SciPy versions without it.
        try:
            from scipy.integrate import simpson
        except ImportError:
            from scipy.integrate import simps as simpson
        ap = simpson(precisions, x=recalls)

        # Compute best IoU
        fgiou_scores = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j]) for j in range(n_thresholds)]

        biniou_scores = [
            0.5*(summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
            0.5*(summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
            for j in range(n_thresholds)
        ]

        index_0p5 = self._threshold_index(0.5)
        index_0p1 = self._threshold_index(0.1)
        index_0p2 = self._threshold_index(0.2)
        index_0p3 = self._threshold_index(0.3)

        index_ct = None
        if self.custom_threshold is not None:
            index_ct = self._threshold_index(self.custom_threshold)

        if all_classes is not None:
            # mean IoU
            mean_ious = [np.mean([summed_by_cls['tp'][c][j] / (1 + summed_by_cls['tp'][c][j] + summed_by_cls['fp'][c][j] + summed_by_cls['fn'][c][j])
                                  for c in all_classes])
                         for j in range(n_thresholds)]

            mean_iou_dict = {
                'miou_best': max(mean_ious),
                'miou_0.5': mean_ious[index_0p5],
                'miou_0.1': mean_ious[index_0p1],
                'miou_0.2': mean_ious[index_0p2],
                'miou_0.3': mean_ious[index_0p3],
                'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
                'mean_iou_ct': mean_ious[index_ct] if index_ct is not None else None,
                'mean_iou_scores': mean_ious,
            }
        else:
            # Without class ids there is no mean IoU; previously this case
            # failed with a NameError on mean_ious / summed_by_cls.
            mean_iou_dict = {
                'miou_best': None,
                'miou_0.5': None,
                'miou_0.1': None,
                'miou_0.2': None,
                'miou_0.3': None,
                'miou_best_t': None,
                'mean_iou_ct': None,
                'mean_iou_scores': None,
            }

        print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')

        return {
            'ap': ap,

            # fgiou
            'fgiou_best': max(fgiou_scores),
            'fgiou_0.5': fgiou_scores[index_0p5],
            'fgiou_0.1': fgiou_scores[index_0p1],
            'fgiou_0.2': fgiou_scores[index_0p2],
            'fgiou_0.3': fgiou_scores[index_0p3],
            'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],

            # biniou
            'biniou_best': max(biniou_scores),
            'biniou_0.5': biniou_scores[index_0p5],
            'biniou_0.1': biniou_scores[index_0p1],
            'biniou_0.2': biniou_scores[index_0p2],
            'biniou_0.3': biniou_scores[index_0p3],
            'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],

            # custom threshold
            'fgiou_ct': fgiou_scores[index_ct] if index_ct is not None else None,
            'biniou_ct': biniou_scores[index_ct] if index_ct is not None else None,
            'ct': self.custom_threshold,

            # statistics
            'fgiou_scores': fgiou_scores,
            'biniou_scores': biniou_scores,
            'precision_recall_curve': sorted(list(set(zip(recalls, precisions)))),
            'summed_statistics': summed,
            'summed_by_cls_statistics': summed_by_cls,

            **mean_iou_dict
        }
267
+
268
+ # ('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'
269
+
270
+ # return ap, best_fgiou, best_mean_iou, iou_0p5, iou_0p1, mean_iou_0p5, mean_iou_0p1, best_biniou, biniou0p5, best_fgiou_thresh, {'summed': summed, 'summed_by_cls': summed_by_cls}
271
+
clipseg/models/clipseg.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from os.path import basename, dirname, join, isfile
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as nnf
6
+ from torch.nn.modules.activation import ReLU
7
+
8
+
9
def precompute_clip_vectors():
    """Encode prompt templates for LVIS category names with CLIP and pickle them.

    For the first 100 category names of the ``LVIS_OneShot3`` training split,
    every template of ``imagenet_templates`` is filled with the name
    (underscores replaced by spaces) and encoded by the CLIP ``ViT-B/32`` text
    encoder on the GPU. The resulting mapping ``prompt text -> vector`` is
    written to ``precomputed_prompt_vectors.pickle`` in the working directory.
    """

    from trails.initialization import init_dataset
    lvis = init_dataset('LVIS_OneShot3', split='train', mask='text_label', image_size=224, aug=1, normalize=True,
                        reduce_factor=None, add_bar=False, negative_prob=0.5)

    all_names = list(lvis.category_names.values())

    import clip
    from models.clip_prompts import imagenet_templates
    clip_model = clip.load("ViT-B/32", device='cuda', jit=False)[0]
    prompt_vectors = {}
    for name in all_names[:100]:
        with torch.no_grad():
            conditionals = [t.format(name).replace('_', ' ') for t in imagenet_templates]
            text_tokens = clip.tokenize(conditionals).cuda()
            cond_vectors = clip_model.encode_text(text_tokens).cpu()

        # separate names for prompt text and vector batch: the loop variable
        # used to shadow the tensor that was being iterated
        for prompt_text, vec in zip(conditionals, cond_vectors):
            # NOTE(review): torch tensors are stored here, while
            # CLIPDenseBase.__init__ applies torch.from_numpy to the loaded
            # values - confirm which format the pickle is expected to hold.
            prompt_vectors[prompt_text] = vec.cpu()

    import pickle

    # context manager so the pickle is flushed and closed before returning
    with open('precomputed_prompt_vectors.pickle', 'wb') as fh:
        pickle.dump(prompt_vectors, fh)
33
+
34
+
35
def get_prompt_list(prompt):
    """Return the list of prompt templates selected by ``prompt``.

    Supported values are ``'plain'``, ``'fixed'``, ``'shuffle'``, ``'shuffle+'``
    and ``'shuffle_clip'`` (the templates from ``models.clip_prompts``). Each
    template contains one ``{}`` placeholder.

    Raises:
        ValueError: for any other value.
    """
    if prompt == 'shuffle_clip':
        # imported lazily: only this mode needs the template module
        from models.clip_prompts import imagenet_templates
        return imagenet_templates

    shuffle = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
    extended = ['a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
                'a bad photo of a {}.', 'a photo of the {}.']

    templates = {
        'plain': ['{}'],
        'fixed': ['a photo of a {}.'],
        'shuffle': shuffle,
        'shuffle+': shuffle + extended,
    }

    try:
        return templates[prompt]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are invalid as well
        raise ValueError('Invalid value for prompt')
51
+
52
+
53
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
    """
    Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
    The mlp and layer norm come from CLIP.

    x: input, indexed as (tokens, batch, embedding).
    b: residual attention block providing ``ln_1``, ``attn`` (a multihead
       attention module), ``ln_2`` and ``mlp``.
    with_aff: if True, additionally return the attention weights after the
       softmax, shaped (batch * heads, tokens, tokens).
    attn_mask: optional pair ``(mask_type, mask)`` with one mask row per
       sample. The pre-softmax attention scores are multiplied by the mask:
       for ``'cls_token'`` only the scores of the first token towards all other
       tokens, for ``'all'`` the scores among all tokens except the first.
    """

    x_ = b.ln_1(x)
    q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
    tgt_len, bsz, embed_dim = q.size()

    head_dim = embed_dim // b.attn.num_heads
    scaling = float(head_dim) ** -0.5

    # fold the heads into the batch axis: row index = sample * num_heads + head
    q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)

    q = q * scaling

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # batch_size * n_heads, tokens, tokens
    if attn_mask is not None:

        attn_mask_type, attn_mask = attn_mask
        n_heads = attn_output_weights.size(0) // attn_mask.size(0)
        # The rows above are ordered sample-major, so each sample's mask has
        # to be repeated n_heads times in a row. The former repeat(n_heads, 1)
        # tiled the whole batch instead and mixed up the masks of different
        # samples whenever the batch size was larger than one.
        attn_mask = attn_mask.repeat_interleave(n_heads, dim=0)

        if attn_mask_type == 'cls_token':
            # the mask only affects similarities compared to the readout-token.
            attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask

        if attn_mask_type == 'all':
            attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]

    attn_output_weights = torch.softmax(attn_output_weights, dim=-1)

    attn_output = torch.bmm(attn_output_weights, v)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = b.attn.out_proj(attn_output)

    # residual connections around attention and mlp
    x = x + attn_output
    x = x + b.mlp(b.ln_2(x))

    if with_aff:
        return x, attn_output_weights
    else:
        return x
105
+
106
+
107
+ class CLIPDenseBase(nn.Module):
108
+
109
def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
    """Load the frozen CLIP model and create the conditioning layers.

    version: CLIP model name passed to ``clip.load`` (loaded on the CPU).
    reduce_cond: if not None, size of a frozen linear layer reducing the
        512-dimensional conditional; otherwise no reduction layer is created.
    reduce_dim: output size of the FiLM layers and of the activation reduction.
    prompt: prompt mode passed to ``get_prompt_list``.
    n_tokens: if not None, scale conv weights such that we obtain n_tokens.
    """
    super().__init__()

    import clip

    # prec = torch.FloatTensor
    self.clip_model, _ = clip.load(version, device='cpu', jit=False)
    self.model = self.clip_model.visual

    # if not None, scale conv weights such that we obtain n_tokens.
    self.n_tokens = n_tokens

    # CLIP itself is never trained
    for p in self.clip_model.parameters():
        p.requires_grad_(False)

    # conditional
    if reduce_cond is not None:
        self.reduce_cond = nn.Linear(512, reduce_cond)
        for p in self.reduce_cond.parameters():
            p.requires_grad_(False)
    else:
        self.reduce_cond = None

    cond_dim = 512 if reduce_cond is None else reduce_cond
    self.film_mul = nn.Linear(cond_dim, reduce_dim)
    self.film_add = nn.Linear(cond_dim, reduce_dim)

    self.reduce = nn.Linear(768, reduce_dim)

    self.prompt_list = get_prompt_list(prompt)

    # precomputed prompts
    import pickle
    if isfile('precomputed_prompt_vectors.pickle'):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only keep a trusted file of this name in the working directory.
        with open('precomputed_prompt_vectors.pickle', 'rb') as fh:
            precomp = pickle.load(fh)
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
    else:
        self.precomputed_prompts = dict()
146
+
147
def rescaled_pos_emb(self, new_size):
    """Return the positional embedding resampled to a ``new_size`` token grid.

    The embedding of the first token is kept unchanged; the remaining
    embeddings are arranged on the ``self.token_shape`` grid, interpolated
    bicubically to ``new_size`` (height, width) and flattened again.
    """
    assert len(new_size) == 2

    pos_emb = self.model.positional_embedding
    # the width is read from the embedding itself instead of assuming 768
    width = pos_emb.shape[1]

    grid = pos_emb[1:].T.view(1, width, *self.token_shape)
    resized = nnf.interpolate(grid, new_size, mode='bicubic', align_corners=False)
    resized = resized.squeeze(0).view(width, new_size[0] * new_size[1]).T
    return torch.cat([pos_emb[:1], resized])
153
+
154
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
    """Run the CLIP visual transformer without gradients.

    x_inp: input image batch.
    extract_layers: indices of the transformer blocks whose outputs and
        attention weights are collected.
    skip: stop after the last block listed in ``extract_layers``.
    mask: optional triple ``(mask_layer, mask_type, mask_tensor)``; the mask is
        resized to the token grid and applied in block ``mask_layer`` (or in
        every block for ``'all'``).

    Returns ``(x, activations, affinities)``: the projected embedding of the
    first token, the collected block outputs and the collected attention
    weights.
    """
    with torch.no_grad():
        vit = self.model

        if self.n_tokens is None:
            x = vit.conv1(x_inp)  # shape = [*, width, grid, grid]
        else:
            # resize the patch kernel so that the image yields n_tokens per side
            stride2 = x_inp.shape[2] // self.n_tokens
            conv_weight2 = nnf.interpolate(vit.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
            x = nnf.conv2d(x_inp, conv_weight2, bias=vit.conv1.bias, stride=stride2, dilation=vit.conv1.dilation)

        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # prepend the class embedding to every sample
        cls_tokens = vit.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]

        standard_n_tokens = 50 if vit.conv1.kernel_size[0] == 32 else 197

        if x.shape[1] == standard_n_tokens:
            x = x + vit.positional_embedding.to(x.dtype)
        else:
            # non-standard token count: resample the positional embedding
            new_shape = int(math.sqrt(x.shape[1] - 1))
            x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None, :, :]

        x = vit.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        activations, affinities = [], []
        for i, res_block in enumerate(vit.transformer.resblocks):

            attn_mask = None
            if mask is not None:
                mask_layer, mask_type, mask_tensor = mask
                if mask_layer == i or mask_layer == 'all':
                    size = int(math.sqrt(x.shape[0] - 1))
                    resized_mask = nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size))
                    attn_mask = (mask_type, resized_mask.view(mask_tensor.shape[0], size * size))

            x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)

            if i in extract_layers:
                affinities += [aff_per_head]
                activations += [x]

            if len(extract_layers) > 0 and i == max(extract_layers) and skip:
                print('early skip')
                break

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = vit.ln_post(x[:, 0, :])

        if vit.proj is not None:
            x = x @ vit.proj

        return x, activations, affinities
222
+
223
+ def sample_prompts(self, words, prompt_list=None):
224
+
225
+ prompt_list = prompt_list if prompt_list is not None else self.prompt_list
226
+
227
+ prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
228
+ prompts = [prompt_list[i] for i in prompt_indices]
229
+ return [promt.format(w) for promt, w in zip(prompts, words)]
230
+
231
+ def get_cond_vec(self, conditional, batch_size):
232
+ # compute conditional from a single string
233
+ if conditional is not None and type(conditional) == str:
234
+ cond = self.compute_conditional(conditional)
235
+ cond = cond.repeat(batch_size, 1)
236
+
237
+ # compute conditional from string list/tuple
238
+ elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
239
+ assert len(conditional) == batch_size
240
+ cond = self.compute_conditional(conditional)
241
+
242
+ # use conditional directly
243
+ elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
244
+ cond = conditional
245
+
246
+ # compute conditional from image
247
+ elif conditional is not None and type(conditional) == torch.Tensor:
248
+ with torch.no_grad():
249
+ cond, _, _ = self.visual_forward(conditional)
250
+ else:
251
+ raise ValueError('invalid conditional')
252
+ return cond
253
+
254
+ def compute_conditional(self, conditional):
255
+ import clip
256
+
257
+ dev = next(self.parameters()).device
258
+
259
+ if type(conditional) in {list, tuple}:
260
+ text_tokens = clip.tokenize(conditional).to(dev)
261
+ cond = self.clip_model.encode_text(text_tokens)
262
+ else:
263
+ if conditional in self.precomputed_prompts:
264
+ cond = self.precomputed_prompts[conditional].float().to(dev)
265
+ else:
266
+ text_tokens = clip.tokenize([conditional]).to(dev)
267
+ cond = self.clip_model.encode_text(text_tokens)[0]
268
+
269
+ if self.shift_vector is not None:
270
+ return cond + self.shift_vector
271
+ else:
272
+ return cond
273
+
274
+
275
def clip_load_untrained(version):
    """Construct a ``CLIP`` model whose hyper-parameters mirror the ViT-B/16 checkpoint.

    The released ViT-B/16 archive is downloaded and loaded only to read tensor
    shapes from its state dict; the returned model is built from those sizes
    and the checkpoint weights are not copied into it here.

    Args:
        version: must be ``'ViT-B/16'`` (asserted).

    Returns:
        A newly constructed ``clip.model.CLIP`` instance.
    """
    assert version == 'ViT-B/16'
    from clip.model import CLIP
    from clip.clip import _MODELS, _download

    reference = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
    sd = reference.state_dict()

    # vision tower sizes
    conv1_weight = sd["visual.conv1.weight"]
    vision_width = conv1_weight.shape[0]
    attn_keys = [
        key for key in sd.keys()
        if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
    ]
    vision_layers = len(attn_keys)
    vision_patch_size = conv1_weight.shape[-1]
    grid_size = round((sd["visual.positional_embedding"].shape[0] - 1) ** 0.5)
    image_resolution = vision_patch_size * grid_size

    # text tower sizes
    embed_dim = sd["text_projection"].shape[1]
    context_length = sd["positional_embedding"].shape[0]
    vocab_size = sd["token_embedding.weight"].shape[0]
    transformer_width = sd["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    resblock_ids = {key.split(".")[2] for key in sd if key.startswith("transformer.resblocks")}
    transformer_layers = len(resblock_ids)

    return CLIP(
        embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
    )
296
+
297
+
298
class CLIPDensePredT(CLIPDenseBase):
    """Dense predictor that decodes CLIP visual activations with small transformer blocks.

    Activations taken from ``extract_layers`` of the visual backbone are
    linearly reduced, summed level by level, modulated once by the conditional
    vector (FiLM-style multiply/add) and finally mapped to a one-channel map
    by a transposed convolution.
    """

    def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None):
        """Create the decoder layers on top of the base class.

        Args:
            version: backbone identifier; must be ``'ViT-B/32'`` or ``'ViT-B/16'``
                (used as a key into the token-shape and kernel-size tables).
            extract_layers: backbone layer indices whose activations are decoded.
            cond_layer: index of the decoder level at which the conditional is applied.
            reduce_dim: width of the decoder.
            n_heads: attention heads of each decoder ``TransformerEncoderLayer``.
            prompt: forwarded to the base class and to ``get_prompt_list``.
            extra_blocks: number of additional residual transformer blocks.
            reduce_cond: forwarded to the base class.
            fix_shift: when true, load a frozen shift vector from ``shift_text_to_vis.pth``.
            learn_trans_conv_only: freeze everything except ``self.trans_conv``.
            limit_to_clip_only: stored on the instance; not used in this class.
            upsample: when true, create a 1x1 ``upsample_proj`` convolution.
            add_calibration: when true, set ``self.calibration_conds = 1``.
            rev_activations: keep extracted activations in extraction order
                instead of reversing them.
            trans_conv: optional explicit (square) transposed-conv kernel size.
            n_tokens: forwarded to the base class.
        """
        super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None
        self.rev_activations = rev_activations

        depth = len(extract_layers)

        if add_calibration:
            self.calibration_conds = 1

        if upsample:
            self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1)
        else:
            self.upsample_proj = None

        self.add_activation1 = True
        self.version = version
        self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]

        if fix_shift:
            # NOTE(review): dirname(basename(__file__)) is always '', so this
            # path is relative to the current working directory - confirm intent.
            shift_path = join(dirname(basename(__file__)), 'shift_text_to_vis.pth')
            self.shift_vector = nn.Parameter(torch.load(shift_path), requires_grad=False)
        else:
            self.shift_vector = None

        if trans_conv is not None:
            # explicitly define transposed conv kernel size
            trans_conv_ks = (trans_conv, trans_conv)
        else:
            trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]

        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        assert len(self.extract_layers) == depth

        # one reduction + one transformer block per extracted level
        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]
        )
        self.extra_blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]
        )

        if learn_trans_conv_only:
            # freeze everything, then re-enable only the transposed conv
            for param in self.parameters():
                param.requires_grad_(False)

            for param in self.trans_conv.parameters():
                param.requires_grad_(True)

        self.prompt_list = get_prompt_list(prompt)

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
        """Predict a one-channel map for ``inp_image`` given ``conditional``.

        Args:
            inp_image: image batch; moved to the backbone's device.
            conditional: anything accepted by ``get_cond_vec``.
            return_features: must be a ``bool``; selects the return form.
            mask: must be ``None``; any other value raises.

        Returns:
            ``(prediction,)`` by default, or
            ``(prediction, visual_q, cond, activations)`` when
            ``return_features`` is true.

        Raises:
            ValueError: if ``mask`` is not ``None``.
        """
        assert type(return_features) is bool

        inp_image = inp_image.to(self.model.positional_embedding.device)

        if mask is not None:
            raise ValueError('mask not supported')

        x_inp = inp_image
        bs = inp_image.shape[0]

        cond = self.get_cond_vec(conditional, bs)

        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        # layer 0 is only kept for the feature output; the rest is decoded
        activation1, activations = activations[0], activations[1:]

        ordered = activations if self.rev_activations else activations[::-1]

        a = None
        for level, (activation, block, reduce) in enumerate(zip(ordered, self.blocks, self.reduces)):
            reduced = reduce(activation)
            a = reduced if a is None else reduced + a

            if level == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        side = int(math.sqrt(a.shape[2]))
        a = a.view(bs, a.shape[1], side, side)
        a = self.trans_conv(a)

        if self.n_tokens is not None:
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        if not return_features:
            return a,
        return a, visual_q, cond, [activation1] + activations
421
+
422
+
423
+
424
class CLIPDensePredTMasked(CLIPDensePredT):
    """``CLIPDensePredT`` variant whose conditional may come from a masked support image."""

    def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
                 prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
                 refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
        """Forward all arguments to ``CLIPDensePredT``; ``refine`` is accepted but not used."""
        parent_kwargs = dict(
            version=version,
            extract_layers=extract_layers,
            cond_layer=cond_layer,
            reduce_dim=reduce_dim,
            n_heads=n_heads,
            prompt=prompt,
            extra_blocks=extra_blocks,
            reduce_cond=reduce_cond,
            fix_shift=fix_shift,
            learn_trans_conv_only=learn_trans_conv_only,
            limit_to_clip_only=limit_to_clip_only,
            upsample=upsample,
            add_calibration=add_calibration,
            n_tokens=n_tokens,
        )
        super().__init__(**parent_kwargs)

    def visual_forward_masked(self, img_s, seg_s):
        """Run the visual backbone on ``img_s`` with ``seg_s`` passed as a 'cls_token' mask for all layers."""
        mask_spec = ('all', 'cls_token', seg_s)
        return super().visual_forward(img_s, mask=mask_spec)

    def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
        """Predict for query ``img_q``.

        Args:
            img_q: query image batch.
            cond_or_img_s: the conditional itself when ``seg_s`` is ``None``,
                otherwise the support image.
            seg_s: optional support mask; when given, the conditional is
                computed from the masked support image without gradients.
            return_features: forwarded to ``CLIPDensePredT.forward``.
        """
        if seg_s is None:
            cond = cond_or_img_s
        else:
            with torch.no_grad():
                cond, _, _ = self.visual_forward_masked(cond_or_img_s, seg_s)

        return super().forward(img_q, cond, return_features=return_features)
450
+
451
+
452
+
453
class CLIPDenseBaseline(CLIPDenseBase):
    """Baseline decoder: one extracted activation, FiLM conditioning, an MLP and a transposed conv."""

    def __init__(self, version='ViT-B/32', cond_layer=0,
                 extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
                 reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
        """Set up the baseline head.

        Args:
            version: ``'ViT-B/32'`` or ``'ViT-B/16'`` (key into the size tables).
            cond_layer: accepted but not stored by this class.
            extract_layer: backbone layer whose activation is decoded.
            reduce_dim: decoder width.
            reduce2_dim: hidden width of the two-layer MLP; must not be ``None``.
            prompt, reduce_cond, n_tokens: forwarded to the base class.
            limit_to_clip_only: stored on the instance.
        """
        super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)

        self.extract_layer = extract_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.shift_vector = None

        self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]

        assert reduce2_dim is not None

        mlp_layers = [
            nn.Linear(reduce_dim, reduce2_dim),
            nn.ReLU(),
            nn.Linear(reduce2_dim, reduce_dim),
        ]
        self.reduce2 = nn.Sequential(*mlp_layers)

        kernel = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, kernel, stride=kernel)

    def forward(self, inp_image, conditional=None, return_features=False):
        """Predict a one-channel map for ``inp_image`` given ``conditional``.

        Returns:
            ``(prediction,)`` by default, or
            ``(prediction, visual_q, cond, activations)`` when
            ``return_features`` is truthy.
        """
        inp_image = inp_image.to(self.model.positional_embedding.device)

        x_inp = inp_image
        bs = inp_image.shape[0]

        cond = self.get_cond_vec(conditional, bs)

        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[self.extract_layer])

        a = self.reduce(activations[0])
        a = self.film_mul(cond) * a + self.film_add(cond)

        if self.reduce2 is not None:
            a = self.reduce2(a)

        # the original model would execute a transformer block here

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        side = int(math.sqrt(a.shape[2]))
        a = self.trans_conv(a.view(bs, a.shape[1], side, side))

        if not return_features:
            return a,
        return a, visual_q, cond, activations
514
+
515
+
516
class CLIPSegMultiLabel(nn.Module):
    """Multi-class wrapper that queries a CLIPSeg model once per class name."""

    def __init__(self, model) -> None:
        """Load the checkpoint identified by ``model`` and switch it to eval mode."""
        super().__init__()

        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC

        from models.clipseg import CLIPDensePredT
        from general_utils import load_model

        self.clipseg = load_model(model, strict=False)
        self.clipseg.eval()

    def forward(self, x):
        """Score every class in ``self.pascal_classes`` for the batch ``x``.

        Each class slot starts at -10 and receives the sigmoid of the model's
        first output channel; the 'background' class is weighted by 3.

        Returns:
            A tensor of shape ``(batch, 21, 352, 352)``.
        """
        bs = x.shape[0]

        # NOTE(review): class count (21) and resolution (352) are hard-coded;
        # presumably they match the class list and model output - confirm.
        out = torch.full((21, bs, 352, 352), -10.0).to(x.device)

        for class_id, class_name in enumerate(self.pascal_classes):
            weight = 3 if class_name == 'background' else 1

            with torch.no_grad():
                logits = self.clipseg(x, class_name)[0]
                pred = torch.sigmoid(logits[:, 0]) * weight

            out[class_id] += pred

        # class-major -> batch-major
        return out.permute(1, 0, 2, 3)
550
+
551
+ # construct output tensor
552
+
clipseg/models/vitseg.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from posixpath import basename, dirname, join
3
+ # import clip
4
+ from clip.model import convert_weights
5
+ import torch
6
+ import json
7
+ from torch import nn
8
+ from torch.nn import functional as nnf
9
+ from torch.nn.modules import activation
10
+ from torch.nn.modules.activation import ReLU
11
+ from torchvision import transforms
12
+
13
+ normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
14
+
15
+ from torchvision.models import ResNet
16
+
17
+
18
def process_prompts(conditional, prompt_list, conditional_map):
    """DEPRECATED. Build one text prompt per class id.

    For every id in ``conditional`` a synonym is drawn uniformly at random
    from ``conditional_map[int(id)]``, underscores are replaced by spaces and
    the word is inserted into a template.

    Args:
        conditional: iterable of class ids (anything ``int()`` accepts).
        prompt_list: sequence of ``str.format`` templates to sample from, or
            ``None`` to use ``'a photo of {}'`` for every word.
        conditional_map: mapping from integer id to a sequence of synonyms.

    Returns:
        A list of formatted prompts, one per id.
    """
    synonym_sets = [conditional_map[int(class_id)] for class_id in conditional]

    # randomly sample a synonym
    words = []
    for synonyms in synonym_sets:
        choice = torch.multinomial(torch.ones(len(synonyms)), 1, replacement=True).item()
        words.append(synonyms[choice].replace('_', ' '))

    if prompt_list is None:
        templates = ['a photo of {}'] * (len(words))
    else:
        picks = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        templates = [prompt_list[i] for i in picks]

    return [template.format(word) for template, word in zip(templates, words)]
33
+
34
+
35
class VITDenseBase(nn.Module):
    """Shared feature-extraction and conditioning helpers for ViT-backed dense predictors.

    The methods read attributes that this class does not create itself
    (``self.model``, ``self.clip_model``, ``self.prompt_list``,
    ``self.precomputed_prompts``, ``self.token_shape``); subclasses are
    expected to set them.
    """

    def rescaled_pos_emb(self, new_size):
        """Bicubically resize the patch positional embedding to a ``new_size`` grid.

        The first embedding row is kept unchanged and prepended to the result.

        NOTE(review): this reads ``self.model.positional_embedding`` while
        ``visual_forward`` reads ``self.model.pos_embed`` - confirm which
        attribute the backbone actually exposes.
        """
        assert len(new_size) == 2

        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
        return torch.cat([self.model.positional_embedding[:1], b])

    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
        """Run the ViT backbone without gradients and collect intermediate activations.

        Args:
            x_inp: image batch; resized to 384x384 before patch embedding.
            extract_layers: indices of backbone blocks whose outputs are collected.
            skip: accepted for signature compatibility; not used here.
            mask: accepted for signature compatibility; not used here.

        Returns:
            ``(x, activations, None)`` where ``x`` is the head output computed
            from the first token and ``activations`` holds the collected block
            outputs with the first two dimensions swapped.
        """
        with torch.no_grad():

            x_inp = nnf.interpolate(x_inp, (384, 384))

            x = self.model.patch_embed(x_inp)
            cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            if self.model.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = self.model.pos_drop(x + self.model.pos_embed)

            activations = []
            for i, block in enumerate(self.model.blocks):
                x = block(x)

                if i in extract_layers:
                    # permute to be compatible with CLIP
                    activations += [x.permute(1, 0, 2)]

            x = self.model.norm(x)
            x = self.model.head(self.model.pre_logits(x[:, 0]))

            return x, activations, None

    def sample_prompts(self, words, prompt_list=None):
        """Insert every word into a prompt template picked uniformly at random.

        Args:
            words: sequence of strings to be placed into the templates.
            prompt_list: optional templates; defaults to ``self.prompt_list``.

        Returns:
            A list with one formatted prompt per entry of ``words``.
        """
        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
        return [prompt.format(w) for prompt, w in zip(prompts, words)]

    def get_cond_vec(self, conditional, batch_size):
        """Convert ``conditional`` into a conditioning tensor for one batch.

        Supported inputs:
          * a single string: encoded with ``compute_conditional`` and repeated
            ``batch_size`` times along dim 0;
          * a non-empty list/tuple whose first element is a string: encoded
            with ``compute_conditional`` (length must equal ``batch_size``);
          * a 2-D tensor: returned unchanged;
          * any other tensor: passed through ``visual_forward`` under
            ``torch.no_grad()`` and the first returned value is used.

        Type checks use ``isinstance`` so that subclasses such as
        ``torch.nn.Parameter`` are accepted as tensors.

        Raises:
            ValueError: if ``conditional`` matches none of the forms above
                (including ``None`` and empty sequences).
            AssertionError: if a list/tuple of strings does not contain
                exactly ``batch_size`` entries.
        """
        # compute conditional from a single string
        if isinstance(conditional, str):
            cond = self.compute_conditional(conditional)
            return cond.repeat(batch_size, 1)

        # compute conditional from string list/tuple
        if isinstance(conditional, (list, tuple)) and len(conditional) > 0 and isinstance(conditional[0], str):
            assert len(conditional) == batch_size
            # hand over a plain list: compute_conditional dispatches on the exact type
            return self.compute_conditional(list(conditional))

        if isinstance(conditional, torch.Tensor):
            # use conditional directly
            if conditional.ndim == 2:
                return conditional

            # compute conditional from image
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
            return cond

        raise ValueError('invalid conditional')

    def compute_conditional(self, conditional):
        """Encode text into a conditional vector using the CLIP text encoder.

        A list/tuple is tokenized and encoded as a batch. A single value is
        looked up in ``self.precomputed_prompts`` first and only encoded when
        it is not found there (the first row of the encoding is returned).
        """
        import clip

        dev = next(self.parameters()).device

        if type(conditional) in {list, tuple}:
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        else:
            if conditional in self.precomputed_prompts:
                cond = self.precomputed_prompts[conditional].float().to(dev)
            else:
                text_tokens = clip.tokenize([conditional]).to(dev)
                cond = self.clip_model.encode_text(text_tokens)[0]

        return cond
121
+
122
+
123
class VITDensePredT(VITDenseBase):
    """Dense predictor that decodes timm ViT activations, conditioned on CLIP text/image vectors."""

    def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, process_cond=None, not_pretrained=False):
        """Build the backbone, the CLIP text encoder and the decoder layers.

        Args:
            extract_layers: backbone block indices whose outputs are decoded;
                its length must equal ``depth``.
            cond_layer: decoder level at which the conditional is applied.
            reduce_dim: decoder width.
            n_heads: attention heads of each decoder ``TransformerEncoderLayer``.
            prompt: one of ``'fixed'``, ``'shuffle'``, ``'shuffle+'``,
                ``'shuffle_clip'``; selects ``self.prompt_list``. Any other
                value leaves ``self.prompt_list`` unset.
            depth: number of decoder levels.
            extra_blocks: number of additional residual transformer blocks.
            reduce_cond: optional reduced conditional width.
            learn_trans_conv_only: freeze everything except ``self.trans_conv``.
            limit_to_clip_only: stored on the instance.
            upsample: when true, create a 1x1 ``upsample_proj`` convolution.
            add_calibration: when true, set ``self.calibration_conds = 1``.
            process_cond: ``'clamp'``, a sequence starting with ``'clamp'``
                (second entry is the clamp value), or a path ending in
                ``'.pth'`` holding a shift tensor.
            fix_shift, refine, not_pretrained: accepted but not used here.
                NOTE(review): the backbone is always created with
                ``pretrained=True`` regardless of ``not_pretrained`` - confirm.
        """
        super().__init__()

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None

        if add_calibration:
            self.calibration_conds = 1

        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        import timm
        self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
        self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)

        # the backbone (including the replaced head) is kept frozen
        for p in self.model.parameters():
            p.requires_grad_(False)

        import clip
        self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)

        self.token_shape = (14, 14)

        # conditional
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
        self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)

        assert len(self.extract_layers) == depth

        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])

        trans_conv_ks = (16, 16)
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        # refinement and trans conv

        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        if prompt == 'fixed':
            self.prompt_list = ['a photo of a {}.']
        elif prompt == 'shuffle':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif prompt == 'shuffle+':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
                                'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
                                'a bad photo of a {}.', 'a photo of the {}.']
        elif prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            self.prompt_list = imagenet_templates

        if process_cond is not None:
            if process_cond == 'clamp' or process_cond[0] == 'clamp':

                val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2

                def clamp_vec(x):
                    return torch.clamp(x, -val, val)

                self.process_cond = clamp_vec

            elif process_cond.endswith('.pth'):

                shift = torch.load(process_cond)

                def add_shift(x):
                    return x + shift.to(x.device)

                self.process_cond = add_shift

        import pickle
        # Close the file deterministically instead of leaking the handle.
        # NOTE: pickle can execute arbitrary code; only load a trusted file here.
        with open('precomputed_prompt_vectors.pickle', 'rb') as prompt_file:
            precomp = pickle.load(prompt_file)
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
        """Predict a one-channel map with the spatial size of ``inp_image``.

        Args:
            inp_image: image batch.
            conditional: anything accepted by ``get_cond_vec``.
            return_features: must be a ``bool``; selects the return form.
            mask: must be ``None``; any other value raises.

        Returns:
            ``(prediction,)`` by default, or
            ``(prediction, visual_q, cond, activations)`` when
            ``return_features`` is true.

        Raises:
            ValueError: if ``mask`` is not ``None``.
        """
        assert type(return_features) == bool

        if mask is not None:
            raise ValueError('mask not supported')

        x_inp = inp_image

        bs = inp_image.shape[0]

        # remembered so the prediction can be resized back at the end
        inp_image_size = inp_image.shape[2:]

        cond = self.get_cond_vec(conditional, bs)

        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        # layer 0 is only kept for the feature output; the rest is decoded
        activation1 = activations[0]
        activations = activations[1:]

        a = None
        for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):

            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        if self.trans_conv is not None:
            a = self.trans_conv(a)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        a = nnf.interpolate(a, inp_image_size)

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            return a,
clipseg/overview.png ADDED
clipseg/score.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.functional import Tensor
2
+
3
+ import torch
4
+ import inspect
5
+ import json
6
+ import yaml
7
+ import time
8
+ import sys
9
+
10
+ from general_utils import log
11
+
12
+ import numpy as np
13
+ from os.path import expanduser, join, isfile, realpath
14
+
15
+ from torch.utils.data import DataLoader
16
+
17
+ from metrics import FixedIntervalMetrics
18
+
19
+ from general_utils import load_model, log, score_config_from_cli_args, AttributeDict, get_attribute, filter_args
20
+
21
+
22
+ DATASET_CACHE = dict()
23
+
24
def load_model(checkpoint_id, weights_file=None, strict=True, model_args='from_config', with_config=False, ignore_weights=False):
    """Instantiate the model described by ``logs/<checkpoint_id>/config.json`` and load its weights.

    Args:
        checkpoint_id: sub-directory of ``logs`` holding ``config.json`` and the weights.
        weights_file: file name inside that directory; defaults to ``weights.pth``.
        strict: forwarded to ``load_state_dict``.
        model_args: ``'from_config'`` to derive constructor arguments from the
            config, or a dict (any ``dict`` subclass) of constructor arguments.
        with_config: when true, return ``(model, config)`` instead of the model.
        ignore_weights: when true, skip loading weights entirely.

    Returns:
        The model, or ``(model, config)`` when ``with_config`` is true.

    Raises:
        ValueError: if ``model_args`` is neither ``'from_config'`` nor a dict.
        FileNotFoundError: if the weights file is missing and
            ``ignore_weights`` is false.
        AssertionError: if any loaded weight contains NaNs.
    """
    # read the config through a context manager so the handle is closed
    with open(join('logs', checkpoint_id, 'config.json')) as config_file:
        config = json.load(config_file)

    if model_args != 'from_config' and not isinstance(model_args, dict):
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    model_cls = get_attribute(config['model'])

    # load model
    if model_args == 'from_config':
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)

    model = model_cls(**model_args)

    weights_name = 'weights.pth' if weights_file is None else weights_file
    weights_file = realpath(join('logs', checkpoint_id, weights_name))

    if not ignore_weights:
        if not isfile(weights_file):
            raise FileNotFoundError(f'model checkpoint {weights_file} was not found')

        # map to CPU so checkpoints saved from a GPU also load on CPU-only
        # hosts; load_state_dict copies the values into the model's parameters
        weights = torch.load(weights_file, map_location='cpu')
        for w in weights.values():
            assert not torch.any(torch.isnan(w)), 'weights contain NaNs'
        model.load_state_dict(weights, strict=strict)

    if with_config:
        return model, config

    return model
57
+
58
+
59
def compute_shift2(model, datasets, seed=123, repetitions=1):
    """Pick the threshold with the best foreground-IoU over ``datasets``.

    The model is switched to eval mode and moved to CUDA, predictions are
    collected for every dataset (batch size 1, at most
    ``repetitions * len(dataset.dataset.data_list)`` samples each) and scored
    with ``FixedIntervalMetrics``.

    Args:
        model: callable as ``model(x0, x1, x2)`` returning a 1-tuple.
        datasets: iterable of datasets exposing ``.dataset.data_list``.
        seed: seed for Python's ``random`` module.
        repetitions: multiplier for the per-dataset iteration cap.

    Returns:
        The threshold at the index of the maximal ``'fgiou_scores'`` entry.
    """
    model.eval()
    model.cuda()

    import random
    random.seed(seed)

    preds, gts = [], []
    for dataset in datasets:

        loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)

        max_iterations = int(repetitions * len(dataset.dataset.data_list))

        with torch.no_grad():
            for n_seen, (data_x, data_y) in enumerate(loader, start=1):

                data_x = [v if v is None else v.cuda(non_blocking=True) for v in data_x]
                data_y = [v if v is None else v.cuda(non_blocking=True) for v in data_y]

                pred, = model(data_x[0], data_x[1], data_x[2])
                preds.append(pred.detach())
                gts.append(data_y)

                if max_iterations and n_seen >= max_iterations:
                    break

    from metrics import FixedIntervalMetrics
    n_values = 51
    # candidate thresholds, excluding both end points 0 and 1
    thresholds = np.linspace(0, 1, n_values)[1:-1]
    metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)

    for pred, target in zip(preds, gts):
        metric.add(pred.unsqueeze(1), target)

    best_idx = np.argmax(metric.value()['fgiou_scores'])
    return thresholds[best_idx]
103
+
104
+
105
def get_cached_pascal_pfe(split, config):
    """Return the validation ``PFEPascalWrapper`` for ``split``, building it at most once.

    Instances are memoised in ``DATASET_CACHE`` under the key
    ``(split, config.image_size, config.label_support, config.mask)``.
    """
    from datasets.pfe_dataset import PFEPascalWrapper

    cache_key = (split, config.image_size, config.label_support, config.mask)
    if cache_key not in DATASET_CACHE:
        DATASET_CACHE[cache_key] = PFEPascalWrapper(
            mode='val',
            split=split,
            mask=config.mask,
            image_size=config.image_size,
            label_support=config.label_support,
        )
    return DATASET_CACHE[cache_key]
113
+
114
+
115
+
116
+
117
def main():
    """Score the checkpoint selected on the command line and print its plain numeric metrics."""
    config, train_checkpoint_id = score_config_from_cli_args()

    metrics = score(config, train_checkpoint_id, None)

    for dataset, results in metrics.items():
        for k in results:
            value = results[k]
            # only exact Python floats/ints are printed
            if type(value) in {float, int}:
                print(dataset, f'{k:<16} {value:.3f}')
126
+
127
+
128
+ def score(config, train_checkpoint_id, train_config):
129
+
130
+ config = AttributeDict(config)
131
+
132
+ print(config)
133
+
134
+ # use training dataset and loss
135
+ train_config = AttributeDict(json.load(open(f'logs/{train_checkpoint_id}/config.json')))
136
+
137
+ cp_str = f'_{config.iteration_cp}' if config.iteration_cp is not None else ''
138
+
139
+
140
+ model_cls = get_attribute(train_config['model'])
141
+
142
+ _, model_args, _ = filter_args(train_config, inspect.signature(model_cls).parameters)
143
+
144
+ model_args = {**model_args, **{k: config[k] for k in ['process_cond', 'fix_shift'] if k in config}}
145
+
146
+ strict_models = {'ConditionBase4', 'PFENetWrapper'}
147
+ model = load_model(train_checkpoint_id, strict=model_cls.__name__ in strict_models, model_args=model_args,
148
+ weights_file=f'weights{cp_str}.pth', )
149
+
150
+
151
+ model.eval()
152
+ model.cuda()
153
+
154
+ metric_args = dict()
155
+
156
+ if 'threshold' in config:
157
+ if config.metric.split('.')[-1] == 'SkLearnMetrics':
158
+ metric_args['threshold'] = config.threshold
159
+
160
+ if 'resize_to' in config:
161
+ metric_args['resize_to'] = config.resize_to
162
+
163
+ if 'sigmoid' in config:
164
+ metric_args['sigmoid'] = config.sigmoid
165
+
166
+ if 'custom_threshold' in config:
167
+ metric_args['custom_threshold'] = config.custom_threshold
168
+
169
+ if config.test_dataset == 'pascal':
170
+
171
+ loss_fn = get_attribute(train_config.loss)
172
+ # assume that if no split is specified in train_config, test on all splits,
173
+
174
+ if 'splits' in config:
175
+ splits = config.splits
176
+ else:
177
+ if 'split' in train_config and type(train_config.split) == int:
178
+ # unless train_config has a split set, in that case assume train mode in training
179
+ splits = [train_config.split]
180
+ assert train_config.mode == 'train'
181
+ else:
182
+ splits = [0,1,2,3]
183
+
184
+ log.info('Test on these splits', splits)
185
+
186
+ scores = dict()
187
+ for split in splits:
188
+
189
+ shift = config.shift if 'shift' in config else 0
190
+
191
+ # automatic shift
192
+ if shift == 'auto':
193
+ shift_compute_t = time.time()
194
+ shift = compute_shift2(model, [get_cached_pascal_pfe(s, config) for s in range(4) if s != split], repetitions=config.compute_shift_fac)
195
+ log.info(f'Best threshold is {shift}, computed on splits: {[s for s in range(4) if s != split]}, took {time.time() - shift_compute_t:.1f}s')
196
+
197
+ dataset = get_cached_pascal_pfe(split, config)
198
+
199
+ eval_start_t = time.time()
200
+
201
+ loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
202
+
203
+ assert config.batch_size is None or config.batch_size == 1, 'When PFE Dataset is used, batch size must be 1'
204
+
205
+ metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, custom_threshold=shift, **metric_args)
206
+
207
+ with torch.no_grad():
208
+
209
+ i, losses = 0, []
210
+ for i_all, (data_x, data_y) in enumerate(loader):
211
+
212
+ data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
213
+ data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
214
+
215
+ if config.mask == 'separate': # for old CondBase model
216
+ pred, = model(data_x[0], data_x[1], data_x[2])
217
+ else:
218
+ # assert config.mask in {'text', 'highlight'}
219
+ pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
220
+
221
+ # loss = loss_fn(pred, data_y[0])
222
+ metric.add(pred.unsqueeze(1) + shift, data_y)
223
+
224
+ # losses += [float(loss)]
225
+
226
+ i += 1
227
+ if config.max_iterations and i >= config.max_iterations:
228
+ break
229
+
230
+ #scores[split] = {m: s for m, s in zip(metric.names(), metric.value())}
231
+
232
+ log.info(f'Dataset length: {len(dataset)}, took {time.time() - eval_start_t:.1f}s to evaluate.')
233
+
234
+ print(metric.value()['mean_iou_scores'])
235
+
236
+ scores[split] = metric.scores()
237
+
238
+ log.info(f'Completed split {split}')
239
+
240
+ key_prefix = config['name'] if 'name' in config else 'pas'
241
+
242
+ all_keys = set.intersection(*[set(v.keys()) for v in scores.values()])
243
+
244
+ valid_keys = [k for k in all_keys if all(v[k] is not None and isinstance(v[k], (int, float, np.float)) for v in scores.values())]
245
+
246
+ return {key_prefix: {k: np.mean([s[k] for s in scores.values()]) for k in valid_keys}}
247
+
248
+
249
+ if config.test_dataset == 'coco':
250
+ from datasets.coco_wrapper import COCOWrapper
251
+
252
+ coco_dataset = COCOWrapper('test', fold=train_config.fold, image_size=train_config.image_size, mask=config.mask,
253
+ with_class_label=True)
254
+
255
+ log.info('Dataset length', len(coco_dataset))
256
+ loader = DataLoader(coco_dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
257
+
258
+ metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
259
+
260
+ shift = config.shift if 'shift' in config else 0
261
+
262
+ with torch.no_grad():
263
+
264
+ i, losses = 0, []
265
+ for i_all, (data_x, data_y) in enumerate(loader):
266
+ data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
267
+ data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
268
+
269
+ if config.mask == 'separate': # for old CondBase model
270
+ pred, = model(data_x[0], data_x[1], data_x[2])
271
+ else:
272
+ # assert config.mask in {'text', 'highlight'}
273
+ pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
274
+
275
+ metric.add([pred + shift], data_y)
276
+
277
+ i += 1
278
+ if config.max_iterations and i >= config.max_iterations:
279
+ break
280
+
281
+ key_prefix = config['name'] if 'name' in config else 'coco'
282
+ return {key_prefix: metric.scores()}
283
+ #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
284
+
285
+
286
+ if config.test_dataset == 'phrasecut':
287
+ from datasets.phrasecut import PhraseCut
288
+
289
+ only_visual = config.only_visual is not None and config.only_visual
290
+ with_visual = config.with_visual is not None and config.with_visual
291
+
292
+ dataset = PhraseCut('test',
293
+ image_size=train_config.image_size,
294
+ mask=config.mask,
295
+ with_visual=with_visual, only_visual=only_visual, aug_crop=False,
296
+ aug_color=False)
297
+
298
+ loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=2, shuffle=False, drop_last=False)
299
+ metric = get_attribute(config.metric)(resize_pred=True, **metric_args)
300
+
301
+ shift = config.shift if 'shift' in config else 0
302
+
303
+
304
+ with torch.no_grad():
305
+
306
+ i, losses = 0, []
307
+ for i_all, (data_x, data_y) in enumerate(loader):
308
+ data_x = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_x]
309
+ data_y = [v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for v in data_y]
310
+
311
+ pred, _, _, _ = model(data_x[0], data_x[1], return_features=True)
312
+ metric.add([pred + shift], data_y)
313
+
314
+ i += 1
315
+ if config.max_iterations and i >= config.max_iterations:
316
+ break
317
+
318
+ key_prefix = config['name'] if 'name' in config else 'phrasecut'
319
+ return {key_prefix: metric.scores()}
320
+ #return {key_prefix: {k: v for k, v in zip(metric.names(), metric.value())}}
321
+
322
+ if config.test_dataset == 'pascal_zs':
323
+ from third_party.JoEm.model.metric import Evaluator
324
+ from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
325
+ from datasets.pascal_zeroshot import PascalZeroShot, PASCAL_VOC_CLASSES_ZS
326
+
327
+ from models.clipseg import CLIPSegMultiLabel
328
+
329
+ n_unseen = train_config.remove_classes[1]
330
+
331
+ pz = PascalZeroShot('val', n_unseen, image_size=352)
332
+ m = CLIPSegMultiLabel(model=train_config.name).cuda()
333
+ m.eval();
334
+
335
+ print(len(pz), n_unseen)
336
+ print('training removed', [c for class_set in PASCAL_VOC_CLASSES_ZS[:n_unseen // 2] for c in class_set])
337
+
338
+ print('unseen', [VOC[i] for i in get_unseen_idx(n_unseen)])
339
+ print('seen', [VOC[i] for i in get_seen_idx(n_unseen)])
340
+
341
+ loader = DataLoader(pz, batch_size=8)
342
+ evaluator = Evaluator(21, get_unseen_idx(n_unseen), get_seen_idx(n_unseen))
343
+
344
+ for i, (data_x, data_y) in enumerate(loader):
345
+ pred = m(data_x[0].cuda())
346
+ evaluator.add_batch(data_y[0].numpy(), pred.argmax(1).cpu().detach().numpy())
347
+
348
+ if config.max_iter is not None and i > config.max_iter:
349
+ break
350
+
351
+ scores = evaluator.Mean_Intersection_over_Union()
352
+ key_prefix = config['name'] if 'name' in config else 'pas_zs'
353
+
354
+ return {key_prefix: {k: scores[k] for k in ['seen', 'unseen', 'harmonic', 'overall']}}
355
+
356
+ elif config.test_dataset in {'same_as_training', 'affordance'}:
357
+ loss_fn = get_attribute(train_config.loss)
358
+
359
+ metric_cls = get_attribute(config.metric)
360
+ metric = metric_cls(**metric_args)
361
+
362
+ if config.test_dataset == 'same_as_training':
363
+ dataset_cls = get_attribute(train_config.dataset)
364
+ elif config.test_dataset == 'affordance':
365
+ dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_Affordance')
366
+ dataset_name = 'aff'
367
+ else:
368
+ dataset_cls = get_attribute('datasets.lvis_oneshot3.LVIS_OneShot')
369
+ dataset_name = 'lvis'
370
+
371
+ _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)
372
+
373
+ dataset_args['image_size'] = train_config.image_size # explicitly use training image size for evaluation
374
+
375
+ if model.__class__.__name__ == 'PFENetWrapper':
376
+ dataset_args['image_size'] = config.image_size
377
+
378
+ log.info('init dataset', str(dataset_cls))
379
+ dataset = dataset_cls(**dataset_args)
380
+
381
+ log.info(f'Score on {model.__class__.__name__} on {dataset_cls.__name__}')
382
+
383
+ data_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=config.shuffle)
384
+
385
+ # explicitly set prompts
386
+ if config.prompt == 'plain':
387
+ model.prompt_list = ['{}']
388
+ elif config.prompt == 'fixed':
389
+ model.prompt_list = ['a photo of a {}.']
390
+ elif config.prompt == 'shuffle':
391
+ model.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
392
+ elif config.prompt == 'shuffle_clip':
393
+ from models.clip_prompts import imagenet_templates
394
+ model.prompt_list = imagenet_templates
395
+
396
+ config.assume_no_unused_keys(exceptions=['max_iterations'])
397
+
398
+ t_start = time.time()
399
+
400
+ with torch.no_grad(): # TODO: switch to inference_mode (torch 1.9)
401
+ i, losses = 0, []
402
+ for data_x, data_y in data_loader:
403
+
404
+ data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
405
+ data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]
406
+
407
+ if model.__class__.__name__ in {'ConditionBase4', 'PFENetWrapper'}:
408
+ pred, = model(data_x[0], data_x[1], data_x[2])
409
+ visual_q = None
410
+ else:
411
+ pred, visual_q, _, _ = model(data_x[0], data_x[1], return_features=True)
412
+
413
+ loss = loss_fn(pred, data_y[0])
414
+
415
+ metric.add([pred], data_y)
416
+
417
+ losses += [float(loss)]
418
+
419
+ i += 1
420
+ if config.max_iterations and i >= config.max_iterations:
421
+ break
422
+
423
+ # scores = {m: s for m, s in zip(metric.names(), metric.value())}
424
+ scores = metric.scores()
425
+
426
+ keys = set(scores.keys())
427
+ if dataset.negative_prob > 0 and 'mIoU' in keys:
428
+ keys.remove('mIoU')
429
+
430
+ name_mask = dataset.mask.replace('text_label', 'txt')[:3]
431
+ name_neg = '' if dataset.negative_prob == 0 else '_' + str(dataset.negative_prob)
432
+
433
+ score_name = config.name if 'name' in config else f'{dataset_name}_{name_mask}{name_neg}'
434
+
435
+ scores = {score_name: {k: v for k,v in scores.items() if k in keys}}
436
+ scores[score_name].update({'test_loss': np.mean(losses)})
437
+
438
+ log.info(f'Evaluation took {time.time() - t_start:.1f}s')
439
+
440
+ return scores
441
+ else:
442
+ raise ValueError('invalid test dataset')
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+ if __name__ == '__main__':
453
+ main()
clipseg/setup.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup

# Runtime dependencies. CLIP is installed directly from its GitHub repository.
INSTALL_REQUIRES = [
    "numpy",
    "scipy",
    "matplotlib",
    "torch",
    "torchvision",
    "opencv-python",
    "CLIP @ git+https://github.com/openai/CLIP.git",
]

# The README is reused verbatim as the long description of the package.
with open("README.md", "r", encoding="utf-8") as readme_handle:
    long_description_text = readme_handle.read()

# The `clipseg` package is served from the `models` directory, and the
# pretrained weight files (*.pth) are shipped as package data.
PACKAGE_METADATA = dict(
    name='clipseg',
    version='0.0.1',
    url='https://github.com/timojl/clipseg',
    description='This repository contains the code used in the paper "Image Segmentation Using Text and Image Prompts".',
    long_description=long_description_text,
    long_description_content_type="text/markdown",
    python_requires='>=3.9',
    install_requires=INSTALL_REQUIRES,
    packages=['clipseg'],
    package_dir={'clipseg': 'models'},
    package_data={'clipseg': ["../weights/*.pth"]},
)

setup(**PACKAGE_METADATA)
clipseg/training.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import inspect
3
+ import json
4
+ import yaml
5
+ import math
6
+ import os
7
+ import sys
8
+
9
+ from general_utils import log
10
+
11
+ import numpy as np
12
+ from functools import partial
13
+ from os.path import expanduser, join, isfile, basename
14
+
15
+ from torch.cuda.amp import autocast, GradScaler
16
+ from torch.optim.lr_scheduler import LambdaLR
17
+ from contextlib import nullcontext
18
+ from torch.utils.data import DataLoader
19
+
20
+ from general_utils import TrainingLogger, get_attribute, filter_args, log, training_config_from_cli_args
21
+
22
+
23
def cosine_warmup_lr(i, warmup=10, max_iter=90):
    """Return the learning-rate multiplier for step ``i`` (cosine LR with warmup).

    The schedule has two phases:

    * ``i < warmup``: linear ramp ``(i + 1) / (warmup + 1)``.
    * ``i >= warmup``: half-cosine decay from 1.0 (at ``i == warmup``) down to
      0.0 (at ``i == max_iter``).

    Args:
        i: Current step index (0-based).
        warmup: Number of warmup steps.
        max_iter: Step at which the cosine decay reaches zero.

    Returns:
        A float multiplier in ``[0, 1]``, suitable as an ``lr_lambda`` for
        ``torch.optim.lr_scheduler.LambdaLR``.
    """
    if i < warmup:
        return (i + 1) / (warmup + 1)

    decay_steps = max_iter - warmup
    if decay_steps <= 0:
        # There is no decay phase (max_iter <= warmup); the schedule is already
        # over. Returning here also avoids a division by zero.
        return 0.0

    # Clamp so that steps beyond max_iter stay at zero instead of following the
    # cosine back up towards 1.0.
    progress = min((i - warmup) / decay_steps, 1.0)
    return 0.5 + 0.5 * math.cos(math.pi * progress)
29
+
30
+
31
def validate(model, dataset, config):
    """Evaluate ``model`` on ``dataset`` and return the validation loss and scores.

    The model is switched to eval mode and moved to the GPU; it is NOT switched
    back to train mode here, so the caller has to do that.

    Args:
        model: Model exposing ``sample_prompts`` and a forward call that accepts
            ``return_features=True`` and returns a 4-tuple whose first element
            is the prediction.
        dataset: Dataset yielding ``(data_x, data_y)`` pairs.
        config: Configuration providing ``val_metric_class``, ``use_val_metric``,
            ``loss`` and ``val_max_iterations``.

    Returns:
        A tuple ``(mean_loss, metric_scores, maximize)``. ``metric_scores`` is
        an empty dict and ``maximize`` is False when ``config.use_val_metric``
        is None; otherwise ``maximize`` is True.
    """
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

    metric_class, use_metric = config.val_metric_class, config.use_val_metric
    loss_fn = get_attribute(config.loss)

    model.eval()
    model.cuda()

    # Always bind `metric`, so that the final `metric is not None` check cannot
    # fail with an UnboundLocalError when no metric class is configured.
    metric = get_attribute(metric_class)() if metric_class is not None else None

    with torch.no_grad():

        i, losses = 0, []
        for data_x, data_y in data_loader:

            # move tensors to the GPU, leave non-tensor entries (e.g. strings) untouched
            data_x = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_x]
            data_y = [x.cuda() if isinstance(x, torch.Tensor) else x for x in data_y]

            # validation always uses one fixed prompt template
            prompts = model.sample_prompts(data_x[1], prompt_list=('a photo of a {}',))
            pred, _, _, _ = model(data_x[0], prompts, return_features=True)

            if metric is not None:
                metric.add([pred], data_y)

            loss = loss_fn(pred, data_y[0])
            losses += [float(loss)]

            i += 1

            # NOTE(review): `i` is incremented before this check, so up to
            # val_max_iterations + 1 batches are evaluated. Confirm whether
            # `>=` was intended; the behavior is left unchanged here.
            if config.val_max_iterations is not None and i > config.val_max_iterations:
                break

    if use_metric is None:
        return np.mean(losses), {}, False

    metric_scores = {m: s for m, s in zip(metric.names(), metric.value())} if metric is not None else {}
    return np.mean(losses), metric_scores, True
72
+
73
+
74
def main():
    """Train a model as described by the configuration given on the command line.

    Builds the model, dataset(s), optimizer and optional LR scheduler from the
    config, then loops over the training data until ``config.max_iterations``
    steps have been taken. The process terminates through ``sys.exit``:
    with status 0 once the iteration budget is reached, with status -1 if the
    loss becomes NaN/inf.

    Weights are saved whenever validation improves (metric score if one is
    configured, otherwise validation loss), at every step listed in
    ``config.checkpoint_iterations``, and at the very end if no ``weights.pth``
    has been written yet.
    """

    config = training_config_from_cli_args()

    val_interval, best_val_loss, best_val_score = config.val_interval, float('inf'), float('-inf')

    model_cls = get_attribute(config.model)
    _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
    model = model_cls(**model_args).cuda()

    dataset_cls = get_attribute(config.dataset)
    _, dataset_args, _ = filter_args(config, inspect.signature(dataset_cls).parameters)

    dataset = dataset_cls(**dataset_args)

    log.info(f'Train dataset {dataset.__class__.__name__} (length: {len(dataset)})')

    if val_interval is not None:
        # config keys prefixed with `val_` override the training dataset arguments
        dataset_val_args = {k[4:]: v for k, v in config.items() if k.startswith('val_') and k != 'val_interval'}
        _, dataset_val_args, _ = filter_args(dataset_val_args, inspect.signature(dataset_cls).parameters)
        print('val args', {**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})

        dataset_val = dataset_cls(**{**dataset_args, **{'split': 'val', 'aug': 0}, **dataset_val_args})

    # optimizer
    opt_cls = get_attribute(config.optimizer)
    # The optimizer is selected through `config.optimizer` (see the line above),
    # so the SGD check has to read the same attribute (was `config.optimize`).
    if config.optimizer == 'torch.optim.SGD':
        opt_args = {'momentum': config.momentum if 'momentum' in config else 0}
    else:
        opt_args = {}
    opt = opt_cls(model.parameters(), lr=config.lr, **opt_args)

    if config.lr_scheduler == 'cosine':
        assert config.T_max is not None and config.eta_min is not None
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, config.T_max, config.eta_min)
    elif config.lr_scheduler == 'warmup_cosine':
        lr_scheduler = LambdaLR(opt, partial(cosine_warmup_lr, max_iter=(config.max_iterations), warmup=config.warmup))
    else:
        lr_scheduler = None

    batch_size, max_iterations = config.batch_size, config.max_iterations

    loss_fn = get_attribute(config.loss)

    if config.amp:
        log.info('Using AMP')
        autocast_fn = autocast
        scaler = GradScaler()
    else:
        autocast_fn, scaler = nullcontext, None

    save_only_trainable = True
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)

    # disable config when hyperparam. opt. to avoid writing logs.
    tracker_config = config if not config.hyperparameter_optimization else None

    with TrainingLogger(log_dir=config.name, model=model, config=tracker_config) as logger:

        i = 0
        while True:
            for data_x, data_y in data_loader:

                # randomly mix text and visual support conditionals
                if config.mix:

                    assert config.mask.startswith('text_and')

                    with autocast_fn():
                        # data_x[1] = text label
                        prompts = model.sample_prompts(data_x[1])

                        text_cond = model.compute_conditional(prompts)
                        if model.__class__.__name__ == 'CLIPDensePredTMasked':
                            # when mask=='separate'
                            visual_s_cond, _, _ = model.visual_forward_masked(data_x[2].cuda(), data_x[3].cuda())
                        else:
                            # data_x[2] = visual prompt
                            visual_s_cond, _, _ = model.visual_forward(data_x[2].cuda())

                    max_txt = config.mix_text_max if config.mix_text_max is not None else 1
                    # size of this particular batch (kept separate from the
                    # configured `batch_size` to avoid shadowing it)
                    n_samples = text_cond.shape[0]

                    # sample weights for each element in batch
                    text_weights = torch.distributions.Uniform(config.mix_text_min, max_txt).sample((n_samples,))[:, None]
                    text_weights = text_weights.cuda()

                    if dataset.__class__.__name__ == 'PhraseCut':
                        # give full weight to text where support_image is invalid
                        visual_is_valid = data_x[4] if model.__class__.__name__ == 'CLIPDensePredTMasked' else data_x[3]
                        text_weights = torch.max(text_weights[:, 0], 1 - visual_is_valid.float().cuda()).unsqueeze(1)

                    cond = text_cond * text_weights + visual_s_cond * (1 - text_weights)

                else:
                    # no mix

                    if model.__class__.__name__ == 'CLIPDensePredTMasked':
                        # compute conditional vector using CLIP masking
                        with autocast_fn():
                            assert config.mask == 'separate'
                            cond, _, _ = model.visual_forward_masked(data_x[1].cuda(), data_x[2].cuda())
                    else:
                        cond = data_x[1]
                        if isinstance(cond, torch.Tensor):
                            cond = cond.cuda()

                with autocast_fn():
                    pred, _, _, _ = model(data_x[0].cuda(), cond, return_features=True)

                    loss = loss_fn(pred, data_y[0].cuda())

                    if torch.isnan(loss) or torch.isinf(loss):
                        # a NaN/inf loss is treated as fatal: abort the whole run
                        log.warning('Training stopped due to inf/nan loss.')
                        sys.exit(-1)

                    # placeholder for an additional loss term; currently always zero
                    extra_loss = 0
                    loss += extra_loss

                opt.zero_grad()

                if scaler is None:
                    loss.backward()
                    opt.step()
                else:
                    # AMP: scale the loss for backward, then step through the scaler
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                    if i % 2000 == 0:
                        current_lr = [g['lr'] for g in opt.param_groups][0]
                        log.info(f'current lr: {current_lr:.5f} ({len(opt.param_groups)} parameter groups)')

                logger.iter(i=i, loss=loss)
                i += 1

                if i >= max_iterations:

                    if not isfile(join(logger.base_path, 'weights.pth')):
                        # only write if no weights were already written
                        logger.save_weights(only_trainable=save_only_trainable)

                    sys.exit(0)

                if config.checkpoint_iterations is not None and i in config.checkpoint_iterations:
                    logger.save_weights(only_trainable=save_only_trainable, weight_file=f'weights_{i}.pth')

                if val_interval is not None and i % val_interval == val_interval - 1:

                    val_loss, val_scores, maximize = validate(model, dataset_val, config)

                    if len(val_scores) > 0:

                        score_str = ', scores: ' + ', '.join(f'{k}: {v}' for k, v in val_scores.items())

                        if maximize and val_scores[config.use_val_metric] > best_val_score:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = val_scores[config.use_val_metric]

                        # NOTE(review): best_val_score starts at -inf, so this
                        # minimizing branch cannot trigger as written.
                        elif not maximize and val_scores[config.use_val_metric] < best_val_score:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_score = val_scores[config.use_val_metric]

                    else:
                        score_str = ''
                        # if no score is used, fall back to loss
                        if val_loss < best_val_loss:
                            logger.save_weights(only_trainable=save_only_trainable)
                            best_val_loss = val_loss

                    log.info(f'Validation loss: {val_loss}' + score_str)
                    logger.iter(i=i, val_loss=val_loss, extra_loss=float(extra_loss), **val_scores)
                    # validate() leaves the model in eval mode
                    model.train()

            print('epoch complete')
263
+
264
+
265
+ if __name__ == '__main__':
266
+ main()
clipseg/weights/rd64-uni.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13845f6cee4d54ca46f62ee19dd354822094a26e0efccc64e606be93d6a7e26f
3
+ size 4306645
init_image.png ADDED
inpainting.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ import PIL
8
+ from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
10
+ from tqdm.auto import tqdm
11
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
12
+
13
+
14
def preprocess_image(image):
    """Convert a PIL image into a normalized image tensor.

    Args:
        image: A PIL image (anything exposing ``size``, ``convert`` and ``resize``).

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, h, w)`` with values in
        ``[-1, 1]``, where ``w`` and ``h`` are the input dimensions rounded down
        to a multiple of 32.
    """
    # Force three channels: RGBA, grayscale or palette inputs would otherwise
    # produce the wrong channel count (or a 2-D array) for the transpose below.
    image = image.convert("RGB")
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # (h, w, c) -> (1, c, h, w)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    # rescale from [0, 1] to [-1, 1]
    return 2.0 * image - 1.0
22
+
23
+
24
def preprocess_mask(mask):
    """Convert a PIL mask image into a 4-channel mask tensor at 1/8 resolution.

    Args:
        mask: A PIL image; white areas are repainted, black areas are kept.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 4, h // 8, w // 8)``, where
        ``w`` and ``h`` are the input dimensions rounded down to a multiple of
        32. The values are inverted relative to the input: 1 marks pixels to
        keep, 0 marks pixels to repaint.
    """
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    # nearest-neighbour resampling avoids introducing intermediate gray values
    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
    mask = np.array(mask).astype(np.float32) / 255.0
    # replicate the single channel four times and add a batch dimension;
    # the former `.transpose(0, 1, 2, 3)` was an identity permutation and is dropped
    mask = np.tile(mask, (4, 1, 1))[None]
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask
35
+
36
class StableDiffusionInpaintingPipeline(DiffusionPipeline):
    """Text-guided inpainting pipeline built from Stable Diffusion components.

    The initial image is encoded into latents and noised. During denoising,
    the region to keep is reset after every step to a noised copy of the
    original latents, so only the masked region is regenerated.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Register all sub-modules; the scheduler is switched to the "pt" format first."""
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: torch.FloatTensor,
        mask_image: torch.FloatTensor,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
    ):
        """Inpaint ``init_image`` according to ``prompt`` inside the masked region.

        Args:
            prompt: A single prompt or a list of prompts; the list length is
                the batch size.
            init_image: Image to inpaint. NOTE: despite the annotation, it is
                passed to ``preprocess_image``, which reads ``.size`` and calls
                ``.resize``, i.e. a PIL image is expected.
            mask_image: Mask passed to ``preprocess_mask`` (which calls
                ``.convert("L")``); white is repainted, black is kept.
            strength: Value in ``[0, 1]`` controlling how many of the
                scheduler's timesteps are used for noising/denoising.
            num_inference_steps: Number of scheduler timesteps to set.
            guidance_scale: Classifier-free guidance weight; guidance is only
                applied when this is greater than 1.0.
            eta: Forwarded to ``scheduler.step`` if that method accepts ``eta``.
            generator: Optional torch generator used for latent sampling and noise.
            output_type: If ``"pil"``, images are converted to PIL images.

        Returns:
            A dict with keys ``"sample"`` (the images) and
            ``"nsfw_content_detected"`` (output of the safety checker).

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``, if
                ``strength`` is outside ``[0, 1]``, or if the mask and latent
                shapes differ.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # preprocess image
        init_image = preprocess_image(init_image).to(self.device)

        # encode the init image into latents and scale the latents
        init_latent_dist = self.vae.encode(init_image).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)
        init_latents = 0.18215 * init_latents

        # replicate the latents once per prompt and keep a clean copy for masking
        init_latents = torch.cat([init_latents] * batch_size)
        init_latents_orig = init_latents

        # preprocess mask
        mask = preprocess_mask(mask_image).to(self.device)
        mask = torch.cat([mask] * batch_size)

        # check sizes
        if not mask.shape == init_latents.shape:
            raise ValueError("The mask and init_image should be the same size!")

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        # Wrap the timestep sequence itself (not an enumerate object) so tqdm can
        # read its length and show the total number of steps.
        for t in tqdm(self.scheduler.timesteps[t_start:]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

            # masking: where mask == 1 (keep), reset to the original latents
            # noised to the current timestep; elsewhere keep the denoised latents
            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
            latents = (init_latents_proper * mask) + (latents * (1 - mask))

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        # map from [-1, 1] to [0, 1] and move channels last
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
mask_image.png ADDED