import os
from typing import List, Tuple

import torch
from torch import Tensor, nn

class ClipTextEncoder(nn.Module):
    def __init__(
        self,
        modelpath: str = "openai/clip-vit-large-patch14",  # or e.g. "openai/clip-vit-base-patch32"
        finetune: bool = False,
        **kwargs,
    ) -> None:

        super().__init__()
        from transformers import AutoModel, AutoTokenizer, logging
        logging.set_verbosity_error()

        # Disable tokenizer parallelism to avoid fork-related warnings
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        self.text_model = AutoModel.from_pretrained(modelpath)

        # Freeze the text encoder unless fine-tuning is requested
        # (eval() propagates to submodules; assigning .training does not)
        if not finetune:
            self.text_model.eval()
            for p in self.text_model.parameters():
                p.requires_grad = False

        # Cache the tokenizer's sequence limit and the text embedding width
        self.max_length = self.tokenizer.model_max_length
        self.text_encoded_dim = self.text_model.config.text_config.hidden_size

    def forward(self, texts: List[str]) -> Tuple[Tensor, Tensor]:
        # Tokenize with padding/truncation to CLIP's maximum sequence length
        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_model.device)
        txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)

        # Safety net: truncate to the maximum length CLIP can handle
        if text_input_ids.shape[-1] > self.max_length:
            text_input_ids = text_input_ids[:, : self.max_length]

        # CLIP text transformer forward; last_hidden_state has shape
        # (batch_size, seq_length, text_encoded_dim). For the pooled
        # (batch_size, text_encoded_dim) embedding, use
        # self.text_model.get_text_features instead.
        text_embeddings = self.text_model.text_model(
            text_input_ids,
            # attention_mask=txt_att_mask,
        ).last_hidden_state

        return text_embeddings, txt_att_mask.bool()
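
# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal example of how the encoder might be invoked. Assumptions: the
# default "openai/clip-vit-large-patch14" checkpoint is available locally or
# downloadable, and its text tower uses max_length 77 with hidden size 768.
if __name__ == "__main__":
    encoder = ClipTextEncoder(finetune=False)

    with torch.no_grad():  # weights are frozen above, so skip autograd
        embeddings, mask = encoder(["a person walks forward", "a person jumps"])

    # embeddings: (2, 77, 768); mask: (2, 77), True at real (non-pad) tokens
    print(embeddings.shape, mask.shape)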