# Spaces: Runtime error
from PIL import Image | |
from transformers import VisionEncoderDecoderModel,ViTFeatureExtractor,PreTrainedTokenizerFast,GPT2Tokenizer,AutoModelForCausalLM,AutoTokenizer | |
import requests | |
import gradio as gr | |
import torch | |
from transformers import pipeline | |
import re | |
# --- Text shown by the web interface (HTML snippets) -------------------------
title = "<h1>Story generator from images using ViT and GPT2</h1>"
description = "<h4>Just upload an image, and generate a story for the image! GPT-2 is not perfect but it's fun to play with.</h4>"

# --- Stage 1: image -> caption ------------------------------------------------
# Encoder-decoder checkpoint used to generate a caption from pixel values;
# it is explicitly placed on the CPU.
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz")
model = model.to("cpu")
# Converts the uploaded image into the `pixel_values` tensor fed to `model`.
vit_feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
# Decodes the token ids produced by `model` back into text.
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

# --- Stage 2: caption -> story ------------------------------------------------
# The story model and its tokenizer are loaded from the same checkpoint id.
_story_checkpoint = "pranavpsv/gpt2-genre-story-generator"
story_gpt = AutoModelForCausalLM.from_pretrained(_story_checkpoint)
st_tokenizer = AutoTokenizer.from_pretrained(_story_checkpoint)
# Gradio components: a single image input (delivered to the callback as a PIL
# image) and a single text output labelled 'Story'.
# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces belong to the older
# Gradio API; confirm the pinned Gradio version still provides them.
inputs = [
    gr.inputs.Image(type="pil", label="Original Image"),
]
outputs = [
    gr.outputs.Textbox(label='Story'),
]

# One row per clickable example, and one value per input component in each row.
# The interface has exactly one input, so each example image gets its own
# single-element row (previously both files were packed into one row, i.e. one
# sample supplying two values for a single input).
examples = [
    ['img_1.jpg'],
    ['img_2.jpg'],
]
def _caption_to_prompt(decoded, max_chars=75):
    """Reduce a decoded caption string to a short prompt for the story model.

    Drops the leading marker that ends in '>', removes every '>' and '|'
    character, keeps the first sentence (text before the first '.'), caps it
    at ``max_chars`` characters and removes the trailing, possibly cut-off,
    word.
    """
    # The '>' is searched in decoded[2:] but the offset is applied to
    # `decoded`, so the cut starts at the character just before that '>';
    # the leftover marker characters are removed by the replaces below.
    # find() returns -1 when no '>' exists, which makes the cut start at 0
    # (keep everything) instead of raising ValueError as index() did.
    start = decoded[2:].find('>') + 1
    caption = decoded[start:].replace('>', '').replace('|', '')

    prompt = caption.split('.')[0][:max_chars]

    # Drop the last word, which may have been truncated by the length cap.
    # A caption without any space is kept whole rather than raising.
    last_space = prompt.rfind(' ')
    if last_space != -1:
        prompt = prompt[:last_space]
    return prompt


def _trim_to_last_sentence(text):
    """Return ``text`` up to and including its last '.', or unchanged if none."""
    end = text.rfind('.')
    if end == -1:
        return text
    return text[:end + 1]


def get_output_senten(img) -> str:
    """Generate a short story for an image.

    The image is captioned with ``model`` (beam search, 7 beams), the caption
    is cleaned into a prompt, and ``story_gpt`` continues that prompt.

    Args:
        img: The image passed in by the Gradio image input; it is handed
            unchanged to ``vit_feature_extractor``.

    Returns:
        The generated story, cut after its last '.' when it contains one.
        An empty string is returned when no usable prompt could be derived
        from the caption.
    """
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values.to('cpu')
    caption_ids = model.generate(pixel_values, num_beams=7)
    decoded_caption = tokenizer.batch_decode(caption_ids)[0]

    prompt = _caption_to_prompt(decoded_caption)
    print(prompt)
    if not prompt:
        # Nothing to condition the story on; building a (1, 0) input tensor
        # below would fail, so stop here.
        return ""

    token_ids = st_tokenizer.encode(prompt)
    # Shape the flat id list into a single-row batch: (1, len(token_ids)).
    input_ids = torch.tensor(token_ids).view(-1, len(token_ids))
    # Named `story_ids` so the module-level `outputs` component list is not
    # shadowed inside this function.
    story_ids = story_gpt.generate(
        input_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    stories = st_tokenizer.batch_decode(story_ids)
    print(len(stories))

    return _trim_to_last_sentence(stories[0])
# Assemble the web UI around the story-generation callback and start serving it.
demo = gr.Interface(
    fn=get_output_senten,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
)
demo.launch(enable_queue=True)