import os

import torch
import gradio as gr
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer

# Maximum length (in tokens) passed to model.generate() for a caption.
# TODO: take care of the caption length
MAX_CAPTION_LENGTH = 15

# Only files with one of these (lower-cased) extensions are offered as Gradio
# examples. Other entries in the examples folder (hidden files such as
# .DS_Store, READMEs, sub-directories) are skipped so Gradio is never handed
# a non-image example.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff')


def create_caption_transformer(img):
    """
    create_caption_transformer()
    Create a caption for an image using a transformer model that was trained
    on the 'Flickr image dataset'.

    Uses the module-level ``feature_extractor``, ``model`` and ``tokenizer``.
    These must be loaded before the function is called.

    :param img: a numpy array of the image
    :return: a string of the image caption (text before the first '.')
    """
    # Convert the image to the pixel tensor the encoder consumes. It is kept on
    # the CPU to match the device the model is moved to at load time.
    sample = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
    # generate() returns a batch of token-id sequences; only one image was fed in.
    caption_ids = model.generate(sample, max_length=MAX_CAPTION_LENGTH)[0]
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    # Keep only the text before the first period, dropping any trailing fragment.
    caption_text = caption_text.split('.')[0]
    return caption_text


def _list_example_images(folder):
    """Return the sorted names of visible image files directly inside *folder*."""
    # os.listdir order is arbitrary, so sort for a stable example gallery.
    return sorted(
        name
        for name in os.listdir(folder)
        if not name.startswith('.')
        and name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


IMAGES_EXAMPLES_FOLDER = 'examples/'
images = _list_example_images(IMAGES_EXAMPLES_FOLDER)
IMAGES_EXAMPLES = [os.path.join(IMAGES_EXAMPLES_FOLDER, img) for img in images]

# The fine-tuned VisionEncoderDecoder checkpoint is loaded from the current
# working directory; the preprocessing and tokenizer come from the hub.
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Bind the Interface itself (not whatever launch() returns), then start it.
iface = gr.Interface(fn=create_caption_transformer,
                     inputs="image",
                     outputs='text',
                     examples=IMAGES_EXAMPLES,
                     )
iface.launch(share=True)