shunk031 commited on
Commit
fd8303c
1 Parent(s): 5511dd3

Delete v2.py

Browse files
Files changed (1) hide show
  1. v2.py +0 -137
v2.py DELETED
@@ -1,137 +0,0 @@
1
- from collections import OrderedDict
2
- from typing import Dict, Final, Optional, Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
- from transformers import CLIPVisionModelWithProjection, logging
7
- from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
8
- from transformers.models.clip.configuration_clip import CLIPVisionConfig
9
-
10
- logging.set_verbosity_error()
11
-
12
- URLS_LINEAR: Final[Dict[str, str]] = {
13
- "sac+logos+ava1-l14-linearMSE": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
14
- "ava+logos-l14-linearMSE": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/ava%2Blogos-l14-linearMSE.pth",
15
- }
16
-
17
-
18
- URLS_RELU: Final[Dict[str, str]] = {
19
- "ava+logos-l14-reluMSE": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/ava%2Blogos-l14-reluMSE.pth",
20
- }
21
-
22
-
23
- class AestheticsPredictorV2Linear(CLIPVisionModelWithProjection):
24
- def __init__(self, config: CLIPVisionConfig) -> None:
25
- super().__init__(config)
26
- self.layers = nn.Sequential(
27
- nn.Linear(config.projection_dim, 1024),
28
- nn.Dropout(0.2),
29
- nn.Linear(1024, 128),
30
- nn.Dropout(0.2),
31
- nn.Linear(128, 64),
32
- nn.Dropout(0.1),
33
- nn.Linear(64, 16),
34
- nn.Linear(16, 1),
35
- )
36
- self.post_init()
37
-
38
- def forward(
39
- self,
40
- pixel_values: Optional[torch.FloatTensor] = None,
41
- output_attentions: Optional[bool] = None,
42
- output_hidden_states: Optional[bool] = None,
43
- labels: Optional[torch.Tensor] = None,
44
- return_dict: Optional[bool] = None,
45
- ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
46
- return_dict = (
47
- return_dict if return_dict is not None else self.config.use_return_dict
48
- )
49
-
50
- outputs = super().forward(
51
- pixel_values=pixel_values,
52
- output_attentions=output_attentions,
53
- output_hidden_states=output_hidden_states,
54
- return_dict=return_dict,
55
- )
56
- image_embeds = outputs[0] # image_embeds
57
- image_embeds /= image_embeds.norm(dim=-1, keepdim=True)
58
-
59
- prediction = self.layers(image_embeds)
60
-
61
- loss = None
62
- if labels is not None:
63
- loss_fct = nn.MSELoss()
64
- loss = loss_fct()
65
-
66
- if not return_dict:
67
- return (loss, prediction, image_embeds)
68
-
69
- return ImageClassifierOutputWithNoAttention(
70
- loss=loss,
71
- logits=prediction,
72
- hidden_states=image_embeds,
73
- )
74
-
75
-
76
- class AestheticsPredictorV2ReLU(AestheticsPredictorV2Linear):
77
- def __init__(self, config: CLIPVisionConfig):
78
- super().__init__(config)
79
- self.layers = nn.Sequential(
80
- nn.Linear(config.projection_dim, 1024),
81
- nn.ReLU(),
82
- nn.Dropout(0.2),
83
- nn.Linear(1024, 128),
84
- nn.ReLU(),
85
- nn.Dropout(0.2),
86
- nn.Linear(128, 64),
87
- nn.ReLU(),
88
- nn.Dropout(0.1),
89
- nn.Linear(64, 16),
90
- nn.ReLU(),
91
- nn.Linear(16, 1),
92
- )
93
- self.post_init()
94
-
95
-
96
- def convert_v2_linear_from_openai_clip(
97
- predictor_head_name: str,
98
- openai_model_name: str = "openai/clip-vit-large-patch14",
99
- ) -> AestheticsPredictorV2Linear:
100
- model = AestheticsPredictorV2Linear.from_pretrained(openai_model_name)
101
-
102
- state_dict = torch.hub.load_state_dict_from_url(
103
- URLS_LINEAR[predictor_head_name], map_location="cpu"
104
- )
105
- assert isinstance(state_dict, OrderedDict)
106
-
107
- # remove `layers.` from the key of the state_dict
108
- state_dict = OrderedDict(
109
- ((k.replace("layers.", ""), v) for k, v in state_dict.items())
110
- )
111
- model.layers.load_state_dict(state_dict)
112
-
113
- model.eval()
114
-
115
- return model
116
-
117
-
118
- def convert_v2_relu_from_openai_clip(
119
- predictor_head_name: str,
120
- openai_model_name: str = "openai/clip-vit-large-patch14",
121
- ) -> AestheticsPredictorV2ReLU:
122
- model = AestheticsPredictorV2ReLU.from_pretrained(openai_model_name)
123
-
124
- state_dict = torch.hub.load_state_dict_from_url(
125
- URLS_RELU[predictor_head_name], map_location="cpu"
126
- )
127
- assert isinstance(state_dict, OrderedDict)
128
-
129
- # remove `layers.` from the key of the state_dict
130
- state_dict = OrderedDict(
131
- ((k.replace("layers.", ""), v) for k, v in state_dict.items())
132
- )
133
- model.layers.load_state_dict(state_dict)
134
-
135
- model.eval()
136
-
137
- return model