import gradio as gr
import torch
from transformers import (
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
    PreTrainedTokenizerFast,
    AutoModelForCausalLM,
    AutoTokenizer,
)
description = "Just upload an image and generate a story for it! GPT-2 isn't perfect, but it's fun to play with."
title = "Story generator from images using ViT and GPT2"
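# Two-stage pipeline (inferred from the models loaded below):
#   1. A ViT + GPT-2 encoder-decoder (gagan3012/ViTGPT2_vizwiz) captions the uploaded image.
#   2. A genre story generator (pranavpsv/gpt2-genre-story-generator) continues that
#      caption into a short story, which is returned as the app's output.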
# Run on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz").to(device)
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

story_gpt = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt2-genre-story-generator")
st_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")
inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]
outputs = [
    gr.outputs.Textbox(label='Story')
]
examples = [['img_1.jpg'], ['img_2.jpg']]
def get_output_senten(img):
    # Encode the uploaded image and generate a caption with beam search.
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
    encoder_outputs = model.generate(pixel_values, num_beams=7)
    generated_sentences = tokenizer.batch_decode(encoder_outputs)

    # Strip the special-token markup (e.g. '<|endoftext|>') from the decoded caption.
    senten = generated_sentences[0][generated_sentences[0][2:].index('>') + 1:]
    senten = senten.replace('>', '')
    senten = senten.replace('|', '')

    # Keep the first sentence, truncated to at most 75 characters on a word boundary,
    # as the prompt for the story generator.
    res = senten.split('.')[0][0:75]
    res = res[0:res.rindex(' ')]
    print(res)  # Log the caption used as the story prompt.

    # Continue the caption into a short story with the genre story generator.
    tokenized_text = st_tokenizer.encode(res)
    input_ids = torch.tensor(tokenized_text).view(-1, len(tokenized_text))
    outputs = story_gpt.generate(input_ids, max_length=100, num_beams=5,
                                 no_repeat_ngram_size=2, early_stopping=True)
    generated_story = st_tokenizer.batch_decode(outputs)

    # Cut the story off at the last complete sentence.
    ans = str(generated_story[0])
    ind = ans.rindex('.')
    ans = ans[0:ind + 1]
    return ans
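# For a quick test outside the Gradio UI, the function can be called directly,
# e.g. with one of the example images referenced above (assuming the file exists
# in the Space repository):
#   from PIL import Image
#   print(get_output_senten(Image.open('img_1.jpg')))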
gr.Interface(
    get_output_senten,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
).launch(enable_queue=True)