not-lain committed on
Commit
58426ee
1 Parent(s): 63aa2d2

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. CustomPipe.py +54 -0
  2. config.json +18 -6
  3. model.safetensors +2 -2
CustomPipe.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+
4
+ from transformers import (
5
+ AutoModelForImageClassification,
6
+ AutoImageProcessor,
7
+ Pipeline,
8
+ )
9
+
10
+ import numpy as np
11
+ from typing import Union
12
+
13
class SiglipTaggerPipe(Pipeline):
    """Custom image-tagging pipeline built on a multi-label classification model.

    Calling the pipeline with an image returns a dict that maps tag names
    (taken from ``model.config.id2label``) to their score formatted as a
    percentage string with two decimals, ordered from highest to lowest score.

    Call-time parameters:
        threshold (float): minimum clipped score (in ``[0, 1]``) a tag needs
            in order to be included in the output. Defaults to ``0``.
    """

    def __init__(self, **kwargs):
        """Build the pipeline, defaulting ``torch_dtype`` to bfloat16.

        All keyword arguments are forwarded unchanged to ``Pipeline.__init__``;
        a caller-supplied ``torch_dtype`` always wins over the default.
        """
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        super().__init__(**kwargs)
        # Assigned AFTER the base initializer on purpose: if the base class
        # manages an attribute named ``processor`` itself (NOTE(review): newer
        # transformers releases appear to do so - confirm), assigning it
        # earlier would be silently overwritten.
        self.processor = AutoImageProcessor.from_pretrained("p1atdev/siglip-tagger-test-3")

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only ``threshold`` is recognised, and it is routed to ``postprocess``.
        """
        postprocess_kwargs = {}
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Image.Image, np.ndarray]):
        """Load the image and turn it into model-ready tensors.

        Args:
            inputs: a path to an image file, a PIL image, or an array that
                ``PIL.Image.fromarray`` can consume.

        Returns:
            The output of the image processor, moved to the model's device
            and cast to the model's dtype.
        """
        if isinstance(inputs, str):
            # The context manager releases the file handle; ``convert`` forces
            # the pixel data to be read before the file is closed.
            with Image.open(inputs) as opened:
                img = opened.convert("RGB")
        elif isinstance(inputs, Image.Image):
            img = inputs.convert("RGB")
        else:
            # TODO: validate the array and consider supporting URLs as well
            img = Image.fromarray(inputs).convert("RGB")

        # Every branch above normalises to RGB so that RGBA / palette /
        # grayscale inputs do not reach the processor with an unexpected
        # number of channels.
        features = self.processor(img, return_tensors="pt")
        return features.to(self.model.device, self.model.dtype)

    def _forward(self, inputs):
        """Run the model and return the first item's logits clipped to [0, 1].

        The logits are detached, moved to CPU and cast to float32 before
        clipping.
        """
        logits = self.model(**inputs).logits.detach().cpu().float()[0]
        # torch.clamp is the native equivalent of the clip; it avoids relying
        # on NumPy's fallback dispatch when handed a torch tensor.
        return torch.clamp(logits, 0.0, 1.0)

    def postprocess(self, logits, threshold: float = 0):
        """Convert clipped logits into a ``{tag: "percent"}`` mapping.

        Args:
            logits: one score per label, indexed like ``id2label``.
            threshold: tags scoring below this value are dropped.

        Returns:
            A dict ordered by descending score whose values are the score
            times 100, formatted with two decimals. Tags with a score of
            zero or less are never included.
        """
        id2label = self.model.config.id2label
        # Convert once to plain Python floats instead of iterating the tensor,
        # which would create one 0-dim tensor per label.
        scores = logits.tolist() if hasattr(logits, "tolist") else list(logits)

        positive = {id2label[i]: score for i, score in enumerate(scores) if score > 0}
        ranked = sorted(positive.items(), key=lambda item: item[1], reverse=True)
        return {tag: f"{score*100:.2f}" for tag, score in ranked if score >= threshold}
config.json CHANGED
@@ -1,10 +1,22 @@
1
  {
2
- "_name_or_path": "google/siglip-so400m-patch14-384",
3
- "architectures": ["SiglipForImageClassification"],
 
 
 
4
  "auto_map": {
5
- "AutoModelForImageClassification": "modeling_siglip.SiglipForImageClassification"
 
 
 
 
 
 
 
 
 
 
6
  },
7
- "attention_dropout": 0.0,
8
  "hidden_act": "gelu_pytorch_tanh",
9
  "hidden_size": 1152,
10
  "id2label": {
@@ -19047,13 +19059,13 @@
19047
  "zzz": 9515,
19048
  "|_": 9516
19049
  },
19050
- "layer_norm_eps": 1e-6,
19051
  "model_type": "siglip_vision_model",
19052
  "num_attention_heads": 16,
19053
  "num_channels": 3,
19054
  "num_hidden_layers": 27,
19055
  "patch_size": 14,
19056
  "problem_type": "multi_label_classification",
19057
- "torch_dtype": "bfloat16",
19058
  "transformers_version": "4.37.2"
19059
  }
 
1
  {
2
+ "_name_or_path": "p1atdev/siglip-tagger-test-3",
3
+ "architectures": [
4
+ "SiglipForImageClassification"
5
+ ],
6
+ "attention_dropout": 0.0,
7
  "auto_map": {
8
+ "AutoModelForImageClassification": "p1atdev/siglip-tagger-test-3--modeling_siglip.SiglipForImageClassification"
9
+ },
10
+ "custom_pipelines": {
11
+ "image-classification": {
12
+ "impl": "CustomPipe.SiglipTaggerPipe",
13
+ "pt": [
14
+ "AutoModelForImageClassification"
15
+ ],
16
+ "tf": [],
17
+ "type": "image"
18
+ }
19
  },
 
20
  "hidden_act": "gelu_pytorch_tanh",
21
  "hidden_size": 1152,
22
  "id2label": {
 
19059
  "zzz": 9515,
19060
  "|_": 9516
19061
  },
19062
+ "layer_norm_eps": 1e-06,
19063
  "model_type": "siglip_vision_model",
19064
  "num_attention_heads": 16,
19065
  "num_channels": 3,
19066
  "num_hidden_layers": 27,
19067
  "patch_size": 14,
19068
  "problem_type": "multi_label_classification",
19069
+ "torch_dtype": "float32",
19070
  "transformers_version": "4.37.2"
19071
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce46ef29aec79fcf0fbe8280521acb10381ef5000706af5f870c08af781fb3eb
3
- size 878455682
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c57dce403a3fbb0b10dd311cd84cc12ecbf884ae444f54aa6f941f5fb3e06f7
3
+ size 1756853084