# siglip-tagger-test-3 / CustomPipe.py
# Uploaded by not-lain ("Upload folder using huggingface_hub"), commit afa42f2 (verified).
# Hub page metadata: raw | history | blame; no virus detected; 2.03 kB.
from PIL import Image
import torch
from transformers import (
AutoModelForImageClassification,
AutoImageProcessor,
Pipeline,
)
import numpy as np
from typing import Union
class SiglipTaggerPipe(Pipeline):
    """Image-tagging pipeline for the ``p1atdev/siglip-tagger-test-3`` model.

    Call kwargs:
        threshold: minimum (inclusive) score for a tag to be kept.
        return_scores: when true, return ``{tag: "NN.NN"}`` instead of a
            comma-separated tag string.
    """

    # Hub repository the image processor is loaded from; override in a
    # subclass to pair the pipeline with a different checkpoint.
    PROCESSOR_ID = "p1atdev/siglip-tagger-test-3"

    def __init__(self, **kwargs):
        """Initialise the base ``Pipeline`` and load the image processor.

        ``torch_dtype`` defaults to ``torch.bfloat16`` when not supplied; every
        other keyword argument is forwarded to ``Pipeline.__init__`` unchanged.
        """
        kwargs.setdefault("torch_dtype", torch.bfloat16)
        super().__init__(**kwargs)
        # Assigned *after* the base initializer: newer transformers versions set
        # ``self.processor`` inside ``Pipeline.__init__`` (to None by default),
        # which would otherwise discard the processor loaded here.
        self.processor = AutoImageProcessor.from_pretrained(self.PROCESSOR_ID)

    def _sanitize_parameters(self, **kwargs):
        """Route the ``threshold`` / ``return_scores`` call kwargs to ``postprocess``.

        Returns:
            ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``; only the
            last one is ever populated.
        """
        postprocess_kwargs = {
            key: kwargs[key]
            for key in ("threshold", "return_scores")
            if key in kwargs
        }
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Image.Image, np.ndarray]):
        """Load ``inputs`` as an RGB image and run the image processor on it.

        Args:
            inputs: a local file path, a ``PIL.Image.Image``, or anything
                ``PIL.Image.fromarray`` accepts (e.g. a NumPy array).

        Returns:
            The processor output, moved to the model's device and dtype.
        """
        if isinstance(inputs, str):
            # TODO: URLs are not supported; strings are treated as local paths.
            # ``with`` closes the file handle once the pixels are loaded.
            with Image.open(inputs) as opened:
                img = opened.convert("RGB")
        elif isinstance(inputs, Image.Image):
            img = inputs.convert("RGB")
        else:
            img = Image.fromarray(inputs).convert("RGB")
        # All branches yield RGB so palette / greyscale / RGBA inputs reach the
        # processor with the same channel layout as ordinary photos.
        features = self.processor(img, return_tensors="pt")
        return features.to(self.model.device, self.model.dtype)

    def _forward(self, inputs):
        """Run the model and return the first item's logits clamped to [0, 1]."""
        logits = self.model(**inputs).logits.detach().cpu().float()[0]
        # Clamp with torch (rather than np.clip on a tensor) so the result is
        # always a float tensor; postprocess treats these values as scores.
        return logits.clamp(0.0, 1.0)

    def postprocess(self, logits, threshold: float = 0, return_scores=False):
        """Turn per-label scores into tags ordered by descending score.

        Args:
            logits: per-label scores from ``_forward`` (values in [0, 1]).
            threshold: minimum (inclusive) score for a tag to be kept. Tags
                scoring exactly 0 are always dropped.
            return_scores: if truthy, return ``{tag: "NN.NN"}`` with the score
                as a percentage string; otherwise a ``", "``-joined tag string.

        Returns:
            ``dict`` of tag -> formatted percentage, or a ``str`` of tags.
        """
        id2label = self.model.config.id2label
        # Dict keyed by tag (built in label-index order) mirrors the original
        # de-duplication; the stable sort keeps index order among equal scores.
        kept = {
            id2label[i]: score
            for i, score in enumerate(logits.tolist())
            if score > 0 and score >= threshold
        }
        ranked = sorted(kept.items(), key=lambda item: item[1], reverse=True)
        if return_scores:
            return {tag: f"{score * 100:.2f}" for tag, score in ranked}
        return ", ".join(tag for tag, _ in ranked)