zamborg commited on
Commit
b58ad35
1 Parent(s): 65193db

default to pics for subreddit sampling

Browse files
Files changed (2) hide show
  1. app.py +62 -62
  2. model.py +13 -9
app.py CHANGED
@@ -9,7 +9,7 @@ from model import *
9
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
10
  with st.spinner("Generating Caption"):
11
  if sub_prompt is None and cap_prompt is not "":
12
- st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
13
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
14
  st.header("Predicted Caption:\n\n")
15
  # st.subheader(f"r/{subreddit}:\t{caption}\n")
@@ -31,7 +31,7 @@ st.sidebar.markdown(
31
  You can also generate captions as if they are from specific subreddits,
32
  as if they start with a particular prompt, or even both.
33
 
34
- Feel free to share your results on twitter with #redcaps or with a friend.
35
  """
36
  )
37
 
@@ -39,91 +39,91 @@ with st.spinner("Loading Model"):
39
  virtexModel, imageLoader, sample_images, valid_subs = create_objects()
40
 
41
 
42
- staggered = st.sidebar.checkbox("Iteratively Generate Captions")
43
 
44
- if staggered:
45
- pass
46
- else:
47
-
48
- select_idx = None
49
 
50
- st.sidebar.title("Select a sample image")
51
 
52
- if st.sidebar.button("Random Sample Image"):
53
- select_idx = get_rand_idx(sample_images)
54
 
55
- sample_image = sample_images[0 if select_idx is None else select_idx]
56
 
57
 
58
- uploaded_image = None
59
- with st.sidebar.form("file-uploader-form", clear_on_submit=True):
60
- uploaded_file = st.file_uploader("Choose a file")
61
- submitted = st.form_submit_button("Submit")
62
- if uploaded_file is not None and submitted:
63
- uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
64
- select_idx = None # set this to help rewrite the cache
65
 
66
- # class OnChange():
67
- # def __init__(self, idx):
68
- # self.idx = idx
69
 
70
- # def __call__(self):
71
- # st.write(f"the idx is: {self.idx}")
72
- # st.write(f"the sample_image is {sample_image}")
73
 
74
- # sample_image = st.sidebar.selectbox(
75
- # "",
76
- # sample_images,
77
- # index = 0 if select_idx is None else select_idx,
78
- # on_change=OnChange(0 if select_idx is None else select_idx)
79
- # )
80
 
81
- st.sidebar.title("Select a Subreddit")
82
- sub = st.sidebar.selectbox(
83
- "Select None for a Predicted Subreddit",
84
- valid_subs
85
- )
86
 
87
- st.sidebar.title("Write a Custom Prompt")
88
- cap_prompt = st.sidebar.text_input(
89
- "Leave this blank for an unbiased caption",
90
- value=""
91
- )
 
 
92
 
93
- _ = st.sidebar.button("Regenerate Caption")
94
-
95
  # advanced = st.sidebar.checkbox("Advanced Options")
96
-
97
  # if advanced:
98
  # nuc_size = st.sidebar.slider("")
99
 
100
- if uploaded_image is None and submitted:
101
- st.write("Please select a file to upload")
102
 
103
- else:
104
- image_file = sample_image
105
 
106
- # LOAD AND CACHE THE IMAGE
107
- if uploaded_image is not None:
108
- image = uploaded_image
109
- elif select_idx is None and 'image' in st.session_state:
110
- image = st.session_state['image']
111
- else:
112
- image = Image.open(image_file)
113
 
114
- image = image.convert("RGB")
115
 
116
- st.session_state['image'] = image
117
 
118
 
119
- image_dict = imageLoader.transform(image)
120
 
121
- show_image = imageLoader.show_resize(image)
122
 
123
- show = st.image(show_image)
124
- show.image(show_image, "Your Image")
125
 
126
- gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
127
 
128
  # from model import *
129
  # sample_images = get_samples()
9
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
10
  with st.spinner("Generating Caption"):
11
  if sub_prompt is None and cap_prompt is not "":
12
+ st.write("Without a specified subreddit we default to /r/pics")
13
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
14
  st.header("Predicted Caption:\n\n")
15
  # st.subheader(f"r/{subreddit}:\t{caption}\n")
31
  You can also generate captions as if they are from specific subreddits,
32
  as if they start with a particular prompt, or even both.
33
 
34
+ Share your results on twitter with #redcaps or with a friend.
35
  """
36
  )
37
 
