shunk031 commited on
Commit
0818091
1 Parent(s): eb096d6

Upload AestheticsPredictorV1

Browse files
Files changed (2) hide show
  1. config.json +2 -2
  2. modeling_v1.py +9 -2
config.json CHANGED
@@ -1,10 +1,10 @@
1
  {
2
- "_name_or_path": "openai/clip-vit-base-patch32",
3
  "architectures": [
4
  "AestheticsPredictorV1"
5
  ],
6
  "attention_dropout": 0.0,
7
  "auto_map": {
 
8
  "AutoModel": "modeling_v1.AestheticsPredictorV1"
9
  },
10
  "dropout": 0.0,
@@ -15,7 +15,7 @@
15
  "initializer_range": 0.02,
16
  "intermediate_size": 3072,
17
  "layer_norm_eps": 1e-05,
18
- "model_type": "clip_vision_model",
19
  "num_attention_heads": 12,
20
  "num_channels": 3,
21
  "num_hidden_layers": 12,
 
1
  {
 
2
  "architectures": [
3
  "AestheticsPredictorV1"
4
  ],
5
  "attention_dropout": 0.0,
6
  "auto_map": {
7
+ "AutoConfig": "configuration_predictor.AestheticsPredictorConfig",
8
  "AutoModel": "modeling_v1.AestheticsPredictorV1"
9
  },
10
  "dropout": 0.0,
 
15
  "initializer_range": 0.02,
16
  "intermediate_size": 3072,
17
  "layer_norm_eps": 1e-05,
18
+ "model_type": "aesthetics_predictor",
19
  "num_attention_heads": 12,
20
  "num_channels": 3,
21
  "num_hidden_layers": 12,
modeling_v1.py CHANGED
@@ -54,8 +54,15 @@ class AestheticsPredictorV1(CLIPVisionModelWithProjection):
54
  )
55
 
56
 
57
- def convert_from_openai_clip(openai_model_name: str) -> AestheticsPredictorV1:
58
- model = AestheticsPredictorV1.from_pretrained(openai_model_name)
 
 
 
 
 
 
 
59
  state_dict = torch.hub.load_state_dict_from_url(URLS[openai_model_name])
60
  model.predictor.load_state_dict(state_dict)
61
  model.eval()
 
54
  )
55
 
56
 
57
+ def convert_from_openai_clip(
58
+ openai_model_name: str, config: Optional[AestheticsPredictorConfig] = None
59
+ ) -> AestheticsPredictorV1:
60
+ config = config or AestheticsPredictorConfig.from_pretrained(openai_model_name)
61
+ model = AestheticsPredictorV1(config)
62
+
63
+ clip_model = CLIPVisionModelWithProjection.from_pretrained(openai_model_name)
64
+ model.load_state_dict(clip_model.state_dict(), strict=False)
65
+
66
  state_dict = torch.hub.load_state_dict_from_url(URLS[openai_model_name])
67
  model.predictor.load_state_dict(state_dict)
68
  model.eval()