diff --git a/AlphaCLIP/.gitignore b/AlphaCLIP/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a77f9fe9b0e05a49ebf1903fe98ec5138b020f26 --- /dev/null +++ b/AlphaCLIP/.gitignore @@ -0,0 +1,12 @@ +__pycache__/ +*.py[cod] +*$py.class +*.egg-info +.pytest_cache +.ipynb_checkpoints + +thumbs.db +.DS_Store +.idea +checkpoints/* +*.pth \ No newline at end of file diff --git a/AlphaCLIP/LICENSE b/AlphaCLIP/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..284803df68c514bdd81477e9248b9a0b4e769533 --- /dev/null +++ b/AlphaCLIP/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [Zeyi Sun] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/AlphaCLIP/MANIFEST.in b/AlphaCLIP/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..670a10cb3795e1c35c281a53e78b671eca9787de --- /dev/null +++ b/AlphaCLIP/MANIFEST.in @@ -0,0 +1 @@ +include alpha_clip/bpe_simple_vocab_16e6.txt.gz diff --git a/AlphaCLIP/alpha_clip/__init__.py b/AlphaCLIP/alpha_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5b643bb3da8fde1fcadedb6919a36fb544cf97 --- /dev/null +++ b/AlphaCLIP/alpha_clip/__init__.py @@ -0,0 +1 @@ +from .alpha_clip import * diff --git a/AlphaCLIP/alpha_clip/alpha_clip.py b/AlphaCLIP/alpha_clip/alpha_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f79d147d9a056dc83b4169781e44651f22b012 --- /dev/null +++ b/AlphaCLIP/alpha_clip/alpha_clip.py @@ -0,0 +1,250 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + alpha_vision_ckpt_pth: str + only changed when inferencing model instead of training + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). 
+ + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device) + if str(device) == "cpu": + model.float() + if alpha_vision_ckpt_pth != "None": + model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth)) + model.eval() # merge lora params if exists (for inference only) + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, 
List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz b/AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/AlphaCLIP/alpha_clip/model.py b/AlphaCLIP/alpha_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4aaf0ccc32d84c88b63bd94a426123b05372c1 --- /dev/null +++ b/AlphaCLIP/alpha_clip/model.py @@ -0,0 +1,598 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +import loralib as lora +import math +import collections + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, alpha=None): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. 
/ 0.01), + attn_drop=0., + proj_drop=0., + lora_adapt=False, + rank=16 + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + if lora_adapt: + print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!") + self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True]) + else: + self.in_proj = nn.Linear(dim, dim * 3) + # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + # if qkv_bias: + # self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + # else: + # self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask = None): + L, N, C = x.shape + q, k, v = self.in_proj(x).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-2, -1)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x, attn + + +class CustomResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16): + super().__init__() + + self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, attn_mask=self.attn_mask) + + def forward(self, x: torch.Tensor, return_attn=False): + attn_out, attn = self.attention(self.ln_1(x)) + x = x + attn_out + x = x + self.mlp(self.ln_2(x)) + if return_attn: + return 
x, attn + else: + return x + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + +class CustomTransformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)]) + + def forward(self, x: torch.Tensor, return_attn=False): + if return_attn: + for i, block in enumerate(self.resblocks): + if i == len(self.resblocks) - 1: + return block(x, return_attn=True) + else: + x = block(x) + assert False + return self.resblocks(x) + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, alpha=None, return_attn=False): + x = self.conv1(x) # shape = [*, width, grid, grid] + # ASSUME alpha is always not None! 
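+        # the alpha map is a single-channel mask at input resolution; it is patch-embedded by
+        # conv1_alpha (in_channels=1, same kernel size and stride as conv1) and added to the RGB
+        # patch embedding, so callers must always supply an alpha tensor of shape [*, 1, H, W].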
+ x = x + self.conv1_alpha(alpha) + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + if return_attn: + x, attn_last = self.transformer(x, return_attn=True) + else: + x = self.transformer(x, return_attn=False) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + if return_attn: + return x, attn_last + else: + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + lora_adapt = False, + rank = 16, + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + lora_adapt=lora_adapt, + rank=rank + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if 
self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + if not hasattr(self.visual, "conv1"): + return self.visual.module.conv1.weight.dtype + return self.visual.conv1.weight.dtype + + def encode_image(self, image, alpha): + assert alpha is not None + return self.visual(image.type(self.dtype), alpha.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text, alpha): + image_features = self.encode_image(image, alpha) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict, lora_adapt=False, rank=16): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + 
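+        # note: the attnpool positional embedding holds (input_resolution // 32) ** 2 + 1 entries,
+        # so the training input resolution is recovered from its grid size below.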
        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+    # always load lora version
+    model = CLIP(
+        embed_dim,
+        image_resolution, vision_layers, vision_width, vision_patch_size,
+        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
+        lora_adapt=lora_adapt, rank=rank,
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+    # rename packed in_proj_weight / in_proj_bias keys to the in_proj.weight / in_proj.bias naming used by the Linear-based Attention
+    new_state_dict = collections.OrderedDict()
+    for k, v in state_dict.items():
+        if 'visual' in k:
+            if 'in_proj_weight' in k:
+                new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
+            elif 'in_proj_bias' in k:
+                new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
+            else:
+                new_state_dict[k] = v
+        else:
+            new_state_dict[k] = v
+
+    state_dict = new_state_dict
+    # add rgba_conv_weight
+    if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel
+        rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
+        rgba_weight = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
+        state_dict['visual.conv1_alpha.weight'] = rgba_weight
+    convert_weights(model)
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
diff --git a/AlphaCLIP/alpha_clip/simple_tokenizer.py b/AlphaCLIP/alpha_clip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91
--- /dev/null
+++ b/AlphaCLIP/alpha_clip/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 bytes and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/AlphaCLIP/eval/README.md b/AlphaCLIP/eval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d72f99fd1fdfba3e17ad4ca6e0e035e6f7b38055 --- /dev/null +++ b/AlphaCLIP/eval/README.md @@ -0,0 +1,6 @@ +# Alpha-CLIP evaluation +## Zero-Shot Classification on ImageNet-S +checkout [imagenet_s_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/imagenet_s_zs_test) + +## Zero-Shot Referring Expression Comprehension on RefCOCO +checkout [rec_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/rec_zs_test) diff --git a/AlphaCLIP/eval/imagenet_s_zs_test/.gitignore b/AlphaCLIP/eval/imagenet_s_zs_test/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..735c18ad2ca128cbff4e55f5a88facd5ecffe0f8 --- /dev/null +++ b/AlphaCLIP/eval/imagenet_s_zs_test/.gitignore @@ -0,0 +1,2 @@ +*.json +data/* 
\ No newline at end of file diff --git a/AlphaCLIP/eval/imagenet_s_zs_test/README.md b/AlphaCLIP/eval/imagenet_s_zs_test/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bb4bd1329ab96cca10e0d1fe0f5c2a525b5b6050 --- /dev/null +++ b/AlphaCLIP/eval/imagenet_s_zs_test/README.md @@ -0,0 +1,21 @@ +# Alpha-CLIP evaluation +## Zero-Shot Classification on ImageNet-S + +1.prepare [imagenet-s](https://github.com/LUSSeg/ImageNet-S) dataset, only `validation` raw image is needed. + +2.download [imagenet_919.json](https://download.openxlab.org.cn/models/SunzeY/AlphaCLIP/weight/imagenet_919.json) we provide as data annotation (generated from imagenet-s annotation). The folder should be structured like + +``` +├── imagenet_s_zs_test +│ ├── data +│ │ ├── imagenet_919.json +│ │ └── ImageNetS919 +│ │ └── validation +``` + +3.run test script. + +``` +cd eval/imagenet_s_zs_test +python imagenet_s_zs_test.py +``` diff --git a/AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py b/AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d9ec466140d9b6a3f5b8afccc9e9392d3fb146 --- /dev/null +++ b/AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py @@ -0,0 +1,149 @@ +import json +import os +import random +from tqdm import tqdm +from torch.utils.data import Dataset +from pycocotools.coco import COCO +from pycocotools import mask as maskUtils +from PIL import Image +import cv2 +import random +from torchvision import transforms +from tqdm import tqdm + +import pickle +import torch +import numpy as np +import copy +import sys +import shutil +from PIL import Image +from nltk.corpus import wordnet + +PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073) +MASK_FILL = [int(255 * c) for c in PIXEL_MEAN] + + +clip_standard_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224, 224), interpolation=Image.BICUBIC), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), +]) + +hi_clip_standard_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((336, 336), interpolation=Image.BICUBIC), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), +]) + +res_clip_standard_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((336, 336), interpolation=Image.BICUBIC), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), +]) + +mask_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.26) +]) + +hi_mask_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((336, 336)), + transforms.Normalize(0.5, 0.26) +]) + +res_mask_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((336, 336)), + transforms.Normalize(0.5, 0.26) +]) + +def crop_center(img, croph, cropw): + h, w = img.shape[:2] + starth = h//2 - (croph//2) + startw = w//2 - (cropw//2) + return img[starth:starth+croph, startw:startw+cropw, :] + +class Imagenet_S(Dataset): + def __init__(self, ann_file='data/imagenet_919.json', hi_res=False, all_one=False): + self.anns = json.load(open(ann_file, 'r')) + self.root_pth = 'data/' + cats = [] + for ann in self.anns: + if ann['category_word'] not in cats: + cats.append(ann['category_word']) + ann['cat_index'] = len(cats) - 1 + self.classes = [] + for cat_word in cats: + synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:])) + 
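+            # each ImageNet-S category word is a WordNet noun offset (e.g. 'n01440764'); the
+            # synset's first lemma is used as the class name for the zero-shot text prompts.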
synonyms = [x.name() for x in synset.lemmas()] + self.classes.append(synonyms[0]) + + self.choice = "center_crop" + if hi_res: + self.mask_transform = res_mask_transform + self.clip_standard_transform = res_clip_standard_transform + else: + self.mask_transform = mask_transform + self.clip_standard_transform = clip_standard_transform + + self.all_one = all_one + + def __len__(self): + return len(self.anns) + + def __getitem__(self, index): + ann = self.anns[index] + image = cv2.imread(os.path.join(self.root_pth, ann['image_pth'])) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + mask = maskUtils.decode(ann['mask']) + # image[mask==0] = MASK_FILL + rgba = np.concatenate((image, np.expand_dims(mask, axis=-1)), axis=-1) + h, w = rgba.shape[:2] + + if self.choice == "padding": + if max(h, w) == w: + pad = (w - h) // 2 + l, r = pad, w - h - pad + rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0) + else: + pad = (h - w) // 2 + l, r = pad, h - w - pad + rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0) + else: + if min(h, w) == h: + rgba = crop_center(rgba, h, h) + else: + rgba = crop_center(rgba, w, w) + rgb = rgba[:, :, :-1] + mask = rgba[:, :, -1] + image_torch = self.clip_standard_transform(rgb) + # using box: bounding-box compute + # bi_mask = mask == 1 + # h, w = bi_mask.shape[-2:] + # in_height = np.max(bi_mask, axis=-1) + # in_height_coords = np.max(bi_mask, axis=-1) * np.arange(h) + # b_e = in_height_coords.max() + # in_height_coords = in_height_coords + h * (~in_height) + # t_e = in_height_coords.min() + # in_width = np.max(bi_mask, axis=-2) + # in_width_coords = np.max(bi_mask, axis=-2) * np.arange(w) + # r_e = in_width_coords.max() + # in_width_coords = in_width_coords + w * (~in_width) + # l_e = in_width_coords.min() + # box = np.zeros_like(mask) + # box[t_e: b_e, l_e:r_e] = 1 + # mask = box + if self.all_one: + mask_torch = self.mask_transform(np.ones_like(mask) * 255) + else: + mask_torch = self.mask_transform(mask * 255) + return image_torch, mask_torch, ann['cat_index'] + +if __name__ == "__main__": + data = Imagenet_S() + for i in tqdm(range(data.__len__())): + data.__getitem__(i) \ No newline at end of file diff --git a/AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py b/AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c71101360b275f16d8b83128a9ad45f9b979f5ad --- /dev/null +++ b/AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py @@ -0,0 +1,66 @@ +import torch +import alpha_clip +from tqdm import tqdm +from imagenet_s import Imagenet_S + +model, preprocess = alpha_clip.load("ViT-L/14@336px", alpha_vision_ckpt_pth="../../clip_l14@336_grit_20m_4xe.pth") + +def zeroshot_classifier(classnames, templates): + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template.format(classname) for template in templates] #format with class + texts = alpha_clip.tokenize(texts).cuda() #tokenize + class_embeddings = model.encode_text(texts) #embed with text encoder + class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) + class_embedding = class_embeddings.mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() + return zeroshot_weights + +dataset = Imagenet_S(hi_res=True) +loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=2) + +imagenet_templates = [ + 'a photo of a {}.' 
+] + +zeroshot_weights = zeroshot_classifier(dataset.classes, imagenet_templates) +temp_corr_dict = dict() + +with torch.no_grad(): + for i, (images, alpha, target) in enumerate(tqdm(loader)): + images = images.cuda() + alpha = alpha.cuda() + target = target.cuda() + # predict + image_features = model.encode_image(images, alpha) + image_features /= image_features.norm(dim=-1, keepdim=True) + score = 100. * image_features @ zeroshot_weights + + pred = score.topk(1, dim=1)[1].squeeze(dim=1) + pred_5 = score.topk(5, dim=1)[1].squeeze(dim=1) + + for i in range(target.shape[0]): + if target[i].item() not in temp_corr_dict: + temp_corr_dict[target[i].item()] = [0, 0, 0] + temp_corr_dict[target[i].item()][0] += 1 + if target[i].item() == pred[i].item(): + temp_corr_dict[target[i].item()][1] += 1 + if target[i].item() in pred_5[i].tolist(): + temp_corr_dict[target[i].item()][2] += 1 + +acc1 = 0.0 +acc5 = 0.0 +num_class = 0 +for v in temp_corr_dict.values(): + if v[0] == 0: continue + acc1 += v[1] / v[0] + acc5 += v[2] / v[0] + num_class += 1 +acc1 = acc1 / num_class * 100 +acc5 = acc5 / num_class * 100 + +print(f"Top-1 accuracy: {acc1:.2f}") +print(f"Top-5 accuracy: {acc5:.2f}") diff --git a/AlphaCLIP/eval/rec_zs_test/LICENSE.md b/AlphaCLIP/eval/rec_zs_test/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/AlphaCLIP/eval/rec_zs_test/README.md b/AlphaCLIP/eval/rec_zs_test/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b959649394886d33d7d8d7f9fa9f58d1f8f0d1b0 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/README.md @@ -0,0 +1,74 @@ +## Zero-Shot Referring Expression Comprehension on RefCOCO + +**Preparing Data** + +1.Download [images for RefCOCO/g/+](http://images.cocodataset.org/zips/train2014.zip). Put downloaded dataset(train2014) to eval/rec_zs_test/data/. + +2.Download preprocessed data files via `gsutil cp gs://reclip-sanjays/reclip_data.tar.gz` and `cd rec_zs_test`, and then extract the data using `tar -xvzf reclip_data.tar.gz`. + +**Preparing model** + +3.Download [SAM](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) (vit-h), [Alpha-CLIP](https://github.com/SunzeY/AlphaCLIP/blob/main/model-zoo.md) model, and put them in ./eval/rec_zs_test/ckpt. 
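+
+(Optional) As a quick sanity check that the checkpoints are in place, you can load Alpha-CLIP directly from Python; this assumes the ViT-B/16 checkpoint sits under `ckpt/grit1m/` as in the directory layout below:
+
+```
+import alpha_clip
+
+model, preprocess = alpha_clip.load(
+    "ViT-B/16",
+    alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth",
+    device="cuda",
+)
+```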
+
+```
+├── eval
+│   ├── rec_zs_test
+│   │   ├── data
+│   │   │   └── train2014
+│   │   ├── reclip_data
+│   │   │   ├── refcoco_val.jsonl
+│   │   │   ├── refcoco_dets_dict.json
+│   │   │   └── ...
+│   │   ├── ckpt
+│   │   │   ├── sam_vit_h_4b8939.pth
+│   │   │   └── grit1m
+│   │   │       ├── clip_b16_grit+mim_fultune_4xe.pth
+│   │   │       └── clip_l14_grit+mim_fultune_6xe.pth
+│   │   ├── methods
+│   │   ├── cache
+│   │   ├── output
+│   │   ├── main.py
+│   │   ├── executor.py
+│   │   ├── run.sh
+│   │   ├── ...
+```
+
+4.Run the test script:
+
+```
+cd eval/rec_zs_test
+```
+```
+bash run.sh
+```
+or
+
+```
+python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco+_dets_dict.json --cache_path ./cache
+```
+(We recommend setting `cache_path` so that the SAM masks for each image are generated only once and reused on later runs.)
+
+For multi-GPU testing, try:
+
+```
+bash run_multi_gpus.sh
+python cal_acc.py refcoco_val
+```
+
+
+**Acknowledgement**
+
+Our evaluation is built on the excellent [ReCLIP](https://github.com/allenai/reclip/tree/main) codebase: we simply replace CLIP with Alpha-CLIP and skip the image-cropping operation.
+
+
+**Experiment results**
+
+| Method         | RefCOCO |      |      | RefCOCO+ |      |      | RefCOCOg |      |
+|----------------|---------|------|------|----------|------|------|----------|------|
+|                | Val     | TestA| TestB| Val      | TestA| TestB| Val      | Test |
+| CPT [67]       | 32.2    | 36.1 | 30.3 | 31.9     | 35.2 | 28.8 | 36.7     | 36.5 |
+| ReCLIP [54]    | 45.8    | 46.1 | 47.1 | 47.9     | 50.1 | 45.1 | 59.3     | 59.0 |
+| Red Circle [52]| 49.8    | 58.6 | 39.9 | 55.3     | 63.9 | 45.4 | 59.4     | 58.9 |
+| Alpha-CLIP     | 55.7    | 61.1 | 50.3 | 55.6     | 62.7 | 46.4 | 61.2     | 62.0 |
+
diff --git a/AlphaCLIP/eval/rec_zs_test/cache/.gitkeep b/AlphaCLIP/eval/rec_zs_test/cache/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AlphaCLIP/eval/rec_zs_test/cal_acc.py b/AlphaCLIP/eval/rec_zs_test/cal_acc.py
new file mode 100644
index 0000000000000000000000000000000000000000..51640609f5d180b06e114606b334fbc99d9b7f22
--- /dev/null
+++ b/AlphaCLIP/eval/rec_zs_test/cal_acc.py
@@ -0,0 +1,21 @@
+import json
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('name', type=str, default='refcoco_val')
+
+args = parser.parse_args()
+
+name = args.name
+print(name)
+count = 0
+all_count = 0
+for i in range(8):
+    pth = f'output/{name}_count_{i}.json'
+    acc = json.load(open(pth, 'r'))
+    a_list = acc.split()
+    a, b = a_list[0], a_list[1]
+    count += int(a)
+    all_count += int(b)
+
+print(float(count) / float(all_count))
\ No newline at end of file
diff --git a/AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep b/AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AlphaCLIP/eval/rec_zs_test/data/.gitkeep b/AlphaCLIP/eval/rec_zs_test/data/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/AlphaCLIP/eval/rec_zs_test/entity_extraction.py b/AlphaCLIP/eval/rec_zs_test/entity_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..008122f414e1463b7e6bd50e39f9e89d3143c5c4
--- /dev/null
+++ b/AlphaCLIP/eval/rec_zs_test/entity_extraction.py
@@ -0,0 +1,142 @@
+from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
+import numpy as np
+from spacy.tokens.token 
import Token +from spacy.tokens.span import Span + +from lattice import Product as L + +from heuristics import Heuristics + +Rel = Tuple[List[Token], "Entity"] +Sup = List[Token] + +DEFAULT_HEURISTICS = Heuristics() + + +def find_superlatives(tokens, heuristics) -> List[Sup]: + """Modify and return a list of superlative tokens.""" + for heuristic in heuristics.superlatives: + if any(tok.text in heuristic.keywords for tok in tokens): + tokens.sort(key=lambda tok: tok.i) + return [tokens] + return [] + +def expand_chunks(doc, chunks): + expanded = {} + for key in chunks: + chunk = chunks[key] + start = chunk.start + end = chunk.end + for i in range(chunk.start-1, -1, -1): + if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)): + if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2): + start = i + for i in range(chunk.end, len(doc)): + if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)): + if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2): + end = i+1 + else: + break + expanded[key] = Span(doc=doc, start=start, end=end) + return expanded + +class Entity(NamedTuple): + """Represents an entity with locative constraints extracted from the parse.""" + + head: Span + relations: List[Rel] + superlatives: List[Sup] + + @classmethod + def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> "Entity": + """Extract entities from a spacy parse. + + Jointly recursive with `_get_rel_sups`.""" + if heuristics is None: + heuristics = DEFAULT_HEURISTICS + + if head.i not in chunks: + # Handles predicative cases. + children = list(head.children) + if children and children[0].i in chunks: + head = children[0] + # TODO: Also extract predicative relations. + else: + return None + hchunk = chunks[head.i] + rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics) + return cls(hchunk, rels, sups) + + @classmethod + def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]: + hchunk = chunks[head.i] + is_keyword = any(token.text in h.keywords for h in heuristics.relations) + is_keyword |= token.text in heuristics.null_keywords + + # Found another entity head. + if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword: + tchunk = chunks[token.i] + tokens.sort(key=lambda tok: tok.i) + subhead = cls.extract(token, chunks, heuristics) + return [(tokens, subhead)], [] + + # End of a chain of modifiers. 
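+        # With no children left, the dependency walk bottomed out without reaching
+        # another entity head, so the accumulated modifier tokens (plus this token)
+        # can only yield a superlative, not a relation to another entity.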
+ n_children = len(list(token.children)) + if n_children == 0: + return [], find_superlatives(tokens + [token], heuristics) + + relations = [] + superlatives = [] + is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives) + for child in token.children: + if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]: + if not any(child.text in h.keywords for h in heuristics.superlatives): + if n_children == 1: + # Catches "the goat on the left" + sups = find_superlatives(tokens + [token], heuristics) + superlatives.extend(sups) + continue + new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens + subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics) + relations.extend(subrel) + superlatives.extend(subsup) + return relations, superlatives + + def expand(self, span: Span = None): + tokens = [token for token in self.head] + if span is None: + span = [None] + for target_token in span: + include = False + stack = [token for token in self.head] + while len(stack) > 0: + token = stack.pop() + if token == target_token: + token2 = target_token.head + while token2.head != token2: + tokens.append(token2) + token2 = token2.head + tokens.append(token2) + stack = [] + include = True + if target_token is None or include: + tokens.append(token) + for child in token.children: + stack.append(child) + tokens = list(set(tokens)) + tokens = sorted(tokens, key=lambda x: x.i) + return ' '.join([token.text for token in tokens]) + + def __eq__(self, other: "Entity") -> bool: + if self.text != other.text: + return False + if self.relations != other.relations: + return False + if self.superlatives != other.superlatives: + return False + return True + + @property + def text(self) -> Text: + """Get the text predicate associated with this entity.""" + return self.head.text diff --git a/AlphaCLIP/eval/rec_zs_test/executor.py b/AlphaCLIP/eval/rec_zs_test/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d4b763e73eb0e0a81cc9fd324da4690441c5ea --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/executor.py @@ -0,0 +1,401 @@ +from typing import List, Dict, Union, Tuple + +from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageEnhance +import spacy +import hashlib +import os + +import torch +import torchvision +import torchvision.transforms as transforms +import clip +from transformers import BertTokenizer, RobertaTokenizerFast +import ruamel.yaml as yaml +import copy + +from interpreter import Box + +import pycocotools.mask as mask_utils +import alpha_clip +from segment_anything import sam_model_registry, SamPredictor +import numpy as np +import cv2 +import matplotlib.pyplot as plt + +import pickle + +class Executor: + def __init__(self, device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None) -> None: + IMPLEMENTED_METHODS = ["blur", "full", "gray"] + if any(m not in IMPLEMENTED_METHODS for m in box_representation_method.split(",")): + raise NotImplementedError + IMPLEMENTED_AGGREGATORS = ["max", "sum"] + if method_aggregator not in IMPLEMENTED_AGGREGATORS: + raise NotImplementedError + self.box_representation_method = box_representation_method + self.method_aggregator = method_aggregator + self.enlarge_boxes = enlarge_boxes + self.device = device + self.expand_position_embedding = 
expand_position_embedding + self.square_size = square_size + self.blur_std_dev = blur_std_dev + self.cache_path = cache_path + + def preprocess_image(self, image: Image) -> List[torch.Tensor]: + return [preprocess(image) for preprocess in self.preprocesses] + + def preprocess_mask(self, mask: Image) -> List[torch.Tensor]: + preprocess = self.preprocesses[0] + return preprocess.transforms[1](preprocess.transforms[0](mask)) + + def preprocess_text(self, text: str) -> torch.Tensor: + raise NotImplementedError + + def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: + raise NotImplementedError + + def tensorize_inputs(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth: str = None) -> Tuple[List[torch.Tensor], torch.Tensor]: + images = [] + for preprocess in self.preprocesses: + images.append([]) + + if 'aclip' in self.clip_type: + self.all_masks = [] + read_save = False + if self.mask_path is not None: # load mask if cached + file_name = image_pth.split('/')[-1].split('.')[0]+'.pkl' + if os.path.exists(os.path.join(self.mask_path, file_name)): + all_rles = pickle.load(open(os.path.join(self.mask_path, file_name),'rb')) + for rle in all_rles: + mask = np.array(mask_utils.decode(rle), dtype=bool) + self.all_masks.append(mask) + read_save = True + if not read_save: + # use SAM to generate masks + self.predictor.set_image(np.array(image.convert('RGB'))) + all_rles = [] + for i in range(len(boxes)): + box = [ + max(boxes[i].left-self.enlarge_boxes, 0), + max(boxes[i].top-self.enlarge_boxes, 0), + min(boxes[i].right+self.enlarge_boxes, image.width), + min(boxes[i].bottom+self.enlarge_boxes, image.height) + ] # box prompt + input_box = np.array(box) + masks, _, _ = self.predictor.predict( + point_coords=None, + point_labels=None, + box=input_box[None, :], + multimask_output=False, + ) + self.all_masks.append(masks[0]) + rle = mask_utils.encode(np.array(masks[0][:, :, None], order='F', dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") + all_rles.append(rle) + if self.mask_path is not None: # save mask + os.makedirs(self.mask_path, exist_ok=True) + pickle.dump(all_rles, open(os.path.join(self.mask_path, file_name),'wb')) + + if self.cache_path is None or any([not os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name, method_name+".pt")) for model_name in self.model_names for method_name in self.box_representation_method.split(',')]): + if "full" in self.box_representation_method: # original full image with alpha-map + for i in range(len(boxes)): + image_i = image.copy() + preprocessed_images = self.preprocess_image(image_i) + for j, img in enumerate(preprocessed_images): + images[j].append(img.to(self.device)) + if "blur" in self.box_representation_method: + for i in range(len(boxes)): + image_i = image.copy() + + mask = Image.new('L', image_i.size, 0) + draw = ImageDraw.Draw(mask) + box = ( + max(boxes[i].left-self.enlarge_boxes, 0), + max(boxes[i].top-self.enlarge_boxes, 0), + min(boxes[i].right+self.enlarge_boxes, image_i.width), + min(boxes[i].bottom+self.enlarge_boxes, image_i.height) + ) + if 'aclip' in self.clip_type: + width, height = image.size + for y in range(height): + for x in range(width): + if self.all_masks[i][y][x] == 1: + draw.point((x, y), fill=255) + else: + draw.rectangle([box[:2], box[2:]], fill=255) + blurred = image_i.filter(ImageFilter.GaussianBlur(self.blur_std_dev)) + blurred.paste(image_i, mask=mask) + 
preprocessed_images = self.preprocess_image(blurred) + + for j, img in enumerate(preprocessed_images): + images[j].append(img.to(self.device)) + if "gray" in self.box_representation_method: + for i in range(len(boxes)): + image_i = image.copy() + mask_i = self.all_masks[i] + width, height = image.size + + pixels = image_i.load() + for y in range(height): + for x in range(width): + if mask_i[y][x] == 0: + pixel_value = pixels[x, y] + gray_value = int(0.2989 * pixel_value[0] + 0.5870 * pixel_value[1] + 0.1140 * pixel_value[2]) + pixels[x, y] = (gray_value, gray_value, gray_value) + preprocessed_images = self.preprocess_image(image_i) + for j, img in enumerate(preprocessed_images): + images[j].append(img.to(self.device)) + + imgs = [torch.stack(image_list) for image_list in images] + else: + imgs = [[] for _ in self.models] + text_tensor = self.preprocess_text(caption.lower()).to(self.device) + return imgs, text_tensor + + @torch.no_grad() + def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor: + images, text_tensor = self.tensorize_inputs(caption, image, boxes, image_name, image_pth) + all_logits_per_image = [] + all_logits_per_text = [] + box_representation_methods = self.box_representation_method.split(',') + caption_hash = hashlib.md5(caption.encode('utf-8')).hexdigest() + for model, images_t, model_name in zip(self.models, images, self.model_names): + self.image_feat_path = "" + if self.cache_path is not None: + text_cache_path = os.path.join(self.cache_path, "refcoco_val", model_name, "text"+("_shade" if self.box_representation_method == "shade" else "")) + image_feat_path = os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name) + self.image_feat_path = image_feat_path + image_features = None + text_features = None + if self.cache_path is not None and os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name)): + if os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")): + text_features = torch.load(os.path.join(text_cache_path, caption_hash+".pt"), map_location=self.device) + if os.path.exists(image_feat_path): + if all([os.path.exists(os.path.join(image_feat_path, method_name+".pt")) for method_name in box_representation_methods]): + image_features = [] + for method_name in box_representation_methods: + features = torch.load(os.path.join(image_feat_path, method_name+".pt"), map_location=self.device) + image_features.append(torch.stack([ + features[(box.x, box.y, box.w, box.h)] + for box in boxes + ])) + image_features = torch.stack(image_features) + image_features = image_features.view(-1, image_features.shape[-1]) + logits_per_image, logits_per_text, image_features, text_features = self.call_model(model, images_t, text_tensor, image_features=image_features, text_features=text_features, boxes=boxes, image_pth=image_pth) + all_logits_per_image.append(logits_per_image) + all_logits_per_text.append(logits_per_text) + if self.cache_path is not None and image_name is not None and image_features is not None: + image_features = image_features.view(len(box_representation_methods), len(boxes), image_features.shape[-1]) + if not os.path.exists(image_feat_path): + os.makedirs(image_feat_path) + for i in range(image_features.shape[0]): + method_name = box_representation_methods[i] + if not os.path.exists(os.path.join(image_feat_path, method_name+".pt")): + image_features_dict = {(box.x, box.y, box.w, box.h): image_features[i,j,:].cpu() for j, box in enumerate(boxes)} + 
torch.save(image_features_dict, os.path.join(image_feat_path, method_name+".pt")) + if self.cache_path is not None and not os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")) and text_features is not None: + assert text_features.shape[0] == 1 + if not os.path.exists(text_cache_path): + os.makedirs(text_cache_path) + torch.save(text_features.cpu(), os.path.join(text_cache_path, caption_hash+".pt")) + + all_logits_per_image = torch.stack(all_logits_per_image).sum(0) + all_logits_per_text = torch.stack(all_logits_per_text).sum(0) + if self.method_aggregator == "max": + all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).max(dim=0, keepdim=True)[0] + elif self.method_aggregator == "sum": + all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).sum(dim=0, keepdim=True) + return all_logits_per_text.view(-1) + +class ClipExecutor(Executor): + def __init__(self, clip_model: str = "ViT-B/32", device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None, clip_type: str=None) -> None: + super().__init__(device, box_representation_method, method_aggregator, enlarge_boxes, expand_position_embedding, square_size, blur_std_dev, cache_path) + self.clip_models = clip_model.split(",") + self.model_names = [model_name.replace("/", "_") for model_name in self.clip_models] + self.models = [] + self.preprocesses = [] + self.data_name = input_file.split('/')[-1].split('.')[0] + self.mask_path = None + self.clip_type = clip_type + if self.cache_path is not None: + self.mask_path = os.path.join(self.cache_path, "refcoco_val", 'det_masks') + sam_checkpoint = "./ckpt/sam_vit_h_4b8939.pth" + model_type = "vit_h" + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) + sam.to(device=device) + self.predictor = SamPredictor(sam) + for model_name in self.clip_models: + if 'aclip' in self.clip_type:#using alpha-clip + self.mask_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.26) + ]) + if model_name == 'ViT-B/16': + model, preprocess = alpha_clip.load("ViT-B/16", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth", device=device) + elif model_name == 'ViT-L/14': + model, preprocess = alpha_clip.load("ViT-L/14", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth", device=device) + + else: model, preprocess = clip.load(model_name, device=device, jit=False) + self.models.append(model) + if self.square_size: + print("Square size!") + preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), interpolation=transforms.InterpolationMode.BICUBIC) + self.preprocesses.append(preprocess) + self.models = torch.nn.ModuleList(self.models) + + def preprocess_text(self, text: str) -> torch.Tensor: + if "aclip" in self.box_representation_method: + return alpha_clip.tokenize([text.lower()]) + if "shade" in self.box_representation_method: + return clip.tokenize([text.lower()+" is in red color."]) + return clip.tokenize(["a photo of "+text.lower()]) + + def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: torch.Tensor, image_features: torch.Tensor = None, text_features: torch.Tensor = None, boxes=None, image_pth=None) -> torch.Tensor: + if image_features is None: + print('computing image features') + if 'aclip' not in 
self.clip_type: + image_features = model.encode_image(images) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + else: + image_features = [] + if 'full' in self.box_representation_method: + aclip_images = images[:len(boxes)] + alphas = [] + + if os.path.exists(os.path.join(self.image_feat_path, 'full.pt')): + features = torch.load(os.path.join(self.image_feat_path, 'full.pt'), map_location=self.device) + aclip_image_features = torch.stack([ + features[(box.x, box.y, box.w, box.h)] + for box in boxes + ]) + else: + for i in range(len(self.all_masks)): + binary_mask = self.all_masks[i] + alpha = self.mask_transform((binary_mask * 255).astype(np.uint8)) + alpha = alpha.half().cuda().unsqueeze(dim=0) + alphas.append(alpha) + + alphas = torch.cat(alphas, dim=0) + aclip_images = aclip_images.half() + aclip_image_features = model.visual(aclip_images, alphas) # using alpha channels + images = images[len(boxes):] + image_features.append(aclip_image_features) + + if 'blur' in self.box_representation_method: + if os.path.exists(os.path.join(self.image_feat_path, 'blur.pt')): + features = torch.load(os.path.join(self.image_feat_path, 'blur.pt'), map_location=self.device) + ablur_images_features = torch.stack([ + features[(box.x, box.y, box.w, box.h)] + for box in boxes + ]) + else: + ablur_images = images[:len(boxes)] + alphas = [] + for i in range(len(self.all_masks)): + binary_mask = self.all_masks[i] + alpha = self.mask_transform((binary_mask * 255).astype(np.uint8)) + alpha = alpha.half().cuda().unsqueeze(dim=0) + alphas.append(alpha) + alphas = torch.cat(alphas, dim=0) + ablur_images = ablur_images.half() + ablur_images_features = model.visual(ablur_images, alphas) + images = images[len(boxes):] + image_features.append(ablur_images_features) + + if 'gray' in self.box_representation_method: + if os.path.exists(os.path.join(self.image_feat_path, 'gray.pt')): + features = torch.load(os.path.join(self.image_feat_path, 'gray.pt'), map_location=self.device) + gray_images_features = torch.stack([ + features[(box.x, box.y, box.w, box.h)] + for box in boxes + ]) + else: + gray_images = images[:len(boxes)] + alphas = [] + for i in range(len(self.all_masks)): + binary_mask = self.all_masks[i] + alpha = self.mask_transform((binary_mask * 255).astype(np.uint8)) + alpha = alpha.half().cuda().unsqueeze(dim=0) + alphas.append(alpha) + alphas = torch.cat(alphas, dim=0) + gray_images = gray_images.half() + gray_images_features = model.visual(gray_images, alphas) + images = images[len(boxes):] + image_features.append(gray_images_features) + + + image_features = torch.cat(image_features, dim=0) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + if text_features is None: + print('computing text features') + text_features = model.encode_text(text) + # normalized features + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = model.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + return logits_per_image, logits_per_text, image_features, text_features + + def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor: + if self.expand_position_embedding: + original_preprocesses = self.preprocesses + new_preprocesses = [] + original_position_embeddings = [] + for model_name, model, preprocess in zip(self.clip_models, self.models, self.preprocesses): + if 
"RN" in model_name: + model_spatial_dim = int((model.visual.attnpool.positional_embedding.shape[0]-1)**0.5) + patch_size = model.visual.input_resolution // model_spatial_dim + original_positional_embedding = model.visual.attnpool.positional_embedding.clone() + model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate( + model.visual.attnpool.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim), + size=(image.height // patch_size, image.width // patch_size), + mode='bicubic', + align_corners=False + ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1])) + model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.cat(( + original_positional_embedding[:1,:], + model.visual.attnpool.positional_embedding + ), dim=0)) + transform = transforms.Compose([ + transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC), + lambda image: image.convert("RGB"), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + else: + model_spatial_dim = int((model.visual.positional_embedding.shape[0]-1)**0.5) + patch_size = model.visual.input_resolution // model_spatial_dim + original_positional_embedding = model.visual.positional_embedding.clone() + model.visual.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate( + model.visual.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim), + size=(image.height // patch_size, image.width // patch_size), + mode='bicubic', + align_corners=False + ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1])) + model.visual.positional_embedding = torch.nn.Parameter(torch.cat(( + original_positional_embedding[:1,:], + model.visual.positional_embedding + ), dim=0)) + transform = transforms.Compose([ + transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC), + lambda image: image.convert("RGB"), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + new_preprocesses.append(transform) + original_position_embeddings.append(original_positional_embedding) + self.preprocesses = new_preprocesses + result = super().__call__(caption, image, boxes, image_name, image_pth) + if self.expand_position_embedding: + self.preprocesses = original_preprocesses + for model, model_name, pos_embedding in zip(self.models, self.clip_models, original_position_embeddings): + if "RN" in model_name: + model.visual.attnpool.positional_embedding = torch.nn.Parameter(pos_embedding) + else: + model.visual.positional_embedding = torch.nn.Parameter(pos_embedding) + return result + diff --git a/AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py b/AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..98e32822f3926b220e16aa138f907b68b40d1b0f --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py @@ -0,0 +1,107 @@ +import os +import clip +import json +import argparse +import ruamel.yaml as yaml + +from PIL import Image +import torch +import torchvision.transforms as transforms +from tqdm import tqdm + +from albef.utils import * +from executor import AlbefExecutor + +parser = argparse.ArgumentParser() +parser.add_argument("--input_path", type=str, help="Path to 
input JSON file") +parser.add_argument("--image_root", type=str, help="Path to directory containing images") +parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF") +parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF") +parser.add_argument("--clip_model", type=str, help="CLIP model to use") +parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use") +parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP") + +args = parser.parse_args() + +if args.albef_path is not None: + executor = AlbefExecutor(checkpoint_path = os.path.join(args.albef_path, "checkpoint.pth"), config_path = os.path.join(args.albef_path, "config.yaml"), device = "cpu" if args.gpu < 0 else "cuda:"+str(args.gpu)) + model = executor.models[0] + preprocess = executor.preprocesses[0] + model = model.eval() +else: + model, preprocess = clip.load(args.clip_model, jit=False, device="cuda:"+str(args.gpu)) + preprocess.transforms[0] == transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), transforms.InterpolationMode.BICUBIC) + model = model.eval() +input_file = open(args.input_path) +data = json.load(input_file) +input_file.close() +correct = 0 +for i in tqdm(range(0, len(data), args.batch_size)): + batch_images = [] + batch_text = [] + for datum in data[i:min(i+args.batch_size, len(data))]: + img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB') + batch_images.append(preprocess(img)) + if "text2" in datum: + if args.albef_path is None: + datum["text1"] = "a photo of "+datum["text1"] + datum["text2"] = "a photo of "+datum["text2"] + batch_text.append(datum["text1"]) + batch_text.append(datum["text2"]) + else: + img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB') + batch_images.append(preprocess(img2)) + batch_text.append(datum["text1"]) + batch_images = torch.stack(batch_images).to("cuda:"+str(args.gpu)) + if args.albef_path is None: + batch_text = clip.tokenize(batch_text).to("cuda:"+str(args.gpu)) + else: + modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text] + batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt") + for key in batch_text: + batch_text[key] = batch_text[key].to(batch_images.device) + + with torch.no_grad(): + if args.albef_path is None: + logits_per_image, logits_per_text = model(batch_images, batch_text) + else: + if not args.albef_itc: + if batch_images.shape[0]*2 == batch_text.input_ids.shape[0]: + batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(batch_images.shape[0]*2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3]) + else: + assert batch_images.shape[0] ==2*batch_text.input_ids.shape[0] + batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1) + batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1) + image_embeds = model.visual_encoder(batch_images) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(batch_images.device) + output = model.text_encoder( + batch_text.input_ids, + attention_mask = batch_text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + vl_embeddings = output.last_hidden_state[:,0,:] + vl_output = model.itm_head(vl_embeddings) + 
logits_per_image = vl_output[:,1:2].view(-1, 2) + else: + image_embeds = model.visual_encoder(batch_images) + image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:,0,:]),dim=-1) + text_output = model.text_encoder(batch_text.input_ids, attention_mask = batch_text.attention_mask, + return_dict = True, mode = 'text') + text_embeds = text_output.last_hidden_state + text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:,0,:]),dim=-1) + sim = image_feat@text_feat.t()/model.temp + logits_per_image = sim + if args.albef_path is None or args.albef_itc: + if logits_per_image.shape[0]*2 == logits_per_image.shape[1]: + for j in range(logits_per_image.shape[0]): + correct += 1 if logits_per_image[j,2*j].item() > logits_per_image[j,2*j+1].item() else 0 + else: + assert logits_per_image.shape[0] == 2*logits_per_image.shape[1] + for j in range(logits_per_image.shape[1]): + correct += 1 if logits_per_image[2*j,j].item() > logits_per_image[2*j+1,j].item() else 0 + else: + correct += (logits_per_image[:,0] > logits_per_image[:,1]).long().sum().item() + +print("Accuracy:", correct/len(data)) diff --git a/AlphaCLIP/eval/rec_zs_test/heuristics.py b/AlphaCLIP/eval/rec_zs_test/heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..531cf115795f8c8d73c8ffb37978895324cb8677 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/heuristics.py @@ -0,0 +1,68 @@ +"""Heuristic rules used to extract and execute entity parses.""" + +from typing import Callable, List, NamedTuple +from argparse import Namespace +import numpy as np + + +class RelHeuristic(NamedTuple): + keywords: List[str] + callback: Callable[["Environment"], np.ndarray] + + +class Heuristics: + """A class defining heuristics that can be enabled/disabled.""" + + RELATIONS = [ + RelHeuristic(["left", "west"], lambda env: env.left_of()), + RelHeuristic(["right", "east"], lambda env: env.right_of()), + RelHeuristic(["above", "north", "top", "back", "behind"], lambda env: env.above()), + RelHeuristic(["below", "south", "under", "front"], lambda env: env.below()), + RelHeuristic(["bigger", "larger", "closer"], lambda env: env.bigger_than()), + RelHeuristic(["smaller", "tinier", "further"], lambda env: env.smaller_than()), + RelHeuristic(["inside", "within", "contained"], lambda env: env.within()), + ] + + TERNARY_RELATIONS = [ + RelHeuristic(["between"], lambda env: env.between()), + ] + + SUPERLATIVES = [ + RelHeuristic(["left", "west", "leftmost", "western"], lambda env: env.left_of()), + RelHeuristic(["right", "rightmost", "east", "eastern"], lambda env: env.right_of()), + RelHeuristic(["above", "north", "top"], lambda env: env.above()), + RelHeuristic(["below", "south", "underneath", "front"], lambda env: env.below()), + RelHeuristic(["bigger", "biggest", "larger", "largest", "closer", "closest"], lambda env: env.bigger_than()), + RelHeuristic(["smaller", "smallest", "tinier", "tiniest", "further", "furthest"], lambda env: env.smaller_than()), + ] + OPPOSITES = {0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4} + + NULL_KEYWORDS = ["part", "image", "side", "picture", "half", "region", "section"] + + EMPTY = [] + + def __init__(self, args: Namespace = None): + self.enable_relations = not args or not args.no_rel + self.enable_superlatives = not args or not args.no_sup + self.enable_nulls = not args or not args.no_null + self.enable_ternary = not args or args.ternary + + @property + def relations(self) -> List[RelHeuristic]: + return self.RELATIONS if self.enable_relations else self.EMPTY + + @property + def 
ternary_relations(self) -> List[RelHeuristic]: + return self.TERNARY_RELATIONS if self.enable_ternary else self.EMPTY + + @property + def superlatives(self) -> List[RelHeuristic]: + return self.SUPERLATIVES if self.enable_superlatives else self.EMPTY + + @property + def opposites(self): + return self.OPPOSITES + + @property + def null_keywords(self) -> List[str]: + return self.NULL_KEYWORDS if self.enable_nulls else self.EMPTY diff --git a/AlphaCLIP/eval/rec_zs_test/interpreter.py b/AlphaCLIP/eval/rec_zs_test/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..6464bfd33d673349f90bb0616532675e1a61b0de --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/interpreter.py @@ -0,0 +1,212 @@ +from typing import NamedTuple, List, Callable +import sys +import re +import numpy as np +import torch +from numpy.linalg import norm +from itertools import product, groupby +from PIL import Image + + +# Do two line segments intersect? Copied from +# https://stackoverflow.com/questions/3838329/how-can-i-check-if-two-segments-intersect + + +def ccw(A, B, C): + return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x) + + +def intersect(A, B, C, D): + """Do line segments AB and CD intersect?""" + return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D) + + +class Box(NamedTuple): + x: int + y: int + w: int = 0 + h: int = 0 + + @property + def left(self): + return self.x + + @property + def right(self): + return self.x + self.w + + @property + def top(self): + return self.y + + @property + def bottom(self): + return self.y + self.h + + @property + def center(self): + return Box(self.x + self.w // 2, self.y + self.h // 2) + + def corners(self): + yield Box(self.x, self.y) + yield Box(self.x + self.w, self.y) + yield Box(self.x + self.w, self.y + self.h) + yield Box(self.x, self.y + self.h) + + @property + def area(self): + return self.w * self.h + + def intersect(self, other: "Box") -> "Box": + x1 = max(self.x, other.x) + x2 = max(x1, min(self.x+self.w, other.x+other.w)) + y1 = max(self.y, other.y) + y2 = max(y1, min(self.y+self.h, other.y+other.h)) + return Box(x=x1, y=y1, w=x2-x1, h=y2-y1) + + def min_bounding(self, other: "Box") -> "Box": + corners = list(self.corners()) + corners.extend(other.corners()) + min_x = min_y = float("inf") + max_x = max_y = -float("inf") + + for item in corners: + min_x = min(min_x, item.x) + min_y = min(min_y, item.y) + max_x = max(max_x, item.x) + max_y = max(max_y, item.y) + + return Box(min_x, min_y, max_x - min_x, max_y - min_y) + + def expand(self, growth: float = .1) -> "Box": + factor = 1 + growth + w = factor * self.w + h = factor * self.h + return Box(min_x - (w - self.w) / 2, min_y - (h - self.h) / 2, w, h) + + +def iou(box1, box2): + x1 = max(box1.x, box2.x) + x2 = max(x1, min(box1.x+box1.w, box2.x+box2.w)) + y1 = max(box1.y, box2.y) + y2 = max(y1, min(box1.y+box1.h, box2.y+box2.h)) + intersection = Box(x=x1, y=y1, w=x2-x1, h=y2-y1) + intersection_area = intersection.area + union_area = box1.area+box2.area-intersection_area + return intersection_area / union_area + + +def all_equal(iterable): + """Are all elements the same?""" + g = groupby(iterable) + return next(g, True) and not next(g, False) + + +class spatial: + """A decorator that converts a predicate over boxes to a function that returns a tensor over all boxes.""" + + def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False): + self.arity = arity + self.enforce_antisymmetry = enforce_antisymmetry # Zero out any entries where two boxes are the same. 
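+        # The decorated predicate is evaluated on every `arity`-tuple of boxes in
+        # the environment, producing an n_boxes^arity numpy tensor (see __call__
+        # below); e.g. with two boxes, `left_of` gives a 2x2 matrix whose [i, j]
+        # entry is 1.0 when the center of box i lies left of the center of box j.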
+ + def __call__(self, predicate: Callable[[Box], float]) -> Callable[["Environment"], np.ndarray]: + def _rel(env): + n_boxes = len(env.boxes) + tensor = np.empty([n_boxes for _ in range(self.arity)]) + enum_boxes = list(enumerate(env.boxes)) + for pairs in product(*[enum_boxes for _ in range(self.arity)]): + indices, boxes = zip(*pairs) + if self.enforce_antisymmetry and len(set(indices)) < len(indices): + tensor[indices] = 0. + else: + tensor[indices] = predicate(*boxes) + return tensor + return _rel + + +class Environment: + def __init__(self, image: Image, boxes: List[Box], executor: "Executor" = None, freeform_boxes: bool = False, image_name: str = None, image_pth: str=None): + self.image = image + self.boxes = boxes + self.executor = executor # An object or callback that can query CLIP with captions/images. + self.freeform_boxes = freeform_boxes + self.image_name = image_name + self.image_pth=image_pth + + def uniform(self) -> np.ndarray: + n_boxes = len(self.boxes) + return 1 / n_boxes * np.ones(n_boxes) + + def filter(self, + caption: str, + temperature: float = 1., + area_threshold: float = 0.0, + softmax: bool = False, + expand: float = None + ) -> np.ndarray: + """Return a new distribution reflecting the likelihood that `caption` describes the content of each box.""" + area_filtered_dist = torch.from_numpy(self.filter_area(area_threshold)).to(self.executor.device) + candidate_indices = [i for i in range(len(self.boxes)) if float(area_filtered_dist[i]) > 0.0] + boxes = [self.boxes[i] for i in candidate_indices] + if len(boxes) == 0: + boxes = self.boxes + candidate_indices = list(range(len(boxes))) + if expand is not None: + boxes = [box.expand(expand) for box in boxes] + result_partial = self.executor(caption, self.image, boxes, image_name=self.image_name, image_pth=self.image_pth) + if self.freeform_boxes: + result_partial, boxes = result_partial + self.boxes = [Box(x=boxes[i,0].item(), y=boxes[i,1].item(), w=boxes[i,2].item()-boxes[i,0].item(), h=boxes[i,3].item()-boxes[i,1].item()) for i in range(boxes.shape[0])] + candidate_indices = list(range(len(self.boxes))) + result_partial = result_partial.float() + if not softmax: + result_partial = (result_partial-result_partial.mean()) / (result_partial.std() + 1e-9) + result_partial = (temperature * result_partial).sigmoid() + result = torch.zeros((len(self.boxes))).to(result_partial.device) + result[candidate_indices] = result_partial + else: + result = torch.zeros((len(self.boxes))).to(result_partial.device) + result[candidate_indices] = result_partial.softmax(dim=-1) #softmax结果 + return result.cpu().numpy() + + def filter_area(self, area_threshold: float) -> np.ndarray: + """Return a new distribution in which all boxes whose area as a fraction of the image is less than the threshold.""" + image_area = self.image.width*self.image.height + return np.array([1 if self.boxes[i].area/image_area > area_threshold else 0 for i in range(len(self.boxes))]) + + @spatial() + def left_of(b1, b2): + return (b1.right+b1.left) / 2 < (b2.right+b2.left) / 2 + + @spatial() + def right_of(b1, b2): + return (b1.right+b1.left) / 2 > (b2.right+b2.left) / 2 + + @spatial() + def above(b1, b2): + return (b1.bottom+b1.top) < (b2.bottom+b2.top) + + @spatial() + def below(b1, b2): + return (b1.bottom+b1.top) > (b2.bottom+b2.top) + + @spatial() + def bigger_than(b1, b2): + return b1.area > b2.area + + @spatial() + def smaller_than(b1, b2): + return b1.area < b2.area + + @spatial(enforce_antisymmetry=False) + def within(box1, box2): + """Return percent of 
box1 inside box2.""" + intersection = box1.intersect(box2) + return intersection.area / box1.area + + @spatial(arity=3, enforce_antisymmetry=True) + def between(box1, box2, box3): + """How much of box1 lies in min bounding box over box2 and box3?""" + min_bounding = box2.min_bounding(box3) + intersect = box1.intersect(min_bounding) + return intersect.area / box1.area diff --git a/AlphaCLIP/eval/rec_zs_test/lattice.py b/AlphaCLIP/eval/rec_zs_test/lattice.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9cf9060ce6842c5fdbfeaa841bd1208f3ec43a --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/lattice.py @@ -0,0 +1,70 @@ +"""Implement lattice interface.""" + +from overrides import overrides +import numpy as np +from abc import ABCMeta, abstractmethod + + +class Lattice(metaclass=ABCMeta): + + """Abstract base class representing a complemented lattice.""" + + @classmethod + @abstractmethod + def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: + return NotImplemented + + @classmethod + @abstractmethod + def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: + return NotImplemented + + @classmethod + @abstractmethod + def join_reduce(cls, probs: np.ndarray) -> np.ndarray: + return NotImplemented + + @classmethod + @abstractmethod + def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: + return NotImplemented + + +class Product(Lattice): + """Lattice where meet=prod and sum is defined accordingly. + + Equivalent to assuming independence, more or less. + """ + + eps = 1e-9 + + @classmethod + @overrides + def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: + return probs1 + probs2 - cls.meet(probs1, probs2) + + @classmethod + @overrides + def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: + return probs1 * probs2 + + @classmethod + @overrides + def join_reduce(cls, probs: np.ndarray) -> np.ndarray: + """Assumes disjoint events.""" + # return cls.comp(cls.meet_reduce(cls.comp(probs))) + return np.sum(probs, axis=-1) + + @classmethod + @overrides + def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: + return np.prod(probs, axis=-1) + + @classmethod + def comp(cls, probs): + return 1 - probs + + @classmethod + def normalize(cls, probs): + """Normalize a distribution by dividing by the total mass.""" + return probs / np.sum(probs + cls.eps, axis=-1) diff --git a/AlphaCLIP/eval/rec_zs_test/main.py b/AlphaCLIP/eval/rec_zs_test/main.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2641a8e0846b3d2a1295606b6d6d02b724b75e --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/main.py @@ -0,0 +1,200 @@ +from collections import defaultdict +import json +import argparse +import os +import random + +import torch +from PIL import Image +from tqdm import tqdm + +from interpreter import * +from executor import * +from methods import * + +METHODS_MAP = { + "baseline": Baseline, + "random": Random, + "parse": Parse, +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_file", type=str, help="input file with expressions and annotations in jsonlines format") + parser.add_argument("--image_root", type=str, help="path to images (train2014 directory of COCO)") + parser.add_argument("--clip_model", type=str, default="RN50x16,ViT-B/32", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma") + parser.add_argument("--clip_type", type=str, default="aclip", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a 
comma") + parser.add_argument("--albef_path", type=str, default=None, help="to use ALBEF (instead of CLIP), specify the path to the ALBEF checkpoint") + parser.add_argument("--method", type=str, default="parse", help="method to solve expressions") + parser.add_argument("--box_representation_method", type=str, default="crop,blur", help="method of representing boxes as individual images (crop, blur, or both separated by a comma)") + parser.add_argument("--box_method_aggregator", type=str, default="sum", help="method of combining box representation scores") + parser.add_argument("--box_area_threshold", type=float, default=0.0, help="minimum area (as a proportion of image area) for a box to be considered as the answer") + parser.add_argument("--output_file", type=str, default=None, help="(optional) output path to save results") + parser.add_argument("--detector_file", type=str, default=None, help="(optional) file containing object detections. if not provided, the gold object boxes will be used.") + parser.add_argument("--mock", action="store_true", help="(optional) mock CLIP execution.") + parser.add_argument("--device", type=int, default=0, help="CUDA device to use.") + parser.add_argument("--shuffle_words", action="store_true", help="If true, shuffle words in the sentence") + parser.add_argument("--gradcam_alpha", type=float, nargs='+', help="alpha value to use for gradcam method") + parser.add_argument("--enlarge_boxes", type=float, default=0.0, help="(optional) whether to enlarge boxes when passing them to the model") + parser.add_argument("--part", type=str, default=None, help="(optional) specify how many parts to divide the dataset into and which part to run in the format NUM_PARTS,PART_NUM") + parser.add_argument("--batch_size", type=int, default=1, help="number of instances to process in one model call (only supported for baseline model)") + parser.add_argument("--baseline_head", action="store_true", help="For baseline, controls whether model is called on both full expression and head noun chunk of expression") + parser.add_argument("--mdetr", type=str, default=None, help="to use MDETR as the executor model, specify the name of the MDETR model") + parser.add_argument("--albef_block_num", type=int, default=8, help="block num for ALBEF gradcam") + parser.add_argument("--albef_mode", type=str, choices=["itm", "itc"], default="itm") + parser.add_argument("--expand_position_embedding",action="store_true") + parser.add_argument("--gradcam_background", action="store_true") + parser.add_argument("--mdetr_given_bboxes", action="store_true") + parser.add_argument("--mdetr_use_token_mapping", action="store_true") + parser.add_argument("--non_square_size", action="store_true") + parser.add_argument("--blur_std_dev", type=int, default=100, help="standard deviation of Gaussian blur") + parser.add_argument("--gradcam_ensemble_before", action="store_true", help="Average gradcam maps of different models before summing over the maps") + parser.add_argument("--cache_path", type=str, default=None, help="cache features") + # Arguments related to Parse method. 
+    parser.add_argument("--no_rel", action="store_true", help="Disable relation extraction.")
+    parser.add_argument("--no_sup", action="store_true", help="Disable superlative extraction.")
+    parser.add_argument("--no_null", action="store_true", help="Disable null keyword heuristics.")
+    parser.add_argument("--ternary", action="store_true", help="Disable ternary relation extraction.")
+    parser.add_argument("--baseline_threshold", type=float, default=float("inf"), help="(Parse) Threshold to use relations/superlatives.")
+    parser.add_argument("--temperature", type=float, default=1., help="(Parse) Sigmoid temperature.")
+    parser.add_argument("--superlative_head_only", action="store_true", help="(Parse) Superlatives only quantify head predicate.")
+    parser.add_argument("--sigmoid", action="store_true", help="(Parse) Use sigmoid, not softmax.")
+    parser.add_argument("--no_possessive", action="store_true", help="(Parse) Do not model extraneous relations as possessive relations.")
+    parser.add_argument("--expand_chunks", action="store_true", help="(Parse) Expand noun chunks to include descendant tokens that aren't ancestors of tokens in other chunks")
+    parser.add_argument("--parse_no_branch", action="store_true", help="(Parse) Run the parsing procedure even if no relation/superlative keyword is in the expression")
+    parser.add_argument("--possessive_no_expand", action="store_true", help="(Parse) Do not expand ent2 in possessive case")
+    args = parser.parse_args()
+
+    with open(args.input_file) as f:
+        lines = f.readlines()
+        data = [json.loads(line) for line in lines]
+
+    device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"
+    gradcam = args.method == "gradcam"
+
+    executor = ClipExecutor(clip_model=args.clip_model, box_representation_method=args.box_representation_method, method_aggregator=args.box_method_aggregator, device=device, square_size=not args.non_square_size, expand_position_embedding=args.expand_position_embedding, blur_std_dev=args.blur_std_dev, cache_path=args.cache_path, input_file=args.input_file, clip_type=args.clip_type)
+
+    method = METHODS_MAP[args.method](args)
+    correct_count = 0
+    total_count = 0
+    if args.output_file:
+        output_file = open(args.output_file, "w")
+    if args.detector_file:
+        detector_file = open(args.detector_file)
+        detections_list = json.load(detector_file)
+        if isinstance(detections_list, dict):
+            detections_map = {int(image_id): detections_list[image_id] for image_id in detections_list}
+        else:
+            detections_map = defaultdict(list)
+            for detection in detections_list:
+                detections_map[detection["image_id"]].append(detection["box"])
+
+    part = 0
+    if args.part is not None: # for multi-gpu test / part-data test
+        num_parts = int(args.part.split(",")[0])
+        part = int(args.part.split(",")[1])
+        data = data[int(len(data)*part/num_parts):int(len(data)*(part+1)/num_parts)]
+
+    batch_count = 0
+    batch_boxes = []
+    batch_gold_boxes = []
+    batch_gold_index = []
+    batch_file_names = []
+    batch_sentences = []
+    for datum in tqdm(data):
+        if "coco" in datum["file_name"].lower():
+            file_name = "_".join(datum["file_name"].split("_")[:-1])+".jpg"
+        else:
+            file_name = datum["file_name"]
+        img_path = os.path.join(args.image_root, file_name)
+        img = Image.open(img_path).convert('RGB')
+        gold_boxes = [Box(x=ann["bbox"][0], y=ann["bbox"][1], w=ann["bbox"][2], h=ann["bbox"][3]) for ann in datum["anns"]]
+        if isinstance(datum["ann_id"], int) or isinstance(datum["ann_id"], str):
+            datum["ann_id"] = [datum["ann_id"]]
+        assert isinstance(datum["ann_id"],
list) + gold_index = [i for i in range(len(datum["anns"])) if datum["anns"][i]["id"] in datum["ann_id"]] + if args.detector_file: + boxes = [Box(x=box[0], y=box[1], w=box[2], h=box[3]) for box in detections_map[int(datum["image_id"])]] + if len(boxes) == 0: + boxes = [Box(x=0, y=0, w=img.width, h=img.height)] + else: + boxes = gold_boxes + for sentence in datum["sentences"]: + env = Environment(img, boxes, executor, (args.mdetr is not None and not args.mdetr_given_bboxes), str(datum["image_id"]), img_path) + if args.shuffle_words: + words = sentence["raw"].lower().split() + random.shuffle(words) + result = method.execute(" ".join(words), env) + else: + result = method.execute(sentence["raw"].lower(), env) + boxes = env.boxes + print(sentence["raw"].lower()) + correct = False + for g_index in gold_index: + if iou(boxes[result["pred"]], gold_boxes[g_index]) > 0.5: + correct = True + break + if correct: + result["correct"] = 1 + correct_count += 1 + else: + result["correct"] = 0 + if args.detector_file: + argmax_ious = [] + max_ious = [] + for g_index in gold_index: + ious = [iou(box, gold_boxes[g_index]) for box in boxes] + argmax_iou = -1 + max_iou = 0 + if max(ious) >= 0.5: + for index, value in enumerate(ious): + if value > max_iou: + max_iou = value + argmax_iou = index + argmax_ious.append(argmax_iou) + max_ious.append(max_iou) + argmax_iou = -1 + max_iou = 0 + if max(max_ious) >= 0.5: + for index, value in zip(argmax_ious, max_ious): + if value > max_iou: + max_iou = value + argmax_iou = index + result["gold_index"] = argmax_iou + else: + result["gold_index"] = gold_index + result["bboxes"] = [[box.left, box.top, box.right, box.bottom] for box in boxes] + result["file_name"] = file_name + result["probabilities"] = result["probs"] + result["text"] = sentence["raw"].lower() + if args.output_file: + # Serialize numpy arrays for JSON. 
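+            # (np.ndarray values become lists and np.int64 values become plain ints so json.dumps can serialize them.)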
+ for key in result: + if isinstance(result[key], np.ndarray): + result[key] = result[key].tolist() + if isinstance(result[key], np.int64): + result[key] = result[key].item() + output_file.write(json.dumps(result)+"\n") + total_count += 1 + print(f"est_acc: {100 * correct_count / total_count:.3f}") + + if args.output_file: + output_file.close() + print(f"acc: {100 * correct_count / total_count:.3f}") + acc = 100 * correct_count / total_count + + result = {} + result['acc'] = acc + json.dump(acc, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_acc_' + str(part)+'.json'),'w')) + json.dump(str(correct_count)+' '+str(total_count), open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_count_' + str(part)+'.json'),'w')) + stats = method.get_stats() + if stats: + pairs = sorted(list(stats.items()), key=lambda tup: tup[0]) + for key, value in pairs: + result[key] = value + if isinstance(value, float): + print(f"{key}: {value:.5f}") + else: + print(f"{key}: {value}") + + json.dump(result, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_' + str(part)+'.json'),'w')) \ No newline at end of file diff --git a/AlphaCLIP/eval/rec_zs_test/methods/__init__.py b/AlphaCLIP/eval/rec_zs_test/methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74361a515ad864653b1159d9ad37e682953a9b23 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/methods/__init__.py @@ -0,0 +1,3 @@ +from .baseline import Baseline +from .random_method import Random +from .parse import Parse diff --git a/AlphaCLIP/eval/rec_zs_test/methods/baseline.py b/AlphaCLIP/eval/rec_zs_test/methods/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..dd533ebb7f74655c6e349c433f0ebcc145701ef1 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/methods/baseline.py @@ -0,0 +1,57 @@ +"""A naive baseline method: just pass the full expression to CLIP.""" + +from overrides import overrides +from typing import Dict, Any, List +import numpy as np +import torch +import spacy +from argparse import Namespace + +from .ref_method import RefMethod +from lattice import Product as L + + +class Baseline(RefMethod): + """CLIP-only baseline where each box is evaluated with the full expression.""" + + nlp = spacy.load('en_core_web_sm') + + def __init__(self, args: Namespace): + self.args = args + self.box_area_threshold = args.box_area_threshold + self.batch_size = args.batch_size + self.batch = [] + + @overrides + def execute(self, caption: str, env: "Environment") -> Dict[str, Any]: + chunk_texts = self.get_chunk_texts(caption) + probs = env.filter(caption, area_threshold = self.box_area_threshold, softmax=True) + if self.args.baseline_head: + probs2 = env.filter(chunk_texts[0], area_threshold = self.box_area_threshold, softmax=True) + probs = L.meet(probs, probs2) + pred = np.argmax(probs) + return { + "probs": probs, + "pred": pred, + "box": env.boxes[pred], + } + + def get_chunk_texts(self, expression: str) -> List: + doc = self.nlp(expression) + head = None + for token in doc: + if token.head.i == token.i: + head = token + break + head_chunk = None + chunk_texts = [] + for chunk in doc.noun_chunks: + if head.i >= chunk.start and head.i < chunk.end: + head_chunk = chunk.text + chunk_texts.append(chunk.text) + if head_chunk is None: + if len(list(doc.noun_chunks)) > 0: + head_chunk = list(doc.noun_chunks)[0].text + else: + head_chunk = expression + return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk] diff --git 
a/AlphaCLIP/eval/rec_zs_test/methods/parse.py b/AlphaCLIP/eval/rec_zs_test/methods/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0dbb8cc672e0f41ced359e0e8016e252c5a0c4 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/methods/parse.py @@ -0,0 +1,239 @@ +"""Use spatial relations extracted from the parses.""" + +from typing import Dict, Any, Callable, List, Tuple, NamedTuple +from numbers import Number +from collections import defaultdict +from overrides import overrides +import numpy as np +import spacy +from spacy.tokens.token import Token +from spacy.tokens.span import Span +from argparse import Namespace + +from .ref_method import RefMethod +from lattice import Product as L +from heuristics import Heuristics +from entity_extraction import Entity, expand_chunks + + +def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity: + """If an entity represents a conjunction of two entities, pull them apart.""" + head = ent.head.root # Not ...root.head. Confusing names here. + if not any(child.text == "and" for child in head.children): + return None + for child in head.children: + if child.i in chunks and head.i is not child.i: + return Entity.extract(child, chunks, heuristics) + return None + + +class Parse(RefMethod): + """An REF method that extracts and composes predicates, relations, and superlatives from a dependency parse. + + The process is as follows: + 1. Use spacy to parse the document. + 2. Extract a semantic entity tree from the parse. + 3. Execute the entity tree to yield a distribution over boxes.""" + + nlp = spacy.load('en_core_web_sm') + + def __init__(self, args: Namespace = None): + self.args = args + self.box_area_threshold = args.box_area_threshold + self.baseline_threshold = args.baseline_threshold + self.temperature = args.temperature + self.superlative_head_only = args.superlative_head_only + self.expand_chunks = args.expand_chunks + self.branch = not args.parse_no_branch + self.possessive_expand = not args.possessive_no_expand + + # Lists of keyword heuristics to use. + self.heuristics = Heuristics(args) + + # Metrics for debugging relation extraction behavor. + self.counts = defaultdict(int) + + @overrides + def execute(self, caption: str, env: "Environment") -> Dict[str, Any]: + """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes.""" + # Start by using the full caption, as in Baseline. + probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True) + ori_probs = probs + + # Extend the baseline using parse stuff. + doc = self.nlp(caption) + head = self.get_head(doc) + chunks = self.get_chunks(doc) + if self.expand_chunks: + chunks = expand_chunks(doc, chunks) + entity = Entity.extract(head, chunks, self.heuristics) + + # If no head noun is found, take the first one. + if entity is None and len(list(doc.noun_chunks)) > 0: + head = list(doc.noun_chunks)[0] + entity = Entity.extract(head.root.head, chunks, self.heuristics) + self.counts["n_0th_noun"] += 1 + + # If we have found some head noun, filter based on it. 
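+        # The parse-derived scores are only combined (via meet) with the full-caption scores when a relation or superlative keyword appears in the expression, or when --parse_no_branch is set; otherwise the baseline probabilities are kept as-is.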
+ if entity is not None and (any(any(token.text in h.keywords for h in self.heuristics.relations+self.heuristics.superlatives) for token in doc) or not self.branch): + ent_probs, texts = self.execute_entity(entity, env, chunks) + probs = L.meet(probs, ent_probs) + else: + texts = [caption] + self.counts["n_full_expr"] += 1 + + if len(ori_probs) == 1: + probs = ori_probs + + self.counts["n_total"] += 1 + pred = np.argmax(probs) + return { + "probs": probs, + "pred": pred, + "box": env.boxes[pred], + "texts": texts + } + + def execute_entity(self, + ent: Entity, + env: "Environment", + chunks: Dict[int, Span], + root: bool = True, + ) -> np.ndarray: + """Execute an `Entity` tree recursively, yielding a distribution over boxes.""" + self.counts["n_rec"] += 1 + probs = [1, 1] + head_probs = probs + + # Only use relations if the head baseline isn't certain. + if len(probs) == 1 or len(env.boxes) == 1: + return probs, [ent.text] + + m1, m2 = probs[:2] # probs[(-probs).argsort()[:2]] + text = ent.text + rel_probs = [] + if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2: + self.counts["n_rec_rel"] += 1 + for tokens, ent2 in ent.relations: + self.counts["n_rel"] += 1 + rel = None + # Heuristically decide which spatial relation is represented. + for heuristic in self.heuristics.relations: + if any(tok.text in heuristic.keywords for tok in tokens): + rel = heuristic.callback(env) + self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1 + break + # Filter and normalize by the spatial relation. + if rel is not None: + probs2 = self.execute_entity(ent2, env, chunks, root=False) + events = L.meet(np.expand_dims(probs2, axis=0), rel) + new_probs = L.join_reduce(events) + rel_probs.append((ent2.text, new_probs, probs2)) + continue + + # This case specifically handles "between", which takes two noun arguments. + rel = None + for heuristic in self.heuristics.ternary_relations: + if any(tok.text in heuristic.keywords for tok in tokens): + rel = heuristic.callback(env) + self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1 + break + if rel is not None: + ent3 = get_conjunct(ent2, chunks, self.heuristics) + if ent3 is not None: + probs2 = self.execute_entity(ent2, env, chunks, root=False) + probs2 = np.expand_dims(probs2, axis=[0, 2]) + probs3 = self.execute_entity(ent3, env, chunks, root=False) + probs3 = np.expand_dims(probs3, axis=[0, 1]) + events = L.meet(L.meet(probs2, probs3), rel) + new_probs = L.join_reduce(L.join_reduce(events)) + probs = L.meet(probs, new_probs) + continue + # Otherwise, treat the relation as a possessive relation. + if not self.args.no_possessive: + if self.possessive_expand: + text = ent.expand(ent2.head) + else: + text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}' + #poss_probs = self._filter(text, env, root=root, expand=.3) + probs = self._filter(text, env, root=root) + texts = [text] + return_probs = [(probs.tolist(), probs.tolist())] + for (ent2_text, new_probs, ent2_only_probs) in rel_probs: + probs = L.meet(probs, new_probs) + probs /= probs.sum() + texts.append(ent2_text) + return_probs.append((probs.tolist(), ent2_only_probs.tolist())) + + # Only use superlatives if thresholds work out. 
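+        # m1/m2 are the two largest box probabilities; superlatives are applied only if the best box does not already beat the runner-up by more than the --baseline_threshold factor.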
+ m1, m2 = probs[(-probs).argsort()[:2]] + if m1 < self.baseline_threshold * m2: + self.counts["n_rec_sup"] += 1 + for tokens in ent.superlatives: + self.counts["n_sup"] += 1 + sup = None + for heuristic_index, heuristic in enumerate(self.heuristics.superlatives): + if any(tok.text in heuristic.keywords for tok in tokens): + texts.append('sup:'+' '.join([tok.text for tok in tokens if tok.text in heuristic.keywords])) + sup = heuristic.callback(env) + self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1 + break + if sup is not None: + # Could use `probs` or `head_probs` here? + precond = head_probs if self.superlative_head_only else probs + probs = L.meet(np.expand_dims(precond, axis=1)*np.expand_dims(precond, axis=0), sup).sum(axis=1) + probs = probs / probs.sum() + return_probs.append((probs.tolist(), None)) + + if root: + assert len(texts) == len(return_probs) + return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values())) + return probs + + def get_head(self, doc) -> Token: + """Return the token that is the head of the dependency parse. """ + for token in doc: + if token.head.i == token.i: + return token + return None + + def get_chunks(self, doc) -> Dict[int, Any]: + """Return a dictionary mapping sentence indices to their noun chunk.""" + chunks = {} + for chunk in doc.noun_chunks: + for idx in range(chunk.start, chunk.end): + chunks[idx] = chunk + return chunks + + @overrides + def get_stats(self) -> Dict[str, Number]: + """Summary statistics that have been tracked on this object.""" + stats = dict(self.counts) + n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_")) + n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_")) + stats.update({ + "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9), + "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9), + "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9), + "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9), + "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9), + "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9), + "avg_rec": self.counts["n_rec"] / self.counts["n_total"], + }) + return stats + + def _filter(self, + caption: str, + env: "Environment", + root: bool = False, + expand: float = None, + ) -> np.ndarray: + """Wrap a filter call in a consistent way for all recursions.""" + kwargs = { + "softmax": not self.args.sigmoid, + "temperature": self.args.temperature, + } + if root: + return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs) + else: + return env.filter(caption, **kwargs) diff --git a/AlphaCLIP/eval/rec_zs_test/methods/random_method.py b/AlphaCLIP/eval/rec_zs_test/methods/random_method.py new file mode 100644 index 0000000000000000000000000000000000000000..9882ddf64b027a1e7795b79466f32ea175df0d1a --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/methods/random_method.py @@ -0,0 +1,30 @@ +"""A naive baseline method: just pass the full expression to CLIP.""" + +from overrides import overrides +from typing import Dict, Any +import random +from argparse import Namespace + +import numpy as np + +from .ref_method import RefMethod + + +class Random(RefMethod): + """CLIP-only baseline where each box is evaluated with the full expression.""" + + def __init__(self, args: Namespace): + self.box_area_threshold = args.box_area_threshold + + @overrides + def execute(self, caption: str, env: "Environment") -> Dict[str, Any]: + probs = 
env.filter_area(self.box_area_threshold)*env.uniform() + random_ordering = list(range(len(env.boxes))) + random.shuffle(random_ordering) + random_ordering = np.array(random_ordering) + pred = np.argmax(probs*random_ordering) + return { + "probs": probs.tolist(), + "pred": int(pred), + "text": caption.lower() + } diff --git a/AlphaCLIP/eval/rec_zs_test/methods/ref_method.py b/AlphaCLIP/eval/rec_zs_test/methods/ref_method.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0a3eac1faf307e6d230c5adb05763370086a41 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/methods/ref_method.py @@ -0,0 +1,13 @@ +"""Base class for a method for doing referring expressions.""" + +from typing import Dict, Any +from abc import ABCMeta, abstractmethod + + +class RefMethod(metaclass=ABCMeta): + @abstractmethod + def execute(self, caption: str, env: "Environment") -> Dict[str, Any]: + return NotImplemented + + def get_stats(self) -> Dict[str, Any]: + return {} diff --git a/AlphaCLIP/eval/rec_zs_test/output/.gitkeep b/AlphaCLIP/eval/rec_zs_test/output/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AlphaCLIP/eval/rec_zs_test/requirements.txt b/AlphaCLIP/eval/rec_zs_test/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2fee0bc1d956715882aae4b5c8c2417cdc7b6fba --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/requirements.txt @@ -0,0 +1,53 @@ +attrs==21.2.0 +blis==0.7.4 +catalogue==2.0.4 +certifi==2021.5.30 +chardet==4.0.0 +click==7.1.2 +cymem==2.0.5 +en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl +filelock==3.0.12 +ftfy==6.0.3 +huggingface-hub==0.0.12 +idna==2.10 +iniconfig==1.1.1 +itsdangerous==2.0.1 +joblib==1.0.1 +MarkupSafe==2.0.1 +murmurhash==1.0.5 +numpy==1.21.0 +overrides==6.1.0 +packaging==21.0 +pathy==0.6.0 +Pillow==8.2.0 +pluggy==0.13.1 +preshed==3.0.5 +py==1.10.0 +pydantic==1.7.4 +pyparsing==2.4.7 +pytest==6.2.4 +PyYAML==5.4.1 +regex==2021.7.6 +requests==2.25.1 +ruamel.yaml==0.17.10 +ruamel.yaml.clib==0.2.6 +sacremoses==0.0.45 +scipy==1.7.0 +six==1.16.0 +smart-open==5.1.0 +spacy==3.0.6 +spacy-legacy==3.0.7 +srsly==2.4.1 +thinc==8.0.7 +timm==0.4.12 +tokenizers==0.10.3 +toml==0.10.2 +tqdm==4.61.2 +transformers==4.9.0 +typer==0.3.2 +typing-extensions==3.10.0.0 +typing-utils==0.1.0 +urllib3==1.26.6 +wasabi==0.8.2 +wcwidth==0.2.5 +Werkzeug==2.0.1 diff --git a/AlphaCLIP/eval/rec_zs_test/run.sh b/AlphaCLIP/eval/rec_zs_test/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..e4622b0e1ff670a4ead33a102ff2dddc528252b8 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/run.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache diff --git a/AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh b/AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6c5c50e12adbeec852154abb4063528e9f330e0 --- /dev/null +++ b/AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh @@ -0,0 +1,15 @@ +CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum 
--clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,0" & + +CUDA_VISIBLE_DEVICES=1 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,1" & + +CUDA_VISIBLE_DEVICES=2 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,2" & + +CUDA_VISIBLE_DEVICES=3 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,3" & + +CUDA_VISIBLE_DEVICES=4 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,4" & + +CUDA_VISIBLE_DEVICES=5 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,5" & + +CUDA_VISIBLE_DEVICES=6 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,6" & + +CUDA_VISIBLE_DEVICES=7 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,7" \ No newline at end of file diff --git a/AlphaCLIP/hubconf.py b/AlphaCLIP/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..d27d2ae444c623b384f344549b4c1318a9537440 --- /dev/null +++ b/AlphaCLIP/hubconf.py @@ -0,0 +1,42 @@ +from alpha_clip.alpha_clip import tokenize as _tokenize, load as _load, available_models as _available_models +import re +import string + +dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"] + +# For compatibility (cannot include special characters in function name) +model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()} + +def _create_hub_entrypoint(model): + def entrypoint(**kwargs): + return _load(model, **kwargs) + + entrypoint.__doc__ = f"""Loads the {model} CLIP model + + Parameters + ---------- + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). 
+ + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The {model} CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + return entrypoint + +def tokenize(): + return _tokenize + +_entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()} + +globals().update(_entrypoints) \ No newline at end of file diff --git a/AlphaCLIP/requirements.txt b/AlphaCLIP/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b98c33f3a0e09ddf982606430472de3061c6e9f --- /dev/null +++ b/AlphaCLIP/requirements.txt @@ -0,0 +1,5 @@ +ftfy +regex +tqdm +torch +torchvision diff --git a/AlphaCLIP/setup.py b/AlphaCLIP/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..95cea76c6614ce032ab6f6f1db2a68eb6f2da76e --- /dev/null +++ b/AlphaCLIP/setup.py @@ -0,0 +1,21 @@ +import os + +import pkg_resources +from setuptools import setup, find_packages + +setup( + name="alpha_clip", + py_modules=["alpha_clip"], + version="1.0", + description="", + author="OpenAI&ZeyiSun", + packages=find_packages(exclude=["tests*"]), + install_requires=[ + str(r) + for r in pkg_resources.parse_requirements( + open(os.path.join(os.path.dirname(__file__), "requirements.txt")) + ) + ], + include_package_data=True, + extras_require={'dev': ['pytest']}, +) diff --git a/README.md b/README.md index 92914e3e1f2109801bd4aa698c1f286cd4e12fdf..d613341922a01092569682b6e5f38668ae675ada 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ emoji: 🏢 colorFrom: green colorTo: red sdk: gradio -sdk_version: 4.36.1 +sdk_version: 3.48.0 app_file: app.py pinned: false license: mit diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c1ca6076714d4f50992396d7c197a4d54c23e4 --- /dev/null +++ b/app.py @@ -0,0 +1,113 @@ +import gradio as gr +import sys +import torch +from omegaconf import OmegaConf +from PIL import Image +from diffusers import StableDiffusionInpaintPipeline +from model.clip_away import CLIPAway +import cv2 +import numpy as np +import argparse + +# Parse command line arguments +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file") +parser.add_argument("--share", action="store_true", help="Share the interface if provided") +args = parser.parse_args() + +# Load configuration and models +config = OmegaConf.load(args.config) +sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float32 +) +clipaway = CLIPAway( + sd_pipe=sd_pipeline, + image_encoder_path=config.image_encoder_path, + ip_ckpt=config.ip_adapter_ckpt_path, + alpha_clip_path=config.alpha_clip_ckpt_pth, + config=config, + alpha_clip_id=config.alpha_clip_id, + device=config.device, + num_tokens=4 +) + +def dilate_mask(mask, kernel_size=5, iterations=5): + mask = mask.convert("L") + kernel = np.ones((kernel_size, kernel_size), np.uint8) + mask = cv2.dilate(np.array(mask), kernel, iterations=iterations) + return Image.fromarray(mask) + +def combine_masks(uploaded_mask, sketched_mask): + if uploaded_mask is not None: + return uploaded_mask + elif sketched_mask is not None: + return sketched_mask + else: + raise ValueError("Please provide a mask") + +def 
remove_obj(image, uploaded_mask, seed): + image_pil, sketched_mask = image["image"], image["mask"] + mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask)) + seed = int(seed) + latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda") + final_image = clipaway.generate( + prompt=[""], scale=1, seed=seed, + pil_image=[image_pil], alpha=[mask], strength=1, latents=latents + )[0] + return final_image + +# Define example data +examples = [ + ["assets/gradio_examples/images/1.jpg", "assets/gradio_examples/masks/1.png", 42], + ["assets/gradio_examples/images/2.jpg", "assets/gradio_examples/masks/2.png", 42], + ["assets/gradio_examples/images/3.jpg", "assets/gradio_examples/masks/3.png", 464], + ["assets/gradio_examples/images/4.jpg", "assets/gradio_examples/masks/4.png", 2024], +] + +# Define the Gradio interface +with gr.Blocks() as demo: + gr.Markdown("