from typing import Dict, Final, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, logging
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention

from .configuration_predictor import AestheticsPredictorConfig

logging.set_verbosity_error()

# Download locations of the linear-head weights, keyed by the name of the
# CLIP checkpoint each head is paired with in `convert_from_openai_clip`.
URLS: Final[Dict[str, str]] = {
    "openai/clip-vit-base-patch16": "https://github.com/LAION-AI/aesthetic-predictor/raw/main/sa_0_4_vit_b_16_linear.pth",
    "openai/clip-vit-base-patch32": "https://github.com/LAION-AI/aesthetic-predictor/raw/main/sa_0_4_vit_b_32_linear.pth",
    "openai/clip-vit-large-patch14": "https://github.com/LAION-AI/aesthetic-predictor/raw/main/sa_0_4_vit_l_14_linear.pth",
}


class AestheticsPredictorV1(CLIPVisionModelWithProjection):
    """CLIP vision model with projection, topped by a single-output linear head.

    The head (`self.predictor`) maps the L2-normalised projected image
    embedding (`config.projection_dim` features) to one scalar per image.
    """

    def __init__(self, config: AestheticsPredictorConfig) -> None:
        super().__init__(config)
        self.predictor = nn.Linear(config.projection_dim, 1)
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        """Embed the images, L2-normalise the embeddings and score them.

        Args:
            pixel_values: Forwarded unchanged to the parent `forward`.
            output_attentions: Forwarded unchanged to the parent `forward`.
            output_hidden_states: Forwarded unchanged to the parent `forward`.
            return_dict: Whether to return a model-output object instead of a
                plain tuple. Falls back to `self.config.use_return_dict`
                when `None`.

        Returns:
            If `return_dict` is falsy, the tuple
            `(None, prediction, image_embeds)`. Otherwise an
            `ImageClassifierOutputWithNoAttention` with `loss=None`,
            `logits=prediction` and `hidden_states=image_embeds`. In both
            cases `image_embeds` is the normalised embedding.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = super().forward(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = outputs[0]  # image_embeds

        # Normalise out of place: `norm` keeps its input for the backward
        # pass, so an in-place `/=` on that same tensor makes `backward()`
        # fail when gradients are enabled. The forward values are identical.
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

        prediction = self.predictor(image_embeds)

        if not return_dict:
            return (None, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=None,
            logits=prediction,
            hidden_states=image_embeds,
        )


def convert_from_openai_clip(
    openai_model_name: str, config: Optional[AestheticsPredictorConfig] = None
) -> AestheticsPredictorV1:
    """Build an `AestheticsPredictorV1` from a CLIP checkpoint plus a linear head.

    Args:
        openai_model_name: Name of the CLIP checkpoint; must be a key of
            `URLS`.
        config: Configuration to build the model with. When omitted (or
            falsy), it is loaded with
            `AestheticsPredictorConfig.from_pretrained(openai_model_name)`.

    Returns:
        The assembled model, switched to evaluation mode.

    Raises:
        KeyError: If `openai_model_name` has no entry in `URLS`.
    """
    # Resolve the head URL first so an unsupported name fails immediately,
    # before any (potentially large) CLIP weights are fetched.
    try:
        head_url = URLS[openai_model_name]
    except KeyError:
        supported = ", ".join(sorted(URLS))
        raise KeyError(
            f"No linear-head weights registered for {openai_model_name!r}; "
            f"supported models: {supported}"
        ) from None

    config = config or AestheticsPredictorConfig.from_pretrained(openai_model_name)
    model = AestheticsPredictorV1(config)

    # strict=False: the plain CLIP state dict has no `predictor.*` entries.
    clip_model = CLIPVisionModelWithProjection.from_pretrained(openai_model_name)
    model.load_state_dict(clip_model.state_dict(), strict=False)

    # Map to CPU so loading does not depend on the device the checkpoint was
    # saved from; the tensors are copied into the freshly built module anyway.
    state_dict = torch.hub.load_state_dict_from_url(head_url, map_location="cpu")
    model.predictor.load_state_dict(state_dict)
    model.eval()

    return model