|
from typing import List, Optional, Union |
|
from torchvision import transforms |
|
from PIL import Image |
|
from transformers.image_processing_utils import BaseImageProcessor |
|
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoImageProcessor, AutoModel |
|
import os |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
import torch.nn as nn |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
from transformers.utils import add_end_docstrings |
|
from transformers.pipelines.base import Pipeline, build_pipeline_init_args |
|
class SscdImageProcessor(BaseImageProcessor):
    """Turn a single PIL image into a normalized, batched tensor for the SSCD model.

    The processing chain is: optional RGB conversion, optional resize,
    ``ToTensor`` and per-channel normalization.
    """

    def __init__(
        self,
        do_resize: bool = True,
        size: int = 288,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        """Store the preprocessing options.

        Args:
            do_resize: Default for whether ``preprocess`` resizes the image.
            size: Value handed to ``transforms.Resize`` when resizing.
            image_mean: Normalization mean; ``[0.485, 0.456, 0.406]`` when ``None``.
            image_std: Normalization std; ``[0.229, 0.224, 0.225]`` when ``None``.
            do_convert_rgb: Whether ``preprocess`` calls ``image.convert('RGB')``.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
        self,
        image: Image.Image,
        do_resize: Optional[bool] = None,
        **kwargs,
    ):
        """Convert ``image`` into a normalized tensor with a leading batch axis.

        Args:
            image: The PIL image to process.
            do_resize: Per-call override of ``self.do_resize``; ``None`` means
                "use the value configured on the processor".
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The transformed image tensor after ``unsqueeze(0)``, i.e. with a
            leading dimension of size 1.
        """
        if do_resize is None:
            do_resize = self.do_resize

        # Previously the resize step was applied unconditionally and the
        # ``do_resize`` flag had no effect; it is now honoured. When enabled,
        # the resize still runs first, on the PIL image, before ``ToTensor``.
        steps = []
        if do_resize:
            steps.append(transforms.Resize(self.size))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=self.image_mean, std=self.image_std))
        transform = transforms.Compose(steps)

        if self.do_convert_rgb:
            image = image.convert('RGB')
        return transform(image).unsqueeze(0)
|
|
|
|
|
class SscdConfig(PretrainedConfig):
    """Configuration for the SSCD copy-detection model.

    Its only own setting is ``model_path``, which is passed on to
    ``PretrainedConfig.__init__`` as a keyword argument.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        """Create the config.

        Args:
            model_path: File name of the TorchScript archive. Falls back to
                ``'sscd_disc_mixup.torchscript.pt'`` when ``None``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        resolved_path = 'sscd_disc_mixup.torchscript.pt' if model_path is None else model_path
        super().__init__(model_path=resolved_path, **kwargs)
|
|
|
|
|
class SscdModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a TorchScript SSCD network.

    The network itself is loaded with ``torch.jit.load`` from a file located
    either in a local checkpoint directory or in a Hugging Face Hub repo.
    """

    config_class = SscdConfig

    def __init__(self, config, model_path: Optional[str] = None):
        """Locate the TorchScript archive and load it.

        Args:
            config: Model configuration. ``config.name_or_path`` is used as a
                local directory if it is one, otherwise as a Hub repo id.
                ``config.base_path`` is set to the directory that holds the
                archive as a side effect.
            model_path: File name of the archive relative to that directory /
                repo. Defaults to ``config.model_path``.
        """
        super().__init__(config)
        # Zero-element parameter; presumably present so that utilities which
        # inspect the module's parameters find at least one -- TODO confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))

        if model_path is None:
            model_path = config.model_path

        if os.path.isdir(config.name_or_path):
            config.base_path = config.name_or_path
        else:
            # Not a local directory: fetch the file from the Hub and remember
            # the directory it was downloaded into.
            downloaded_file = hf_hub_download(repo_id=config.name_or_path, filename=model_path)
            config.base_path = os.path.dirname(downloaded_file)

        # os.path.join replaces manual '/' concatenation. The former trailing
        # ``if model_path is not None`` guard was unreachable as False (the
        # concatenation before it would already have raised) and is dropped.
        self.model = torch.jit.load(os.path.join(config.base_path, model_path))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from the config found at ``pretrained_model_name_or_path``.

        ``**kwargs`` are forwarded to ``AutoConfig.from_pretrained``;
        ``*model_args`` are accepted for signature compatibility and ignored.
        """
        return cls(AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))

    def forward(self, inputs):
        """Run the wrapped network and return row 0 of its output.

        Only the first row along dimension 0 is returned (``[0, :]``), so any
        further rows of the output are discarded.
        """
        return self.model(inputs)[0, :]
|
|
|
|
|
|
|
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class SscdPipeline(Pipeline):
    """Pipeline that feeds one image through the image processor and the model.

    It takes no call-time parameters and returns the model output unchanged.
    """

    def __init__(self, model, **kwargs):
        """Create the pipeline.

        Args:
            model: The model to run.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        # ``device`` is optional; indexing ``kwargs['device']`` raised KeyError
        # whenever the pipeline was created without it. ``device_id`` is None
        # in that case.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess, forward and postprocess parameter dicts."""
        return {}, {}, {}

    def preprocess(self, input):
        """Delegate to ``self.image_processor.preprocess``."""
        return self.image_processor.preprocess(input)

    def _forward(self, inputs):
        """Call the model on the preprocessed inputs."""
        return self.model(inputs)

    def postprocess(self, model_outputs):
        """Return the model outputs as they are."""
        return model_outputs
|
|
|
|
|
# Make the custom config, model and image processor discoverable through the
# transformers Auto* factories.
AutoConfig.register('sscd-copy-detection', SscdConfig)
AutoModel.register(SscdConfig, SscdModel)
AutoImageProcessor.register(
    SscdConfig,
    slow_image_processor_class=SscdImageProcessor,
)

# NOTE: this runs at import time and binds the loaded model to the
# module-level name ``models``.
models = AutoModel.from_pretrained('m3/sscd-copy-detection')

# Expose the pipeline under its task name.
PIPELINE_REGISTRY.register_pipeline(
    task='sscd-copy-detection',
    pipeline_class=SscdPipeline,
    pt_model=SscdModel,
)
|
|