| from transformers import PreTrainedModel, PretrainedConfig | |
| from .module import ConditionalViT | |
| from sentence_transformers import SentenceTransformer | |
class CondViTConfig(PretrainedConfig):
    """Configuration for the Conditional ViT embedding model.

    Bundles the vision-transformer hyperparameters with the identifier
    (and pinned revision) of the sentence-transformer backbone that
    produces the text conditioning vectors, plus the device both
    sub-models are placed on.
    """

    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        lm_backbone: str = "sentence-transformers/sentence-t5-xl",
        lm_revision: str = "e0976ba9afd18be963c22c680367a3928c44fd22",
        device: str = "cpu",
        **kwargs,
    ):
        """Store every hyperparameter as an attribute, then let the
        ``PretrainedConfig`` base class consume the remaining kwargs."""
        # Mirror each constructor argument onto the instance; the base
        # class serializes these attributes when the config is saved.
        for attr_name, attr_value in (
            ("input_resolution", input_resolution),
            ("patch_size", patch_size),
            ("width", width),
            ("layers", layers),
            ("heads", heads),
            ("output_dim", output_dim),
            ("n_categories", n_categories),
            ("lm_backbone", lm_backbone),
            ("lm_revision", lm_revision),
            ("device", device),
        ):
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
class CondViTForEmbedding(PreTrainedModel):
    """Hugging Face wrapper pairing a ``ConditionalViT`` image encoder
    with a sentence-transformer that embeds the conditioning texts.
    """

    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the vision tower from the config's ViT hyperparameters.
        vit_kwargs = {
            "input_resolution": config.input_resolution,
            "patch_size": config.patch_size,
            "width": config.width,
            "layers": config.layers,
            "heads": config.heads,
            "output_dim": config.output_dim,
        }
        self.condvit = ConditionalViT(**vit_kwargs)
        # NOTE(review): device placement is driven by the config rather
        # than the usual ``model.to(...)`` call — confirm this matches
        # how callers load the model.
        if config.device:
            self.condvit.to(config.device)
        # Language model that turns conditioning texts into vectors,
        # pinned to a fixed revision for reproducibility.
        self.lm = SentenceTransformer(
            config.lm_backbone, revision=config.lm_revision, device=config.device
        )

    def forward(self, pixel_values, texts=None):
        """Embed a batch of images, optionally conditioned on texts.

        Args:
            pixel_values: image batch forwarded to the vision tower.
            texts: optional conditioning strings; when ``None`` the
                vision tower receives no conditioning vector.

        Returns:
            Whatever ``ConditionalViT`` returns for the image batch
            (an embedding tensor, presumably — defined in ``.module``).
        """
        text_embeddings = (
            self.lm.encode(
                texts,
                convert_to_tensor=True,
                convert_to_numpy=False,
            )
            if texts is not None
            else None
        )
        return self.condvit(imgs=pixel_values, c=text_embeddings)