File size: 12,721 Bytes
5d1f0ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
import os
from typing import Union, List
from pkg_resources import packaging
import torch
import numpy as np
from AnomalyCLIP_lib.simple_tokenizer import SimpleTokenizer as _Tokenizer
# from open_clip import tokenizer
# simple_tokenizer = tokenizer.SimpleTokenizer()
from copy import deepcopy
import torch.nn as nn
# Shared BPE tokenizer instance; tokenize() reads its `encoder` table and
# `encode` method to turn strings into CLIP token ids.
_tokenizer = _Tokenizer()
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Convert one or more strings into a zero-padded matrix of CLIP BPE token ids.

    Each string is wrapped in the ``<|startoftext|>`` / ``<|endoftext|>``
    markers, encoded with the module-level ``_tokenizer`` and written
    left-aligned into its own row.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string, or a list of strings, to tokenize.
    context_length : int
        Number of token slots per row; all CLIP models use 77.
    truncate : bool
        If True, an over-long encoding is cut to ``context_length`` tokens and
        its last slot is overwritten with the end-of-text token. If False, an
        over-long encoding raises.

    Returns
    -------
    A tensor of shape [number of input strings, context_length]. Its dtype is
    torch.long when torch < 1.8.0 (older index_select requires long indices)
    and torch.int otherwise.

    Raises
    ------
    RuntimeError
        If an encoding is longer than ``context_length`` and ``truncate`` is False.
    """
    batch = [texts] if isinstance(texts, str) else texts
    sot = _tokenizer.encoder["<|startoftext|>"]
    eot = _tokenizer.encoder["<|endoftext|>"]
    # Encode everything up front so the output row count is known.
    encoded = [[sot] + _tokenizer.encode(text) + [eot] for text in batch]

    legacy_torch = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    out_dtype = torch.long if legacy_torch else torch.int
    out = torch.zeros(len(encoded), context_length, dtype=out_dtype)

    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {batch[row]} is too long for context length {context_length}")
            ids = ids[:context_length]
            # Keep a terminating EOT so downstream EOT-position lookups still work.
            ids[-1] = eot
        out[row, :len(ids)] = torch.tensor(ids)
    return out
def encode_text_with_prompt_ensemble(model, texts, device):
    """
    Build one "normal" and one "abnormal" text embedding for ``texts[0]``.

    Every state phrase (e.g. ``"flawless {}"``, ``"damaged {}"``) is filled
    with ``texts[0]``. It is then placed into every photo template. The
    resulting sentences are encoded with ``model.encode_text`` and
    L2-normalised, then averaged per state group, and the mean is
    re-normalised. Only ``texts[0]`` is used; any further entries are ignored.

    Args:
        model: object exposing ``encode_text(token_ids)``.
        texts: indexable collection of class-name strings.
        device: device the token ids (and the stacked result) are moved to.

    Returns:
        A tensor with one row per state group (row 0 normal, row 1 abnormal).
        This assumes ``encode_text`` returns ``(num_sentences, dim)``, giving
        ``(2, dim)`` — TODO confirm against the model.
    """
    normal_states = [
        '{}', 'flawless {}', 'perfect {}', 'unblemished {}',
        '{} without flaw', '{} without defect', '{} without damage',
    ]
    abnormal_states = [
        'damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage',
    ]
    templates = [
        'a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.',
        'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.',
        'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.',
        'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.',
        'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.',
        'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.',
        'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.',
        'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.',
        'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.',
        'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.',
        'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.',
        'this is the {} in the scene.', 'this is one {} in the scene.',
    ]

    per_state = []
    for state_group in (normal_states, abnormal_states):
        # State-major ordering: every template for the first state, then the next state, ...
        sentences = [
            template.format(state.format(texts[0]))
            for state in state_group
            for template in templates
        ]
        feats = model.encode_text(tokenize(sentences).to(device))
        feats /= feats.norm(dim=-1, keepdim=True)
        mean_feat = feats.mean(dim=0)
        mean_feat /= mean_feat.norm()
        per_state.append(mean_feat)

    # Stack as columns and transpose so each state group becomes a row.
    return torch.stack(per_state, dim=1).to(device).t()
def _get_clones(module, N):
return nn.ModuleList([deepcopy(module) for i in range(N)])
class AnomalyCLIP_PromptLearner(nn.Module):
    """Learnable normal / anomalous text prompts for the generic class "object".

    Two prompt families are built: "normal" from the template ``"{}"`` and
    "anomalous" from ``"damaged {}"``. Each is prefixed by ``n_ctx`` learnable
    context vectors (``ctx_pos`` / ``ctx_neg``). The frozen token embeddings
    around that context are stored as buffers: the first token (start-of-text),
    and everything after the context (class words, ".", end-of-text and zero
    padding). ``forward`` splices the learnable context between them.

    The module also owns ``learnabel_text_embedding_depth - 1`` extra learnable
    token blocks (``compound_prompts_text``). These are returned from
    ``forward`` for the caller to consume.

    Args:
        clip_model: CLIP-like model. This class reads
            ``clip_model.transformer.get_cast_dtype()``,
            ``clip_model.ln_final.weight`` and ``clip_model.token_embedding``.
        design_details: dict providing ``"Prompt_length"``,
            ``"learnabel_text_embedding_length"`` and
            ``"learnabel_text_embedding_depth"`` (key spelling as used by callers).
    """

    def __init__(self, clip_model, design_details):
        super().__init__()
        classnames = ["object"]
        self.n_cls = len(classnames)
        self.n_ctx = design_details["Prompt_length"]
        n_ctx_pos = self.n_ctx
        n_ctx_neg = self.n_ctx
        self.text_encoder_n_ctx = design_details["learnabel_text_embedding_length"]
        # Optional words used to initialise the learnable context. Both are
        # empty, so the random-initialisation branch below is the one taken.
        ctx_init_pos = ""
        ctx_init_neg = ""
        dtype = clip_model.transformer.get_cast_dtype()
        ctx_dim = clip_model.ln_final.weight.shape[0]
        self.classnames = classnames
        self.state_normal_list = [
            "{}",
        ]
        self.state_anomaly_list = [
            "damaged {}",
        ]
        normal_num = len(self.state_normal_list)
        anormaly_num = len(self.state_anomaly_list)
        self.normal_num = normal_num
        self.anormaly_num = anormaly_num

        if ctx_init_pos and ctx_init_neg:
            # Initialise the context vectors from the embeddings of the given words.
            ctx_init_pos = ctx_init_pos.replace("_", " ")
            ctx_init_neg = ctx_init_neg.replace("_", " ")
            n_ctx_pos = len(ctx_init_pos.split(" "))
            n_ctx_neg = len(ctx_init_neg.split(" "))
            # BPE-encode the initialisation text.
            prompt_pos = tokenize(ctx_init_pos)
            prompt_neg = tokenize(ctx_init_neg)
            with torch.no_grad():
                # Look up the corresponding token embeddings.
                embedding_pos = clip_model.token_embedding(prompt_pos).type(dtype)
                embedding_neg = clip_model.token_embedding(prompt_neg).type(dtype)
            # Skip the start-of-text token (index 0) and keep the next n_ctx
            # embeddings as the learnable textual prompt.
            ctx_vectors_pos = embedding_pos[0, 1: 1 + n_ctx_pos, :]
            ctx_vectors_neg = embedding_neg[0, 1: 1 + n_ctx_neg, :]
            prompt_prefix_pos = ctx_init_pos
            prompt_prefix_neg = ctx_init_neg
            # forward() concatenates the context along dim 2 with buffers of
            # shape (n_cls, n_states, *, dim), so give it the same two leading
            # dims: (n_ctx, dim) -> (n_cls, n_states, n_ctx, dim). repeat() copies.
            ctx_vectors_pos = ctx_vectors_pos.repeat(self.n_cls, normal_num, 1, 1)
            ctx_vectors_neg = ctx_vectors_neg.repeat(self.n_cls, anormaly_num, 1, 1)
        else:
            # Random initialisation of class- and state-specific contexts.
            print("Initializing class-specific contexts")
            # Shape: (n_cls, n_states, n_ctx, ctx_dim). n_cls is the number of
            # classes, n_ctx the number of learnable tokens, and ctx_dim the
            # prompt embedding width.
            ctx_vectors_pos = torch.empty(self.n_cls, normal_num, n_ctx_pos, ctx_dim, dtype=dtype)
            ctx_vectors_neg = torch.empty(self.n_cls, anormaly_num, n_ctx_neg, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors_pos, std=0.02)
            nn.init.normal_(ctx_vectors_neg, std=0.02)
            # Placeholder words reserve the context slots in the tokenised prompt.
            # The suffix offset below assumes each "X" encodes to one token.
            prompt_prefix_pos = " ".join(["X"] * n_ctx_pos)
            prompt_prefix_neg = " ".join(["X"] * n_ctx_neg)

        self.compound_prompts_depth = design_details["learnabel_text_embedding_depth"]
        # Extra learnable token blocks, one per additional text-encoder depth.
        # These are created in the default dtype, not `dtype`.
        self.compound_prompts_text = nn.ParameterList(
            [nn.Parameter(torch.empty(self.text_encoder_n_ctx, ctx_dim))
             for _ in range(self.compound_prompts_depth - 1)]
        )
        for single_para in self.compound_prompts_text:
            print("single_para", single_para.shape)
            nn.init.normal_(single_para, std=0.02)
        # NOTE(review): forward() never uses these projections, and the output
        # width 896 is hard-coded. They presumably exist for checkpoint
        # compatibility — confirm before removing or resizing.
        single_layer = nn.Linear(ctx_dim, 896)
        self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)
        self.ctx_pos = nn.Parameter(ctx_vectors_pos)  # to be optimized
        self.ctx_neg = nn.Parameter(ctx_vectors_neg)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        # Template-major order: all class names for template 0, then template 1, ...
        prompts_pos = [prompt_prefix_pos + " " + template.format(name) + "." for template in self.state_normal_list for name in classnames]
        prompts_neg = [prompt_prefix_neg + " " + template.format(name) + "." for template in self.state_anomaly_list for name in classnames]
        tokenized_prompts_pos = torch.cat([tokenize(p) for p in prompts_pos])
        tokenized_prompts_neg = torch.cat([tokenize(p) for p in prompts_neg])

        # Look up the frozen token embeddings of the full prompts.
        with torch.no_grad():
            embedding_pos = clip_model.token_embedding(tokenized_prompts_pos).type(dtype)
            embedding_neg = clip_model.token_embedding(tokenized_prompts_neg).type(dtype)
        _, l, d = embedding_pos.shape
        print("embedding_pos", embedding_pos.shape)
        # Undo the template-major flattening: (n_states * n_cls, l, d) ->
        # (n_states, n_cls, l, d) -> (n_cls, n_states, l, d).
        embedding_pos = embedding_pos.reshape(normal_num, self.n_cls, l, d).permute(1, 0, 2, 3)
        embedding_neg = embedding_neg.reshape(anormaly_num, self.n_cls, l, d).permute(1, 0, 2, 3)
        # Frozen pieces around the learnable context: the first token, and
        # everything after the 1 + n_ctx slots.
        self.register_buffer("token_prefix_pos", embedding_pos[:, :, :1, :])
        self.register_buffer("token_suffix_pos", embedding_pos[:, :, 1 + n_ctx_pos:, :])
        self.register_buffer("token_prefix_neg", embedding_neg[:, :, :1, :])
        self.register_buffer("token_suffix_neg", embedding_neg[:, :, 1 + n_ctx_neg:, :])

        # Same (n_cls, n_states, context_length) layout for the token ids.
        seq_len_pos = tokenized_prompts_pos.shape[1]
        tokenized_prompts_pos = tokenized_prompts_pos.reshape(normal_num, self.n_cls, seq_len_pos).permute(1, 0, 2)
        seq_len_neg = tokenized_prompts_neg.shape[1]
        tokenized_prompts_neg = tokenized_prompts_neg.reshape(anormaly_num, self.n_cls, seq_len_neg).permute(1, 0, 2)
        self.n_ctx_pos = n_ctx_pos
        self.n_ctx_neg = n_ctx_neg
        self.register_buffer("tokenized_prompts_pos", tokenized_prompts_pos)
        self.register_buffer("tokenized_prompts_neg", tokenized_prompts_neg)
        print("tokenized_prompts shape", self.tokenized_prompts_pos.shape, self.tokenized_prompts_neg.shape)

    def forward(self, cls_id=None):
        """Assemble the full normal and anomalous prompt embeddings.

        Args:
            cls_id: accepted for interface compatibility; not used.

        Returns:
            A tuple of three items:
            - ``prompts``: shape ``(n_cls * (normal_num + anormaly_num), seq_len, dim)``;
              all normal prompts come first, followed by all anomalous ones.
            - ``tokenized_prompts``: the matching token ids in the same row order.
            - ``self.compound_prompts_text``: the extra learnable token blocks.
        """
        # Concatenate along the token axis:
        # (n_cls, n_states, 1 | n_ctx | rest, dim) -> (n_cls, n_states, seq_len, dim).
        prompts_pos = torch.cat([self.token_prefix_pos, self.ctx_pos, self.token_suffix_pos], dim=2)
        prompts_neg = torch.cat([self.token_prefix_neg, self.ctx_neg, self.token_suffix_neg], dim=2)
        # Flatten (n_cls, n_states) into a single prompt axis.
        prompts_pos = prompts_pos.reshape(-1, *prompts_pos.shape[-2:])
        prompts_neg = prompts_neg.reshape(-1, *prompts_neg.shape[-2:])
        prompts = torch.cat([prompts_pos, prompts_neg], dim=0)

        tokenized_prompts_pos = self.tokenized_prompts_pos.reshape(-1, self.tokenized_prompts_pos.shape[-1])
        tokenized_prompts_neg = self.tokenized_prompts_neg.reshape(-1, self.tokenized_prompts_neg.shape[-1])
        tokenized_prompts = torch.cat((tokenized_prompts_pos, tokenized_prompts_neg), dim=0)
        return prompts, tokenized_prompts, self.compound_prompts_text