Continual-Mega committed (verified)
Commit 940092e · 1 Parent(s): 7caf841

Upload CoOp.py with huggingface_hub

Files changed (1):
  1. CoOp.py  +205  -0

CoOp.py ADDED
import torch
import torch.nn as nn

from CLIP.tokenizer import SimpleTokenizer, tokenize
from huggingface_hub import PyTorchModelHubMixin

class TextEncoder(nn.Module):
    """Frozen CLIP text tower, applied to precomputed token embeddings."""

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, _, _ = self.transformer(x)  # this CLIP variant's transformer returns a tuple
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # Pool at the EOT token (the highest token id in each sequence) and project.
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x
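
# How TextEncoder is meant to be called (a sketch; the shapes are assumptions
# based on CLIP's usual layout and are not spelled out in this file):
#   prompts:           (n_prompts, seq_len, ctx_dim) soft token embeddings,
#                      so gradients can flow back into the learnable context
#   tokenized_prompts: (n_prompts, seq_len) integer ids, used only to locate
#                      each sequence's EOT position via argmax
#   TextEncoder(clip_model)(prompts, tokenized_prompts) -> (n_prompts, proj_dim)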

class PromptLearner(nn.Module):
    def __init__(self,
                 prompts,                # dict: class name -> list of prompt strings
                 n_ctx,                  # number of learnable context tokens
                 CSC,                    # class-specific context: one context per prompt if True
                 class_token_position,   # list of class-token positions ('end', 'middle', 'front')
                 clip_model):
        super().__init__()

        ctx_dim = clip_model.ln_final.weight.shape[0]  # width of the text embedding

        # One learnable context per (class, position) pair.
        self.ctx = {}
        for cls in prompts:
            for position in class_token_position:
                if CSC:
                    ctx_vectors = torch.empty(len(prompts[cls]), n_ctx, ctx_dim).to(clip_model.device)
                else:
                    ctx_vectors = torch.empty(n_ctx, ctx_dim).to(clip_model.device)
                nn.init.normal_(ctx_vectors, std=0.02)
                self.ctx['{}_{}'.format(cls, position)] = nn.Parameter(ctx_vectors, requires_grad=True)
        self.ctx = nn.ParameterDict(self.ctx)  # the only parameters to be optimized

        prompt_prefix = " ".join(["X"] * n_ctx)  # placeholder tokens later replaced by ctx

        _tokenizer = SimpleTokenizer()

        prompts_split = {cls: [prompt.replace("_", " ") for prompt in prompts[cls]]
                         for cls in prompts}
        prompts_lens = {cls: [len(_tokenizer.encode(prompt)) for prompt in prompts_split[cls]]
                        for cls in prompts_split}
        prompts_learnable_tokens = {cls: [prompt_prefix + " " + prompt + "." for prompt in prompts_split[cls]]
                                    for cls in prompts_split}
        tokenized_prompts = {cls: torch.cat([tokenize(prompt) for prompt in prompts_learnable_tokens[cls]]).to(clip_model.device)
                             for cls in prompts_learnable_tokens}

        # Frozen token embeddings; only the slices around the ctx slots are kept.
        with torch.no_grad():
            embeddings = {cls: clip_model.token_embedding(tokenized_prompts[cls]) for cls in tokenized_prompts}

        # Plain dict (not registered as buffers), so these frozen slices stay out of the state dict.
        self.register_embeddings = {}
        for cls in embeddings:
            self.register_embeddings['{}_token_prefix'.format(cls)] = embeddings[cls][:, :1, :]          # SOS
            self.register_embeddings['{}_token_suffix'.format(cls)] = embeddings[cls][:, 1 + n_ctx:, :]  # class tokens, '.', EOS, padding

        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.prompts_lens = prompts_lens
        self.class_token_position = class_token_position

    def forward(self):
        cls_prompts = {}

        for cls in self.tokenized_prompts:
            prefix = self.register_embeddings['{}_token_prefix'.format(cls)]
            suffix = self.register_embeddings['{}_token_suffix'.format(cls)]

            cls_prompts[cls] = []

            for position in self.class_token_position:
                ctx = self.ctx['{}_{}'.format(cls, position)]
                if ctx.dim() == 2:  # shared context: broadcast over this class's prompts
                    ctx = ctx.unsqueeze(0).expand(len(self.prompts_lens[cls]), -1, -1)

                if position == "end":
                    prompts = torch.cat(
                        [
                            prefix,  # (n_cls, 1, dim)
                            ctx,     # (n_cls, n_ctx, dim)
                            suffix,  # (n_cls, *, dim)
                        ],
                        dim=1,
                    )

                elif position == "middle":
                    half_n_ctx = self.n_ctx // 2
                    prompts = []
                    for i in range(len(self.prompts_lens[cls])):
                        p_len = self.prompts_lens[cls][i]

                        prefix_i = prefix[i : i + 1, :, :]
                        class_i = suffix[i : i + 1, :p_len, :]
                        suffix_i = suffix[i : i + 1, p_len:, :]
                        ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                        ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]

                        prompt = torch.cat(
                            [
                                prefix_i,     # (1, 1, dim)
                                ctx_i_half1,  # (1, n_ctx//2, dim)
                                class_i,      # (1, name_len, dim)
                                ctx_i_half2,  # (1, n_ctx//2, dim)
                                suffix_i,     # (1, *, dim)
                            ],
                            dim=1,
                        )
                        prompts.append(prompt)
                    prompts = torch.cat(prompts, dim=0)

                else:
                    assert position == "front"
                    prompts = []
                    for i in range(len(self.prompts_lens[cls])):
                        p_len = self.prompts_lens[cls][i]

                        prefix_i = prefix[i : i + 1, :, :]
                        class_i = suffix[i : i + 1, :p_len, :]
                        suffix_i = suffix[i : i + 1, p_len:, :]
                        ctx_i = ctx[i : i + 1, :, :]

                        prompt = torch.cat(
                            [
                                prefix_i,  # (1, 1, dim)
                                class_i,   # (1, name_len, dim)
                                ctx_i,     # (1, n_ctx, dim)
                                suffix_i,  # (1, *, dim)
                            ],
                            dim=1,
                        )
                        prompts.append(prompt)
                    prompts = torch.cat(prompts, dim=0)

                cls_prompts[cls].append(prompts)
            cls_prompts[cls] = torch.cat(cls_prompts[cls], dim=0)
        return cls_prompts
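
# Illustrative token layout (an assumption based on standard CLIP tokenization;
# the concrete strings and n_ctx = 4 are made up) for one "end"-position prompt:
#
#   text passed to tokenize():  "X X X X flawless object."
#   embedding rows:  [SOS | X X X X | flawless object . EOS pad ...]
#                     ^prefix  ^replaced by learnable ctx   ^suffix
#
# "middle" and "front" re-split the same prefix/suffix rows so the class tokens
# sit inside or before the learned context instead of after it.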

class PromptMaker(nn.Module,
                  PyTorchModelHubMixin,
                  repo_url="https://github.com/Continual-Mega/Continual-Mega",
                  paper_url="https://arxiv.org/abs/2506.00956"):

    def __init__(self,
                 prompts,                               # dict with 'normal' and 'abnormal' prompt lists
                 clip_model,
                 n_ctx: int = 8,                        # number of learnable context tokens
                 CSC: bool = True,                      # class-specific context
                 class_token_position: list = ['end'],  # class-token positions
                 ):
        super().__init__()
        assert 'normal' in prompts and 'abnormal' in prompts

        for position in class_token_position:
            assert position in ['end', 'middle', 'front']

        self.prompt_learner = PromptLearner(prompts, n_ctx, CSC, class_token_position, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts

        self.class_token_position = class_token_position
        self.text_encoder = TextEncoder(clip_model)

    def forward(self):
        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = []

        for cls in prompts:
            class_embedding = self.text_encoder(prompts[cls],
                                                tokenized_prompts[cls].repeat(len(self.class_token_position), 1))
            # Average over this class's prompt ensemble, then L2-normalize.
            class_embedding = class_embedding.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            text_features.append(class_embedding)
        text_features = torch.stack(text_features, dim=1)  # (dim, n_classes)

        return text_features
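
A minimal usage sketch, assuming a loader for the customized CLIP build this file imports from; the load_clip helper, the model name, and the prompt lists are placeholders, not part of the file:

# Hedged sketch: everything here except PromptMaker itself is an assumption.
clip_model = load_clip("ViT-L/14")  # hypothetical loader; the model must expose .device,
                                    # .token_embedding, and a tuple-returning .transformer

prompts = {
    "normal":   ["flawless object", "perfect object"],  # illustrative prompt lists
    "abnormal": ["damaged object", "broken object"],
}

prompt_maker = PromptMaker(prompts, clip_model, n_ctx=8, CSC=True,
                           class_token_position=["end"])

text_features = prompt_maker()  # (proj_dim, 2): one L2-normalized column per class

The learnable context lives in prompt_maker.prompt_learner.ctx; in CoOp-style training the shared CLIP weights are typically kept frozen while only these context vectors are optimized.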