zamborg commited on
Commit
b80df5c
1 Parent(s): e1cdab4

added sub_prompting

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. model.py +8 -1
app.py CHANGED
@@ -5,9 +5,9 @@ import time
5
  import json
6
  sys.path.append("./virtex/")
7
 
8
- def gen_show_caption():
9
  with st.spinner("Generating Caption"):
10
- subreddit, caption = virtexModel.predict(image_dict)
11
  st.header("Predicted Caption:\n\n")
12
  st.subheader(f"r/{subreddit}:\t{caption}\n")
13
 
@@ -68,7 +68,7 @@ else:
68
  show = st.image(image)
69
  show.image(image, "Your Image")
70
 
71
- gen_show_caption()
72
 
73
  image.close()
74
 
5
  import json
6
  sys.path.append("./virtex/")
7
 
8
+ def gen_show_caption(sub_prompt=None):
9
  with st.spinner("Generating Caption"):
10
+ subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt)
11
  st.header("Predicted Caption:\n\n")
12
  st.subheader(f"r/{subreddit}:\t{caption}\n")
13
 
68
  show = st.image(image)
69
  show.image(image, "Your Image")
70
 
71
+ gen_show_caption(sub)
72
 
73
  image.close()
74
 
model.py CHANGED
@@ -54,7 +54,14 @@ class VirTexModel():
54
  if sub_prompt is None:
55
  subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
56
  else:
57
- subreddit_tokens = torch.tensor([self.tokenizer.token_to_id(sub_prompt)], device=self.device).long()
 
 
 
 
 
 
 
58
  predictions: List[Dict[str, Any]] = []
59
 
60
  is_valid_subreddit = False
54
  if sub_prompt is None:
55
  subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
56
  else:
57
+ subreddit_tokens = " ".join(ws.segment(ws.clean(sub_prompt)))
58
+ subreddit_tokens = (
59
+ [self.model.sos_index] +
60
+ self.tokenizer.encode(subreddit_tokens) +
61
+ [tokenizer.token_to_id("[SEP]")]
62
+ )
63
+
64
+
65
  predictions: List[Dict[str, Any]] = []
66
 
67
  is_valid_subreddit = False