39
  virtexModel, imageLoader, sample_images, valid_subs = create_objects()
40
 
41
 
42
+ # staggered = st.sidebar.checkbox("Iteratively Generate Captions")
43
 
44
+ # if staggered:
45
+ # pass
46
+ # else:
47
+
48
+ select_idx = None
49
 
50
+ st.sidebar.title("Select a sample image")
51
 
52
+ if st.sidebar.button("Random Sample Image"):
53
+ select_idx = get_rand_idx(sample_images)
54
 
55
+ sample_image = sample_images[0 if select_idx is None else select_idx]
56
 
57
 
58
+ uploaded_image = None
59
+ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
60
+ uploaded_file = st.file_uploader("Choose a file")
61
+ submitted = st.form_submit_button("Submit")
62
+ if uploaded_file is not None and submitted:
63
+ uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
64
+ select_idx = None # set this to help rewrite the cache
65
 
66
+ # class OnChange():
67
+ # def __init__(self, idx):
68
+ # self.idx = idx
69
 
70
+ # def __call__(self):
71
+ # st.write(f"the idx is: {self.idx}")
72
+ # st.write(f"the sample_image is {sample_image}")
73
 
74
+ # sample_image = st.sidebar.selectbox(
75
+ # "",
76
+ # sample_images,
77
+ # index = 0 if select_idx is None else select_idx,
78
+ # on_change=OnChange(0 if select_idx is None else select_idx)
79
+ # )
80
 
81
+ st.sidebar.title("Select a Subreddit")
82
+ sub = st.sidebar.selectbox(
83
+ "Type below to condition on a subreddit. Select None for a predicted subreddit",
84
+ valid_subs
85
+ )
86
 
87
+ st.sidebar.title("Write a Custom Prompt")
88
+ cap_prompt = st.sidebar.text_input(
89
+ "Write the start of your caption below",
90
+ value=""
91
+ )
92
+
93
+ _ = st.sidebar.button("Regenerate Caption")
94
 
 
 
95
  # advanced = st.sidebar.checkbox("Advanced Options")
96
+
97
  # if advanced:
98
  # nuc_size = st.sidebar.slider("")
99
 
100
+ if uploaded_image is None and submitted:
101
+ st.write("Please select a file to upload")
102
 
103
+ else:
104
+ image_file = sample_image
105
 
106
+ # LOAD AND CACHE THE IMAGE
107
+ if uploaded_image is not None:
108
+ image = uploaded_image
109
+ elif select_idx is None and 'image' in st.session_state:
110
+ image = st.session_state['image']
111
+ else:
112
+ image = Image.open(image_file)
113
 
114
+ image = image.convert("RGB")
115
 
116
+ st.session_state['image'] = image
117
 
118
 
119
+ image_dict = imageLoader.transform(image)
120
 
121
+ show_image = imageLoader.show_resize(image)
122
 
123
+ show = st.image(show_image)
124
+ show.image(show_image, "Your Image")
125
 
126
+ gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
127
 
128
  # from model import *
129
  # sample_images = get_samples()
model.py CHANGED
@@ -80,15 +80,19 @@ class VirTexModel():
80
  if prompt is not "":
81
  # at present prompts without subreddits will break without this change
82
  # TODO FIX
83
- if True: #sub_prompt is not None:
84
- cap_tokens = self.tokenizer.encode(prompt)
85
- cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
86
- subreddit_tokens = torch.cat(
87
- [
88
- subreddit_tokens,
89
- torch.tensor([self.tokenizer.token_to_id("[SEP]")], device=self.device).long(),
90
- cap_tokens
91
- ])
 
 
 
 
92
 
93
 
94
  predictions: List[Dict[str, Any]] = []
80
  if prompt is not "":
81
  # at present prompts without subreddits will break without this change
82
  # TODO FIX
83
+ cap_tokens = self.tokenizer.encode(prompt)
84
+ cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
85
+ subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor([
86
+ [self.model.sos_index] +
87
+ self.tokenizer.encode("pics") +
88
+ [self.tokenizer.token_to_id("[SEP]")]
89
+ ])
90
+ subreddit_tokens = torch.cat(
91
+ [
92
+ subreddit_tokens,
93
+ torch.tensor([self.tokenizer.token_to_id("[SEP]")], device=self.device).long(),
94
+ cap_tokens
95
+ ])
96
 
97
 
98
  predictions: List[Dict[str, Any]] = []