IAB_VIDEO_AD_CLASSIFIER / image_caption.py
arjunanand13's picture
Upload 10 files
6a6f954 verified
raw
history blame
No virus
2.03 kB
import argparse
from pathlib import Path
import os
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
class Caption:
    """Generate text captions for images with a pretrained captioning model.

    The constructor loads a ``VisionEncoderDecoderModel`` together with its
    matching image processor and tokenizer from the same checkpoint, moves
    the model to the requested device, and stores the generation settings
    that ``predict_step`` forwards to ``model.generate``.
    """

    # Checkpoint loaded when no ``model_name`` is given.
    DEFAULT_MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

    def __init__(
        self,
        *,
        model_name=DEFAULT_MODEL_NAME,
        device="cpu",
        max_length=16,
        num_beams=4,
    ):
        """Load the model, image processor and tokenizer.

        Args:
            model_name: Checkpoint identifier passed to every
                ``from_pretrained`` call.
            device: Anything accepted by ``torch.device``. Defaults to
                ``"cpu"``, which matches the previous hard-coded behavior.
            max_length: ``max_length`` value forwarded to ``model.generate``.
            num_beams: ``num_beams`` value forwarded to ``model.generate``.
        """
        self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
        self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.device = torch.device(device)
        self.model.to(self.device)

        self.max_length = max_length
        self.num_beams = num_beams
        self.gen_kwargs = {
            "max_length": self.max_length,
            "num_beams": self.num_beams,
        }

    def predict_step(self, image_paths):
        """Return one caption per image path.

        Args:
            image_paths: Iterable of paths (or file objects) accepted by
                ``PIL.Image.open``.

        Returns:
            A list of whitespace-stripped caption strings, in the same order
            as ``image_paths``. An empty input yields an empty list.
        """
        image_paths = list(image_paths)
        if not image_paths:
            # Nothing to caption; avoid handing an empty batch to the
            # image processor and the model.
            return []

        images = []
        for image_path in image_paths:
            # The context manager releases the underlying file as soon as
            # the pixel data has been copied out by ``convert``.
            with Image.open(image_path) as source_image:
                images.append(source_image.convert(mode="RGB"))

        pixel_values = self.feature_extractor(
            images=images, return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(self.device)

        output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
        decoded = self.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True
        )
        return [caption.strip() for caption in decoded]

    def get_args(self, argv=None):
        """Parse the command-line options of the captioning script.

        Args:
            argv: Optional list of argument strings. When ``None`` (the
                default) ``argparse`` reads ``sys.argv[1:]``, as before.

        Returns:
            The parsed ``argparse.Namespace``; ``input_img_paths`` holds the
            image path (default ``"farmer.jpg"``).
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-i",
            "--input_img_paths",
            type=str,
            default="farmer.jpg",
            help="img for caption",
        )
        return parser.parse_args(argv)
if __name__ == "__main__":
    # Build the captioner first, then read the CLI options through it.
    captioner = Caption()
    cli_args = captioner.get_args()
    # predict_step takes a collection of paths, so wrap the single CLI path.
    captions = captioner.predict_step([cli_args.input_img_paths])
    print(captions)