nishantguvvada committed on
Commit
2ec9390
·
1 Parent(s): f2edbd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -102
app.py CHANGED
@@ -1,53 +1,17 @@
1
  import streamlit as st
2
- import pickle
3
  import tensorflow as tf
4
- import cv2
5
  import numpy as np
6
- from PIL import Image, ImageOps
7
- import imageio.v3 as iio
8
- import time
9
- from textwrap import wrap
10
- import matplotlib.pylab as plt
11
- from tensorflow.keras import Input
12
- from tensorflow.keras.layers import (
13
- GRU,
14
- Add,
15
- AdditiveAttention,
16
- Attention,
17
- Concatenate,
18
- Dense,
19
- Embedding,
20
- LayerNormalization,
21
- Reshape,
22
- StringLookup,
23
- TextVectorization,
24
- )
25
 
26
- MAX_CAPTION_LEN = 64
27
- MINIMUM_SENTENCE_LENGTH = 5
28
- IMG_HEIGHT = 299
29
- IMG_WIDTH = 299
30
- IMG_CHANNELS = 3
31
- ATTENTION_DIM = 512 # size of dense layer in Attention
32
- VOCAB_SIZE = 20000
33
- FEATURES_SHAPE = (8, 8, 1536)
34
 
35
- @st.cache_resource()
36
- def load_image_model():
37
- image_model=tf.keras.models.load_model('./image_caption_model.h5')
38
- return image_model
39
 
40
- @st.cache_resource()
41
- def load_decoder_model():
42
- decoder_model=tf.keras.models.load_model('./decoder_pred_model.h5')
43
- return decoder_model
44
-
45
- @st.cache_resource()
46
- def load_encoder_model():
47
- encoder=tf.keras.models.load_model('./encoder_model.h5')
48
- return encoder
49
-
50
-
51
  st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")
52
  image = Image.open('./title.jpg')
53
  st.image(image)
@@ -56,74 +20,34 @@ st.write("""
56
  """
57
  )
58
 
59
- file = st.file_uploader("Upload any image and the model will try to provide a caption to it!", type= ['png', 'jpg'])
60
-
61
-
62
-
63
- # We will override the default standardization of TextVectorization to preserve
64
- # "<>" characters, so we preserve the tokens for the <start> and <end>.
65
- def standardize(inputs):
66
- inputs = tf.strings.lower(inputs)
67
- return tf.strings.regex_replace(
68
- inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?", ""
69
- )
70
-
71
- # Choose the most frequent words from the vocabulary & remove punctuation etc.
72
- vocab = open('./tokenizer_vocab.txt', 'rb')
73
- tokenizer = pickle.load(vocab)
74
-
75
-
76
- # Lookup table: Word -> Index
77
- word_to_index = StringLookup(
78
- mask_token="", vocabulary=tokenizer
79
- )
80
-
81
-
82
- ## Probabilistic prediction using the trained model
83
- def predict_caption(file):
84
- filename = Image.open(file)
85
- image = filename.convert('RGB')
86
- image = np.array(image)
87
- gru_state = tf.zeros((1, ATTENTION_DIM))
88
 
89
- resize = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
90
- img = resize/255
91
-
92
- encoder = load_encoder_model()
93
- features = encoder(tf.expand_dims(img, axis=0))
94
- dec_input = tf.expand_dims([word_to_index("<start>")], 1)
95
- result = []
96
- decoder_pred_model = load_decoder_model()
97
- for i in range(MAX_CAPTION_LEN):
98
- predictions, gru_state = decoder_pred_model(
99
- [dec_input, gru_state, features]
100
- )
101
 
102
- # draws from log distribution given by predictions
103
- top_probs, top_idxs = tf.math.top_k(
104
- input=predictions[0][0], k=10, sorted=False
105
- )
106
- chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
107
- predicted_id = top_idxs.numpy()[chosen_id][0]
108
 
109
- result.append(tokenizer[predicted_id])
 
110
 
111
- if predicted_id == word_to_index("<end>"):
112
- return img, result
113
 
114
- dec_input = tf.expand_dims([predicted_id], 1)
 
 
115
 
116
- return img, result
117
 
118
  def on_click():
119
  if file is None:
120
  st.text("Please upload an image file")
121
  else:
122
- image = Image.open(file)
123
- st.image(image, use_column_width=True)
124
- for i in range(5):
125
- image, caption = predict_caption(file)
126
- #print(" ".join(caption[:-1]) + ".")
127
- st.write(" ".join(caption[:-1]) + ".")
128
 
129
  st.button('Generate', on_click=on_click)
 
1
  import streamlit as st
 
2
  import tensorflow as tf
 
3
  import numpy as np
4
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
+ import torch
6
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
@st.cache_resource()
def load_captioning_components():
    """Load the captioning model, image processor and tokenizer once.

    Streamlit re-executes the whole script on every widget interaction, so
    loading the checkpoint at module level would repeat the work on each
    rerun. Wrapping the load in ``st.cache_resource`` keeps a single shared
    copy for the lifetime of the server process.

    Returns:
        A ``(model, feature_extractor, tokenizer, device)`` tuple; the model
        has already been moved to ``device``.
    """
    checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    captioning_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    image_processor = ViTImageProcessor.from_pretrained(checkpoint)
    caption_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Use the GPU when torch can see one; otherwise stay on the CPU.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    captioning_model.to(target_device)
    return captioning_model, image_processor, caption_tokenizer, target_device


# Keep the original module-level names so the rest of the script is unchanged.
model, feature_extractor, tokenizer, device = load_captioning_components()
 
 
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")
16
  image = Image.open('./title.jpg')
17
  st.image(image)
 
20
  """
21
  )
22
 
23
+ file = st.file_uploader("Upload an image to generate captions!", type= ['png', 'jpg'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
# Generation settings forwarded to ``model.generate`` by ``predict_step``.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def predict_step(image_paths):
    """Generate one caption for each input image.

    Args:
        image_paths: Iterable of items that ``PIL.Image.open`` accepts
            (file paths or file-like objects such as a Streamlit upload).

    Returns:
        A list of caption strings with surrounding whitespace stripped.
        An empty list is returned when ``image_paths`` yields nothing.
    """
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        # Normalise grayscale / RGBA / palette images to three-channel RGB.
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    # Nothing to caption: skip the processor and model, which expect a
    # non-empty batch.
    if not images:
        return []

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]
45
 
 
46
 
47
  def on_click():
48
  if file is None:
49
  st.text("Please upload an image file")
50
  else:
51
+ predict_step([file])
 
 
 
 
 
52
 
53
  st.button('Generate', on_click=on_click)