not-lain commited on
Commit
afa42f2
1 Parent(s): 63aa2d2

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. CustomPipe.py +59 -0
  2. config.json +10 -0
  3. model.safetensors +2 -2
CustomPipe.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+
4
+ from transformers import (
5
+ AutoModelForImageClassification,
6
+ AutoImageProcessor,
7
+ Pipeline,
8
+ )
9
+
10
+ import numpy as np
11
+ from typing import Union
12
+
13
+ class SiglipTaggerPipe(Pipeline):
14
+ def __init__(self,**kwargs):
15
+ self.processor = AutoImageProcessor.from_pretrained("p1atdev/siglip-tagger-test-3")
16
+ if "torch_dtype" not in kwargs :
17
+ kwargs["torch_dtype"] = torch.bfloat16
18
+ Pipeline.__init__(self,**kwargs)
19
+ def _sanitize_parameters(self, **kwargs):
20
+ postprocess_kwargs = {}
21
+ if "threshold" in kwargs :
22
+ # if threshold parameter is present
23
+ # we pass it to the postprocess method
24
+ postprocess_kwargs["threshold"] = kwargs["threshold"]
25
+ if "return_scores" in kwargs :
26
+ postprocess_kwargs["return_scores"] = kwargs["return_scores"]
27
+ return {},{},postprocess_kwargs
28
+
29
+ def preprocess(self,inputs: Union[str,Image.Image,np.ndarray]):
30
+ if isinstance(inputs,str) :
31
+ img = Image.open(inputs)
32
+ elif isinstance(inputs,Image.Image) :
33
+ img = inputs
34
+ else :
35
+ # TODO: double check this implementation
36
+ # consider adding try except
37
+ # maybe add url checker too
38
+ img = Image.fromarray(inputs)
39
+
40
+ inputs = self.processor(img, return_tensors="pt").to(self.model.device, self.model.dtype)
41
+ return inputs
42
+
43
+ def _forward(self,inputs):
44
+ logits = self.model(**inputs).logits.detach().cpu().float()[0]
45
+ logits = np.clip(logits, 0.0, 1.0)
46
+ return logits
47
+ def postprocess(self,logits,threshold:float=0,return_scores=False):
48
+ results = {
49
+ self.model.config.id2label[i]: logit for i, logit in enumerate(logits) if logit > 0
50
+ }
51
+ results = sorted(results.items(), key=lambda x: x[1], reverse=True)
52
+ out = {}
53
+ for tag, score in results:
54
+ if score >= threshold :
55
+ out[tag] = f"{score*100:.2f}"
56
+ if return_scores == True :
57
+ return out
58
+ else :
59
+ return ", ".join(list(out.keys()))
config.json CHANGED
@@ -4,6 +4,16 @@
4
  "auto_map": {
5
  "AutoModelForImageClassification": "modeling_siglip.SiglipForImageClassification"
6
  },
 
 
 
 
 
 
 
 
 
 
7
  "attention_dropout": 0.0,
8
  "hidden_act": "gelu_pytorch_tanh",
9
  "hidden_size": 1152,
 
4
  "auto_map": {
5
  "AutoModelForImageClassification": "modeling_siglip.SiglipForImageClassification"
6
  },
7
+ "custom_pipelines": {
8
+ "image-classification": {
9
+ "impl": "CustomPipe.SiglipTaggerPipe",
10
+ "pt": [
11
+ "AutoModelForImageClassification"
12
+ ],
13
+ "tf": [],
14
+ "type": "image"
15
+ }
16
+ },
17
  "attention_dropout": 0.0,
18
  "hidden_act": "gelu_pytorch_tanh",
19
  "hidden_size": 1152,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce46ef29aec79fcf0fbe8280521acb10381ef5000706af5f870c08af781fb3eb
3
- size 878455682
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c57dce403a3fbb0b10dd311cd84cc12ecbf884ae444f54aa6f941f5fb3e06f7
3
+ size 1756853084