hpc-yekin commited on
Commit
92e0882
1 Parent(s): 4aa0b3a

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. AlphaCLIP/.gitignore +12 -0
  2. AlphaCLIP/LICENSE +201 -0
  3. AlphaCLIP/MANIFEST.in +1 -0
  4. AlphaCLIP/alpha_clip/__init__.py +1 -0
  5. AlphaCLIP/alpha_clip/alpha_clip.py +250 -0
  6. AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  7. AlphaCLIP/alpha_clip/model.py +598 -0
  8. AlphaCLIP/alpha_clip/simple_tokenizer.py +132 -0
  9. AlphaCLIP/eval/README.md +6 -0
  10. AlphaCLIP/eval/imagenet_s_zs_test/.gitignore +2 -0
  11. AlphaCLIP/eval/imagenet_s_zs_test/README.md +21 -0
  12. AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py +149 -0
  13. AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py +66 -0
  14. AlphaCLIP/eval/rec_zs_test/LICENSE.md +201 -0
  15. AlphaCLIP/eval/rec_zs_test/README.md +74 -0
  16. AlphaCLIP/eval/rec_zs_test/cache/.gitkeep +0 -0
  17. AlphaCLIP/eval/rec_zs_test/cal_acc.py +21 -0
  18. AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep +0 -0
  19. AlphaCLIP/eval/rec_zs_test/data/.gitkeep +0 -0
  20. AlphaCLIP/eval/rec_zs_test/entity_extraction.py +142 -0
  21. AlphaCLIP/eval/rec_zs_test/executor.py +401 -0
  22. AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py +107 -0
  23. AlphaCLIP/eval/rec_zs_test/heuristics.py +68 -0
  24. AlphaCLIP/eval/rec_zs_test/interpreter.py +212 -0
  25. AlphaCLIP/eval/rec_zs_test/lattice.py +70 -0
  26. AlphaCLIP/eval/rec_zs_test/main.py +200 -0
  27. AlphaCLIP/eval/rec_zs_test/methods/__init__.py +3 -0
  28. AlphaCLIP/eval/rec_zs_test/methods/baseline.py +57 -0
  29. AlphaCLIP/eval/rec_zs_test/methods/parse.py +239 -0
  30. AlphaCLIP/eval/rec_zs_test/methods/random_method.py +30 -0
  31. AlphaCLIP/eval/rec_zs_test/methods/ref_method.py +13 -0
  32. AlphaCLIP/eval/rec_zs_test/output/.gitkeep +0 -0
  33. AlphaCLIP/eval/rec_zs_test/requirements.txt +53 -0
  34. AlphaCLIP/eval/rec_zs_test/run.sh +1 -0
  35. AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh +15 -0
  36. AlphaCLIP/hubconf.py +42 -0
  37. AlphaCLIP/requirements.txt +5 -0
  38. AlphaCLIP/setup.py +21 -0
  39. README.md +1 -1
  40. app.py +113 -0
  41. clip_l14_grit+mim_fultune_6xe.pth +3 -0
  42. config/inference_config.yaml +16 -0
  43. image_encoder/config.json +23 -0
  44. image_encoder/pytorch_model.bin +3 -0
  45. ip-adapter_sd15.bin +3 -0
  46. model.safetensors +3 -0
  47. model/__init__.py +5 -0
  48. model/attention_processor.py +189 -0
  49. model/clip_away.py +280 -0
  50. model/resampler.py +158 -0
