stupidog04's picture
Update pipeline.py
151724b
raw
history blame
No virus
2.48 kB
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ImageClassificationPipeline
import torch
class PreTrainedPipeline:
    """Inference entry point for a ViT classifier applied to side-by-side image pairs.

    Model weights come from ``path``. Preprocessing uses a ViT feature extractor
    loaded from a separate checkpoint. Calls are delegated to
    ``PairClassificationPipeline``.
    """

    def __init__(self, path, feature_extractor_name='google/vit-base-patch16-224-in21k'):
        """
        Initialize the feature extractor, the model and the pair pipeline.

        Args:
            path (:obj:`str`):
                Location passed to ``ViTForImageClassification.from_pretrained``
                (a local directory or hub id with the classifier weights).
            feature_extractor_name (:obj:`str`, `optional`):
                Checkpoint passed to ``ViTFeatureExtractor.from_pretrained``.
                Its preprocessing config is applied to both image halves.
                Defaults to ``'google/vit-base-patch16-224-in21k'``, the value
                that was previously hard-coded.
        """
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = ViTForImageClassification.from_pretrained(path)
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """
        Classify one input by forwarding it to the pair-classification pipeline.

        Args:
            inputs:
                The image to classify. It is passed unchanged to
                ``PairClassificationPipeline``, which splits it down the middle
                into a left and a right half (presumably a ``PIL.Image.Image``
                holding two images side by side; confirm against callers).
        Return:
            Whatever ``PairClassificationPipeline`` returns. It inherits
            ``ImageClassificationPipeline``'s postprocessing, which in
            transformers yields a list of ``{"label": ..., "score": ...}``
            dicts.
        """
        return self.pipe(inputs)
class PairClassificationPipeline(ImageClassificationPipeline):
    """Image-classification pipeline for two images stored side by side in one image.

    The input is cropped into equal-width left and right halves. Each half is
    run through the feature extractor independently, and the two
    ``pixel_values`` tensors are concatenated along dim 1. Forwarding and
    postprocessing are inherited from ``ImageClassificationPipeline``. The
    wrapped model must accept the concatenated input; presumably that means
    twice the usual channel count (e.g. 6 for RGB), so TODO confirm against
    the model config.
    """

    # NOTE(review): no method in this class uses these transforms. They are
    # kept only because external code may reference them as class attributes.
    pipe_to_tensor = transforms.ToTensor()
    pipe_to_pil = transforms.ToPILImage()

    def preprocess(self, image):
        """Split ``image`` into halves and build the concatenated model inputs.

        Args:
            image: A PIL image containing the pair side by side, or a ``str``
                (local path or URL) that transformers' ``load_image`` can open.
        Returns:
            The feature-extractor output for the left half, with its
            ``pixel_values`` replaced by the left/right concatenation.
        """
        if isinstance(image, str):
            # The base ImageClassificationPipeline accepts paths/URLs, but this
            # override indexes ``image.size`` directly. Load strings first so
            # the inherited contract still holds.
            try:
                from transformers.image_utils import load_image
            except ImportError:  # older transformers defined it in the pipeline module
                from transformers.pipelines.image_classification import load_image
            image = load_image(image)
        left_image, right_image = self.horizontal_split_image(image)
        return self.extract_split_feature(left_image, right_image)

    def horizontal_split_image(self, image):
        """Crop ``image`` into left and right halves of identical width.

        Args:
            image: A PIL image (anything with ``.size`` and ``.crop``).
        Returns:
            ``(left_image, right_image)``, each ``w // 2`` pixels wide and
            full height.
        Raises:
            ValueError: if the image is narrower than 2 pixels, which would
                otherwise produce empty crops.
        """
        w, h = image.size
        half_w = w // 2
        if half_w == 0:
            raise ValueError(
                f"Image width must be at least 2 pixels to split into a pair, got {w}."
            )
        left_image = image.crop((0, 0, half_w, h))
        # The right box ends at 2*half_w rather than w so both halves have the
        # same width. For an odd width, the last pixel column is dropped.
        right_image = image.crop((half_w, 0, 2 * half_w, h))
        return left_image, right_image

    def extract_split_feature(self, left_image, right_image):
        """Run the feature extractor on both halves and concatenate their pixels.

        Args:
            left_image: Left half of the pair.
            right_image: Right half of the pair.
        Returns:
            The left half's feature-extractor output, with ``pixel_values``
            set to ``torch.cat([left, right], dim=1)``. Any other keys come
            from the left half only.
        """
        model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
        right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
        # dim=1 is presumably the channel axis of (batch, channels, H, W)
        # pixel tensors (TODO confirm), so the halves are stacked as extra
        # channels. torch.cat means this only works with the PyTorch framework.
        model_inputs['pixel_values'] = torch.cat(
            [model_inputs['pixel_values'], right_inputs['pixel_values']], dim=1
        )
        return model_inputs