ydshieh commited on
Commit
d28411b
1 Parent(s): 3568832

use real predict method

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. model.py +1 -0
app.py CHANGED
@@ -12,7 +12,7 @@ st.write('\n')
12
  #show = st.image(image, use_column_width=True)
13
  #show.image(image, 'Preloaded Image', use_column_width=True)
14
 
15
- with st.spinner('Loading ViT-GPT2 model ...'):
16
 
17
  from model import *
18
  st.sidebar.write(f'Vit-GPT2 model loaded :)')
@@ -29,14 +29,14 @@ sample_path = os.path.join(sample_dir, sample_name)
29
 
30
  image = Image.open(sample_path)
31
  show = st.image(image, use_column_width=True)
32
- show.image(image, 'Selected Image', use_column_width=True)
33
 
34
  # For newline
35
  st.sidebar.write('\n')
36
 
37
  with st.spinner('Generating image caption ...'):
38
 
39
- caption = predict_dummy(image)
40
  image.close()
41
  st.success(f'caption: {caption}')
42
 
 
12
  #show = st.image(image, use_column_width=True)
13
  #show.image(image, 'Preloaded Image', use_column_width=True)
14
 
15
+ with st.spinner('Loading and compiling ViT-GPT2 model ...'):
16
 
17
  from model import *
18
  st.sidebar.write(f'Vit-GPT2 model loaded :)')
 
29
 
30
  image = Image.open(sample_path)
31
  show = st.image(image, use_column_width=True)
32
+ show.image(image, '\nSelected Image', use_column_width=True)
33
 
34
  # For newline
35
  st.sidebar.write('\n')
36
 
37
  with st.spinner('Generating image caption ...'):
38
 
39
+ caption = predict(image)
40
  image.close()
41
  st.success(f'caption: {caption}')
42
 
model.py CHANGED
@@ -52,6 +52,7 @@ def predict(image):
52
 
53
  token_ids = np.array(generation.sequences)[0]
54
  caption = tokenizer.decode(token_ids)
 
55
 
56
  return caption
57
 
 
52
 
53
  token_ids = np.array(generation.sequences)[0]
54
  caption = tokenizer.decode(token_ids)
55
+ caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
56
 
57
  return caption
58