"""Gradio app: caption an uploaded image with ViT-GPT2, then grow the caption
into a short story with a GPT-2 story-generation model."""
from PIL import Image
from transformers import (
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
    PreTrainedTokenizerFast,
    GPT2Tokenizer,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import requests
import gradio as gr
import torch
from transformers import pipeline
import re

description = "Just upload an image, and generate a short story for the image.\n PS: GPT-2 is not perfect but it's fun to play with.May take a minute for the output to generate. Enjoyy!!!"
title = "Story generator from images using ViT and GPT2"

# Everything runs on CPU.
DEVICE = "cpu"
# Longest caption prefix (in characters) fed to the story model as a prompt.
CAPTION_MAX_CHARS = 75
# `max_length` passed to the story model's generate() call.
STORY_MAX_LENGTH = 100

# Image captioning: ViT encoder + GPT-2 decoder, with distilgpt2's tokenizer
# used to decode the caption ids.
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz").to(DEVICE)
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

# Story generation model and its tokenizer.
story_gpt = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt2-genre-story-generator")
st_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")

inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]
outputs = [
    gr.outputs.Textbox(label='Story')
]
examples = [['img_1.jpg'], ['img_2.jpg']]


def _caption_image(img):
    """Return the first decoded caption string generated for ``img``.

    Special tokens are NOT skipped during decoding; the raw string is cleaned
    up by ``_caption_to_prompt``.
    """
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values.to(DEVICE)
    caption_ids = model.generate(pixel_values, num_beams=7)
    return tokenizer.batch_decode(caption_ids)[0]


def _caption_to_prompt(decoded, max_chars=CAPTION_MAX_CHARS):
    """Turn a raw decoded caption into a short prompt for the story model.

    Steps:
      1. Drop everything up to and including the first '>' at index >= 2
         (presumably the end of a leading special token such as
         ``<|endoftext|>`` -- confirm against the caption model's output).
         If there is no such '>', the whole string is kept.
      2. Remove any remaining '>' and '|' characters.
      3. Keep only the text before the first '.'.
      4. Cap the result at ``max_chars`` characters; only when that cap splits
         a word, back up to the previous space so no half-word is left.
      5. Strip trailing whitespace.

    Args:
        decoded: caption text as returned by ``tokenizer.batch_decode``.
        max_chars: maximum prompt length in characters.

    Returns:
        The prompt string (may be empty if the caption had no usable text).
    """
    # find() returns -1 when absent, so -1 + 1 == 0 keeps the whole string.
    caption = decoded[decoded.find(">", 2) + 1:]
    caption = caption.replace(">", "").replace("|", "")
    first_sentence = caption.split(".")[0]

    prompt = first_sentence[:max_chars]
    cut_mid_word = len(first_sentence) > max_chars and first_sentence[max_chars] != " "
    if cut_mid_word:
        last_space = prompt.rfind(" ")
        # last_space <= 0 means the prefix is one giant word; keep the hard cut
        # rather than producing an empty prompt.
        if last_space > 0:
            prompt = prompt[:last_space]
    return prompt.rstrip()


def _generate_story(prompt):
    """Continue ``prompt`` with the story model and return the decoded text."""
    # Shape (1, len(prompt_ids)): a batch containing a single prompt.
    input_ids = torch.tensor([st_tokenizer.encode(prompt)])
    story_ids = story_gpt.generate(
        input_ids,
        max_length=STORY_MAX_LENGTH,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    return st_tokenizer.batch_decode(story_ids)[0]


def _trim_to_last_sentence(text):
    """Cut ``text`` after its last '.'; return it unchanged if it has none."""
    end = text.rfind(".")
    return text[:end + 1] if end != -1 else text


def get_output_senten(img):
    """Gradio callback: caption ``img`` and expand the caption into a story.

    Args:
        img: PIL image supplied by the Gradio Image input.

    Returns:
        The generated story, trimmed to end on a complete sentence when the
        model produced at least one '.'.

    Raises:
        ValueError: if no usable caption text could be extracted from the image.
    """
    prompt = _caption_to_prompt(_caption_image(img))
    print(prompt)
    if not prompt:
        raise ValueError("Could not extract a caption from the image.")
    return _trim_to_last_sentence(_generate_story(prompt))


gr.Interface(
    get_output_senten,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
).launch(enable_queue=True)