gagan3012 commited on
Commit
759ce14
1 Parent(s): 8630fe8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import (AutoTokenizer, VisionEncoderDecoderModel,
4
+ ViTFeatureExtractor)
5
+ from data_loaders import modify_dataset
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ import gradio as gr
9
+
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ else:
13
+ device = "cpu"
14
+
15
+ encoder_checkpoint = "google/vit-base-patch16-224-in21k"
16
+ decoder_checkpoint = "distilgpt2"
17
+ model_checkpoint = "gagan3012/ViTGPT2_vizwiz"
18
+ feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
19
+ tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
20
+ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
21
+
22
+ def predict(image):
23
+ clean_text = lambda x: x.replace("<|endoftext|>", "").split("\n")[0]
24
+ sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
25
+ caption_ids = model.generate(sample, max_length=50)[0]
26
+ caption_text = clean_text(tokenizer.decode(caption_ids))
27
+ return caption_text
28
+
29
+ inputs = [
30
+ gr.inputs.Image(type="pil", label="Original Image")
31
+ ]
32
+
33
+ outputs = [
34
+ gr.outputs.Textbox(label = 'Caption')
35
+ ]
36
+
37
+ title = "Image Captioning using ViT + GPT2"
38
+ description = "ViT and GPT2 are used to generate Image Caption for the uploaded images"
39
+ article = " <a href='https://huggingface.co/gagan3012/ViTGPT2_vizwiz'>Model Repo on Hugging Face Model Hub</a>"
40
+ examples = [
41
+ ["people-walking-street-pedestrian-crossing-traffic-light-city.jpeg"],
42
+ ["elonmusk.jpeg"]
43
+ ]
44
+
45
+ gr.Interface(
46
+ predict,
47
+ inputs,
48
+ outputs,
49
+ title=title,
50
+ description=description,
51
+ article=article,
52
+ examples=examples,
53
+ theme="huggingface",
54
+ ).launch(debug=True, enable_queue=True)