stupidog04 commited on
Commit
100aae6
1 Parent(s): d6c86e9

commit files to HF hub

Browse files
README.md CHANGED
@@ -11,11 +11,11 @@ model-index:
11
  results:
12
  - task:
13
  name: Image Classification
14
- type: image-classification
15
  metrics:
16
  - name: Accuracy
17
  type: accuracy
18
- value: 0.375
19
  ---
20
 
21
  # krenzcolor_chkpt_classifier
 
11
  results:
12
  - task:
13
  name: Image Classification
14
+ type: pair-classification
15
  metrics:
16
  - name: Accuracy
17
  type: accuracy
18
+ value: 0.3928571343421936
19
  ---
20
 
21
  # krenzcolor_chkpt_classifier
config.json CHANGED
@@ -4,6 +4,17 @@
4
  "ViTForImageClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.0,
 
 
 
 
 
 
 
 
 
 
 
7
  "encoder_stride": 16,
8
  "hidden_act": "gelu",
9
  "hidden_dropout_prob": 0.0,
 
4
  "ViTForImageClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.0,
7
+ "custom_pipelines": {
8
+ "pair-classification": {
9
+ "impl": "pair_classification.PairClassificationPipeline",
10
+ "pt": [
11
+ "ViTForImageClassification"
12
+ ],
13
+ "tf": [
14
+ "TFViTForImageClassification"
15
+ ]
16
+ }
17
+ },
18
  "encoder_stride": 16,
19
  "hidden_act": "gelu",
20
  "hidden_dropout_prob": 0.0,
pair_classification.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from transformers import ImageClassificationPipeline
3
+ import torch
4
+
5
+
6
+ class PairClassificationPipeline(ImageClassificationPipeline):
7
+ pipe_to_tensor = transforms.ToTensor()
8
+ pipe_to_pil = transforms.ToPILImage()
9
+
10
+ def preprocess(self, image):
11
+ left_image, right_image = self.horizontal_split_image(image)
12
+ model_inputs = self.extract_split_feature(left_image, right_image)
13
+ # model_inputs = super().preprocess(image)
14
+ # print(model_inputs['pixel_values'].shape)
15
+ return model_inputs
16
+
17
+ def horizontal_split_image(self, image):
18
+ # image = image.resize((448,224))
19
+ w, h = image.size
20
+ half_w = w//2
21
+ left_image = image.crop([0,0,half_w,h])
22
+ right_image = image.crop([half_w,0,2*half_w,h])
23
+ return left_image, right_image
24
+
25
+ def extract_split_feature(self, left_image, right_image):
26
+ model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
27
+ right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
28
+ model_inputs['pixel_values'] = torch.cat([model_inputs['pixel_values'],right_inputs['pixel_values']], dim=1)
29
+ return model_inputs
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c850dd17c2fb514bd822185575af28eb23a38a41eeff521fca00ad8d9f43efc
3
  size 345635761
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5233e44bcfb172d9200f5f0c63a00ac4411a1783bf6dd42b48dce384a468da1
3
  size 345635761
runs/events.out.tfevents.1666962937.sa103.22908.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a878e370777ce85d35b80fa02f4354af89b306a44916d4b592109528e5dad9ed
3
+ size 1939