File size: 2,126 Bytes
6b849d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0ab090
6b849d7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from PIL import Image
from transformers import VisionEncoderDecoderModel,ViTFeatureExtractor,PreTrainedTokenizerFast,GPT2Tokenizer,AutoModelForCausalLM,AutoTokenizer
import requests
import gradio as gr
import torch
from transformers import pipeline
import re



# User-facing text shown on the Gradio page.
description = "Just upload an image, and generate a story for the image! GPT-2 isn't perfect but it's fun to play with."
title = "Story generator from images using ViT and GPT2"

# Use the GPU when one is available; otherwise fall back to CPU so the app
# can still start (slowly) on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Image-captioning model (ViT encoder + GPT-2 decoder) and its preprocessing.
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz").to(device)
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
# Tokenizer used to decode the caption model's output ids.
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")
# Story model and its tokenizer; the caption is used as the story prompt.
story_gpt = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt2-genre-story-generator")
st_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt2-genre-story-generator")

inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label = 'Story')
]

# One inner list per example; each inner list holds one value per input
# component (there is a single image input).
examples = [['img_1.jpg'], ['img_2.jpg']]

def _clean_caption(raw_caption, max_chars=75):
    """Turn a decoded caption into a short, plain-text story prompt.

    Removes any leftover ``<|...|>`` special-token markers and stray ``>``/``|``
    characters, keeps only the text before the first ``.``, collapses
    whitespace, and, if the result is longer than ``max_chars``, truncates it
    at the last word boundary that fits (or hard-cuts a single long word).

    Returns an empty string when nothing usable remains.
    """
    text = re.sub(r"<\|[^<>|]*\|>", " ", raw_caption)
    text = text.replace(">", "").replace("|", "")
    sentence = " ".join(text.split(".")[0].split())
    if len(sentence) > max_chars:
        # Look one char past the limit so a space right at the limit counts
        # as a word boundary.
        head = sentence[: max_chars + 1]
        sentence = head.rsplit(" ", 1)[0] if " " in head else sentence[:max_chars]
    return sentence


def _trim_to_last_sentence(text):
    """Cut *text* right after its final '.'; return it stripped if it has none."""
    end = text.rfind(".")
    return text[: end + 1] if end != -1 else text.strip()


def get_output_senten(img):
    """Caption *img* with the ViT-GPT2 model, then expand the caption into a story.

    Args:
      img: the uploaded image (the Gradio input is declared with ``type="pil"``).

    Returns:
      The generated story text, cut after its last complete sentence.

    Raises:
      ValueError: if the captioning model produced no usable caption text.
    """
    # Send inputs to whichever device each model was actually loaded onto.
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values.to(model.device)
    caption_ids = model.generate(pixel_values, num_beams=7)
    raw_caption = tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]

    prompt = _clean_caption(raw_caption)
    if not prompt:
        raise ValueError("Could not produce a caption for this image.")
    print(prompt)

    # Shape (1, seq_len): a batch containing the single prompt.
    input_ids = torch.tensor([st_tokenizer.encode(prompt)], device=story_gpt.device)
    story_ids = story_gpt.generate(
        input_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    story = st_tokenizer.batch_decode(story_ids, skip_special_tokens=True)[0]
    return _trim_to_last_sentence(story)



# Assemble the Gradio UI around the story generator, then start serving it
# with request queueing enabled.
demo = gr.Interface(
    fn=get_output_senten,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
)
demo.launch(enable_queue=True)