import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Wraps a pretrained CLIP text encoder (Hugging Face transformers) with frozen weights."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu",
                 max_length=77, freeze=True, layer="last", layer_idx=None):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.layer = layer          # "last", "pooled", or "hidden"
        self.layer_idx = layer_idx  # index into hidden_states when layer == "hidden"
        if freeze:
            # Put the encoder in eval mode and disable gradients for every parameter.
            self.transformer = self.transformer.eval()
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, text):
        # Tokenize to fixed-length sequences, padded/truncated to max_length.
        enc = self.tokenizer(
            text, truncation=True, max_length=self.max_length,
            return_length=True, return_overflowing_tokens=False,
            padding="max_length", return_tensors="pt"
        )
        tokens = enc["input_ids"].to(next(self.transformer.parameters()).device)
        out = self.transformer(input_ids=tokens,
                               output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            return out.last_hidden_state          # (batch, max_length, hidden_dim)
        if self.layer == "pooled":
            return out.pooler_output[:, None, :]  # (batch, 1, hidden_dim)
        return out.hidden_states[self.layer_idx]  # a chosen intermediate layer

    def encode(self, text):
        return self(text)
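

# Usage sketch (illustrative, not part of the original module): encode a batch of
# prompts and inspect the embedding shape. Assumes the default
# "openai/clip-vit-large-patch14" weights can be fetched from the Hugging Face hub.
if __name__ == "__main__":
    embedder = FrozenCLIPEmbedder(layer="last")
    with torch.no_grad():
        emb = embedder.encode(["a photograph of an astronaut riding a horse"])
    # For the ViT-L/14 text encoder: torch.Size([1, 77, 768])
    print(emb.shape)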