tennant commited on
Commit
450eb04
1 Parent(s): 71e1a05
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -63,6 +63,10 @@ def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
63
  x_l = cross_attn2(x_l, latent)
64
 
65
  pred = model.to_logits(x_l)
 
 
 
 
66
  next_word = pred.argmax(dim=-1)[0, -1]
67
  next_word = tokenizer.decode(next_word)
68
 
@@ -73,6 +77,8 @@ def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
73
  while len(words) < max_len:
74
  next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
75
  words.append(next_word)
 
 
76
  return ' '.join(words)
77
 
78
 
 
63
  x_l = cross_attn2(x_l, latent)
64
 
65
  pred = model.to_logits(x_l)
66
+ pred[:, :, 103] = -100
67
+ pred[:, :, 101] = -100
68
+ pred[:, :, 100] = -100
69
+ pred[:, :, 0] = -100
70
  next_word = pred.argmax(dim=-1)[0, -1]
71
  next_word = tokenizer.decode(next_word)
72
 
 
77
  while len(words) < max_len:
78
  next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
79
  words.append(next_word)
80
+ if next_word == '[SEP]':
81
+ break
82
  return ' '.join(words)
83
 
84