zamborg commited on
Commit
4dab50d
1 Parent(s): 7d1df38

updated code

Browse files
Files changed (2) hide show
  1. app.py +56 -49
  2. model.py +5 -4
app.py CHANGED
@@ -1,59 +1,66 @@
1
  import streamlit as st
2
  import io
3
 
4
- # x = st.slider("Select a value")
5
- # st.write(x, "squared is", x * x)
6
-
7
- st.title("Image Captioning Demo from Redcaps")
8
- st.sidebar.markdown(
9
- """
10
- Image Captioning Model from VirTex trained on Redcaps
11
- """
12
- )
13
-
14
- with st.spinner("Loading Model"):
15
- from model import *
16
- sample_images = glob.glob("./samples/*.jpg")
17
- download_files()
18
- virtexModel = VirTexModel()
19
- imageLoader = ImageLoader()
20
-
21
- random_image = get_rand_img(sample_images)
22
-
23
- st.sidebar.title("Select a sample image")
24
- sample_image = st.sidebar.selectbox(
25
- "",
26
- sample_images
27
- )
28
-
29
- if st.sidebar.button("Random Sample Image"):
30
- random_image = get_rand_img(sample_images)
31
- sample_image = None
32
 
33
- uploaded_image = None
34
- with st.sidebar.form("file-uploader-form", clear_on_submit=True):
35
- uploaded_file = st.file_uploader("Choose a file")
36
- submitted = st.form_submit_button("Submit")
37
- if uploaded_file is not None and submitted:
38
- uploaded_image = Image.open(io.BytesIO(uploaded_file.get_values()))
39
-
40
- if uploaded_image is None and submitted:
41
- st.write("Please select a file to upload")
42
-
43
- else:
44
- image_file = sample_image if sample_image is not None else random_image
45
 
46
- image = uploaded_image if uploaded_image is not None else Image.open()
47
 
48
- image_dict = imageLoader.transform(image)
49
 
50
- show.image(st.image(image_dict["image"]), "Target Image")
51
 
52
- with st.spinner("Generating Caption"):
53
- subreddit, caption = virtexModel.predict(image_dict)
54
- st.header("Predicted Caption:\n\n")
55
- st.subheader(f"Subreddit: {subreddit}\n")
56
- st.subheader(f"Caption: {caption}\n")
57
 
58
- image.close()
 
 
 
 
 
59
 
 
 
 
 
 
1
  import streamlit as st
2
  import io
3
 
4
+ # st.title("Image Captioning Demo from Redcaps")
5
+ # st.sidebar.markdown(
6
+ # """
7
+ # Image Captioning Model from VirTex trained on Redcaps
8
+ # """
9
+ # )
10
+
11
+ # with st.spinner("Loading Model"):
12
+ # from model import *
13
+ # sample_images = glob.glob("./samples/*.jpg")
14
+ # download_files()
15
+ # virtexModel = VirTexModel()
16
+ # imageLoader = ImageLoader()
17
+
18
+ # random_image = get_rand_img(sample_images)
19
+
20
+ # st.sidebar.title("Select a sample image")
21
+ # sample_image = st.sidebar.selectbox(
22
+ # "",
23
+ # sample_images
24
+ # )
25
+
26
+ # if st.sidebar.button("Random Sample Image"):
27
+ # random_image = get_rand_img(sample_images)
28
+ # sample_image = None
 
 
 
29
 
30
+ # uploaded_image = None
31
+ # with st.sidebar.form("file-uploader-form", clear_on_submit=True):
32
+ # uploaded_file = st.file_uploader("Choose a file")
33
+ # submitted = st.form_submit_button("Submit")
34
+ # if uploaded_file is not None and submitted:
35
+ # uploaded_image = Image.open(io.BytesIO(uploaded_file.get_values()))
36
+
37
+ # if uploaded_image is None and submitted:
38
+ # st.write("Please select a file to upload")
39
+
40
+ # else:
41
+ # image_file = sample_image if sample_image is not None else random_image
42
 
43
+ # image = uploaded_image if uploaded_image is not None else Image.open()
44
 
45
+ # image_dict = imageLoader.transform(image)
46
 
47
+ # show.image(st.image(image_dict["image"]), "Target Image")
48
 
49
+ # with st.spinner("Generating Caption"):
50
+ # subreddit, caption = virtexModel.predict(image_dict)
51
+ # st.header("Predicted Caption:\n\n")
52
+ # st.subheader(f"Subreddit: {subreddit}\n")
53
+ # st.subheader(f"Caption: {caption}\n")
54
 
55
+ # image.close()
56
+
57
+ from model import *
58
+ download_files()
59
+ sample_images = get_samples()
60
+ v, il = VirTexModel(), ImageLoader()
61
 
62
+ for s in sample_images:
63
+ subreddit, caption = v.predict(il.load(s))
64
+ print("=====================")
65
+ print(subreddit)
66
+ print(caption)
model.py CHANGED
@@ -24,12 +24,12 @@ class ImageLoader():
24
  self.transformer = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
25
  torchvision.transforms.CenterCrop(224),
26
  torchvision.transforms.ToTensor()])
27
- def load(self, im_path, prompt = ""):
28
  im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
29
- return {"image": im, "decode_prompt": prompt}
30
- def transform(self, image, prompt = ""):
31
  im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
32
- return {"image": im, "decode_prompt": prompt}
33
 
34
  class VirTexModel():
35
  def __init__(self):
@@ -51,6 +51,7 @@ class VirTexModel():
51
 
52
  is_valid_subreddit = False
53
  subreddit, rest_of_caption = "", ""
 
54
  while not is_valid_subreddit:
55
 
56
  with torch.no_grad():
24
  self.transformer = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
25
  torchvision.transforms.CenterCrop(224),
26
  torchvision.transforms.ToTensor()])
27
+ def load(self, im_path):
28
  im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
29
+ return {"image": im}
30
+ def transform(self, image):
31
  im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
32
+ return {"image": im}
33
 
34
  class VirTexModel():
35
  def __init__(self):
51
 
52
  is_valid_subreddit = False
53
  subreddit, rest_of_caption = "", ""
54
+ image_dict["decode_prompt"] = subreddit_tokens
55
  while not is_valid_subreddit:
56
 
57
  with torch.no_grad():