from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ImageClassificationPipeline
import torch


class PreTrainedPipeline():
    """Wrap a ViT classifier that scores an image made of two side-by-side halves."""

    def __init__(self, path, feature_extractor_name='google/vit-base-patch16-224-in21k'):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                Checkpoint passed to ``ViTForImageClassification.from_pretrained``.
            feature_extractor_name (:obj:`str`, optional):
                Checkpoint whose ``ViTFeatureExtractor`` preprocesses each half of
                the input image. Defaults to ``'google/vit-base-patch16-224-in21k'``
                (the previously hard-coded value); e.g.
                ``'google/vit-base-patch16-384'`` selects the 384px variant.
        """
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = ViTForImageClassification.from_pretrained(path)
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """
        Classify one image that contains an image pair laid out side by side.

        Args:
            inputs: The combined image. Presumably a PIL image (TODO confirm with
                the caller); a str path/URL is also accepted and loaded.
        Return:
            The output of the wrapped ``PairClassificationPipeline`` (an
            ``ImageClassificationPipeline`` subclass), i.e. its label/score
            predictions for the pair.
        """
        return self.pipe(inputs)


class PairClassificationPipeline(ImageClassificationPipeline):
    """Image-classification pipeline whose input is two images placed side by side.

    The input is split into left and right halves. Each half goes through the
    feature extractor separately, and the two ``pixel_values`` tensors are
    concatenated along dim 1. Assuming PyTorch's (batch, channels, height, width)
    layout, this is the channel axis, so the model sees twice the extractor's
    channel count. The checkpoint is presumably configured for that (e.g.
    ``num_channels=6``) — confirm.
    """

    # NOTE(review): not referenced by any method in this class; kept because
    # they are public class attributes that other code may rely on.
    pipe_to_tensor = transforms.ToTensor()
    pipe_to_pil = transforms.ToPILImage()

    def preprocess(self, image, **kwargs):
        """Turn one side-by-side image into model inputs for the pair.

        Args:
            image: The combined image. A str is treated as a path/URL and loaded
                with transformers' ``load_image``; anything else must already
                expose ``size`` and ``crop`` (e.g. a PIL image).
            **kwargs: Extra preprocess parameters forwarded by the base pipeline
                (newer transformers versions pass ``timeout``). ``timeout`` is
                honoured when loading a str input; the rest are ignored.
        Returns:
            The feature-extractor output with the two halves' ``pixel_values``
            concatenated along dim 1.
        """
        if isinstance(image, str):
            # Imported lazily so PIL inputs keep working on transformers versions
            # that lack this helper.
            from transformers.image_utils import load_image
            timeout = kwargs.get("timeout")
            image = load_image(image) if timeout is None else load_image(image, timeout=timeout)
        left_image, right_image = self.horizontal_split_image(image)
        return self.extract_split_feature(left_image, right_image)

    def horizontal_split_image(self, image):
        """Split ``image`` into equal-width left and right halves.

        Args:
            image: Object with ``size`` -> (width, height) and ``crop(box)``
                (e.g. a PIL image).
        Returns:
            ``(left_image, right_image)``, each ``width // 2`` pixels wide. For an
            odd width the last pixel column is dropped so both halves match.
        Raises:
            ValueError: If the image is narrower than 2 pixels, which would
                otherwise produce empty crops that fail later in preprocessing.
        """
        w, h = image.size
        half_w = w // 2
        if half_w == 0:
            raise ValueError(f"image width must be at least 2 pixels to split, got {w}")
        left_image = image.crop([0, 0, half_w, h])
        # Stop at 2 * half_w rather than w so both halves have identical width.
        right_image = image.crop([half_w, 0, 2 * half_w, h])
        return left_image, right_image

    def extract_split_feature(self, left_image, right_image):
        """Run the feature extractor on both halves and merge them into one input.

        Args:
            left_image: Left half of the pair.
            right_image: Right half of the pair.
        Returns:
            The left half's feature-extractor output, with its ``pixel_values``
            replaced by the concatenation (dim 1) of both halves' pixel values.
        """
        model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
        right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
        # torch.cat implies the PyTorch framework ('pt'); assumes NCHW, so dim=1
        # stacks the halves channel-wise — confirm against the model config.
        model_inputs['pixel_values'] = torch.cat(
            [model_inputs['pixel_values'], right_inputs['pixel_values']], dim=1
        )
        return model_inputs