zamborg commited on
Commit
11ab28e
1 Parent(s): b58ad35

model jutsu

Browse files
Files changed (1) hide show
  1. model.py +8 -6
model.py CHANGED
@@ -82,11 +82,13 @@ class VirTexModel():
82
  # TODO FIX
83
  cap_tokens = self.tokenizer.encode(prompt)
84
  cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
85
- subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor([
86
- [self.model.sos_index] +
87
- self.tokenizer.encode("pics") +
88
- [self.tokenizer.token_to_id("[SEP]")]
89
- ])
 
 
90
  subreddit_tokens = torch.cat(
91
  [
92
  subreddit_tokens,
@@ -118,7 +120,7 @@ class VirTexModel():
118
  else:
119
  subreddit, rest_of_caption = "", caption
120
 
121
- is_valid_subreddit = True if sub_prompt is not None or prompt is not None else subreddit in self.valid_subs
122
 
123
 
124
  return subreddit, rest_of_caption
82
  # TODO FIX
83
  cap_tokens = self.tokenizer.encode(prompt)
84
  cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
85
+ subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor(
86
+ (
87
+ [self.model.sos_index] +
88
+ self.tokenizer.encode("pics") +
89
+ [self.tokenizer.token_to_id("[SEP]")]
90
+ ), device = self.device).long()
91
+
92
  subreddit_tokens = torch.cat(
93
  [
94
  subreddit_tokens,
120
  else:
121
  subreddit, rest_of_caption = "", caption
122
 
123
+ is_valid_subreddit = subreddit in self.valid_subs
124
 
125
 
126
  return subreddit, rest_of_caption