fbrynpk commited on
Commit
70d2bcf
1 Parent(s): 6437669

Update user interface to use trained models

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py CHANGED
@@ -3,9 +3,50 @@ import os
3
  import streamlit as st
4
  import requests
5
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  st.title('Image-Caption-Generator')
8
  img_url = st.text_input(label='Enter an Image URL')
9
 
 
 
 
 
 
 
 
 
 
10
  st.markdown('<center style="opacity: 70%">OR</center>', unsafe_allow_html=True)
11
  img_upload = st.file_uploader(label='Upload Image', type=['jpg', 'png', 'jpeg'])
 
 
 
 
 
 
 
 
 
 
 
3
  import streamlit as st
4
  import requests
5
  from PIL import Image
6
+ from model import get_caption_model, generate_caption
7
+
8
+ @st.cache(allow_output_mutation=True)
9
+ def get_model():
10
+ return get_caption_model()
11
+
12
+ caption_model = get_model()
13
+
14
+ def predict():
15
+ captions = []
16
+ pred_caption = generate_caption('tmp.jpg', caption_model)
17
+
18
+ st.markdown('#### Predicted Captions:')
19
+ captions.append(pred_caption)
20
+
21
+ for _ in range(4):
22
+ pred_caption = generate_caption('tmp.jpg', caption_model, add_noise=True)
23
+ if pred_caption not in captions:
24
+ captions.append(pred_caption)
25
+
26
+ for c in captions:
27
+ st.write(c)
28
 
29
  st.title('Image-Caption-Generator')
30
  img_url = st.text_input(label='Enter an Image URL')
31
 
32
+ if (img_url != "") and (img_url != None):
33
+ img = Image.open(requests.get(img_url, stream=True).raw)
34
+ img = img.convert('RGB')
35
+ st.image(img)
36
+ img.save('tmp.jpg')
37
+ predict()
38
+ os.remove('tmp.jpg')
39
+
40
+
41
  st.markdown('<center style="opacity: 70%">OR</center>', unsafe_allow_html=True)
42
  img_upload = st.file_uploader(label='Upload Image', type=['jpg', 'png', 'jpeg'])
43
+
44
+ if img_upload != None:
45
+ img = img_upload.read()
46
+ img = Image.open(io.BytesIO(img))
47
+ img = img.convert('RGB')
48
+ img.save('tmp.jpg')
49
+ st.image(img)
50
+ predict()
51
+ os.remove('tmp.jpg')
52
+