zamborg commited on
Commit
c838395
1 Parent(s): 54fa80d

cap_prompt attempts

Browse files
Files changed (2) hide show
  1. app.py +3 -6
  2. model.py +3 -0
app.py CHANGED
@@ -5,9 +5,9 @@ import time
5
  import json
6
  sys.path.append("./virtex/")
7
 
8
- def gen_show_caption(sub_prompt=None, cap_prompt = None):
9
  with st.spinner("Generating Caption"):
10
- subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt)
11
  st.header("Predicted Caption:\n\n")
12
  st.subheader(f"r/{subreddit}:\t{caption}\n")
13
 
@@ -44,12 +44,9 @@ sub = st.sidebar.selectbox(
44
  st.sidebar.title("Write a Custom Prompt")
45
  cap_prompt = st.sidebar.text_input(
46
  "Leave this blank for an unbiased caption",
47
- value=None
48
  )
49
 
50
- if cap_prompt is None:
51
- st.write("HAHA")
52
-
53
  if st.sidebar.button("Random Sample Image"):
54
  random_image = get_rand_img(sample_images)
55
  sample_image = None
 
5
  import json
6
  sys.path.append("./virtex/")
7
 
8
+ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
9
  with st.spinner("Generating Caption"):
10
+ subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
11
  st.header("Predicted Caption:\n\n")
12
  st.subheader(f"r/{subreddit}:\t{caption}\n")
13
 
 
44
  st.sidebar.title("Write a Custom Prompt")
45
  cap_prompt = st.sidebar.text_input(
46
  "Leave this blank for an unbiased caption",
47
+ value=""
48
  )
49
 
 
 
 
50
  if st.sidebar.button("Random Sample Image"):
51
  random_image = get_rand_img(sample_images)
52
  sample_image = None
model.py CHANGED
@@ -62,6 +62,9 @@ class VirTexModel():
62
  )
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
 
 
 
65
 
66
  predictions: List[Dict[str, Any]] = []
67
 
 
62
  )
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
+ if prompt is not "":
66
+ cap_tokens = self.tokenizer.encode(subreddit_tokens)
67
+ subreddit_tokens = torch.cat((subreddit_tokens, cap_tokens))
68
 
69
  predictions: List[Dict[str, Any]] = []
70