# m3's picture
# chore: add readme
# 8e8cbdc
from typing import List, Optional, Union
from torchvision import transforms
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoImageProcessor, AutoModel
import os
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils import add_end_docstrings
from transformers.pipelines.base import Pipeline, build_pipeline_init_args
class SscdImageProcessor(BaseImageProcessor):
    """Image processor that converts a PIL image into a normalized input tensor.

    Args:
        do_resize: Default for whether ``preprocess`` resizes the image.
        size: Size argument handed to ``torchvision.transforms.Resize``.
        image_mean: Per-channel mean used for normalization. Defaults to
            ``[0.485, 0.456, 0.406]`` when ``None``.
        image_std: Per-channel standard deviation used for normalization.
            Defaults to ``[0.229, 0.224, 0.225]`` when ``None``.
        do_convert_rgb: Whether ``preprocess`` converts the image to RGB.
        **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
    """

    def __init__(
        self,
        do_resize: bool = True,
        size: int = 288,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
        self,
        image: Image.Image,
        do_resize: Optional[bool] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Resize (optionally), tensorize and normalize a single image.

        Args:
            image: The PIL image to process.
            do_resize: Per-call override of ``self.do_resize``; ``None`` means
                "use the value configured on the processor".
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The processed image tensor with a leading dimension of size 1
            added via ``unsqueeze(0)``.
        """
        if do_resize is None:
            do_resize = self.do_resize

        steps = []
        if do_resize:
            # Resize operates on the PIL image, so it must come before ToTensor.
            steps.append(transforms.Resize(self.size))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=self.image_mean, std=self.image_std))
        pipeline = transforms.Compose(steps)

        if self.do_convert_rgb:
            image = image.convert('RGB')
        return pipeline(image).unsqueeze(0)
class SscdConfig(PretrainedConfig):
    """Configuration for the ``'sscd-copy-detection'`` model type.

    Args:
        model_path: Value stored as the ``model_path`` config entry. When
            ``None``, ``'sscd_disc_mixup.torchscript.pt'`` is used instead.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        # Fall back to the default file name only when nothing was supplied.
        resolved_path = (
            'sscd_disc_mixup.torchscript.pt' if model_path is None else model_path
        )
        super().__init__(model_path=resolved_path, **kwargs)
class SscdModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a TorchScript module loaded from disk or the Hub.

    Args:
        config: Configuration object; ``config.name_or_path`` and
            ``config.model_path`` are read, and ``config.base_path`` is set
            to the directory the weights were loaded from.
        model_path: File name of the TorchScript file. When ``None``,
            ``config.model_path`` is used.
    """

    config_class = SscdConfig

    def __init__(self, config, model_path: Optional[str] = None):
        super().__init__(config)
        # Zero-element parameter registered on the module.
        # NOTE(review): presumably present so parameter-based helpers
        # (device/dtype lookup) have something to inspect - confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))

        if model_path is None:
            model_path = config.model_path

        if os.path.isdir(config.name_or_path):
            # Local checkout: the weights live next to the config.
            config.base_path = config.name_or_path
            weights_path = os.path.join(config.base_path, model_path)
        else:
            # Hub repo: load exactly the file that was downloaded instead of
            # re-deriving its path, which would duplicate any sub-folder
            # contained in ``model_path``.
            weights_path = hf_hub_download(repo_id=config.name_or_path, filename=model_path)
            config.base_path = os.path.dirname(weights_path)

        self.model = torch.jit.load(weights_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from the config found at ``pretrained_model_name_or_path``.

        ``kwargs`` are forwarded to ``AutoConfig.from_pretrained``;
        ``model_args`` are accepted but not used.
        """
        return cls(AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))

    def forward(self, inputs):
        """Run the TorchScript module and return row 0 of its output.

        NOTE(review): indexing with ``[0, :]`` keeps only the first row, so
        any further rows of the output are discarded - confirm this is
        intended for multi-image inputs.
        """
        return self.model(inputs)[0, :]
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class SscdPipeline(Pipeline):
    """Pipeline that preprocesses an image and returns the raw model output."""

    def __init__(self, model, **kwargs):
        # ``device`` is optional; ``.get`` keeps construction from failing
        # with KeyError when the caller does not pass it.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty parameter dicts; no call-time options are supported."""
        return {}, {}, {}

    def preprocess(self, input):
        """Delegate preprocessing of ``input`` to the attached image processor."""
        return self.image_processor.preprocess(input)

    def _forward(self, inputs):
        """Run the model on the preprocessed inputs."""
        return self.model(inputs)

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs
# Make the custom classes discoverable through the transformers Auto* factories.
AutoConfig.register(SscdConfig.model_type, SscdConfig)
AutoModel.register(SscdConfig, SscdModel)
AutoImageProcessor.register(
    SscdConfig,
    slow_image_processor_class=SscdImageProcessor,
)

# NOTE(review): this loads the checkpoint at import time and the result is not
# used by any code visible here - confirm the import-time load is intentional.
models = AutoModel.from_pretrained('m3/sscd-copy-detection')

# Expose the pipeline under its task name.
PIPELINE_REGISTRY.register_pipeline(
    task='sscd-copy-detection',
    pipeline_class=SscdPipeline,
    pt_model=SscdModel,
)