vesteinn committed on
Commit
b83ef5e
Parents (2): a058c06, a29a069

Merge branch 'main' of https://huggingface.co/vesteinn/vit-mae-inat21

Files changed (1):
  1. README.md +171 -0
README.md ADDED
Note that this model does not work out of the box with Hugging Face `transformers`: a modified model that applies mean pooling over the patch tokens before the layernorm and classification head is needed, as shown below.

```python
from transformers import (
    ViTForImageClassification,
    pipeline,
    AutoImageProcessor,
    ViTConfig,
    ViTModel,
)

from transformers.modeling_outputs import (
    ImageClassifierOutput,
    BaseModelOutputWithPooling,
)

from PIL import Image
import torch
from torch import nn
from typing import Optional, Union, Tuple


class CustomViTModel(ViTModel):
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # This is the modification: mean-pool the patch tokens (dropping the
        # [CLS] token at position 0) before the layernorm, instead of taking
        # the [CLS] token as the stock ViT implementation does.
        sequence_output = sequence_output[:, 1:, :].mean(dim=1)

        sequence_output = self.layernorm(sequence_output)
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            head_outputs = (
                (sequence_output, pooled_output)
                if pooled_output is not None
                else (sequence_output,)
            )
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class CustomViTForImageClassification(ViTForImageClassification):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = CustomViTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # CustomViTModel already returns a pooled (batch_size, hidden_size)
        # tensor, so it can be fed straight to the classifier head.
        sequence_output = outputs[0]

        logits = self.classifier(sequence_output)

        # Loss computation is omitted; this wrapper is intended for inference.
        loss = None

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


model = CustomViTForImageClassification.from_pretrained("vesteinn/vit-mae-inat21")
image_processor = AutoImageProcessor.from_pretrained("vesteinn/vit-mae-inat21")

classifier = pipeline(
    "image-classification", model=model, image_processor=image_processor
)
```
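
With the pipeline built as above, inference works like any other `image-classification` pipeline. A minimal usage sketch follows; the file name `example.jpg` and the `top_k` value are placeholders, not part of the original card:

```python
from PIL import Image

# Hypothetical local image of a plant or animal; substitute your own file.
image = Image.open("example.jpg")

# top_k limits the number of returned (label, score) pairs.
predictions = classifier(image, top_k=5)
print(predictions)
```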