AlphaCLIP/.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.egg-info
5
+ .pytest_cache
6
+ .ipynb_checkpoints
7
+
8
+ thumbs.db
9
+ .DS_Store
10
+ .idea
11
+ checkpoints/*
12
+ *.pth
AlphaCLIP/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [Zeyi Sun] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
AlphaCLIP/MANIFEST.in ADDED
@@ -0,0 +1 @@
 
 
1
+ include alpha_clip/bpe_simple_vocab_16e6.txt.gz
AlphaCLIP/alpha_clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .alpha_clip import *
AlphaCLIP/alpha_clip/alpha_clip.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def _transform(n_px):
80
+ return Compose([
81
+ Resize(n_px, interpolation=BICUBIC),
82
+ CenterCrop(n_px),
83
+ _convert_image_to_rgb,
84
+ ToTensor(),
85
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
+ ])
87
+
88
+
89
+ def available_models() -> List[str]:
90
+ """Returns the names of available CLIP models"""
91
+ return list(_MODELS.keys())
92
+
93
+
94
+ def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16):
95
+ """Load a CLIP model
96
+
97
+ Parameters
98
+ ----------
99
+ name : str
100
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
+
102
+ alpha_vision_ckpt_pth: str
103
+ only changed when inferencing model instead of training
104
+
105
+ device : Union[str, torch.device]
106
+ The device to put the loaded model
107
+
108
+ jit : bool
109
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
110
+
111
+ download_root: str
112
+ path to download the model files; by default, it uses "~/.cache/clip"
113
+
114
+ Returns
115
+ -------
116
+ model : torch.nn.Module
117
+ The CLIP model
118
+
119
+ preprocess : Callable[[PIL.Image], torch.Tensor]
120
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
121
+ """
122
+ if name in _MODELS:
123
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
124
+ elif os.path.isfile(name):
125
+ model_path = name
126
+ else:
127
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
128
+
129
+ with open(model_path, 'rb') as opened_file:
130
+ try:
131
+ # loading JIT archive
132
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
133
+ state_dict = None
134
+ except RuntimeError:
135
+ # loading saved state dict
136
+ if jit:
137
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
138
+ jit = False
139
+ state_dict = torch.load(opened_file, map_location="cpu")
140
+
141
+ if not jit:
142
+ model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device)
143
+ if str(device) == "cpu":
144
+ model.float()
145
+ if alpha_vision_ckpt_pth != "None":
146
+ model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth))
147
+ model.eval() # merge lora params if exists (for inference only)
148
+ return model, _transform(model.visual.input_resolution)
149
+
150
+ # patch the device names
151
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
152
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
153
+
154
+ def _node_get(node: torch._C.Node, key: str):
155
+ """Gets attributes of a node which is polymorphic over return type.
156
+
157
+ From https://github.com/pytorch/pytorch/pull/82628
158
+ """
159
+ sel = node.kindOf(key)
160
+ return getattr(node, sel)(key)
161
+
162
+ def patch_device(module):
163
+ try:
164
+ graphs = [module.graph] if hasattr(module, "graph") else []
165
+ except RuntimeError:
166
+ graphs = []
167
+
168
+ if hasattr(module, "forward1"):
169
+ graphs.append(module.forward1.graph)
170
+
171
+ for graph in graphs:
172
+ for node in graph.findAllNodes("prim::Constant"):
173
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
174
+ node.copyAttributes(device_node)
175
+
176
+ model.apply(patch_device)
177
+ patch_device(model.encode_image)
178
+ patch_device(model.encode_text)
179
+
180
+ # patch dtype to float32 on CPU
181
+ if str(device) == "cpu":
182
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
183
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
184
+ float_node = float_input.node()
185
+
186
+ def patch_float(module):
187
+ try:
188
+ graphs = [module.graph] if hasattr(module, "graph") else []
189
+ except RuntimeError:
190
+ graphs = []
191
+
192
+ if hasattr(module, "forward1"):
193
+ graphs.append(module.forward1.graph)
194
+
195
+ for graph in graphs:
196
+ for node in graph.findAllNodes("aten::to"):
197
+ inputs = list(node.inputs())
198
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
199
+ if _node_get(inputs[i].node(), "value") == 5:
200
+ inputs[i].node().copyAttributes(float_node)
201
+
202
+ model.apply(patch_float)
203
+ patch_float(model.encode_image)
204
+ patch_float(model.encode_text)
205
+
206
+ model.float()
207
+ return model, _transform(model.input_resolution.item())
208
+
209
+
210
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
211
+ """
212
+ Returns the tokenized representation of given input string(s)
213
+
214
+ Parameters
215
+ ----------
216
+ texts : Union[str, List[str]]
217
+ An input string or a list of input strings to tokenize
218
+
219
+ context_length : int
220
+ The context length to use; all CLIP models use 77 as the context length
221
+
222
+ truncate: bool
223
+ Whether to truncate the text in case its encoding is longer than the context length
224
+
225
+ Returns
226
+ -------
227
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
228
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
229
+ """
230
+ if isinstance(texts, str):
231
+ texts = [texts]
232
+
233
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
234
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
235
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
236
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
237
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
238
+ else:
239
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
240
+
241
+ for i, tokens in enumerate(all_tokens):
242
+ if len(tokens) > context_length:
243
+ if truncate:
244
+ tokens = tokens[:context_length]
245
+ tokens[-1] = eot_token
246
+ else:
247
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
248
+ result[i, :len(tokens)] = torch.tensor(tokens)
249
+
250
+ return result
AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
AlphaCLIP/alpha_clip/model.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import loralib as lora
9
+ import math
10
+ import collections
11
+
12
+ class Bottleneck(nn.Module):
13
+ expansion = 4
14
+
15
+ def __init__(self, inplanes, planes, stride=1):
16
+ super().__init__()
17
+
18
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+ self.relu1 = nn.ReLU(inplace=True)
22
+
23
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+ self.relu2 = nn.ReLU(inplace=True)
26
+
27
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
+
29
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
+ self.relu3 = nn.ReLU(inplace=True)
32
+
33
+ self.downsample = None
34
+ self.stride = stride
35
+
36
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
+ self.downsample = nn.Sequential(OrderedDict([
39
+ ("-1", nn.AvgPool2d(stride)),
40
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41
+ ("1", nn.BatchNorm2d(planes * self.expansion))
42
+ ]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu1(self.bn1(self.conv1(x)))
48
+ out = self.relu2(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu3(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62
+ super().__init__()
63
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68
+ self.num_heads = num_heads
69
+
70
+ def forward(self, x):
71
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
72
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74
+ x, _ = F.multi_head_attention_forward(
75
+ query=x[:1], key=x, value=x,
76
+ embed_dim_to_check=x.shape[-1],
77
+ num_heads=self.num_heads,
78
+ q_proj_weight=self.q_proj.weight,
79
+ k_proj_weight=self.k_proj.weight,
80
+ v_proj_weight=self.v_proj.weight,
81
+ in_proj_weight=None,
82
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83
+ bias_k=None,
84
+ bias_v=None,
85
+ add_zero_attn=False,
86
+ dropout_p=0,
87
+ out_proj_weight=self.c_proj.weight,
88
+ out_proj_bias=self.c_proj.bias,
89
+ use_separate_proj_weight=True,
90
+ training=self.training,
91
+ need_weights=False
92
+ )
93
+ return x.squeeze(0)
94
+
95
+
96
+ class ModifiedResNet(nn.Module):
97
+ """
98
+ A ResNet class that is similar to torchvision's but contains the following changes:
99
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101
+ - The final pooling layer is a QKV attention instead of an average pool
102
+ """
103
+
104
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105
+ super().__init__()
106
+ self.output_dim = output_dim
107
+ self.input_resolution = input_resolution
108
+
109
+ # the 3-layer stem
110
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111
+ self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False)
112
+ self.bn1 = nn.BatchNorm2d(width // 2)
113
+ self.relu1 = nn.ReLU(inplace=True)
114
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
115
+ self.bn2 = nn.BatchNorm2d(width // 2)
116
+ self.relu2 = nn.ReLU(inplace=True)
117
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
118
+ self.bn3 = nn.BatchNorm2d(width)
119
+ self.relu3 = nn.ReLU(inplace=True)
120
+ self.avgpool = nn.AvgPool2d(2)
121
+
122
+ # residual layers
123
+ self._inplanes = width # this is a *mutable* variable used during construction
124
+ self.layer1 = self._make_layer(width, layers[0])
125
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
126
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
127
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
128
+
129
+ embed_dim = width * 32 # the ResNet feature dimension
130
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
131
+
132
+ def _make_layer(self, planes, blocks, stride=1):
133
+ layers = [Bottleneck(self._inplanes, planes, stride)]
134
+
135
+ self._inplanes = planes * Bottleneck.expansion
136
+ for _ in range(1, blocks):
137
+ layers.append(Bottleneck(self._inplanes, planes))
138
+
139
+ return nn.Sequential(*layers)
140
+
141
+ def forward(self, x, alpha=None):
142
+ def stem(x):
143
+ x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha)))
144
+ x = self.relu2(self.bn2(self.conv2(x)))
145
+ x = self.relu3(self.bn3(self.conv3(x)))
146
+ x = self.avgpool(x)
147
+ return x
148
+
149
+ x = x.type(self.conv1.weight.dtype)
150
+ x = stem(x)
151
+ x = self.layer1(x)
152
+ x = self.layer2(x)
153
+ x = self.layer3(x)
154
+ x = self.layer4(x)
155
+ x = self.attnpool(x)
156
+
157
+ return x
158
+
159
+
160
+ class LayerNorm(nn.LayerNorm):
161
+ """Subclass torch's LayerNorm to handle fp16."""
162
+
163
+ def forward(self, x: torch.Tensor):
164
+ orig_type = x.dtype
165
+ ret = super().forward(x.type(torch.float32))
166
+ return ret.type(orig_type)
167
+
168
+
169
+ class QuickGELU(nn.Module):
170
+ def forward(self, x: torch.Tensor):
171
+ return x * torch.sigmoid(1.702 * x)
172
+
173
+ class Attention(nn.Module):
174
+ def __init__(
175
+ self,
176
+ dim,
177
+ num_heads=8,
178
+ qkv_bias=True,
179
+ scaled_cosine=False,
180
+ scale_heads=False,
181
+ logit_scale_max=math.log(1. / 0.01),
182
+ attn_drop=0.,
183
+ proj_drop=0.,
184
+ lora_adapt=False,
185
+ rank=16
186
+ ):
187
+ super().__init__()
188
+ self.scaled_cosine = scaled_cosine
189
+ self.scale_heads = scale_heads
190
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
191
+ self.num_heads = num_heads
192
+ self.head_dim = dim // num_heads
193
+ self.scale = self.head_dim ** -0.5
194
+ self.logit_scale_max = logit_scale_max
195
+
196
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
197
+ if lora_adapt:
198
+ print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!")
199
+ self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True])
200
+ else:
201
+ self.in_proj = nn.Linear(dim, dim * 3)
202
+ # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
203
+ # if qkv_bias:
204
+ # self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
205
+ # else:
206
+ # self.in_proj_bias = None
207
+
208
+ if self.scaled_cosine:
209
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
210
+ else:
211
+ self.logit_scale = None
212
+ self.attn_drop = nn.Dropout(attn_drop)
213
+ if self.scale_heads:
214
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
215
+ else:
216
+ self.head_scale = None
217
+ self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank)
218
+ self.out_drop = nn.Dropout(proj_drop)
219
+
220
+ def forward(self, x, attn_mask = None):
221
+ L, N, C = x.shape
222
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
223
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
224
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
225
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
226
+
227
+ if self.logit_scale is not None:
228
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
229
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
230
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
231
+ attn = attn.view(-1, L, L)
232
+ else:
233
+ q = q * self.scale
234
+ attn = torch.bmm(q, k.transpose(-2, -1))
235
+
236
+ if attn_mask is not None:
237
+ if attn_mask.dtype == torch.bool:
238
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
239
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
240
+ attn_mask = new_attn_mask
241
+ attn += attn_mask
242
+
243
+ attn = attn.softmax(dim=-1)
244
+ attn = self.attn_drop(attn)
245
+
246
+ x = torch.bmm(attn, v)
247
+ if self.head_scale is not None:
248
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
249
+ x = x.view(-1, L, C)
250
+ x = x.transpose(0, 1).reshape(L, N, C)
251
+ x = self.out_proj(x)
252
+ x = self.out_drop(x)
253
+ return x, attn
254
+
255
+
256
+ class CustomResidualAttentionBlock(nn.Module):
257
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
258
+ super().__init__()
259
+
260
+ self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank)
261
+ self.ln_1 = LayerNorm(d_model)
262
+ self.mlp = nn.Sequential(OrderedDict([
263
+ ("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)),
264
+ ("gelu", QuickGELU()),
265
+ ("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank))
266
+ ]))
267
+ self.ln_2 = LayerNorm(d_model)
268
+ self.attn_mask = attn_mask
269
+
270
+ def attention(self, x: torch.Tensor):
271
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
272
+ return self.attn(x, attn_mask=self.attn_mask)
273
+
274
+ def forward(self, x: torch.Tensor, return_attn=False):
275
+ attn_out, attn = self.attention(self.ln_1(x))
276
+ x = x + attn_out
277
+ x = x + self.mlp(self.ln_2(x))
278
+ if return_attn:
279
+ return x, attn
280
+ else:
281
+ return x
282
+
283
+ class ResidualAttentionBlock(nn.Module):
284
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
285
+ super().__init__()
286
+
287
+ self.attn = nn.MultiheadAttention(d_model, n_head)
288
+ self.ln_1 = LayerNorm(d_model)
289
+ self.mlp = nn.Sequential(OrderedDict([
290
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
291
+ ("gelu", QuickGELU()),
292
+ ("c_proj", nn.Linear(d_model * 4, d_model))
293
+ ]))
294
+ self.ln_2 = LayerNorm(d_model)
295
+ self.attn_mask = attn_mask
296
+
297
+ def attention(self, x: torch.Tensor):
298
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
299
+ return self.attn(x, x, x, attn_mask=self.attn_mask)[0]
300
+
301
+ def forward(self, x: torch.Tensor):
302
+ x = x + self.attention(self.ln_1(x))
303
+ x = x + self.mlp(self.ln_2(x))
304
+ return x
305
+
306
+ class Transformer(nn.Module):
307
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
308
+ super().__init__()
309
+ self.width = width
310
+ self.layers = layers
311
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
312
+
313
+ def forward(self, x: torch.Tensor):
314
+ return self.resblocks(x)
315
+
316
+ class CustomTransformer(nn.Module):
317
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
318
+ super().__init__()
319
+ self.width = width
320
+ self.layers = layers
321
+ self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)])
322
+
323
+ def forward(self, x: torch.Tensor, return_attn=False):
324
+ if return_attn:
325
+ for i, block in enumerate(self.resblocks):
326
+ if i == len(self.resblocks) - 1:
327
+ return block(x, return_attn=True)
328
+ else:
329
+ x = block(x)
330
+ assert False
331
+ return self.resblocks(x)
332
+
333
+ class VisionTransformer(nn.Module):
334
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16):
335
+ super().__init__()
336
+ self.input_resolution = input_resolution
337
+ self.output_dim = output_dim
338
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
339
+ self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
340
+
341
+ scale = width ** -0.5
342
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
343
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
344
+ self.ln_pre = LayerNorm(width)
345
+
346
+ self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank)
347
+
348
+ self.ln_post = LayerNorm(width)
349
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
350
+
351
+ def forward(self, x: torch.Tensor, alpha=None, return_attn=False):
352
+ x = self.conv1(x) # shape = [*, width, grid, grid]
353
+ # ASSUME alpha is always not None!
354
+ x = x + self.conv1_alpha(alpha)
355
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
356
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
357
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
358
+ x = x + self.positional_embedding.to(x.dtype)
359
+ x = self.ln_pre(x)
360
+
361
+ x = x.permute(1, 0, 2) # NLD -> LND
362
+ if return_attn:
363
+ x, attn_last = self.transformer(x, return_attn=True)
364
+ else:
365
+ x = self.transformer(x, return_attn=False)
366
+ x = x.permute(1, 0, 2) # LND -> NLD
367
+
368
+ x = self.ln_post(x[:, 0, :])
369
+
370
+ if self.proj is not None:
371
+ x = x @ self.proj
372
+ if return_attn:
373
+ return x, attn_last
374
+ else:
375
+ return x
376
+
377
+
378
+ class CLIP(nn.Module):
379
+ def __init__(self,
380
+ embed_dim: int,
381
+ # vision
382
+ image_resolution: int,
383
+ vision_layers: Union[Tuple[int, int, int, int], int],
384
+ vision_width: int,
385
+ vision_patch_size: int,
386
+ # text
387
+ context_length: int,
388
+ vocab_size: int,
389
+ transformer_width: int,
390
+ transformer_heads: int,
391
+ transformer_layers: int,
392
+ lora_adapt = False,
393
+ rank = 16,
394
+ ):
395
+ super().__init__()
396
+
397
+ self.context_length = context_length
398
+
399
+ if isinstance(vision_layers, (tuple, list)):
400
+ vision_heads = vision_width * 32 // 64
401
+ self.visual = ModifiedResNet(
402
+ layers=vision_layers,
403
+ output_dim=embed_dim,
404
+ heads=vision_heads,
405
+ input_resolution=image_resolution,
406
+ width=vision_width
407
+ )
408
+ else:
409
+ vision_heads = vision_width // 64
410
+ self.visual = VisionTransformer(
411
+ input_resolution=image_resolution,
412
+ patch_size=vision_patch_size,
413
+ width=vision_width,
414
+ layers=vision_layers,
415
+ heads=vision_heads,
416
+ output_dim=embed_dim,
417
+ lora_adapt=lora_adapt,
418
+ rank=rank
419
+ )
420
+
421
+ self.transformer = Transformer(
422
+ width=transformer_width,
423
+ layers=transformer_layers,
424
+ heads=transformer_heads,
425
+ attn_mask=self.build_attention_mask()
426
+ )
427
+
428
+ self.vocab_size = vocab_size
429
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
430
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
431
+ self.ln_final = LayerNorm(transformer_width)
432
+
433
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
434
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
435
+
436
+ self.initialize_parameters()
437
+
438
+ def initialize_parameters(self):
439
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
440
+ nn.init.normal_(self.positional_embedding, std=0.01)
441
+
442
+ if isinstance(self.visual, ModifiedResNet):
443
+ if self.visual.attnpool is not None:
444
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
445
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
446
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
447
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
448
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
449
+
450
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
451
+ for name, param in resnet_block.named_parameters():
452
+ if name.endswith("bn3.weight"):
453
+ nn.init.zeros_(param)
454
+
455
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
456
+ attn_std = self.transformer.width ** -0.5
457
+ fc_std = (2 * self.transformer.width) ** -0.5
458
+ for block in self.transformer.resblocks:
459
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
460
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
461
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
462
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
463
+
464
+ if self.text_projection is not None:
465
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
466
+
467
+ def build_attention_mask(self):
468
+ # lazily create causal attention mask, with full attention between the vision tokens
469
+ # pytorch uses additive attention mask; fill with -inf
470
+ mask = torch.empty(self.context_length, self.context_length)
471
+ mask.fill_(float("-inf"))
472
+ mask.triu_(1) # zero out the lower diagonal
473
+ return mask
474
+
475
+ @property
476
+ def dtype(self):
477
+ if not hasattr(self.visual, "conv1"):
478
+ return self.visual.module.conv1.weight.dtype
479
+ return self.visual.conv1.weight.dtype
480
+
481
+ def encode_image(self, image, alpha):
482
+ assert alpha is not None
483
+ return self.visual(image.type(self.dtype), alpha.type(self.dtype))
484
+
485
+ def encode_text(self, text):
486
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
487
+
488
+ x = x + self.positional_embedding.type(self.dtype)
489
+ x = x.permute(1, 0, 2) # NLD -> LND
490
+ x = self.transformer(x)
491
+ x = x.permute(1, 0, 2) # LND -> NLD
492
+ x = self.ln_final(x).type(self.dtype)
493
+
494
+ # x.shape = [batch_size, n_ctx, transformer.width]
495
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
496
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
497
+
498
+ return x
499
+
500
+ def forward(self, image, text, alpha):
501
+ image_features = self.encode_image(image, alpha)
502
+ text_features = self.encode_text(text)
503
+
504
+ # normalized features
505
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
506
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
507
+
508
+ # cosine similarity as logits
509
+ logit_scale = self.logit_scale.exp()
510
+ logits_per_image = logit_scale * image_features @ text_features.t()
511
+ logits_per_text = logits_per_image.t()
512
+
513
+ # shape = [global_batch_size, global_batch_size]
514
+ return logits_per_image, logits_per_text
515
+
516
+
517
+ def convert_weights(model: nn.Module):
518
+ """Convert applicable model parameters to fp16"""
519
+
520
+ def _convert_weights_to_fp16(l):
521
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
522
+ l.weight.data = l.weight.data.half()
523
+ if l.bias is not None:
524
+ l.bias.data = l.bias.data.half()
525
+
526
+ if isinstance(l, nn.MultiheadAttention):
527
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
528
+ tensor = getattr(l, attr)
529
+ if tensor is not None:
530
+ tensor.data = tensor.data.half()
531
+
532
+ for name in ["text_projection", "proj"]:
533
+ if hasattr(l, name):
534
+ attr = getattr(l, name)
535
+ if attr is not None:
536
+ attr.data = attr.data.half()
537
+
538
+ model.apply(_convert_weights_to_fp16)
539
+
540
+
541
+ def build_model(state_dict: dict, lora_adapt=False, rank=16):
542
+ vit = "visual.proj" in state_dict
543
+
544
+ if vit:
545
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
546
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
547
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
548
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
549
+ image_resolution = vision_patch_size * grid_size
550
+ else:
551
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
552
+ vision_layers = tuple(counts)
553
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
554
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
555
+ vision_patch_size = None
556
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
557
+ image_resolution = output_width * 32
558
+
559
+ embed_dim = state_dict["text_projection"].shape[1]
560
+ context_length = state_dict["positional_embedding"].shape[0]
561
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
562
+ transformer_width = state_dict["ln_final.weight"].shape[0]
563
+ transformer_heads = transformer_width // 64
564
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
565
+
566
+ # always load lora version
567
+ model = CLIP(
568
+ embed_dim,
569
+ image_resolution, vision_layers, vision_width, vision_patch_size,
570
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
571
+ lora_adapt=lora_adapt, rank=rank,
572
+ )
573
+
574
+ for key in ["input_resolution", "context_length", "vocab_size"]:
575
+ if key in state_dict:
576
+ del state_dict[key]
577
+ # para_wb to linear
578
+ new_state_dict = collections.OrderedDict()
579
+ for k, v in state_dict.items():
580
+ if 'visual' in k:
581
+ if 'in_proj_weight' in k:
582
+ new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
583
+ elif 'in_proj_bias' in k:
584
+ new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
585
+ else:
586
+ new_state_dict[k] = v
587
+ else:
588
+ new_state_dict[k] = v
589
+
590
+ state_dict = new_state_dict
591
+ # add rgba_conv_weight
592
+ if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel
593
+ rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
594
+ rgba_weigth = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
595
+ state_dict['visual.conv1_alpha.weight'] = rgba_weigth
596
+ convert_weights(model)
597
+ model.load_state_dict(state_dict, strict=False)
598
+ return model.eval()
AlphaCLIP/alpha_clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
AlphaCLIP/eval/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Alpha-CLIP evaluation
2
+ ## Zero-Shot Classification on ImageNet-S
3
+ checkout [imagenet_s_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/imagenet_s_zs_test)
4
+
5
+ ## Zero-Shot Referring Expression Comprehension on RefCOCO
6
+ checkout [rec_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/rec_zs_test)
AlphaCLIP/eval/imagenet_s_zs_test/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.json
2
+ data/*
AlphaCLIP/eval/imagenet_s_zs_test/README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Alpha-CLIP evaluation
2
+ ## Zero-Shot Classification on ImageNet-S
3
+
4
+ 1.prepare [imagenet-s](https://github.com/LUSSeg/ImageNet-S) dataset, only `validation` raw image is needed.
5
+
6
+ 2.download [imagenet_919.json](https://download.openxlab.org.cn/models/SunzeY/AlphaCLIP/weight/imagenet_919.json) we provide as data annotation (generated from imagenet-s annotation). The folder should be structured like
7
+
8
+ ```
9
+ ├── imagenet_s_zs_test
10
+ │ ├── data
11
+ │ │ ├── imagenet_919.json
12
+ │ │ └── ImageNetS919
13
+ │ │ └── validation
14
+ ```
15
+
16
+ 3.run test script.
17
+
18
+ ```
19
+ cd eval/imagenet_s_zs_test
20
+ python imagenet_s_zs_test.py
21
+ ```
AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from tqdm import tqdm
5
+ from torch.utils.data import Dataset
6
+ from pycocotools.coco import COCO
7
+ from pycocotools import mask as maskUtils
8
+ from PIL import Image
9
+ import cv2
10
+ import random
11
+ from torchvision import transforms
12
+ from tqdm import tqdm
13
+
14
+ import pickle
15
+ import torch
16
+ import numpy as np
17
+ import copy
18
+ import sys
19
+ import shutil
20
+ from PIL import Image
21
+ from nltk.corpus import wordnet
22
+
23
+ PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
24
+ MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]
25
+
26
+
27
+ clip_standard_transform = transforms.Compose([
28
+ transforms.ToTensor(),
29
+ transforms.Resize((224, 224), interpolation=Image.BICUBIC),
30
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
31
+ ])
32
+
33
+ hi_clip_standard_transform = transforms.Compose([
34
+ transforms.ToTensor(),
35
+ transforms.Resize((336, 336), interpolation=Image.BICUBIC),
36
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
37
+ ])
38
+
39
+ res_clip_standard_transform = transforms.Compose([
40
+ transforms.ToTensor(),
41
+ transforms.Resize((336, 336), interpolation=Image.BICUBIC),
42
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
43
+ ])
44
+
45
+ mask_transform = transforms.Compose([
46
+ transforms.ToTensor(),
47
+ transforms.Resize((224, 224)),
48
+ transforms.Normalize(0.5, 0.26)
49
+ ])
50
+
51
+ hi_mask_transform = transforms.Compose([
52
+ transforms.ToTensor(),
53
+ transforms.Resize((336, 336)),
54
+ transforms.Normalize(0.5, 0.26)
55
+ ])
56
+
57
+ res_mask_transform = transforms.Compose([
58
+ transforms.ToTensor(),
59
+ transforms.Resize((336, 336)),
60
+ transforms.Normalize(0.5, 0.26)
61
+ ])
62
+
63
+ def crop_center(img, croph, cropw):
64
+ h, w = img.shape[:2]
65
+ starth = h//2 - (croph//2)
66
+ startw = w//2 - (cropw//2)
67
+ return img[starth:starth+croph, startw:startw+cropw, :]
68
+
69
+ class Imagenet_S(Dataset):
70
+ def __init__(self, ann_file='data/imagenet_919.json', hi_res=False, all_one=False):
71
+ self.anns = json.load(open(ann_file, 'r'))
72
+ self.root_pth = 'data/'
73
+ cats = []
74
+ for ann in self.anns:
75
+ if ann['category_word'] not in cats:
76
+ cats.append(ann['category_word'])
77
+ ann['cat_index'] = len(cats) - 1
78
+ self.classes = []
79
+ for cat_word in cats:
80
+ synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
81
+ synonyms = [x.name() for x in synset.lemmas()]
82
+ self.classes.append(synonyms[0])
83
+
84
+ self.choice = "center_crop"
85
+ if hi_res:
86
+ self.mask_transform = res_mask_transform
87
+ self.clip_standard_transform = res_clip_standard_transform
88
+ else:
89
+ self.mask_transform = mask_transform
90
+ self.clip_standard_transform = clip_standard_transform
91
+
92
+ self.all_one = all_one
93
+
94
+ def __len__(self):
95
+ return len(self.anns)
96
+
97
+ def __getitem__(self, index):
98
+ ann = self.anns[index]
99
+ image = cv2.imread(os.path.join(self.root_pth, ann['image_pth']))
100
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
101
+
102
+ mask = maskUtils.decode(ann['mask'])
103
+ # image[mask==0] = MASK_FILL
104
+ rgba = np.concatenate((image, np.expand_dims(mask, axis=-1)), axis=-1)
105
+ h, w = rgba.shape[:2]
106
+
107
+ if self.choice == "padding":
108
+ if max(h, w) == w:
109
+ pad = (w - h) // 2
110
+ l, r = pad, w - h - pad
111
+ rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
112
+ else:
113
+ pad = (h - w) // 2
114
+ l, r = pad, h - w - pad
115
+ rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
116
+ else:
117
+ if min(h, w) == h:
118
+ rgba = crop_center(rgba, h, h)
119
+ else:
120
+ rgba = crop_center(rgba, w, w)
121
+ rgb = rgba[:, :, :-1]
122
+ mask = rgba[:, :, -1]
123
+ image_torch = self.clip_standard_transform(rgb)
124
+ # using box: bounding-box compute
125
+ # bi_mask = mask == 1
126
+ # h, w = bi_mask.shape[-2:]
127
+ # in_height = np.max(bi_mask, axis=-1)
128
+ # in_height_coords = np.max(bi_mask, axis=-1) * np.arange(h)
129
+ # b_e = in_height_coords.max()
130
+ # in_height_coords = in_height_coords + h * (~in_height)
131
+ # t_e = in_height_coords.min()
132
+ # in_width = np.max(bi_mask, axis=-2)
133
+ # in_width_coords = np.max(bi_mask, axis=-2) * np.arange(w)
134
+ # r_e = in_width_coords.max()
135
+ # in_width_coords = in_width_coords + w * (~in_width)
136
+ # l_e = in_width_coords.min()
137
+ # box = np.zeros_like(mask)
138
+ # box[t_e: b_e, l_e:r_e] = 1
139
+ # mask = box
140
+ if self.all_one:
141
+ mask_torch = self.mask_transform(np.ones_like(mask) * 255)
142
+ else:
143
+ mask_torch = self.mask_transform(mask * 255)
144
+ return image_torch, mask_torch, ann['cat_index']
145
+
146
+ if __name__ == "__main__":
147
+ data = Imagenet_S()
148
+ for i in tqdm(range(data.__len__())):
149
+ data.__getitem__(i)
AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import alpha_clip
3
+ from tqdm import tqdm
4
+ from imagenet_s import Imagenet_S
5
+
6
+ model, preprocess = alpha_clip.load("ViT-L/14@336px", alpha_vision_ckpt_pth="../../clip_l14@336_grit_20m_4xe.pth")
7
+
8
+ def zeroshot_classifier(classnames, templates):
9
+ with torch.no_grad():
10
+ zeroshot_weights = []
11
+ for classname in tqdm(classnames):
12
+ texts = [template.format(classname) for template in templates] #format with class
13
+ texts = alpha_clip.tokenize(texts).cuda() #tokenize
14
+ class_embeddings = model.encode_text(texts) #embed with text encoder
15
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
16
+ class_embedding = class_embeddings.mean(dim=0)
17
+ class_embedding /= class_embedding.norm()
18
+ zeroshot_weights.append(class_embedding)
19
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
20
+ return zeroshot_weights
21
+
22
+ dataset = Imagenet_S(hi_res=True)
23
+ loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=2)
24
+
25
+ imagenet_templates = [
26
+ 'a photo of a {}.'
27
+ ]
28
+
29
+ zeroshot_weights = zeroshot_classifier(dataset.classes, imagenet_templates)
30
+ temp_corr_dict = dict()
31
+
32
+ with torch.no_grad():
33
+ for i, (images, alpha, target) in enumerate(tqdm(loader)):
34
+ images = images.cuda()
35
+ alpha = alpha.cuda()
36
+ target = target.cuda()
37
+ # predict
38
+ image_features = model.encode_image(images, alpha)
39
+ image_features /= image_features.norm(dim=-1, keepdim=True)
40
+ score = 100. * image_features @ zeroshot_weights
41
+
42
+ pred = score.topk(1, dim=1)[1].squeeze(dim=1)
43
+ pred_5 = score.topk(5, dim=1)[1].squeeze(dim=1)
44
+
45
+ for i in range(target.shape[0]):
46
+ if target[i].item() not in temp_corr_dict:
47
+ temp_corr_dict[target[i].item()] = [0, 0, 0]
48
+ temp_corr_dict[target[i].item()][0] += 1
49
+ if target[i].item() == pred[i].item():
50
+ temp_corr_dict[target[i].item()][1] += 1
51
+ if target[i].item() in pred_5[i].tolist():
52
+ temp_corr_dict[target[i].item()][2] += 1
53
+
54
+ acc1 = 0.0
55
+ acc5 = 0.0
56
+ num_class = 0
57
+ for v in temp_corr_dict.values():
58
+ if v[0] == 0: continue
59
+ acc1 += v[1] / v[0]
60
+ acc5 += v[2] / v[0]
61
+ num_class += 1
62
+ acc1 = acc1 / num_class * 100
63
+ acc5 = acc5 / num_class * 100
64
+
65
+ print(f"Top-1 accuracy: {acc1:.2f}")
66
+ print(f"Top-5 accuracy: {acc5:.2f}")
AlphaCLIP/eval/rec_zs_test/LICENSE.md ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
AlphaCLIP/eval/rec_zs_test/README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Zero-Shot Referring Expression Comprehension on RefCOCO
2
+
3
+ **Preparing Data**
4
+
5
+ 1.Download [images for RefCOCO/g/+](http://images.cocodataset.org/zips/train2014.zip). Put downloaded dataset(train2014) to eval/rec_zs_test/data/.
6
+
7
+ 2.Download preprocessed data files via `gsutil cp gs://reclip-sanjays/reclip_data.tar.gz` and `cd rec_zs_test`, and then extract the data using `tar -xvzf reclip_data.tar.gz`.
8
+
9
+ **Preparing model**
10
+
11
+ 3.Download [SAM](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) (vit-h), [Alpha-CLIP](https://github.com/SunzeY/AlphaCLIP/blob/main/model-zoo.md) model, and put them in ./eval/rec_zs_test/ckpt.
12
+
13
+ ```
14
+ ├── eval
15
+ │ ├── rec_zs_test
16
+ │ │ ├── data
17
+ │ │ └── train2014
18
+ │ │ ├── reclip_data
19
+ │ │ └── refcoco_val.jsonl
20
+ │ │ └── refcoco_dets_dict.json
21
+ │ │ ...
22
+ │ │ ├── ckpt
23
+ │ │ └── sam_vit_h_4b8939.pth
24
+ │ │ └── grit1m
25
+ │ │ └── clip_b16_grit+mim_fultune_4xe.pth
26
+ │ │ └── clip_l14_grit+mim_fultune_6xe.pth
27
+ │ │ ├── methods
28
+ │ │ ├── cache
29
+ │ │ ├── output
30
+ │ │ ├── main.py
31
+ │ │ ├── executor.py
32
+ │ │ ├── run.sh
33
+ │ │ ├── ...
34
+ ```
35
+
36
+ 4.run test script.
37
+
38
+ ```
39
+ cd eval/rec_zs_test
40
+ ```
41
+ ```
42
+ bash run.sh
43
+ ```
44
+ or
45
+
46
+ ```
47
+ python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco+_dets_dict.json --cache_path ./cache
48
+ ```
49
+ (We recommend using `cache_path` to reduce time to generate mask by SAM for a image repeatedly.`)
50
+
51
+ For multi-gpus testing, try:
52
+
53
+ ```
54
+ bash run_multi_gpus.sh
55
+ python cal_acc.py refcoco_val
56
+ ```
57
+
58
+
59
+ **Acknowledgement**
60
+
61
+ We test our model based on the wonderful work [ReCLIP](https://github.com/allenai/reclip/tree/main). We simply replace CLIP with Alpha-CLIP; and skip the image-cropping operation.
62
+
63
+
64
+
65
+ **Experiment results**
66
+
67
+ | Method | RefCOCO | | | RefCOCO+ | | | RefCOCOg | |
68
+ |----------------|---------|------|------|----------|------|------|----------|------|
69
+ | | Val | TestA| TestB| Val | TestA| TestB| Val | Test |
70
+ | CPT [67] | 32.2 | 36.1 | 30.3 | 31.9 | 35.2 | 28.8 | 36.7 | 36.5 |
71
+ | ReCLIP [54] | 45.8 | 46.1 | 47.1 | 47.9 | 50.1 | 45.1 | 59.3 | 59.0 |
72
+ | Red Circle [52]| 49.8 | 58.6 | 39.9 | 55.3 | 63.9 | 45.4 | 59.4 | 58.9 |
73
+ | Alpha-CLIP | 55.7 | 61.1 | 50.3 | 55.6 | 62.7 | 46.4 | 61.2 | 62.0 |
74
+
AlphaCLIP/eval/rec_zs_test/cache/.gitkeep ADDED
File without changes
AlphaCLIP/eval/rec_zs_test/cal_acc.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+
4
+ parser = argparse.ArgumentParser()
5
+ parser.add_argument('name', type=str, default='refcoco_val')
6
+
7
+ args = parser.parse_args()
8
+
9
+ name = args.name
10
+ print(name)
11
+ count = 0
12
+ all_count = 0
13
+ for i in range(8):
14
+ pth = f'output/{name}_count_{i}.json'
15
+ acc = json.load(open(pth, 'r'))
16
+ a_list = acc.split()
17
+ a, b = a_list[0], a_list[1]
18
+ count += int(a)
19
+ all_count += int(b)
20
+
21
+ print(float(count) / float(all_count))
AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep ADDED
File without changes
AlphaCLIP/eval/rec_zs_test/data/.gitkeep ADDED
File without changes
AlphaCLIP/eval/rec_zs_test/entity_extraction.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
2
+ import numpy as np
3
+ from spacy.tokens.token import Token
4
+ from spacy.tokens.span import Span
5
+
6
+ from lattice import Product as L
7
+
8
+ from heuristics import Heuristics
9
+
10
+ Rel = Tuple[List[Token], "Entity"]
11
+ Sup = List[Token]
12
+
13
+ DEFAULT_HEURISTICS = Heuristics()
14
+
15
+
16
+ def find_superlatives(tokens, heuristics) -> List[Sup]:
17
+ """Modify and return a list of superlative tokens."""
18
+ for heuristic in heuristics.superlatives:
19
+ if any(tok.text in heuristic.keywords for tok in tokens):
20
+ tokens.sort(key=lambda tok: tok.i)
21
+ return [tokens]
22
+ return []
23
+
24
+ def expand_chunks(doc, chunks):
25
+ expanded = {}
26
+ for key in chunks:
27
+ chunk = chunks[key]
28
+ start = chunk.start
29
+ end = chunk.end
30
+ for i in range(chunk.start-1, -1, -1):
31
+ if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
32
+ if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
33
+ start = i
34
+ for i in range(chunk.end, len(doc)):
35
+ if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
36
+ if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
37
+ end = i+1
38
+ else:
39
+ break
40
+ expanded[key] = Span(doc=doc, start=start, end=end)
41
+ return expanded
42
+
43
+ class Entity(NamedTuple):
44
+ """Represents an entity with locative constraints extracted from the parse."""
45
+
46
+ head: Span
47
+ relations: List[Rel]
48
+ superlatives: List[Sup]
49
+
50
+ @classmethod
51
+ def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> "Entity":
52
+ """Extract entities from a spacy parse.
53
+
54
+ Jointly recursive with `_get_rel_sups`."""
55
+ if heuristics is None:
56
+ heuristics = DEFAULT_HEURISTICS
57
+
58
+ if head.i not in chunks:
59
+ # Handles predicative cases.
60
+ children = list(head.children)
61
+ if children and children[0].i in chunks:
62
+ head = children[0]
63
+ # TODO: Also extract predicative relations.
64
+ else:
65
+ return None
66
+ hchunk = chunks[head.i]
67
+ rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
68
+ return cls(hchunk, rels, sups)
69
+
70
+ @classmethod
71
+ def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
72
+ hchunk = chunks[head.i]
73
+ is_keyword = any(token.text in h.keywords for h in heuristics.relations)
74
+ is_keyword |= token.text in heuristics.null_keywords
75
+
76
+ # Found another entity head.
77
+ if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
78
+ tchunk = chunks[token.i]
79
+ tokens.sort(key=lambda tok: tok.i)
80
+ subhead = cls.extract(token, chunks, heuristics)
81
+ return [(tokens, subhead)], []
82
+
83
+ # End of a chain of modifiers.
84
+ n_children = len(list(token.children))
85
+ if n_children == 0:
86
+ return [], find_superlatives(tokens + [token], heuristics)
87
+
88
+ relations = []
89
+ superlatives = []
90
+ is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
91
+ for child in token.children:
92
+ if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
93
+ if not any(child.text in h.keywords for h in heuristics.superlatives):
94
+ if n_children == 1:
95
+ # Catches "the goat on the left"
96
+ sups = find_superlatives(tokens + [token], heuristics)
97
+ superlatives.extend(sups)
98
+ continue
99
+ new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
100
+ subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
101
+ relations.extend(subrel)
102
+ superlatives.extend(subsup)
103
+ return relations, superlatives
104
+
105
+ def expand(self, span: Span = None):
106
+ tokens = [token for token in self.head]
107
+ if span is None:
108
+ span = [None]
109
+ for target_token in span:
110
+ include = False
111
+ stack = [token for token in self.head]
112
+ while len(stack) > 0:
113
+ token = stack.pop()
114
+ if token == target_token:
115
+ token2 = target_token.head
116
+ while token2.head != token2:
117
+ tokens.append(token2)
118
+ token2 = token2.head
119
+ tokens.append(token2)
120
+ stack = []
121
+ include = True
122
+ if target_token is None or include:
123
+ tokens.append(token)
124
+ for child in token.children:
125
+ stack.append(child)
126
+ tokens = list(set(tokens))
127
+ tokens = sorted(tokens, key=lambda x: x.i)
128
+ return ' '.join([token.text for token in tokens])
129
+
130
+ def __eq__(self, other: "Entity") -> bool:
131
+ if self.text != other.text:
132
+ return False
133
+ if self.relations != other.relations:
134
+ return False
135
+ if self.superlatives != other.superlatives:
136
+ return False
137
+ return True
138
+
139
+ @property
140
+ def text(self) -> Text:
141
+ """Get the text predicate associated with this entity."""
142
+ return self.head.text
AlphaCLIP/eval/rec_zs_test/executor.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Union, Tuple
2
+
3
+ from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageEnhance
4
+ import spacy
5
+ import hashlib
6
+ import os
7
+
8
+ import torch
9
+ import torchvision
10
+ import torchvision.transforms as transforms
11
+ import clip
12
+ from transformers import BertTokenizer, RobertaTokenizerFast
13
+ import ruamel.yaml as yaml
14
+ import copy
15
+
16
+ from interpreter import Box
17
+
18
+ import pycocotools.mask as mask_utils
19
+ import alpha_clip
20
+ from segment_anything import sam_model_registry, SamPredictor
21
+ import numpy as np
22
+ import cv2
23
+ import matplotlib.pyplot as plt
24
+
25
+ import pickle
26
+
27
+ class Executor:
28
+ def __init__(self, device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None) -> None:
29
+ IMPLEMENTED_METHODS = ["blur", "full", "gray"]
30
+ if any(m not in IMPLEMENTED_METHODS for m in box_representation_method.split(",")):
31
+ raise NotImplementedError
32
+ IMPLEMENTED_AGGREGATORS = ["max", "sum"]
33
+ if method_aggregator not in IMPLEMENTED_AGGREGATORS:
34
+ raise NotImplementedError
35
+ self.box_representation_method = box_representation_method
36
+ self.method_aggregator = method_aggregator
37
+ self.enlarge_boxes = enlarge_boxes
38
+ self.device = device
39
+ self.expand_position_embedding = expand_position_embedding
40
+ self.square_size = square_size
41
+ self.blur_std_dev = blur_std_dev
42
+ self.cache_path = cache_path
43
+
44
+ def preprocess_image(self, image: Image) -> List[torch.Tensor]:
45
+ return [preprocess(image) for preprocess in self.preprocesses]
46
+
47
+ def preprocess_mask(self, mask: Image) -> List[torch.Tensor]:
48
+ preprocess = self.preprocesses[0]
49
+ return preprocess.transforms[1](preprocess.transforms[0](mask))
50
+
51
+ def preprocess_text(self, text: str) -> torch.Tensor:
52
+ raise NotImplementedError
53
+
54
+ def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
55
+ raise NotImplementedError
56
+
57
+ def tensorize_inputs(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth: str = None) -> Tuple[List[torch.Tensor], torch.Tensor]:
58
+ images = []
59
+ for preprocess in self.preprocesses:
60
+ images.append([])
61
+
62
+ if 'aclip' in self.clip_type:
63
+ self.all_masks = []
64
+ read_save = False
65
+ if self.mask_path is not None: # load mask if cached
66
+ file_name = image_pth.split('/')[-1].split('.')[0]+'.pkl'
67
+ if os.path.exists(os.path.join(self.mask_path, file_name)):
68
+ all_rles = pickle.load(open(os.path.join(self.mask_path, file_name),'rb'))
69
+ for rle in all_rles:
70
+ mask = np.array(mask_utils.decode(rle), dtype=bool)
71
+ self.all_masks.append(mask)
72
+ read_save = True
73
+ if not read_save:
74
+ # use SAM to generate masks
75
+ self.predictor.set_image(np.array(image.convert('RGB')))
76
+ all_rles = []
77
+ for i in range(len(boxes)):
78
+ box = [
79
+ max(boxes[i].left-self.enlarge_boxes, 0),
80
+ max(boxes[i].top-self.enlarge_boxes, 0),
81
+ min(boxes[i].right+self.enlarge_boxes, image.width),
82
+ min(boxes[i].bottom+self.enlarge_boxes, image.height)
83
+ ] # box prompt
84
+ input_box = np.array(box)
85
+ masks, _, _ = self.predictor.predict(
86
+ point_coords=None,
87
+ point_labels=None,
88
+ box=input_box[None, :],
89
+ multimask_output=False,
90
+ )
91
+ self.all_masks.append(masks[0])
92
+ rle = mask_utils.encode(np.array(masks[0][:, :, None], order='F', dtype="uint8"))[0]
93
+ rle["counts"] = rle["counts"].decode("utf-8")
94
+ all_rles.append(rle)
95
+ if self.mask_path is not None: # save mask
96
+ os.makedirs(self.mask_path, exist_ok=True)
97
+ pickle.dump(all_rles, open(os.path.join(self.mask_path, file_name),'wb'))
98
+
99
+ if self.cache_path is None or any([not os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name, method_name+".pt")) for model_name in self.model_names for method_name in self.box_representation_method.split(',')]):
100
+ if "full" in self.box_representation_method: # original full image with alpha-map
101
+ for i in range(len(boxes)):
102
+ image_i = image.copy()
103
+ preprocessed_images = self.preprocess_image(image_i)
104
+ for j, img in enumerate(preprocessed_images):
105
+ images[j].append(img.to(self.device))
106
+ if "blur" in self.box_representation_method:
107
+ for i in range(len(boxes)):
108
+ image_i = image.copy()
109
+
110
+ mask = Image.new('L', image_i.size, 0)
111
+ draw = ImageDraw.Draw(mask)
112
+ box = (
113
+ max(boxes[i].left-self.enlarge_boxes, 0),
114
+ max(boxes[i].top-self.enlarge_boxes, 0),
115
+ min(boxes[i].right+self.enlarge_boxes, image_i.width),
116
+ min(boxes[i].bottom+self.enlarge_boxes, image_i.height)
117
+ )
118
+ if 'aclip' in self.clip_type:
119
+ width, height = image.size
120
+ for y in range(height):
121
+ for x in range(width):
122
+ if self.all_masks[i][y][x] == 1:
123
+ draw.point((x, y), fill=255)
124
+ else:
125
+ draw.rectangle([box[:2], box[2:]], fill=255)
126
+ blurred = image_i.filter(ImageFilter.GaussianBlur(self.blur_std_dev))
127
+ blurred.paste(image_i, mask=mask)
128
+ preprocessed_images = self.preprocess_image(blurred)
129
+
130
+ for j, img in enumerate(preprocessed_images):
131
+ images[j].append(img.to(self.device))
132
+ if "gray" in self.box_representation_method:
133
+ for i in range(len(boxes)):
134
+ image_i = image.copy()
135
+ mask_i = self.all_masks[i]
136
+ width, height = image.size
137
+
138
+ pixels = image_i.load()
139
+ for y in range(height):
140
+ for x in range(width):
141
+ if mask_i[y][x] == 0:
142
+ pixel_value = pixels[x, y]
143
+ gray_value = int(0.2989 * pixel_value[0] + 0.5870 * pixel_value[1] + 0.1140 * pixel_value[2])
144
+ pixels[x, y] = (gray_value, gray_value, gray_value)
145
+ preprocessed_images = self.preprocess_image(image_i)
146
+ for j, img in enumerate(preprocessed_images):
147
+ images[j].append(img.to(self.device))
148
+
149
+ imgs = [torch.stack(image_list) for image_list in images]
150
+ else:
151
+ imgs = [[] for _ in self.models]
152
+ text_tensor = self.preprocess_text(caption.lower()).to(self.device)
153
+ return imgs, text_tensor
154
+
155
+ @torch.no_grad()
156
+ def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
157
+ images, text_tensor = self.tensorize_inputs(caption, image, boxes, image_name, image_pth)
158
+ all_logits_per_image = []
159
+ all_logits_per_text = []
160
+ box_representation_methods = self.box_representation_method.split(',')
161
+ caption_hash = hashlib.md5(caption.encode('utf-8')).hexdigest()
162
+ for model, images_t, model_name in zip(self.models, images, self.model_names):
163
+ self.image_feat_path = ""
164
+ if self.cache_path is not None:
165
+ text_cache_path = os.path.join(self.cache_path, "refcoco_val", model_name, "text"+("_shade" if self.box_representation_method == "shade" else ""))
166
+ image_feat_path = os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name)
167
+ self.image_feat_path = image_feat_path
168
+ image_features = None
169
+ text_features = None
170
+ if self.cache_path is not None and os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name)):
171
+ if os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")):
172
+ text_features = torch.load(os.path.join(text_cache_path, caption_hash+".pt"), map_location=self.device)
173
+ if os.path.exists(image_feat_path):
174
+ if all([os.path.exists(os.path.join(image_feat_path, method_name+".pt")) for method_name in box_representation_methods]):
175
+ image_features = []
176
+ for method_name in box_representation_methods:
177
+ features = torch.load(os.path.join(image_feat_path, method_name+".pt"), map_location=self.device)
178
+ image_features.append(torch.stack([
179
+ features[(box.x, box.y, box.w, box.h)]
180
+ for box in boxes
181
+ ]))
182
+ image_features = torch.stack(image_features)
183
+ image_features = image_features.view(-1, image_features.shape[-1])
184
+ logits_per_image, logits_per_text, image_features, text_features = self.call_model(model, images_t, text_tensor, image_features=image_features, text_features=text_features, boxes=boxes, image_pth=image_pth)
185
+ all_logits_per_image.append(logits_per_image)
186
+ all_logits_per_text.append(logits_per_text)
187
+ if self.cache_path is not None and image_name is not None and image_features is not None:
188
+ image_features = image_features.view(len(box_representation_methods), len(boxes), image_features.shape[-1])
189
+ if not os.path.exists(image_feat_path):
190
+ os.makedirs(image_feat_path)
191
+ for i in range(image_features.shape[0]):
192
+ method_name = box_representation_methods[i]
193
+ if not os.path.exists(os.path.join(image_feat_path, method_name+".pt")):
194
+ image_features_dict = {(box.x, box.y, box.w, box.h): image_features[i,j,:].cpu() for j, box in enumerate(boxes)}
195
+ torch.save(image_features_dict, os.path.join(image_feat_path, method_name+".pt"))
196
+ if self.cache_path is not None and not os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")) and text_features is not None:
197
+ assert text_features.shape[0] == 1
198
+ if not os.path.exists(text_cache_path):
199
+ os.makedirs(text_cache_path)
200
+ torch.save(text_features.cpu(), os.path.join(text_cache_path, caption_hash+".pt"))
201
+
202
+ all_logits_per_image = torch.stack(all_logits_per_image).sum(0)
203
+ all_logits_per_text = torch.stack(all_logits_per_text).sum(0)
204
+ if self.method_aggregator == "max":
205
+ all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).max(dim=0, keepdim=True)[0]
206
+ elif self.method_aggregator == "sum":
207
+ all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).sum(dim=0, keepdim=True)
208
+ return all_logits_per_text.view(-1)
209
+
210
+ class ClipExecutor(Executor):
211
+ def __init__(self, clip_model: str = "ViT-B/32", device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None, clip_type: str=None) -> None:
212
+ super().__init__(device, box_representation_method, method_aggregator, enlarge_boxes, expand_position_embedding, square_size, blur_std_dev, cache_path)
213
+ self.clip_models = clip_model.split(",")
214
+ self.model_names = [model_name.replace("/", "_") for model_name in self.clip_models]
215
+ self.models = []
216
+ self.preprocesses = []
217
+ self.data_name = input_file.split('/')[-1].split('.')[0]
218
+ self.mask_path = None
219
+ self.clip_type = clip_type
220
+ if self.cache_path is not None:
221
+ self.mask_path = os.path.join(self.cache_path, "refcoco_val", 'det_masks')
222
+ sam_checkpoint = "./ckpt/sam_vit_h_4b8939.pth"
223
+ model_type = "vit_h"
224
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
225
+ sam.to(device=device)
226
+ self.predictor = SamPredictor(sam)
227
+ for model_name in self.clip_models:
228
+ if 'aclip' in self.clip_type:#using alpha-clip
229
+ self.mask_transform = transforms.Compose([
230
+ transforms.ToTensor(),
231
+ transforms.Resize((224, 224)),
232
+ transforms.Normalize(0.5, 0.26)
233
+ ])
234
+ if model_name == 'ViT-B/16':
235
+ model, preprocess = alpha_clip.load("ViT-B/16", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth", device=device)
236
+ elif model_name == 'ViT-L/14':
237
+ model, preprocess = alpha_clip.load("ViT-L/14", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth", device=device)
238
+
239
+ else: model, preprocess = clip.load(model_name, device=device, jit=False)
240
+ self.models.append(model)
241
+ if self.square_size:
242
+ print("Square size!")
243
+ preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), interpolation=transforms.InterpolationMode.BICUBIC)
244
+ self.preprocesses.append(preprocess)
245
+ self.models = torch.nn.ModuleList(self.models)
246
+
247
+ def preprocess_text(self, text: str) -> torch.Tensor:
248
+ if "aclip" in self.box_representation_method:
249
+ return alpha_clip.tokenize([text.lower()])
250
+ if "shade" in self.box_representation_method:
251
+ return clip.tokenize([text.lower()+" is in red color."])
252
+ return clip.tokenize(["a photo of "+text.lower()])
253
+
254
+ def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: torch.Tensor, image_features: torch.Tensor = None, text_features: torch.Tensor = None, boxes=None, image_pth=None) -> torch.Tensor:
255
+ if image_features is None:
256
+ print('computing image features')
257
+ if 'aclip' not in self.clip_type:
258
+ image_features = model.encode_image(images)
259
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
260
+ else:
261
+ image_features = []
262
+ if 'full' in self.box_representation_method:
263
+ aclip_images = images[:len(boxes)]
264
+ alphas = []
265
+
266
+ if os.path.exists(os.path.join(self.image_feat_path, 'full.pt')):
267
+ features = torch.load(os.path.join(self.image_feat_path, 'full.pt'), map_location=self.device)
268
+ aclip_image_features = torch.stack([
269
+ features[(box.x, box.y, box.w, box.h)]
270
+ for box in boxes
271
+ ])
272
+ else:
273
+ for i in range(len(self.all_masks)):
274
+ binary_mask = self.all_masks[i]
275
+ alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
276
+ alpha = alpha.half().cuda().unsqueeze(dim=0)
277
+ alphas.append(alpha)
278
+
279
+ alphas = torch.cat(alphas, dim=0)
280
+ aclip_images = aclip_images.half()
281
+ aclip_image_features = model.visual(aclip_images, alphas) # using alpha channels
282
+ images = images[len(boxes):]
283
+ image_features.append(aclip_image_features)
284
+
285
+ if 'blur' in self.box_representation_method:
286
+ if os.path.exists(os.path.join(self.image_feat_path, 'blur.pt')):
287
+ features = torch.load(os.path.join(self.image_feat_path, 'blur.pt'), map_location=self.device)
288
+ ablur_images_features = torch.stack([
289
+ features[(box.x, box.y, box.w, box.h)]
290
+ for box in boxes
291
+ ])
292
+ else:
293
+ ablur_images = images[:len(boxes)]
294
+ alphas = []
295
+ for i in range(len(self.all_masks)):
296
+ binary_mask = self.all_masks[i]
297
+ alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
298
+ alpha = alpha.half().cuda().unsqueeze(dim=0)
299
+ alphas.append(alpha)
300
+ alphas = torch.cat(alphas, dim=0)
301
+ ablur_images = ablur_images.half()
302
+ ablur_images_features = model.visual(ablur_images, alphas)
303
+ images = images[len(boxes):]
304
+ image_features.append(ablur_images_features)
305
+
306
+ if 'gray' in self.box_representation_method:
307
+ if os.path.exists(os.path.join(self.image_feat_path, 'gray.pt')):
308
+ features = torch.load(os.path.join(self.image_feat_path, 'gray.pt'), map_location=self.device)
309
+ gray_images_features = torch.stack([
310
+ features[(box.x, box.y, box.w, box.h)]
311
+ for box in boxes
312
+ ])
313
+ else:
314
+ gray_images = images[:len(boxes)]
315
+ alphas = []
316
+ for i in range(len(self.all_masks)):
317
+ binary_mask = self.all_masks[i]
318
+ alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
319
+ alpha = alpha.half().cuda().unsqueeze(dim=0)
320
+ alphas.append(alpha)
321
+ alphas = torch.cat(alphas, dim=0)
322
+ gray_images = gray_images.half()
323
+ gray_images_features = model.visual(gray_images, alphas)
324
+ images = images[len(boxes):]
325
+ image_features.append(gray_images_features)
326
+
327
+
328
+ image_features = torch.cat(image_features, dim=0)
329
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
330
+
331
+ if text_features is None:
332
+ print('computing text features')
333
+ text_features = model.encode_text(text)
334
+ # normalized features
335
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
336
+
337
+ # cosine similarity as logits
338
+ logit_scale = model.logit_scale.exp()
339
+ logits_per_image = logit_scale * image_features @ text_features.t()
340
+ logits_per_text = logits_per_image.t()
341
+ return logits_per_image, logits_per_text, image_features, text_features
342
+
343
+ def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
344
+ if self.expand_position_embedding:
345
+ original_preprocesses = self.preprocesses
346
+ new_preprocesses = []
347
+ original_position_embeddings = []
348
+ for model_name, model, preprocess in zip(self.clip_models, self.models, self.preprocesses):
349
+ if "RN" in model_name:
350
+ model_spatial_dim = int((model.visual.attnpool.positional_embedding.shape[0]-1)**0.5)
351
+ patch_size = model.visual.input_resolution // model_spatial_dim
352
+ original_positional_embedding = model.visual.attnpool.positional_embedding.clone()
353
+ model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
354
+ model.visual.attnpool.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
355
+ size=(image.height // patch_size, image.width // patch_size),
356
+ mode='bicubic',
357
+ align_corners=False
358
+ ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
359
+ model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.cat((
360
+ original_positional_embedding[:1,:],
361
+ model.visual.attnpool.positional_embedding
362
+ ), dim=0))
363
+ transform = transforms.Compose([
364
+ transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
365
+ lambda image: image.convert("RGB"),
366
+ transforms.ToTensor(),
367
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
368
+ ])
369
+ else:
370
+ model_spatial_dim = int((model.visual.positional_embedding.shape[0]-1)**0.5)
371
+ patch_size = model.visual.input_resolution // model_spatial_dim
372
+ original_positional_embedding = model.visual.positional_embedding.clone()
373
+ model.visual.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
374
+ model.visual.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
375
+ size=(image.height // patch_size, image.width // patch_size),
376
+ mode='bicubic',
377
+ align_corners=False
378
+ ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
379
+ model.visual.positional_embedding = torch.nn.Parameter(torch.cat((
380
+ original_positional_embedding[:1,:],
381
+ model.visual.positional_embedding
382
+ ), dim=0))
383
+ transform = transforms.Compose([
384
+ transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
385
+ lambda image: image.convert("RGB"),
386
+ transforms.ToTensor(),
387
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
388
+ ])
389
+ new_preprocesses.append(transform)
390
+ original_position_embeddings.append(original_positional_embedding)
391
+ self.preprocesses = new_preprocesses
392
+ result = super().__call__(caption, image, boxes, image_name, image_pth)
393
+ if self.expand_position_embedding:
394
+ self.preprocesses = original_preprocesses
395
+ for model, model_name, pos_embedding in zip(self.models, self.clip_models, original_position_embeddings):
396
+ if "RN" in model_name:
397
+ model.visual.attnpool.positional_embedding = torch.nn.Parameter(pos_embedding)
398
+ else:
399
+ model.visual.positional_embedding = torch.nn.Parameter(pos_embedding)
400
+ return result
401
+
AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import clip
3
+ import json
4
+ import argparse
5
+ import ruamel.yaml as yaml
6
+
7
+ from PIL import Image
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+ from tqdm import tqdm
11
+
12
+ from albef.utils import *
13
+ from executor import AlbefExecutor
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--input_path", type=str, help="Path to input JSON file")
17
+ parser.add_argument("--image_root", type=str, help="Path to directory containing images")
18
+ parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF")
19
+ parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF")
20
+ parser.add_argument("--clip_model", type=str, help="CLIP model to use")
21
+ parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use")
22
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP")
23
+
24
+ args = parser.parse_args()
25
+
26
+ if args.albef_path is not None:
27
+ executor = AlbefExecutor(checkpoint_path = os.path.join(args.albef_path, "checkpoint.pth"), config_path = os.path.join(args.albef_path, "config.yaml"), device = "cpu" if args.gpu < 0 else "cuda:"+str(args.gpu))
28
+ model = executor.models[0]
29
+ preprocess = executor.preprocesses[0]
30
+ model = model.eval()
31
+ else:
32
+ model, preprocess = clip.load(args.clip_model, jit=False, device="cuda:"+str(args.gpu))
33
+ preprocess.transforms[0] == transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), transforms.InterpolationMode.BICUBIC)
34
+ model = model.eval()
35
+ input_file = open(args.input_path)
36
+ data = json.load(input_file)
37
+ input_file.close()
38
+ correct = 0
39
+ for i in tqdm(range(0, len(data), args.batch_size)):
40
+ batch_images = []
41
+ batch_text = []
42
+ for datum in data[i:min(i+args.batch_size, len(data))]:
43
+ img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB')
44
+ batch_images.append(preprocess(img))
45
+ if "text2" in datum:
46
+ if args.albef_path is None:
47
+ datum["text1"] = "a photo of "+datum["text1"]
48
+ datum["text2"] = "a photo of "+datum["text2"]
49
+ batch_text.append(datum["text1"])
50
+ batch_text.append(datum["text2"])
51
+ else:
52
+ img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB')
53
+ batch_images.append(preprocess(img2))
54
+ batch_text.append(datum["text1"])
55
+ batch_images = torch.stack(batch_images).to("cuda:"+str(args.gpu))
56
+ if args.albef_path is None:
57
+ batch_text = clip.tokenize(batch_text).to("cuda:"+str(args.gpu))
58
+ else:
59
+ modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text]
60
+ batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt")
61
+ for key in batch_text:
62
+ batch_text[key] = batch_text[key].to(batch_images.device)
63
+
64
+ with torch.no_grad():
65
+ if args.albef_path is None:
66
+ logits_per_image, logits_per_text = model(batch_images, batch_text)
67
+ else:
68
+ if not args.albef_itc:
69
+ if batch_images.shape[0]*2 == batch_text.input_ids.shape[0]:
70
+ batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(batch_images.shape[0]*2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3])
71
+ else:
72
+ assert batch_images.shape[0] ==2*batch_text.input_ids.shape[0]
73
+ batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
74
+ batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
75
+ image_embeds = model.visual_encoder(batch_images)
76
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(batch_images.device)
77
+ output = model.text_encoder(
78
+ batch_text.input_ids,
79
+ attention_mask = batch_text.attention_mask,
80
+ encoder_hidden_states = image_embeds,
81
+ encoder_attention_mask = image_atts,
82
+ return_dict = True,
83
+ )
84
+ vl_embeddings = output.last_hidden_state[:,0,:]
85
+ vl_output = model.itm_head(vl_embeddings)
86
+ logits_per_image = vl_output[:,1:2].view(-1, 2)
87
+ else:
88
+ image_embeds = model.visual_encoder(batch_images)
89
+ image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:,0,:]),dim=-1)
90
+ text_output = model.text_encoder(batch_text.input_ids, attention_mask = batch_text.attention_mask,
91
+ return_dict = True, mode = 'text')
92
+ text_embeds = text_output.last_hidden_state
93
+ text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:,0,:]),dim=-1)
94
+ sim = image_feat@text_feat.t()/model.temp
95
+ logits_per_image = sim
96
+ if args.albef_path is None or args.albef_itc:
97
+ if logits_per_image.shape[0]*2 == logits_per_image.shape[1]:
98
+ for j in range(logits_per_image.shape[0]):
99
+ correct += 1 if logits_per_image[j,2*j].item() > logits_per_image[j,2*j+1].item() else 0
100
+ else:
101
+ assert logits_per_image.shape[0] == 2*logits_per_image.shape[1]
102
+ for j in range(logits_per_image.shape[1]):
103
+ correct += 1 if logits_per_image[2*j,j].item() > logits_per_image[2*j+1,j].item() else 0
104
+ else:
105
+ correct += (logits_per_image[:,0] > logits_per_image[:,1]).long().sum().item()
106
+
107
+ print("Accuracy:", correct/len(data))
AlphaCLIP/eval/rec_zs_test/heuristics.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Heuristic rules used to extract and execute entity parses."""
2
+
3
+ from typing import Callable, List, NamedTuple
4
+ from argparse import Namespace
5
+ import numpy as np
6
+
7
+
8
+ class RelHeuristic(NamedTuple):
9
+ keywords: List[str]
10
+ callback: Callable[["Environment"], np.ndarray]
11
+
12
+
13
+ class Heuristics:
14
+ """A class defining heuristics that can be enabled/disabled."""
15
+
16
+ RELATIONS = [
17
+ RelHeuristic(["left", "west"], lambda env: env.left_of()),
18
+ RelHeuristic(["right", "east"], lambda env: env.right_of()),
19
+ RelHeuristic(["above", "north", "top", "back", "behind"], lambda env: env.above()),
20
+ RelHeuristic(["below", "south", "under", "front"], lambda env: env.below()),
21
+ RelHeuristic(["bigger", "larger", "closer"], lambda env: env.bigger_than()),
22
+ RelHeuristic(["smaller", "tinier", "further"], lambda env: env.smaller_than()),
23
+ RelHeuristic(["inside", "within", "contained"], lambda env: env.within()),
24
+ ]
25
+
26
+ TERNARY_RELATIONS = [
27
+ RelHeuristic(["between"], lambda env: env.between()),
28
+ ]
29
+
30
+ SUPERLATIVES = [
31
+ RelHeuristic(["left", "west", "leftmost", "western"], lambda env: env.left_of()),
32
+ RelHeuristic(["right", "rightmost", "east", "eastern"], lambda env: env.right_of()),
33
+ RelHeuristic(["above", "north", "top"], lambda env: env.above()),
34
+ RelHeuristic(["below", "south", "underneath", "front"], lambda env: env.below()),
35
+ RelHeuristic(["bigger", "biggest", "larger", "largest", "closer", "closest"], lambda env: env.bigger_than()),
36
+ RelHeuristic(["smaller", "smallest", "tinier", "tiniest", "further", "furthest"], lambda env: env.smaller_than()),
37
+ ]
38
+ OPPOSITES = {0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4}
39
+
40
+ NULL_KEYWORDS = ["part", "image", "side", "picture", "half", "region", "section"]
41
+
42
+ EMPTY = []
43
+
44
+ def __init__(self, args: Namespace = None):
45
+ self.enable_relations = not args or not args.no_rel
46
+ self.enable_superlatives = not args or not args.no_sup
47
+ self.enable_nulls = not args or not args.no_null
48
+ self.enable_ternary = not args or args.ternary
49
+
50
+ @property
51
+ def relations(self) -> List[RelHeuristic]:
52
+ return self.RELATIONS if self.enable_relations else self.EMPTY
53
+
54
+ @property
55
+ def ternary_relations(self) -> List[RelHeuristic]:
56
+ return self.TERNARY_RELATIONS if self.enable_ternary else self.EMPTY
57
+
58
+ @property
59
+ def superlatives(self) -> List[RelHeuristic]:
60
+ return self.SUPERLATIVES if self.enable_superlatives else self.EMPTY
61
+
62
+ @property
63
+ def opposites(self):
64
+ return self.OPPOSITES
65
+
66
+ @property
67
+ def null_keywords(self) -> List[str]:
68
+ return self.NULL_KEYWORDS if self.enable_nulls else self.EMPTY
AlphaCLIP/eval/rec_zs_test/interpreter.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple, List, Callable
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import torch
6
+ from numpy.linalg import norm
7
+ from itertools import product, groupby
8
+ from PIL import Image
9
+
10
+
11
+ # Do two line segments intersect? Copied from
12
+ # https://stackoverflow.com/questions/3838329/how-can-i-check-if-two-segments-intersect
13
+
14
+
15
+ def ccw(A, B, C):
16
+ return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x)
17
+
18
+
19
+ def intersect(A, B, C, D):
20
+ """Do line segments AB and CD intersect?"""
21
+ return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
22
+
23
+
24
+ class Box(NamedTuple):
25
+ x: int
26
+ y: int
27
+ w: int = 0
28
+ h: int = 0
29
+
30
+ @property
31
+ def left(self):
32
+ return self.x
33
+
34
+ @property
35
+ def right(self):
36
+ return self.x + self.w
37
+
38
+ @property
39
+ def top(self):
40
+ return self.y
41
+
42
+ @property
43
+ def bottom(self):
44
+ return self.y + self.h
45
+
46
+ @property
47
+ def center(self):
48
+ return Box(self.x + self.w // 2, self.y + self.h // 2)
49
+
50
+ def corners(self):
51
+ yield Box(self.x, self.y)
52
+ yield Box(self.x + self.w, self.y)
53
+ yield Box(self.x + self.w, self.y + self.h)
54
+ yield Box(self.x, self.y + self.h)
55
+
56
+ @property
57
+ def area(self):
58
+ return self.w * self.h
59
+
60
+ def intersect(self, other: "Box") -> "Box":
61
+ x1 = max(self.x, other.x)
62
+ x2 = max(x1, min(self.x+self.w, other.x+other.w))
63
+ y1 = max(self.y, other.y)
64
+ y2 = max(y1, min(self.y+self.h, other.y+other.h))
65
+ return Box(x=x1, y=y1, w=x2-x1, h=y2-y1)
66
+
67
+ def min_bounding(self, other: "Box") -> "Box":
68
+ corners = list(self.corners())
69
+ corners.extend(other.corners())
70
+ min_x = min_y = float("inf")
71
+ max_x = max_y = -float("inf")
72
+
73
+ for item in corners:
74
+ min_x = min(min_x, item.x)
75
+ min_y = min(min_y, item.y)
76
+ max_x = max(max_x, item.x)
77
+ max_y = max(max_y, item.y)
78
+
79
+ return Box(min_x, min_y, max_x - min_x, max_y - min_y)
80
+
81
+ def expand(self, growth: float = .1) -> "Box":
82
+ factor = 1 + growth
83
+ w = factor * self.w
84
+ h = factor * self.h
85
+ return Box(min_x - (w - self.w) / 2, min_y - (h - self.h) / 2, w, h)
86
+
87
+
88
+ def iou(box1, box2):
89
+ x1 = max(box1.x, box2.x)
90
+ x2 = max(x1, min(box1.x+box1.w, box2.x+box2.w))
91
+ y1 = max(box1.y, box2.y)
92
+ y2 = max(y1, min(box1.y+box1.h, box2.y+box2.h))
93
+ intersection = Box(x=x1, y=y1, w=x2-x1, h=y2-y1)
94
+ intersection_area = intersection.area
95
+ union_area = box1.area+box2.area-intersection_area
96
+ return intersection_area / union_area
97
+
98
+
99
+ def all_equal(iterable):
100
+ """Are all elements the same?"""
101
+ g = groupby(iterable)
102
+ return next(g, True) and not next(g, False)
103
+
104
+
105
+ class spatial:
106
+ """A decorator that converts a predicate over boxes to a function that returns a tensor over all boxes."""
107
+
108
+ def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False):
109
+ self.arity = arity
110
+ self.enforce_antisymmetry = enforce_antisymmetry # Zero out any entries where two boxes are the same.
111
+
112
+ def __call__(self, predicate: Callable[[Box], float]) -> Callable[["Environment"], np.ndarray]:
113
+ def _rel(env):
114
+ n_boxes = len(env.boxes)
115
+ tensor = np.empty([n_boxes for _ in range(self.arity)])
116
+ enum_boxes = list(enumerate(env.boxes))
117
+ for pairs in product(*[enum_boxes for _ in range(self.arity)]):
118
+ indices, boxes = zip(*pairs)
119
+ if self.enforce_antisymmetry and len(set(indices)) < len(indices):
120
+ tensor[indices] = 0.
121
+ else:
122
+ tensor[indices] = predicate(*boxes)
123
+ return tensor
124
+ return _rel
125
+
126
+
127
+ class Environment:
128
+ def __init__(self, image: Image, boxes: List[Box], executor: "Executor" = None, freeform_boxes: bool = False, image_name: str = None, image_pth: str=None):
129
+ self.image = image
130
+ self.boxes = boxes
131
+ self.executor = executor # An object or callback that can query CLIP with captions/images.
132
+ self.freeform_boxes = freeform_boxes
133
+ self.image_name = image_name
134
+ self.image_pth=image_pth
135
+
136
+ def uniform(self) -> np.ndarray:
137
+ n_boxes = len(self.boxes)
138
+ return 1 / n_boxes * np.ones(n_boxes)
139
+
140
+ def filter(self,
141
+ caption: str,
142
+ temperature: float = 1.,
143
+ area_threshold: float = 0.0,
144
+ softmax: bool = False,
145
+ expand: float = None
146
+ ) -> np.ndarray:
147
+ """Return a new distribution reflecting the likelihood that `caption` describes the content of each box."""
148
+ area_filtered_dist = torch.from_numpy(self.filter_area(area_threshold)).to(self.executor.device)
149
+ candidate_indices = [i for i in range(len(self.boxes)) if float(area_filtered_dist[i]) > 0.0]
150
+ boxes = [self.boxes[i] for i in candidate_indices]
151
+ if len(boxes) == 0:
152
+ boxes = self.boxes
153
+ candidate_indices = list(range(len(boxes)))
154
+ if expand is not None:
155
+ boxes = [box.expand(expand) for box in boxes]
156
+ result_partial = self.executor(caption, self.image, boxes, image_name=self.image_name, image_pth=self.image_pth)
157
+ if self.freeform_boxes:
158
+ result_partial, boxes = result_partial
159
+ self.boxes = [Box(x=boxes[i,0].item(), y=boxes[i,1].item(), w=boxes[i,2].item()-boxes[i,0].item(), h=boxes[i,3].item()-boxes[i,1].item()) for i in range(boxes.shape[0])]
160
+ candidate_indices = list(range(len(self.boxes)))
161
+ result_partial = result_partial.float()
162
+ if not softmax:
163
+ result_partial = (result_partial-result_partial.mean()) / (result_partial.std() + 1e-9)
164
+ result_partial = (temperature * result_partial).sigmoid()
165
+ result = torch.zeros((len(self.boxes))).to(result_partial.device)
166
+ result[candidate_indices] = result_partial
167
+ else:
168
+ result = torch.zeros((len(self.boxes))).to(result_partial.device)
169
+ result[candidate_indices] = result_partial.softmax(dim=-1) #softmax结果
170
+ return result.cpu().numpy()
171
+
172
+ def filter_area(self, area_threshold: float) -> np.ndarray:
173
+ """Return a new distribution in which all boxes whose area as a fraction of the image is less than the threshold."""
174
+ image_area = self.image.width*self.image.height
175
+ return np.array([1 if self.boxes[i].area/image_area > area_threshold else 0 for i in range(len(self.boxes))])
176
+
177
+ @spatial()
178
+ def left_of(b1, b2):
179
+ return (b1.right+b1.left) / 2 < (b2.right+b2.left) / 2
180
+
181
+ @spatial()
182
+ def right_of(b1, b2):
183
+ return (b1.right+b1.left) / 2 > (b2.right+b2.left) / 2
184
+
185
+ @spatial()
186
+ def above(b1, b2):
187
+ return (b1.bottom+b1.top) < (b2.bottom+b2.top)
188
+
189
+ @spatial()
190
+ def below(b1, b2):
191
+ return (b1.bottom+b1.top) > (b2.bottom+b2.top)
192
+
193
+ @spatial()
194
+ def bigger_than(b1, b2):
195
+ return b1.area > b2.area
196
+
197
+ @spatial()
198
+ def smaller_than(b1, b2):
199
+ return b1.area < b2.area
200
+
201
+ @spatial(enforce_antisymmetry=False)
202
+ def within(box1, box2):
203
+ """Return percent of box1 inside box2."""
204
+ intersection = box1.intersect(box2)
205
+ return intersection.area / box1.area
206
+
207
+ @spatial(arity=3, enforce_antisymmetry=True)
208
+ def between(box1, box2, box3):
209
+ """How much of box1 lies in min bounding box over box2 and box3?"""
210
+ min_bounding = box2.min_bounding(box3)
211
+ intersect = box1.intersect(min_bounding)
212
+ return intersect.area / box1.area
AlphaCLIP/eval/rec_zs_test/lattice.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implement lattice interface."""
2
+
3
+ from overrides import overrides
4
+ import numpy as np
5
+ from abc import ABCMeta, abstractmethod
6
+
7
+
8
+ class Lattice(metaclass=ABCMeta):
9
+
10
+ """Abstract base class representing a complemented lattice."""
11
+
12
+ @classmethod
13
+ @abstractmethod
14
+ def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
15
+ return NotImplemented
16
+
17
+ @classmethod
18
+ @abstractmethod
19
+ def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
20
+ return NotImplemented
21
+
22
+ @classmethod
23
+ @abstractmethod
24
+ def join_reduce(cls, probs: np.ndarray) -> np.ndarray:
25
+ return NotImplemented
26
+
27
+ @classmethod
28
+ @abstractmethod
29
+ def meet_reduce(cls, probs: np.ndarray) -> np.ndarray:
30
+ return NotImplemented
31
+
32
+
33
+ class Product(Lattice):
34
+ """Lattice where meet=prod and sum is defined accordingly.
35
+
36
+ Equivalent to assuming independence, more or less.
37
+ """
38
+
39
+ eps = 1e-9
40
+
41
+ @classmethod
42
+ @overrides
43
+ def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
44
+ return probs1 + probs2 - cls.meet(probs1, probs2)
45
+
46
+ @classmethod
47
+ @overrides
48
+ def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
49
+ return probs1 * probs2
50
+
51
+ @classmethod
52
+ @overrides
53
+ def join_reduce(cls, probs: np.ndarray) -> np.ndarray:
54
+ """Assumes disjoint events."""
55
+ # return cls.comp(cls.meet_reduce(cls.comp(probs)))
56
+ return np.sum(probs, axis=-1)
57
+
58
+ @classmethod
59
+ @overrides
60
+ def meet_reduce(cls, probs: np.ndarray) -> np.ndarray:
61
+ return np.prod(probs, axis=-1)
62
+
63
+ @classmethod
64
+ def comp(cls, probs):
65
+ return 1 - probs
66
+
67
+ @classmethod
68
+ def normalize(cls, probs):
69
+ """Normalize a distribution by dividing by the total mass."""
70
+ return probs / np.sum(probs + cls.eps, axis=-1)
AlphaCLIP/eval/rec_zs_test/main.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import json
3
+ import argparse
4
+ import os
5
+ import random
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+ from interpreter import *
12
+ from executor import *
13
+ from methods import *
14
+
15
+ METHODS_MAP = {
16
+ "baseline": Baseline,
17
+ "random": Random,
18
+ "parse": Parse,
19
+ }
20
+
21
+ if __name__ == "__main__":
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--input_file", type=str, help="input file with expressions and annotations in jsonlines format")
24
+ parser.add_argument("--image_root", type=str, help="path to images (train2014 directory of COCO)")
25
+ parser.add_argument("--clip_model", type=str, default="RN50x16,ViT-B/32", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
26
+ parser.add_argument("--clip_type", type=str, default="aclip", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
27
+ parser.add_argument("--albef_path", type=str, default=None, help="to use ALBEF (instead of CLIP), specify the path to the ALBEF checkpoint")
28
+ parser.add_argument("--method", type=str, default="parse", help="method to solve expressions")
29
+ parser.add_argument("--box_representation_method", type=str, default="crop,blur", help="method of representing boxes as individual images (crop, blur, or both separated by a comma)")
30
+ parser.add_argument("--box_method_aggregator", type=str, default="sum", help="method of combining box representation scores")
31
+ parser.add_argument("--box_area_threshold", type=float, default=0.0, help="minimum area (as a proportion of image area) for a box to be considered as the answer")
32
+ parser.add_argument("--output_file", type=str, default=None, help="(optional) output path to save results")
33
+ parser.add_argument("--detector_file", type=str, default=None, help="(optional) file containing object detections. if not provided, the gold object boxes will be used.")
34
+ parser.add_argument("--mock", action="store_true", help="(optional) mock CLIP execution.")
35
+ parser.add_argument("--device", type=int, default=0, help="CUDA device to use.")
36
+ parser.add_argument("--shuffle_words", action="store_true", help="If true, shuffle words in the sentence")
37
+ parser.add_argument("--gradcam_alpha", type=float, nargs='+', help="alpha value to use for gradcam method")
38
+ parser.add_argument("--enlarge_boxes", type=float, default=0.0, help="(optional) whether to enlarge boxes when passing them to the model")
39
+ parser.add_argument("--part", type=str, default=None, help="(optional) specify how many parts to divide the dataset into and which part to run in the format NUM_PARTS,PART_NUM")
40
+ parser.add_argument("--batch_size", type=int, default=1, help="number of instances to process in one model call (only supported for baseline model)")
41
+ parser.add_argument("--baseline_head", action="store_true", help="For baseline, controls whether model is called on both full expression and head noun chunk of expression")
42
+ parser.add_argument("--mdetr", type=str, default=None, help="to use MDETR as the executor model, specify the name of the MDETR model")
43
+ parser.add_argument("--albef_block_num", type=int, default=8, help="block num for ALBEF gradcam")
44
+ parser.add_argument("--albef_mode", type=str, choices=["itm", "itc"], default="itm")
45
+ parser.add_argument("--expand_position_embedding",action="store_true")
46
+ parser.add_argument("--gradcam_background", action="store_true")
47
+ parser.add_argument("--mdetr_given_bboxes", action="store_true")
48
+ parser.add_argument("--mdetr_use_token_mapping", action="store_true")
49
+ parser.add_argument("--non_square_size", action="store_true")
50
+ parser.add_argument("--blur_std_dev", type=int, default=100, help="standard deviation of Gaussian blur")
51
+ parser.add_argument("--gradcam_ensemble_before", action="store_true", help="Average gradcam maps of different models before summing over the maps")
52
+ parser.add_argument("--cache_path", type=str, default=None, help="cache features")
53
+ # Arguments related to Parse method.
54
+ parser.add_argument("--no_rel", action="store_true", help="Disable relation extraction.")
55
+ parser.add_argument("--no_sup", action="store_true", help="Disable superlative extraction.")
56
+ parser.add_argument("--no_null", action="store_true", help="Disable null keyword heuristics.")
57
+ parser.add_argument("--ternary", action="store_true", help="Disable ternary relation extraction.")
58
+ parser.add_argument("--baseline_threshold", type=float, default=float("inf"), help="(Parse) Threshold to use relations/superlatives.")
59
+ parser.add_argument("--temperature", type=float, default=1., help="(Parse) Sigmoid temperature.")
60
+ parser.add_argument("--superlative_head_only", action="store_true", help="(Parse) Superlatives only quanntify head predicate.")
61
+ parser.add_argument("--sigmoid", action="store_true", help="(Parse) Use sigmoid, not softmax.")
62
+ parser.add_argument("--no_possessive", action="store_true", help="(Parse) Model extraneous relations as possessive relations.")
63
+ parser.add_argument("--expand_chunks", action="store_true", help="(Parse) Expand noun chunks to include descendant tokens that aren't ancestors of tokens in other chunks")
64
+ parser.add_argument("--parse_no_branch", action="store_true", help="(Parse) Only do the parsing procedure if some relation/superlative keyword is in the expression")
65
+ parser.add_argument("--possessive_no_expand", action="store_true", help="(Parse) Expand ent2 in possessive case")
66
+ args = parser.parse_args()
67
+
68
+ with open(args.input_file) as f:
69
+ lines = f.readlines()
70
+ data = [json.loads(line) for line in lines]
71
+
72
+ device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"
73
+ gradcam = args.method == "gradcam"
74
+
75
+ executor = ClipExecutor(clip_model=args.clip_model, box_representation_method=args.box_representation_method, method_aggregator=args.box_method_aggregator, device=device, square_size=not args.non_square_size, expand_position_embedding=args.expand_position_embedding, blur_std_dev=args.blur_std_dev, cache_path=args.cache_path, input_file=args.input_file, clip_type=args.clip_type)
76
+
77
+ method = METHODS_MAP[args.method](args)
78
+ correct_count = 0
79
+ total_count = 0
80
+ if args.output_file:
81
+ output_file = open(args.output_file, "w")
82
+ if args.detector_file:
83
+ detector_file = open(args.detector_file)
84
+ detections_list = json.load(detector_file)
85
+ if isinstance(detections_list, dict):
86
+ detections_map = {int(image_id): detections_list[image_id] for image_id in detections_list}
87
+ else:
88
+ detections_map = defaultdict(list)
89
+ for detection in detections_list:
90
+ detections_map[detection["image_id"]].append(detection["box"])
91
+
92
+ part = 0
93
+ if args.part is not None: # for multi-gpu test / part-data test
94
+ num_parts = int(args.part.split(",")[0])
95
+ part = int(args.part.split(",")[1])
96
+ data = data[int(len(data)*part/num_parts):int(len(data)*(part+1)/num_parts)]
97
+
98
+ batch_count = 0
99
+ batch_boxes = []
100
+ batch_gold_boxes = []
101
+ batch_gold_index = []
102
+ batch_file_names = []
103
+ batch_sentences = []
104
+ for datum in tqdm(data):
105
+ if "coco" in datum["file_name"].lower():
106
+ file_name = "_".join(datum["file_name"].split("_")[:-1])+".jpg"
107
+ else:
108
+ file_name = datum["file_name"]
109
+ img_path = os.path.join(args.image_root, file_name)
110
+ img = Image.open(img_path).convert('RGB')
111
+ gold_boxes = [Box(x=ann["bbox"][0], y=ann["bbox"][1], w=ann["bbox"][2], h=ann["bbox"][3]) for ann in datum["anns"]]
112
+ if isinstance(datum["ann_id"], int) or isinstance(datum["ann_id"], str):
113
+ datum["ann_id"] = [datum["ann_id"]]
114
+ assert isinstance(datum["ann_id"], list)
115
+ gold_index = [i for i in range(len(datum["anns"])) if datum["anns"][i]["id"] in datum["ann_id"]]
116
+ if args.detector_file:
117
+ boxes = [Box(x=box[0], y=box[1], w=box[2], h=box[3]) for box in detections_map[int(datum["image_id"])]]
118
+ if len(boxes) == 0:
119
+ boxes = [Box(x=0, y=0, w=img.width, h=img.height)]
120
+ else:
121
+ boxes = gold_boxes
122
+ for sentence in datum["sentences"]:
123
+ env = Environment(img, boxes, executor, (args.mdetr is not None and not args.mdetr_given_bboxes), str(datum["image_id"]), img_path)
124
+ if args.shuffle_words:
125
+ words = sentence["raw"].lower().split()
126
+ random.shuffle(words)
127
+ result = method.execute(" ".join(words), env)
128
+ else:
129
+ result = method.execute(sentence["raw"].lower(), env)
130
+ boxes = env.boxes
131
+ print(sentence["raw"].lower())
132
+ correct = False
133
+ for g_index in gold_index:
134
+ if iou(boxes[result["pred"]], gold_boxes[g_index]) > 0.5:
135
+ correct = True
136
+ break
137
+ if correct:
138
+ result["correct"] = 1
139
+ correct_count += 1
140
+ else:
141
+ result["correct"] = 0
142
+ if args.detector_file:
143
+ argmax_ious = []
144
+ max_ious = []
145
+ for g_index in gold_index:
146
+ ious = [iou(box, gold_boxes[g_index]) for box in boxes]
147
+ argmax_iou = -1
148
+ max_iou = 0
149
+ if max(ious) >= 0.5:
150
+ for index, value in enumerate(ious):
151
+ if value > max_iou:
152
+ max_iou = value
153
+ argmax_iou = index
154
+ argmax_ious.append(argmax_iou)
155
+ max_ious.append(max_iou)
156
+ argmax_iou = -1
157
+ max_iou = 0
158
+ if max(max_ious) >= 0.5:
159
+ for index, value in zip(argmax_ious, max_ious):
160
+ if value > max_iou:
161
+ max_iou = value
162
+ argmax_iou = index
163
+ result["gold_index"] = argmax_iou
164
+ else:
165
+ result["gold_index"] = gold_index
166
+ result["bboxes"] = [[box.left, box.top, box.right, box.bottom] for box in boxes]
167
+ result["file_name"] = file_name
168
+ result["probabilities"] = result["probs"]
169
+ result["text"] = sentence["raw"].lower()
170
+ if args.output_file:
171
+ # Serialize numpy arrays for JSON.
172
+ for key in result:
173
+ if isinstance(result[key], np.ndarray):
174
+ result[key] = result[key].tolist()
175
+ if isinstance(result[key], np.int64):
176
+ result[key] = result[key].item()
177
+ output_file.write(json.dumps(result)+"\n")
178
+ total_count += 1
179
+ print(f"est_acc: {100 * correct_count / total_count:.3f}")
180
+
181
+ if args.output_file:
182
+ output_file.close()
183
+ print(f"acc: {100 * correct_count / total_count:.3f}")
184
+ acc = 100 * correct_count / total_count
185
+
186
+ result = {}
187
+ result['acc'] = acc
188
+ json.dump(acc, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_acc_' + str(part)+'.json'),'w'))
189
+ json.dump(str(correct_count)+' '+str(total_count), open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_count_' + str(part)+'.json'),'w'))
190
+ stats = method.get_stats()
191
+ if stats:
192
+ pairs = sorted(list(stats.items()), key=lambda tup: tup[0])
193
+ for key, value in pairs:
194
+ result[key] = value
195
+ if isinstance(value, float):
196
+ print(f"{key}: {value:.5f}")
197
+ else:
198
+ print(f"{key}: {value}")
199
+
200
+ json.dump(result, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_' + str(part)+'.json'),'w'))
AlphaCLIP/eval/rec_zs_test/methods/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .baseline import Baseline
2
+ from .random_method import Random
3
+ from .parse import Parse
AlphaCLIP/eval/rec_zs_test/methods/baseline.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A naive baseline method: just pass the full expression to CLIP."""
2
+
3
+ from overrides import overrides
4
+ from typing import Dict, Any, List
5
+ import numpy as np
6
+ import torch
7
+ import spacy
8
+ from argparse import Namespace
9
+
10
+ from .ref_method import RefMethod
11
+ from lattice import Product as L
12
+
13
+
14
+ class Baseline(RefMethod):
15
+ """CLIP-only baseline where each box is evaluated with the full expression."""
16
+
17
+ nlp = spacy.load('en_core_web_sm')
18
+
19
+ def __init__(self, args: Namespace):
20
+ self.args = args
21
+ self.box_area_threshold = args.box_area_threshold
22
+ self.batch_size = args.batch_size
23
+ self.batch = []
24
+
25
+ @overrides
26
+ def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
27
+ chunk_texts = self.get_chunk_texts(caption)
28
+ probs = env.filter(caption, area_threshold = self.box_area_threshold, softmax=True)
29
+ if self.args.baseline_head:
30
+ probs2 = env.filter(chunk_texts[0], area_threshold = self.box_area_threshold, softmax=True)
31
+ probs = L.meet(probs, probs2)
32
+ pred = np.argmax(probs)
33
+ return {
34
+ "probs": probs,
35
+ "pred": pred,
36
+ "box": env.boxes[pred],
37
+ }
38
+
39
+ def get_chunk_texts(self, expression: str) -> List:
40
+ doc = self.nlp(expression)
41
+ head = None
42
+ for token in doc:
43
+ if token.head.i == token.i:
44
+ head = token
45
+ break
46
+ head_chunk = None
47
+ chunk_texts = []
48
+ for chunk in doc.noun_chunks:
49
+ if head.i >= chunk.start and head.i < chunk.end:
50
+ head_chunk = chunk.text
51
+ chunk_texts.append(chunk.text)
52
+ if head_chunk is None:
53
+ if len(list(doc.noun_chunks)) > 0:
54
+ head_chunk = list(doc.noun_chunks)[0].text
55
+ else:
56
+ head_chunk = expression
57
+ return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk]
AlphaCLIP/eval/rec_zs_test/methods/parse.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Use spatial relations extracted from the parses."""
2
+
3
+ from typing import Dict, Any, Callable, List, Tuple, NamedTuple
4
+ from numbers import Number
5
+ from collections import defaultdict
6
+ from overrides import overrides
7
+ import numpy as np
8
+ import spacy
9
+ from spacy.tokens.token import Token
10
+ from spacy.tokens.span import Span
11
+ from argparse import Namespace
12
+
13
+ from .ref_method import RefMethod
14
+ from lattice import Product as L
15
+ from heuristics import Heuristics
16
+ from entity_extraction import Entity, expand_chunks
17
+
18
+
19
+ def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity:
20
+ """If an entity represents a conjunction of two entities, pull them apart."""
21
+ head = ent.head.root # Not ...root.head. Confusing names here.
22
+ if not any(child.text == "and" for child in head.children):
23
+ return None
24
+ for child in head.children:
25
+ if child.i in chunks and head.i is not child.i:
26
+ return Entity.extract(child, chunks, heuristics)
27
+ return None
28
+
29
+
30
+ class Parse(RefMethod):
31
+ """An REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.
32
+
33
+ The process is as follows:
34
+ 1. Use spacy to parse the document.
35
+ 2. Extract a semantic entity tree from the parse.
36
+ 3. Execute the entity tree to yield a distribution over boxes."""
37
+
38
+ nlp = spacy.load('en_core_web_sm')
39
+
40
+ def __init__(self, args: Namespace = None):
41
+ self.args = args
42
+ self.box_area_threshold = args.box_area_threshold
43
+ self.baseline_threshold = args.baseline_threshold
44
+ self.temperature = args.temperature
45
+ self.superlative_head_only = args.superlative_head_only
46
+ self.expand_chunks = args.expand_chunks
47
+ self.branch = not args.parse_no_branch
48
+ self.possessive_expand = not args.possessive_no_expand
49
+
50
+ # Lists of keyword heuristics to use.
51
+ self.heuristics = Heuristics(args)
52
+
53
+ # Metrics for debugging relation extraction behavor.
54
+ self.counts = defaultdict(int)
55
+
56
+ @overrides
57
+ def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
58
+ """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes."""
59
+ # Start by using the full caption, as in Baseline.
60
+ probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
61
+ ori_probs = probs
62
+
63
+ # Extend the baseline using parse stuff.
64
+ doc = self.nlp(caption)
65
+ head = self.get_head(doc)
66
+ chunks = self.get_chunks(doc)
67
+ if self.expand_chunks:
68
+ chunks = expand_chunks(doc, chunks)
69
+ entity = Entity.extract(head, chunks, self.heuristics)
70
+
71
+ # If no head noun is found, take the first one.
72
+ if entity is None and len(list(doc.noun_chunks)) > 0:
73
+ head = list(doc.noun_chunks)[0]
74
+ entity = Entity.extract(head.root.head, chunks, self.heuristics)
75
+ self.counts["n_0th_noun"] += 1
76
+
77
+ # If we have found some head noun, filter based on it.
78
+ if entity is not None and (any(any(token.text in h.keywords for h in self.heuristics.relations+self.heuristics.superlatives) for token in doc) or not self.branch):
79
+ ent_probs, texts = self.execute_entity(entity, env, chunks)
80
+ probs = L.meet(probs, ent_probs)
81
+ else:
82
+ texts = [caption]
83
+ self.counts["n_full_expr"] += 1
84
+
85
+ if len(ori_probs) == 1:
86
+ probs = ori_probs
87
+
88
+ self.counts["n_total"] += 1
89
+ pred = np.argmax(probs)
90
+ return {
91
+ "probs": probs,
92
+ "pred": pred,
93
+ "box": env.boxes[pred],
94
+ "texts": texts
95
+ }
96
+
97
+ def execute_entity(self,
98
+ ent: Entity,
99
+ env: "Environment",
100
+ chunks: Dict[int, Span],
101
+ root: bool = True,
102
+ ) -> np.ndarray:
103
+ """Execute an `Entity` tree recursively, yielding a distribution over boxes."""
104
+ self.counts["n_rec"] += 1
105
+ probs = [1, 1]
106
+ head_probs = probs
107
+
108
+ # Only use relations if the head baseline isn't certain.
109
+ if len(probs) == 1 or len(env.boxes) == 1:
110
+ return probs, [ent.text]
111
+
112
+ m1, m2 = probs[:2] # probs[(-probs).argsort()[:2]]
113
+ text = ent.text
114
+ rel_probs = []
115
+ if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
116
+ self.counts["n_rec_rel"] += 1
117
+ for tokens, ent2 in ent.relations:
118
+ self.counts["n_rel"] += 1
119
+ rel = None
120
+ # Heuristically decide which spatial relation is represented.
121
+ for heuristic in self.heuristics.relations:
122
+ if any(tok.text in heuristic.keywords for tok in tokens):
123
+ rel = heuristic.callback(env)
124
+ self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
125
+ break
126
+ # Filter and normalize by the spatial relation.
127
+ if rel is not None:
128
+ probs2 = self.execute_entity(ent2, env, chunks, root=False)
129
+ events = L.meet(np.expand_dims(probs2, axis=0), rel)
130
+ new_probs = L.join_reduce(events)
131
+ rel_probs.append((ent2.text, new_probs, probs2))
132
+ continue
133
+
134
+ # This case specifically handles "between", which takes two noun arguments.
135
+ rel = None
136
+ for heuristic in self.heuristics.ternary_relations:
137
+ if any(tok.text in heuristic.keywords for tok in tokens):
138
+ rel = heuristic.callback(env)
139
+ self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
140
+ break
141
+ if rel is not None:
142
+ ent3 = get_conjunct(ent2, chunks, self.heuristics)
143
+ if ent3 is not None:
144
+ probs2 = self.execute_entity(ent2, env, chunks, root=False)
145
+ probs2 = np.expand_dims(probs2, axis=[0, 2])
146
+ probs3 = self.execute_entity(ent3, env, chunks, root=False)
147
+ probs3 = np.expand_dims(probs3, axis=[0, 1])
148
+ events = L.meet(L.meet(probs2, probs3), rel)
149
+ new_probs = L.join_reduce(L.join_reduce(events))
150
+ probs = L.meet(probs, new_probs)
151
+ continue
152
+ # Otherwise, treat the relation as a possessive relation.
153
+ if not self.args.no_possessive:
154
+ if self.possessive_expand:
155
+ text = ent.expand(ent2.head)
156
+ else:
157
+ text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'
158
+ #poss_probs = self._filter(text, env, root=root, expand=.3)
159
+ probs = self._filter(text, env, root=root)
160
+ texts = [text]
161
+ return_probs = [(probs.tolist(), probs.tolist())]
162
+ for (ent2_text, new_probs, ent2_only_probs) in rel_probs:
163
+ probs = L.meet(probs, new_probs)
164
+ probs /= probs.sum()
165
+ texts.append(ent2_text)
166
+ return_probs.append((probs.tolist(), ent2_only_probs.tolist()))
167
+
168
+ # Only use superlatives if thresholds work out.
169
+ m1, m2 = probs[(-probs).argsort()[:2]]
170
+ if m1 < self.baseline_threshold * m2:
171
+ self.counts["n_rec_sup"] += 1
172
+ for tokens in ent.superlatives:
173
+ self.counts["n_sup"] += 1
174
+ sup = None
175
+ for heuristic_index, heuristic in enumerate(self.heuristics.superlatives):
176
+ if any(tok.text in heuristic.keywords for tok in tokens):
177
+ texts.append('sup:'+' '.join([tok.text for tok in tokens if tok.text in heuristic.keywords]))
178
+ sup = heuristic.callback(env)
179
+ self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
180
+ break
181
+ if sup is not None:
182
+ # Could use `probs` or `head_probs` here?
183
+ precond = head_probs if self.superlative_head_only else probs
184
+ probs = L.meet(np.expand_dims(precond, axis=1)*np.expand_dims(precond, axis=0), sup).sum(axis=1)
185
+ probs = probs / probs.sum()
186
+ return_probs.append((probs.tolist(), None))
187
+
188
+ if root:
189
+ assert len(texts) == len(return_probs)
190
+ return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
191
+ return probs
192
+
193
+ def get_head(self, doc) -> Token:
194
+ """Return the token that is the head of the dependency parse. """
195
+ for token in doc:
196
+ if token.head.i == token.i:
197
+ return token
198
+ return None
199
+
200
+ def get_chunks(self, doc) -> Dict[int, Any]:
201
+ """Return a dictionary mapping sentence indices to their noun chunk."""
202
+ chunks = {}
203
+ for chunk in doc.noun_chunks:
204
+ for idx in range(chunk.start, chunk.end):
205
+ chunks[idx] = chunk
206
+ return chunks
207
+
208
+ @overrides
209
+ def get_stats(self) -> Dict[str, Number]:
210
+ """Summary statistics that have been tracked on this object."""
211
+ stats = dict(self.counts)
212
+ n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
213
+ n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
214
+ stats.update({
215
+ "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9),
216
+ "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9),
217
+ "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9),
218
+ "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9),
219
+ "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9),
220
+ "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9),
221
+ "avg_rec": self.counts["n_rec"] / self.counts["n_total"],
222
+ })
223
+ return stats
224
+
225
+ def _filter(self,
226
+ caption: str,
227
+ env: "Environment",
228
+ root: bool = False,
229
+ expand: float = None,
230
+ ) -> np.ndarray:
231
+ """Wrap a filter call in a consistent way for all recursions."""
232
+ kwargs = {
233
+ "softmax": not self.args.sigmoid,
234
+ "temperature": self.args.temperature,
235
+ }
236
+ if root:
237
+ return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
238
+ else:
239
+ return env.filter(caption, **kwargs)
AlphaCLIP/eval/rec_zs_test/methods/random_method.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A naive baseline method: just pass the full expression to CLIP."""
2
+
3
+ from overrides import overrides
4
+ from typing import Dict, Any
5
+ import random
6
+ from argparse import Namespace
7
+
8
+ import numpy as np
9
+
10
+ from .ref_method import RefMethod
11
+
12
+
13
+ class Random(RefMethod):
14
+ """CLIP-only baseline where each box is evaluated with the full expression."""
15
+
16
+ def __init__(self, args: Namespace):
17
+ self.box_area_threshold = args.box_area_threshold
18
+
19
+ @overrides
20
+ def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
21
+ probs = env.filter_area(self.box_area_threshold)*env.uniform()
22
+ random_ordering = list(range(len(env.boxes)))
23
+ random.shuffle(random_ordering)
24
+ random_ordering = np.array(random_ordering)
25
+ pred = np.argmax(probs*random_ordering)
26
+ return {
27
+ "probs": probs.tolist(),
28
+ "pred": int(pred),
29
+ "text": caption.lower()
30
+ }
AlphaCLIP/eval/rec_zs_test/methods/ref_method.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base class for a method for doing referring expressions."""
2
+
3
+ from typing import Dict, Any
4
+ from abc import ABCMeta, abstractmethod
5
+
6
+
7
+ class RefMethod(metaclass=ABCMeta):
8
+ @abstractmethod
9
+ def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
10
+ return NotImplemented
11
+
12
+ def get_stats(self) -> Dict[str, Any]:
13
+ return {}
AlphaCLIP/eval/rec_zs_test/output/.gitkeep ADDED
File without changes
AlphaCLIP/eval/rec_zs_test/requirements.txt ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ attrs==21.2.0
2
+ blis==0.7.4
3
+ catalogue==2.0.4
4
+ certifi==2021.5.30
5
+ chardet==4.0.0
6
+ click==7.1.2
7
+ cymem==2.0.5
8
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl
9
+ filelock==3.0.12
10
+ ftfy==6.0.3
11
+ huggingface-hub==0.0.12
12
+ idna==2.10
13
+ iniconfig==1.1.1
14
+ itsdangerous==2.0.1
15
+ joblib==1.0.1
16
+ MarkupSafe==2.0.1
17
+ murmurhash==1.0.5
18
+ numpy==1.21.0
19
+ overrides==6.1.0
20
+ packaging==21.0
21
+ pathy==0.6.0
22
+ Pillow==8.2.0
23
+ pluggy==0.13.1
24
+ preshed==3.0.5
25
+ py==1.10.0
26
+ pydantic==1.7.4
27
+ pyparsing==2.4.7
28
+ pytest==6.2.4
29
+ PyYAML==5.4.1
30
+ regex==2021.7.6
31
+ requests==2.25.1
32
+ ruamel.yaml==0.17.10
33
+ ruamel.yaml.clib==0.2.6
34
+ sacremoses==0.0.45
35
+ scipy==1.7.0
36
+ six==1.16.0
37
+ smart-open==5.1.0
38
+ spacy==3.0.6
39
+ spacy-legacy==3.0.7
40
+ srsly==2.4.1
41
+ thinc==8.0.7
42
+ timm==0.4.12
43
+ tokenizers==0.10.3
44
+ toml==0.10.2
45
+ tqdm==4.61.2
46
+ transformers==4.9.0
47
+ typer==0.3.2
48
+ typing-extensions==3.10.0.0
49
+ typing-utils==0.1.0
50
+ urllib3==1.26.6
51
+ wasabi==0.8.2
52
+ wcwidth==0.2.5
53
+ Werkzeug==2.0.1
AlphaCLIP/eval/rec_zs_test/run.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache
AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,0" &
2
+
3
+ CUDA_VISIBLE_DEVICES=1 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,1" &
4
+
5
+ CUDA_VISIBLE_DEVICES=2 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,2" &
6
+
7
+ CUDA_VISIBLE_DEVICES=3 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,3" &
8
+
9
+ CUDA_VISIBLE_DEVICES=4 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,4" &
10
+
11
+ CUDA_VISIBLE_DEVICES=5 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,5" &
12
+
13
+ CUDA_VISIBLE_DEVICES=6 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,6" &
14
+
15
+ CUDA_VISIBLE_DEVICES=7 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,7"
AlphaCLIP/hubconf.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from alpha_clip.alpha_clip import tokenize as _tokenize, load as _load, available_models as _available_models
2
+ import re
3
+ import string
4
+
5
+ dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
6
+
7
+ # For compatibility (cannot include special characters in function name)
8
+ model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
9
+
10
+ def _create_hub_entrypoint(model):
11
+ def entrypoint(**kwargs):
12
+ return _load(model, **kwargs)
13
+
14
+ entrypoint.__doc__ = f"""Loads the {model} CLIP model
15
+
16
+ Parameters
17
+ ----------
18
+ device : Union[str, torch.device]
19
+ The device to put the loaded model
20
+
21
+ jit : bool
22
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
23
+
24
+ download_root: str
25
+ path to download the model files; by default, it uses "~/.cache/clip"
26
+
27
+ Returns
28
+ -------
29
+ model : torch.nn.Module
30
+ The {model} CLIP model
31
+
32
+ preprocess : Callable[[PIL.Image], torch.Tensor]
33
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
34
+ """
35
+ return entrypoint
36
+
37
+ def tokenize():
38
+ return _tokenize
39
+
40
+ _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
41
+
42
+ globals().update(_entrypoints)
AlphaCLIP/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ftfy
2
+ regex
3
+ tqdm
4
+ torch
5
+ torchvision
AlphaCLIP/setup.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pkg_resources
4
+ from setuptools import setup, find_packages
5
+
6
+ setup(
7
+ name="alpha_clip",
8
+ py_modules=["alpha_clip"],
9
+ version="1.0",
10
+ description="",
11
+ author="OpenAI&ZeyiSun",
12
+ packages=find_packages(exclude=["tests*"]),
13
+ install_requires=[
14
+ str(r)
15
+ for r in pkg_resources.parse_requirements(
16
+ open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
17
+ )
18
+ ],
19
+ include_package_data=True,
20
+ extras_require={'dev': ['pytest']},
21
+ )
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏢
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.48.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import sys
3
+ import torch
4
+ from omegaconf import OmegaConf
5
+ from PIL import Image
6
+ from diffusers import StableDiffusionInpaintPipeline
7
+ from model.clip_away import CLIPAway
8
+ import cv2
9
+ import numpy as np
10
+ import argparse
11
+
12
+ # Parse command line arguments
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file")
15
+ parser.add_argument("--share", action="store_true", help="Share the interface if provided")
16
+ args = parser.parse_args()
17
+
18
+ # Load configuration and models
19
+ config = OmegaConf.load(args.config)
20
+ sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
21
+ "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float32
22
+ )
23
+ clipaway = CLIPAway(
24
+ sd_pipe=sd_pipeline,
25
+ image_encoder_path=config.image_encoder_path,
26
+ ip_ckpt=config.ip_adapter_ckpt_path,
27
+ alpha_clip_path=config.alpha_clip_ckpt_pth,
28
+ config=config,
29
+ alpha_clip_id=config.alpha_clip_id,
30
+ device=config.device,
31
+ num_tokens=4
32
+ )
33
+
34
+ def dilate_mask(mask, kernel_size=5, iterations=5):
35
+ mask = mask.convert("L")
36
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
37
+ mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
38
+ return Image.fromarray(mask)
39
+
40
+ def combine_masks(uploaded_mask, sketched_mask):
41
+ if uploaded_mask is not None:
42
+ return uploaded_mask
43
+ elif sketched_mask is not None:
44
+ return sketched_mask
45
+ else:
46
+ raise ValueError("Please provide a mask")
47
+
48
+ def remove_obj(image, uploaded_mask, seed):
49
+ image_pil, sketched_mask = image["image"], image["mask"]
50
+ mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
51
+ seed = int(seed)
52
+ latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda")
53
+ final_image = clipaway.generate(
54
+ prompt=[""], scale=1, seed=seed,
55
+ pil_image=[image_pil], alpha=[mask], strength=1, latents=latents
56
+ )[0]
57
+ return final_image
58
+
59
+ # Define example data
60
+ examples = [
61
+ ["assets/gradio_examples/images/1.jpg", "assets/gradio_examples/masks/1.png", 42],
62
+ ["assets/gradio_examples/images/2.jpg", "assets/gradio_examples/masks/2.png", 42],
63
+ ["assets/gradio_examples/images/3.jpg", "assets/gradio_examples/masks/3.png", 464],
64
+ ["assets/gradio_examples/images/4.jpg", "assets/gradio_examples/masks/4.png", 2024],
65
+ ]
66
+
67
+ # Define the Gradio interface
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")
70
+ gr.Markdown("""
71
+ <div style='display:flex; justify-content:center; align-items:center;'>
72
+ <a href='https://arxiv.org/abs/2406.09368' style="margin:10px;">Paper</a> |
73
+ <a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a> |
74
+ <a href='https://github.com/YigitEkin/CLIPAway' style="margin:10px;">GitHub</a>
75
+ </div>
76
+ """)
77
+ gr.Markdown("""
78
+ This application allows you to remove objects from images using the CLIPAway method with diffusion models.
79
+ To use this tool:
80
+ 1. Upload an image.
81
+ 2. Either Sketch a mask over the object you want to remove or upload a pre-defined mask if you have one.
82
+ 4. Set the seed for reproducibility (default is 42).
83
+ 5. Click 'Remove Object' to process the image.
84
+ 6. The result will be displayed on the right side.
85
+ Note: The mask should be a binary image where the object to be removed is white and the background is black.
86
+ """)
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ image_input = gr.Image(label="Upload Image and Sketch Mask", type="pil", tool="sketch")
91
+ uploaded_mask = gr.Image(label="Upload Mask (Optional)", type="pil", optional=True)
92
+ seed_input = gr.Number(value=42, label="Seed")
93
+ process_button = gr.Button("Remove Object")
94
+ with gr.Column():
95
+ result_image = gr.Image(label="Result")
96
+
97
+ process_button.click(
98
+ fn=remove_obj,
99
+ inputs=[image_input, uploaded_mask, seed_input],
100
+ outputs=result_image
101
+ )
102
+
103
+ gr.Examples(
104
+ examples=examples,
105
+ inputs=[image_input, uploaded_mask, seed_input],
106
+ outputs=result_image
107
+ )
108
+
109
+ # Launch the interface with caching
110
+ if args.share:
111
+ demo.launch(share=True)
112
+ else:
113
+ demo.launch()
clip_l14_grit+mim_fultune_6xe.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5f3f2e24459e9764d9f4b4c053fb354dc9d508bd8f647b952402d6860bc9c3d
3
+ size 1216760175
config/inference_config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ device: "cuda"
2
+ root_path: assets/gradio_examples
3
+ image_encoder_path: image_encoder
4
+ alpha_clip_ckpt_pth: clip_l14_grit+mim_fultune_6xe.pth
5
+ alpha_clip_id: ViT-L/14
6
+ ip_adapter_ckpt_path: ip-adapter_sd15.bin
7
+ sd_model_key: "runwayml/stable-diffusion-inpainting"
8
+ number_of_hidden_layers: 6
9
+ alpha_clip_embed_dim: 768
10
+ ip_adapter_embed_dim: 1024
11
+ mlp_projection_layer_ckpt_path: model.safetensors
12
+ save_path_prefix: test/results
13
+ seed: 42
14
+ scale: 1
15
+ strength: 1
16
+ display_focused_embeds: True
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./image_encoder",
3
+ "architectures": [
4
+ "CLIPVisionModelWithProjection"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.28.0.dev0"
23
+ }
image_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d3ec1e66737f77a4f3bc2df3c52eacefc69ce7825e2784183b1d4e9877d9193
3
+ size 2528481905
ip-adapter_sd15.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68e1df30d760f280e578c302f1e73b37ea08654eff16a31153588047affe0058
3
+ size 44642825
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ade94c0505170a7698afe8ad4b4fb2307d06f67917b877cf1fd694a43cd6e335
3
+ size 22877152
model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .clip_away import CLIPAway
2
+
3
+ __all__ = [
4
+ "CLIPAway"
5
+ ]
model/attention_processor.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ taken from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/attention_processor.py
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+
9
+
10
+ class AttnProcessor(nn.Module):
11
+ r"""
12
+ Default processor for performing attention-related computations.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ hidden_size=None,
18
+ cross_attention_dim=None,
19
+ ):
20
+ super().__init__()
21
+
22
+ def __call__(
23
+ self,
24
+ attn,
25
+ hidden_states,
26
+ encoder_hidden_states=None,
27
+ attention_mask=None,
28
+ temb=None,
29
+ ):
30
+ residual = hidden_states
31
+
32
+ if attn.spatial_norm is not None:
33
+ hidden_states = attn.spatial_norm(hidden_states, temb)
34
+
35
+ input_ndim = hidden_states.ndim
36
+
37
+ if input_ndim == 4:
38
+ batch_size, channel, height, width = hidden_states.shape
39
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
40
+
41
+ batch_size, sequence_length, _ = (
42
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
43
+ )
44
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
45
+
46
+ if attn.group_norm is not None:
47
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
48
+
49
+ query = attn.to_q(hidden_states)
50
+
51
+ if encoder_hidden_states is None:
52
+ encoder_hidden_states = hidden_states
53
+ elif attn.norm_cross:
54
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
55
+
56
+ key = attn.to_k(encoder_hidden_states)
57
+ value = attn.to_v(encoder_hidden_states)
58
+
59
+ query = attn.head_to_batch_dim(query)
60
+ key = attn.head_to_batch_dim(key)
61
+ value = attn.head_to_batch_dim(value)
62
+
63
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
64
+ hidden_states = torch.bmm(attention_probs, value)
65
+ hidden_states = attn.batch_to_head_dim(hidden_states)
66
+
67
+ # linear proj
68
+ hidden_states = attn.to_out[0](hidden_states)
69
+ # dropout
70
+ hidden_states = attn.to_out[1](hidden_states)
71
+
72
+ if input_ndim == 4:
73
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
74
+
75
+ if attn.residual_connection:
76
+ hidden_states = hidden_states + residual
77
+
78
+ hidden_states = hidden_states / attn.rescale_output_factor
79
+
80
+ return hidden_states
81
+
82
+
83
+ class IPAttnProcessor(nn.Module):
84
+ r"""
85
+ Attention processor for IP-Adapater.
86
+ Args:
87
+ hidden_size (`int`):
88
+ The hidden size of the attention layer.
89
+ cross_attention_dim (`int`):
90
+ The number of channels in the `encoder_hidden_states`.
91
+ scale (`float`, defaults to 1.0):
92
+ the weight scale of image prompt.
93
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
94
+ The context length of the image features.
95
+ """
96
+
97
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
98
+ super().__init__()
99
+
100
+ self.hidden_size = hidden_size
101
+ self.cross_attention_dim = cross_attention_dim
102
+ self.scale = scale
103
+ self.num_tokens = num_tokens
104
+
105
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
106
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
107
+
108
+ def __call__(
109
+ self,
110
+ attn,
111
+ hidden_states,
112
+ encoder_hidden_states=None,
113
+ attention_mask=None,
114
+ temb=None,
115
+ ):
116
+ residual = hidden_states
117
+
118
+ if attn.spatial_norm is not None:
119
+ hidden_states = attn.spatial_norm(hidden_states, temb)
120
+
121
+ input_ndim = hidden_states.ndim
122
+
123
+ if input_ndim == 4:
124
+ batch_size, channel, height, width = hidden_states.shape
125
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
126
+
127
+ batch_size, sequence_length, _ = (
128
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
129
+ )
130
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
131
+
132
+ if attn.group_norm is not None:
133
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
134
+
135
+ query = attn.to_q(hidden_states)
136
+
137
+ if encoder_hidden_states is None:
138
+ encoder_hidden_states = hidden_states
139
+ else:
140
+ # get encoder_hidden_states, ip_hidden_states
141
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
142
+ encoder_hidden_states, ip_hidden_states = (
143
+ encoder_hidden_states[:, :end_pos, :],
144
+ encoder_hidden_states[:, end_pos:, :],
145
+ )
146
+ if attn.norm_cross:
147
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
148
+
149
+ key = attn.to_k(encoder_hidden_states)
150
+ value = attn.to_v(encoder_hidden_states)
151
+
152
+ query = attn.head_to_batch_dim(query)
153
+ key = attn.head_to_batch_dim(key)
154
+ value = attn.head_to_batch_dim(value)
155
+
156
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
157
+ #!MASK HERE
158
+ hidden_states = torch.bmm(attention_probs, value)
159
+ hidden_states = attn.batch_to_head_dim(hidden_states)
160
+
161
+ # for ip-adapter
162
+ ip_key = self.to_k_ip(ip_hidden_states)
163
+ ip_value = self.to_v_ip(ip_hidden_states)
164
+
165
+ ip_key = attn.head_to_batch_dim(ip_key)
166
+ ip_value = attn.head_to_batch_dim(ip_value)
167
+
168
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
169
+ #!MASK HERE
170
+ self.attn_map = ip_attention_probs
171
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
172
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
173
+
174
+ hidden_states = hidden_states + self.scale * ip_hidden_states
175
+
176
+ # linear proj
177
+ hidden_states = attn.to_out[0](hidden_states)
178
+ # dropout
179
+ hidden_states = attn.to_out[1](hidden_states)
180
+
181
+ if input_ndim == 4:
182
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
183
+
184
+ if attn.residual_connection:
185
+ hidden_states = hidden_states + residual
186
+
187
+ hidden_states = hidden_states / attn.rescale_output_factor
188
+
189
+ return hidden_states
model/clip_away.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modified from from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py
3
+ """
4
+ import os
5
+ from typing import List
6
+ import torch
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ from transformers import CLIPVisionModelWithProjection
10
+ import alpha_clip
11
+ from .utils import get_generator
12
+ from .attention_processor import AttnProcessor, IPAttnProcessor
13
+ from safetensors import safe_open
14
+ from safetensors.torch import load_model
15
+ import numpy as np
16
+
17
+ import torch.nn as nn
18
+
19
+
20
+ class ImageProjModel(torch.nn.Module):
21
+ """Projection Model of IP-Adapter"""
22
+
23
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
24
+ super().__init__()
25
+
26
+ self.generator = None
27
+ self.cross_attention_dim = cross_attention_dim
28
+ self.clip_extra_context_tokens = clip_extra_context_tokens
29
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
30
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
31
+
32
+ def forward(self, image_embeds):
33
+ embeds = image_embeds
34
+ clip_extra_context_tokens = self.proj(embeds).reshape(
35
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
36
+ )
37
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
38
+ return clip_extra_context_tokens
39
+
40
+ class CLIPAway:
41
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, alpha_clip_path, config, alpha_clip_id="ViT-L/14", device="cuda", num_tokens=4):
42
+ super().__init__()
43
+ self.device = device
44
+ self.ipadapter_image_encoder_path = image_encoder_path
45
+ self.ipadapter_ckpt = ip_ckpt
46
+ self.num_tokens = num_tokens
47
+
48
+ self.pipe = sd_pipe.to(self.device)
49
+ self.set_ip_adapter()
50
+ alpha_clip_model, alpha_clip_preprocess = alpha_clip.load(alpha_clip_id, alpha_vision_ckpt_pth=alpha_clip_path, device=device)
51
+
52
+ # load image encoder
53
+ self.image_encoder = alpha_clip_model.visual.to(self.device, dtype=torch.float32)
54
+
55
+ self.clip_proj = CLIPVisionModelWithProjection.from_pretrained(self.ipadapter_image_encoder_path).to(
56
+ self.device, dtype=torch.float32
57
+ )
58
+ self.alpha_clip_image_processor = alpha_clip_preprocess
59
+
60
+ # preprocess mask transformation for alpha clip
61
+ if "@336" in alpha_clip_id:
62
+ self.mask_transform = transforms.Compose([
63
+ transforms.ToTensor(),
64
+ transforms.Resize((336, 336)), # change to (336,336) when using ViT-L/14@336px
65
+ transforms.Normalize(0.5, 0.26)
66
+ ])
67
+ else:
68
+ self.mask_transform = transforms.Compose([
69
+ transforms.ToTensor(),
70
+ transforms.Resize((224, 224)), # change to (336,336) when using ViT-L/14@336px
71
+ transforms.Normalize(0.5, 0.26)
72
+ ])
73
+ # image proj model
74
+ self.image_proj_model = self.init_proj()
75
+
76
+ self.load_ip_adapter()
77
+ self.mlp_projection_layer = self.generate_projection_layer(config)
78
+
79
+ print(config.mlp_projection_layer_ckpt_path, type(config.mlp_projection_layer_ckpt_path) )
80
+ if config.mlp_projection_layer_ckpt_path is not None:
81
+ self.load_projection_layer(config.mlp_projection_layer_ckpt_path)
82
+
83
+ def load_projection_layer(self, path):
84
+ load_model(self.mlp_projection_layer, path)
85
+ print("Projection layer loaded from", path)
86
+
87
+ def generate_projection_layer(self, config):
88
+ projection_layer = nn.ModuleList()
89
+
90
+ for i in range(config.number_of_hidden_layers):
91
+ if i < config.number_of_hidden_layers // 2:
92
+ projection_layer.append(nn.Linear(config.alpha_clip_embed_dim, config.alpha_clip_embed_dim))
93
+ projection_layer.append(nn.LayerNorm(config.alpha_clip_embed_dim))
94
+ elif i == config.number_of_hidden_layers // 2:
95
+ projection_layer.append(nn.Linear(config.alpha_clip_embed_dim, config.ip_adapter_embed_dim))
96
+ projection_layer.append(nn.LayerNorm(config.ip_adapter_embed_dim))
97
+ else:
98
+ projection_layer.append(nn.Linear(config.ip_adapter_embed_dim, config.ip_adapter_embed_dim))
99
+ projection_layer.append(nn.LayerNorm(config.ip_adapter_embed_dim))
100
+ projection_layer.append(nn.GELU())
101
+
102
+ projection_layer.append(nn.Linear(config.ip_adapter_embed_dim, config.ip_adapter_embed_dim))
103
+
104
+ return nn.Sequential(*projection_layer).to(self.device).to(torch.float32)
105
+
106
+ def init_proj(self):
107
+ image_proj_model = ImageProjModel(
108
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
109
+ clip_embeddings_dim=self.clip_proj.config.projection_dim,
110
+ clip_extra_context_tokens=self.num_tokens,
111
+ ).to(self.device, dtype=torch.float32)
112
+ return image_proj_model
113
+
114
+ def set_ip_adapter(self):
115
+ unet = self.pipe.unet
116
+ attn_procs = {}
117
+ for name in unet.attn_processors.keys():
118
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
119
+ if name.startswith("mid_block"):
120
+ hidden_size = unet.config.block_out_channels[-1]
121
+ elif name.startswith("up_blocks"):
122
+ block_id = int(name[len("up_blocks.")])
123
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
124
+ elif name.startswith("down_blocks"):
125
+ block_id = int(name[len("down_blocks.")])
126
+ hidden_size = unet.config.block_out_channels[block_id]
127
+ if cross_attention_dim is None:
128
+ attn_procs[name] = AttnProcessor().to(self.device)
129
+ else:
130
+ attn_procs[name] = IPAttnProcessor(
131
+ hidden_size=hidden_size,
132
+ cross_attention_dim=cross_attention_dim,
133
+ scale=1.0,
134
+ num_tokens=self.num_tokens,
135
+ ).to(self.device, dtype=torch.float32)
136
+ unet.set_attn_processor(attn_procs)
137
+
138
+ def get_alpha_clip_embeds(self, pil_image, alpha):
139
+ clip_image = [self.alpha_clip_image_processor(image) for image in pil_image]
140
+ clip_image = torch.stack(clip_image).to(self.device, dtype=torch.float32)
141
+ masks = [self.mask_transform(mask) for mask in alpha]
142
+ masks = torch.stack(masks).to(self.device, dtype=torch.float32)
143
+
144
+ return self.image_encoder(clip_image, masks)
145
+
146
+ def load_ip_adapter(self):
147
+ if os.path.splitext(self.ipadapter_ckpt)[-1] == ".safetensors":
148
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
149
+ with safe_open(self.ipadapter_ckpt, framework="pt", device="cpu") as f:
150
+ for key in f.keys():
151
+ if key.startswith("image_proj."):
152
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
153
+ elif key.startswith("ip_adapter."):
154
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
155
+ else:
156
+ state_dict = torch.load(self.ipadapter_ckpt, map_location="cpu")
157
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
158
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
159
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
160
+
161
+ def get_complement_of_mask(self, mask):
162
+ return Image.fromarray((255 - np.array(mask[0])).astype(np.uint8))
163
+
164
+ def clipaway_projection_block(self, bg_embeds, fg_embeds):
165
+ projected_vector_magnitude = bg_embeds[0].dot(fg_embeds[0]) / fg_embeds[0].norm()
166
+ projected_vector = projected_vector_magnitude * fg_embeds / fg_embeds.norm()
167
+ return bg_embeds - projected_vector
168
+
169
+ def get_focused_embeddings(self, pil_image, alpha, use_projection_block=False):
170
+ # get focused alpha clip embeds
171
+ clip_image_embeds_fg = self.get_alpha_clip_embeds(pil_image, alpha)
172
+ clip_image_embeds_bg = self.get_alpha_clip_embeds(pil_image, [self.get_complement_of_mask(alpha)])
173
+
174
+ # mlp projection
175
+ projected_alpha_clip_embeds_fg = self.mlp_projection_layer(clip_image_embeds_fg)
176
+ projected_alpha_clip_embeds_bg = self.mlp_projection_layer(clip_image_embeds_bg)
177
+
178
+ # ip adapter logic
179
+ image_prompt_embeds_fg = self.image_proj_model(projected_alpha_clip_embeds_fg)
180
+ image_prompt_embeds_bg = self.image_proj_model(projected_alpha_clip_embeds_bg)
181
+ uncond_image_prompt_embeds = self.image_proj_model(self.mlp_projection_layer(torch.zeros_like(clip_image_embeds_fg)))
182
+
183
+ if use_projection_block:
184
+ # clipaway projection block
185
+ projected_alpha_clip_embeds = self.clipaway_projection_block(projected_alpha_clip_embeds_bg, projected_alpha_clip_embeds_fg)
186
+ image_prompt_embeds = self.image_proj_model(projected_alpha_clip_embeds)
187
+ return image_prompt_embeds, image_prompt_embeds_fg, image_prompt_embeds_bg, uncond_image_prompt_embeds
188
+
189
+ return image_prompt_embeds_fg, image_prompt_embeds_bg, uncond_image_prompt_embeds
190
+
191
+
192
+ def get_ipadapter_embeds(self, pil_image=None, alpha=None):
193
+ # get focused alpha clip embeds
194
+ clip_image_embeds_fg = self.get_alpha_clip_embeds(pil_image, alpha)
195
+ clip_image_embeds_bg = self.get_alpha_clip_embeds(pil_image, [self.get_complement_of_mask(alpha)])
196
+
197
+ # mlp projection
198
+ projected_alpha_clip_embeds_fg = self.mlp_projection_layer(clip_image_embeds_fg)
199
+ projected_alpha_clip_embeds_bg = self.mlp_projection_layer(clip_image_embeds_bg)
200
+
201
+ # clipaway projection block
202
+ projected_alpha_clip_embeds = self.clipaway_projection_block(projected_alpha_clip_embeds_bg, projected_alpha_clip_embeds_fg)
203
+
204
+ # ip adapter logic
205
+ image_prompt_embeds = self.image_proj_model(projected_alpha_clip_embeds)
206
+ uncond_image_prompt_embeds = self.image_proj_model(self.mlp_projection_layer(torch.zeros_like(clip_image_embeds_fg)))
207
+
208
+ return image_prompt_embeds, uncond_image_prompt_embeds
209
+
210
+
211
+ def set_scale(self, scale):
212
+ for attn_processor in self.pipe.unet.attn_processors.values():
213
+ if isinstance(attn_processor, IPAttnProcessor):
214
+ attn_processor.scale = scale
215
+
216
+ @torch.inference_mode()
217
+ def generate(
218
+ self,
219
+ pil_image=None,
220
+ alpha=None,
221
+ prompt=None,
222
+ negative_prompt=None,
223
+ image_prompt_embeds=None,
224
+ uncond_image_prompt_embeds=None,
225
+ scale=1.0,
226
+ num_samples=1,
227
+ seed=None,
228
+ guidance_scale=7.5,
229
+ num_inference_steps=50,
230
+ **kwargs,
231
+ ):
232
+ self.set_scale(scale)
233
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
234
+
235
+ if prompt is None:
236
+ prompt = "best quality, high quality"
237
+ if negative_prompt is None:
238
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
239
+
240
+ if not isinstance(prompt, List):
241
+ prompt = [prompt] * num_prompts
242
+ if not isinstance(negative_prompt, List):
243
+ negative_prompt = [negative_prompt] * num_prompts
244
+
245
+ if image_prompt_embeds is None or uncond_image_prompt_embeds is None:
246
+ image_prompt_embeds, uncond_image_prompt_embeds= self.get_ipadapter_embeds(pil_image=pil_image, alpha=alpha)
247
+ else:
248
+ image_prompt_embeds = image_prompt_embeds.to(self.device)
249
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device)
250
+
251
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
252
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed, seq_len, -1)
253
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed, seq_len, -1)
254
+
255
+ with torch.inference_mode():
256
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
257
+ prompt,
258
+ device=self.device,
259
+ num_images_per_prompt=num_samples,
260
+ do_classifier_free_guidance=True,
261
+ negative_prompt=negative_prompt,
262
+ )
263
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
264
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
265
+
266
+ generator = get_generator(seed, self.device)
267
+
268
+ images = self.pipe(
269
+ prompt_embeds=prompt_embeds,
270
+ negative_prompt_embeds=negative_prompt_embeds,
271
+ guidance_scale=guidance_scale,
272
+ num_inference_steps=num_inference_steps,
273
+ generator=generator,
274
+ image=pil_image,
275
+ mask_image=alpha,
276
+ **kwargs,
277
+ ).images
278
+
279
+ return images
280
+
model/resampler.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ taken from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
3
+ """
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)