File size: 2,029 Bytes
afa42f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from PIL import Image
import torch

from transformers import (
    AutoModelForImageClassification,
    AutoImageProcessor,
    Pipeline,
)

import numpy as np
from typing import Union

class SiglipTaggerPipe(Pipeline):
    """Image-tagging pipeline built around a SigLIP tagger checkpoint.

    Accepts a single image (file path, PIL image, or array-like), runs the
    model, clamps its outputs to ``[0, 1]`` and returns the labels whose
    score passes a threshold, either as a comma-separated string or as a
    ``{tag: "NN.NN"}`` dict of percentage strings.
    """

    # Image processor checkpoint used when the caller does not pass one.
    DEFAULT_PROCESSOR = "p1atdev/siglip-tagger-test-3"

    def __init__(self, **kwargs):
        """Set up the image processor and initialise the base ``Pipeline``.

        If ``image_processor`` is supplied in ``kwargs`` it is used; otherwise
        the processor is loaded from ``DEFAULT_PROCESSOR``. ``torch_dtype``
        defaults to ``torch.bfloat16`` when not given. All keyword arguments
        are forwarded to ``Pipeline.__init__``.
        """
        self.processor = kwargs.get("image_processor") or AutoImageProcessor.from_pretrained(
            self.DEFAULT_PROCESSOR
        )
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        Pipeline.__init__(self, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Route call-time options to the pipeline stages.

        Only ``threshold`` and ``return_scores`` are recognised; both are
        forwarded to :meth:`postprocess`. Preprocess and forward take none.
        """
        postprocess_kwargs = {
            key: kwargs[key] for key in ("threshold", "return_scores") if key in kwargs
        }
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Image.Image, np.ndarray]):
        """Load ``inputs`` as an RGB PIL image and turn it into model tensors.

        Args:
            inputs: path to an image file, a PIL image, or anything
                ``np.asarray`` can turn into an array ``Image.fromarray``
                accepts.

        Returns:
            The processor output (``return_tensors="pt"``) moved to the
            model's device and dtype.
        """
        if isinstance(inputs, str):
            # Image.open is lazy and keeps the file open until the pixels are
            # read; decode inside ``with`` so the handle is closed promptly.
            with Image.open(inputs) as opened:
                img = opened.convert("RGB")
        elif isinstance(inputs, Image.Image):
            img = inputs
        else:
            img = Image.fromarray(np.asarray(inputs))
        if img.mode != "RGB":
            # Grayscale / RGBA / palette images do not have 3 channels; the
            # processor is assumed to expect RGB input.
            img = img.convert("RGB")
        model_inputs = self.processor(img, return_tensors="pt")
        return model_inputs.to(self.model.device, self.model.dtype)

    def _forward(self, inputs):
        """Run the model and return the first batch row of logits, clamped.

        The result is a float32 CPU tensor with every value clamped to
        ``[0, 1]`` (presumably one entry per label — confirm against the
        model's output head).
        """
        logits = self.model(**inputs).logits.detach().cpu().float()[0]
        # torch.clamp keeps the result a tensor; np.clip on a tensor relied on
        # NumPy's duck-typing fallback and its return type varied by version.
        return logits.clamp(0.0, 1.0)

    def postprocess(self, logits, threshold: float = 0, return_scores=False):
        """Convert clamped scores into tags.

        Args:
            logits: per-label scores produced by :meth:`_forward`.
            threshold: minimum score for a tag to be kept. Tags scoring
                exactly 0 (or less) are always dropped.
            return_scores: when truthy, return ``{tag: "NN.NN"}`` with the
                score as a percentage string; otherwise return the tag names
                joined by ``", "``. Both are ordered by descending score.
        """
        id2label = self.model.config.id2label
        # One bulk conversion instead of creating a 0-d tensor per label.
        scores = np.asarray(logits, dtype=np.float64).tolist()
        tagged = {}
        for idx, score in enumerate(scores):
            if score > 0:
                tagged[id2label[idx]] = score
        ranked = sorted(tagged.items(), key=lambda item: item[1], reverse=True)
        kept = {tag: f"{score * 100:.2f}" for tag, score in ranked if score >= threshold}
        if return_scores:
            return kept
        return ", ".join(kept)