import torch
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ImageClassificationPipeline

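# Entry point in the style of Hugging Face's api-inference-community custom
# pipelines: the serving code is assumed to instantiate PreTrainedPipeline with
# the path of the fine-tuned model repository and call it once per request.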
class PreTrainedPipeline():
    def __init__(self, path):
        """
        Initialize the pipeline: the image preprocessing configuration comes
        from the base ViT checkpoint, while the fine-tuned classification
        weights are loaded from `path`.
        """
        model_flag = 'google/vit-base-patch16-224-in21k'

        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
        self.model = ViTForImageClassification.from_pretrained(path)
        # PairClassificationPipeline is defined below; the name is only looked
        # up when __init__ runs, so the forward reference is safe.
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """
        Args:
            inputs (:obj:`PIL.Image.Image`):
                The image to classify, holding the two images of the pair
                side by side.
        Return:
            A :obj:`list` of :obj:`dict`: the predicted labels and their
            scores, e.g. [{"label": "XXX", "score": 0.97}, ...].
        """
        return self.pipe(inputs)
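

# Pipeline variant for image pairs: the input image carries the two members of
# the pair side by side, so preprocessing splits it down the middle and stacks
# the halves channel-wise (see extract_split_feature below).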
class PairClassificationPipeline(ImageClassificationPipeline):
    pipe_to_tensor = transforms.ToTensor()
    pipe_to_pil = transforms.ToPILImage()

    def preprocess(self, image):
        # Split the incoming pair image into its two halves, then build the
        # stacked model inputs from them.
        left_image, right_image = self.horizontal_split_image(image)
        model_inputs = self.extract_split_feature(left_image, right_image)
        return model_inputs
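
    # Note: only preprocess() is overridden; _forward() and postprocess() are
    # inherited from ImageClassificationPipeline, which softmaxes the logits
    # and returns the top labels with their scores.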

    def horizontal_split_image(self, image):
        # Crop equal-width left and right halves; using 2 * half_w as the
        # right edge drops the middle column of odd-width images, so the two
        # halves always match in size.
        w, h = image.size
        half_w = w // 2
        left_image = image.crop((0, 0, half_w, h))
        right_image = image.crop((half_w, 0, 2 * half_w, h))
        return left_image, right_image

    def extract_split_feature(self, left_image, right_image):
        # Preprocess each half separately, then concatenate along the channel
        # dimension: `pixel_values` ends up with shape (batch, 6, H, W), with
        # the right half's channels first. The fine-tuned model loaded from
        # `path` must therefore accept 6 input channels.
        model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
        right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
        model_inputs['pixel_values'] = torch.cat([right_inputs['pixel_values'], model_inputs['pixel_values']], dim=1)
        return model_inputs
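

# Minimal usage sketch (hypothetical paths; assumes the checkpoint at `path`
# was fine-tuned with a patch embedding that accepts the 6-channel stacked
# input built above):
#
#     from PIL import Image
#
#     pipe = PreTrainedPipeline("path/to/fine-tuned-model")
#     pair = Image.open("pair.jpg")  # two images placed side by side
#     print(pipe(pair))              # e.g. [{"label": "...", "score": 0.97}, ...]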