Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from transformers import VisionEncoderDecoderModel,ViTFeatureExtractor,PreTrainedTokenizerFast,GPT2Tokenizer,AutoModelForCausalLM,AutoTokenizer
|
3 |
+
import requests
|
4 |
+
import gradio as gr
|
5 |
+
import torch
|
6 |
+
from transformers import pipeline
|
7 |
+
import re
|
# --- UI text ---------------------------------------------------------------
description = "Just upload an image, and generate a story for the image! GPT-2 isn't perfect but it's fun to play with."
title = "Story generator from images using ViT and GPT2"

# Use the GPU when one is available.  The original hard-coded .to('cuda'),
# which crashes at import time on CPU-only hosts (e.g. a free Spaces runtime).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Image-captioning model (ViT encoder + GPT-2 decoder) and its preprocessors.
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz").to(device)
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

# Story-continuation model, seeded later with the generated caption.
story_gpt = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt2-genre-story-generator")
st_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")

inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label='Story')
]

# One example ROW per inner list.  The interface has a single image input, so
# each row is a one-element list; the original packed both filenames into one
# row, which Gradio interprets as two input values for a one-input interface.
examples = [['img_1.jpg'], ['img_2.jpg']]
def get_output_senten(img):
    """Caption *img* with the ViT+GPT2 model, then expand the caption into a
    short story with the genre-story GPT-2 model.

    Args:
        img: PIL image supplied by the Gradio Image input.

    Returns:
        str: the generated story, truncated at the last complete sentence.
    """
    # Run the captioner on whatever device the model lives on (the original
    # hard-coded 'cuda' twice, which breaks on CPU-only hosts).
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values.to(model.device)
    encoder_outputs = model.generate(pixel_values, num_beams=7)
    generated_sentences = tokenizer.batch_decode(encoder_outputs)

    # Strip the leading special token (everything up to the first '>' past
    # position 2) and any stray special-token characters left in the text.
    senten = generated_sentences[0][generated_sentences[0][2:].index('>') + 1:]
    senten = senten.replace('>', '')
    senten = senten.replace('|', '')

    # Keep only the first sentence, capped at 75 chars, then cut back to the
    # last complete word so the story prompt never ends mid-word.
    res = senten.split('.')[0][0:75]
    if ' ' in res:  # rindex would raise ValueError on a single-word caption
        res = res[0:res.rindex(' ')]

    print(res)

    # Seed the story generator with the cleaned caption.
    tokenized_text = st_tokenizer.encode(res)
    input_ids = torch.tensor(tokenized_text).view(-1, len(tokenized_text))
    outputs = story_gpt.generate(input_ids, max_length=100, num_beams=5,
                                 no_repeat_ngram_size=2, early_stopping=True)

    generated_story = st_tokenizer.batch_decode(outputs)
    print(len(generated_story))
    ans = str(generated_story[0])

    # Trim to the last complete sentence; guard against a story with no '.'
    # (the original rindex would raise ValueError in that case).
    if '.' in ans:
        ans = ans[0:ans.rindex('.') + 1]
    return ans
# Build and launch the web UI.
# NOTE: the original source was missing the comma after `examples = examples`,
# which is a SyntaxError ("positional argument follows keyword argument") and
# the most likely cause of the Space's "Runtime error".
gr.Interface(
    get_output_senten,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
).launch(enable_queue=True)