# Image_to_story / app.py
# Aanisha's picture
# Update app.py
# 745910c
from PIL import Image
from transformers import VisionEncoderDecoderModel,ViTFeatureExtractor,PreTrainedTokenizerFast,GPT2Tokenizer,AutoModelForCausalLM,AutoTokenizer
import requests
import gradio as gr
import torch
from transformers import pipeline
import re
# Text shown on the Gradio page.
title = "Story generator from images using ViT and GPT2"
description = "Just upload an image, and generate a short story for the image.\n PS: GPT-2 is not perfect but it's fun to play with.May take a minute for the output to generate. Enjoyy!!!"

# Stage 1 (image -> caption): vision encoder-decoder model, the feature
# extractor that prepares its pixel input, and the tokenizer used to decode
# the generated caption ids. The model is placed on the CPU explicitly.
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz")
model = model.to("cpu")
vit_feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

# Stage 2 (caption -> story): causal language model and its matching
# tokenizer, both loaded from the same checkpoint.
story_gpt = AutoModelForCausalLM.from_pretrained(
    "pranavpsv/gpt2-genre-story-generator"
)
st_tokenizer = AutoTokenizer.from_pretrained(
    "pranavpsv/gpt2-genre-story-generator"
)

# Gradio components: one PIL image in, one text box out.
inputs = [gr.inputs.Image(type="pil", label="Original Image")]
outputs = [gr.outputs.Textbox(label="Story")]

# Example images offered in the UI (paths relative to the app directory).
examples = [["img_1.jpg"], ["img_2.jpg"]]
def _caption_to_prompt(decoded_caption, max_chars=75):
    """Turn a raw decoded caption into a short story prompt.

    Drops everything up to and including the first ``>`` found at or after
    index 2 (the end of a leading ``<|...|>``-style marker, if there is
    one), removes any remaining ``>`` and ``|`` characters, keeps only the
    text before the first period, limits it to ``max_chars`` characters and
    finally cuts it back to the last space.

    Returns the text unchanged by a step whenever that step finds nothing
    to cut (no marker, no space), instead of raising ``ValueError``.
    """
    # str.find returns -1 when there is no '>', so `marker_end + 1` is 0 and
    # the whole string is kept in that case.
    marker_end = decoded_caption.find('>', 2)
    text = decoded_caption[marker_end + 1:]
    text = text.replace('>', '').replace('|', '')
    text = text.split('.')[0][:max_chars]
    # Cut back to the last space: the character limit may have split a word,
    # and the final word may still carry leftover marker text.
    last_space = text.rfind(' ')
    if last_space == -1:
        return text
    return text[:last_space]


def _trim_to_last_sentence(text):
    """Return ``text`` up to and including its last period.

    If ``text`` contains no period it is returned unchanged.
    """
    last_period = text.rfind('.')
    if last_period == -1:
        return text
    return text[:last_period + 1]


def get_output_senten(img):
    """Generate a short story for an image.

    The image is captioned with the module-level ``model``; the cleaned-up
    caption is then used as the prompt for ``story_gpt``, and the generated
    text is cut after its last complete sentence.

    Args:
        img: The uploaded image (the Gradio input is declared with
            ``type="pil"``).

    Returns:
        The generated story as a string, or an empty string when no usable
        caption could be extracted from the image.
    """
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values.to('cpu')
    caption_ids = model.generate(pixel_values, num_beams=7)
    decoded_captions = tokenizer.batch_decode(caption_ids)

    prompt = _caption_to_prompt(decoded_captions[0])
    print(prompt)
    if not prompt:
        # Nothing to condition the story model on; an empty id tensor
        # cannot be fed to generate().
        return ""

    prompt_ids = st_tokenizer.encode(prompt)
    # Wrap in a list to get a (1, sequence_length) batch.
    input_ids = torch.tensor([prompt_ids])
    story_ids = story_gpt.generate(
        input_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    stories = st_tokenizer.batch_decode(story_ids)
    return _trim_to_last_sentence(stories[0])
# Build the web UI around the story generator, then start serving it with
# request queuing enabled.
demo = gr.Interface(
    fn=get_output_senten,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
)
demo.launch(enable_queue=True)