Aanisha commited on
Commit
6b849d7
1 Parent(s): 28aad11

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from transformers import VisionEncoderDecoderModel,ViTFeatureExtractor,PreTrainedTokenizerFast,GPT2Tokenizer,AutoModelForCausalLM,AutoTokenizer
3
+ import requests
4
+ import gradio as gr
5
+ import torch
6
+ from transformers import pipeline
7
+ import re
8
+
9
+
10
+
11
# UI copy shown on the Gradio page ("ins't" typo fixed to "isn't").
description = "Just upload an image, and generate a story for the image! GPT-2 isn't perfect but it's fun to play with."
title = "Story generator from images using ViT and GPT2"
13
+
14
+
15
# Pick the device once; the original hard-coded .to('cuda') and crashed on
# CPU-only machines (e.g. free HF Spaces hardware).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ViT encoder + GPT-2 decoder fine-tuned for image captioning (VizWiz dataset).
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz").to(device)
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

# Separate GPT-2 that expands the short caption into a story (kept on CPU,
# as in the original).
story_gpt = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt2-genre-story-generator")
st_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")
20
+
21
# Gradio I/O spec: one PIL image in, one text box (the generated story) out.
inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label='Story')
]

# One example row per inner list, with one value per input component.
# The original packed both filenames into a single row ([['a','b']]), which
# does not match the single image input.
examples = [['img_1.jpg'], ['img_2.jpg']]
30
+
31
def get_output_senten(img):
    """Caption *img* with ViT-GPT2, then expand the caption into a short story.

    Args:
        img: PIL image supplied by the Gradio image input.

    Returns:
        str: the generated story, truncated at the last full sentence when
        a period is present.
    """
    # Encode the image and generate a caption with beam search. Follow the
    # model's own device instead of hard-coding 'cuda' twice like the original.
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
    encoder_outputs = model.generate(pixel_values.to(model.device), num_beams=7)
    generated_sentences = tokenizer.batch_decode(encoder_outputs)

    # Strip the leading special token (presumably '<|endoftext|>' — the exact
    # decode format depends on the checkpoint's tokenizer).
    # NOTE(review): the original mixes an index computed on a [2:] slice with
    # slicing of the full string; kept byte-for-byte to preserve output.
    senten = generated_sentences[0][generated_sentences[0][2:].index('>') + 1:]
    senten = senten.replace('>', '')
    senten = senten.replace('|', '')

    # Keep only the first sentence, capped at 75 chars, cut at a word boundary.
    # Guard rindex: a single-word caption raised ValueError in the original.
    res = senten.split('.')[0][0:75]
    if ' ' in res:
        res = res[0:res.rindex(' ')]

    print(res)

    # Use the caption as the prompt for the story generator.
    tokenized_text = st_tokenizer.encode(res)
    input_ids = torch.tensor(tokenized_text).view(-1, len(tokenized_text))
    # Renamed from `outputs` to avoid shadowing the module-level Gradio
    # `outputs` list.
    story_ids = story_gpt.generate(
        input_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    generated_story = st_tokenizer.batch_decode(story_ids)
    print(len(generated_story))

    # Truncate at the last period so the story ends on a complete sentence.
    # Guard rindex: a period-free generation raised ValueError in the original.
    ans = str(generated_story[0])
    if '.' in ans:
        ans = ans[0:ans.rindex('.') + 1]
    return ans
59
+
60
+
61
+
62
# Build and launch the demo. The original omitted the comma after
# `examples = examples`, which made the whole file a SyntaxError.
gr.Interface(
    get_output_senten,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
).launch(enable_queue=True)