krenzcolor_chkpt_classifier / pair_classification_pipeline.py
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ImageClassificationPipeline

class PreTrainedPipeline():
    def __init__(self, path):
        """
        Initialize the feature extractor, the fine-tuned model, and the
        pair-classification pipeline.
        """
        # Features are extracted with the base ViT checkpoint; the
        # classifier weights are loaded from the fine-tuned checkpoint at `path`.
        model_flag = 'google/vit-base-patch16-224-in21k'
        # model_flag = 'google/vit-base-patch16-384'
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
        self.model = ViTForImageClassification.from_pretrained(path)
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """
        Args:
            inputs (:obj:`PIL.Image.Image`):
                A side-by-side image pair; the left and right halves are
                split apart and classified jointly.
        Return:
            A list of :obj:`dict` objects like [{"label": "XXX", "score": 0.99}],
            as produced by ImageClassificationPipeline.
        """
        return self.pipe(inputs)

class PairClassificationPipeline(ImageClassificationPipeline):
    def preprocess(self, image):
        # Split the input into its left and right halves and stack their
        # pixel values channel-wise before handing them to the model.
        left_image, right_image = self.horizontal_split_image(image)
        model_inputs = self.extract_split_feature(left_image, right_image)
        return model_inputs

    def horizontal_split_image(self, image):
        # Crop the image into two equal halves; an odd trailing pixel
        # column is discarded.
        w, h = image.size
        half_w = w // 2
        left_image = image.crop((0, 0, half_w, h))
        right_image = image.crop((half_w, 0, 2 * half_w, h))
        return left_image, right_image

    def extract_split_feature(self, left_image, right_image):
        # Run the ViT feature extractor on each half separately, then
        # concatenate along the channel dimension (dim=1). The result is a
        # 6-channel pixel_values tensor (left RGB + right RGB), so the
        # fine-tuned model must accept 6 input channels.
        model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
        right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
        model_inputs['pixel_values'] = torch.cat(
            [model_inputs['pixel_values'], right_inputs['pixel_values']], dim=1
        )
        return model_inputs