zamborg commited on
Commit
ed768de
1 Parent(s): d92334f

fixing st writeouts

Browse files
Files changed (2) hide show
  1. app.py +2 -0
  2. model.py +2 -3
app.py CHANGED
@@ -7,6 +7,8 @@ sys.path.append("./virtex/")
7
 
8
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
9
  with st.spinner("Generating Caption"):
 
 
10
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
11
  st.header("Predicted Caption:\n\n")
12
  st.subheader(f"r/{subreddit}:\t{caption}\n")
 
7
 
8
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
9
  with st.spinner("Generating Caption"):
10
+ if sub_prompt is None and cap_prompt is not "":
11
+ st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
12
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
13
  st.header("Predicted Caption:\n\n")
14
  st.subheader(f"r/{subreddit}:\t{caption}\n")
model.py CHANGED
@@ -63,13 +63,12 @@ class VirTexModel():
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
  if prompt is not "":
 
 
66
  if sub_prompt is not None:
67
  cap_tokens = self.tokenizer.encode(prompt)
68
  cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
69
  subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
70
- else:
71
- st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
72
- #TODO fix this
73
 
74
 
75
  predictions: List[Dict[str, Any]] = []
 
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
  if prompt is not "":
66
+ # at present prompts without subreddits will break without this change
67
+ # TODO FIX
68
  if sub_prompt is not None:
69
  cap_tokens = self.tokenizer.encode(prompt)
70
  cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
71
  subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
 
 
 
72
 
73
 
74
  predictions: List[Dict[str, Any]] = []