Spaces:
Runtime error
Runtime error
add mask
Browse files
app.py
CHANGED
@@ -63,6 +63,10 @@ def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
|
|
63 |
x_l = cross_attn2(x_l, latent)
|
64 |
|
65 |
pred = model.to_logits(x_l)
|
|
|
|
|
|
|
|
|
66 |
next_word = pred.argmax(dim=-1)[0, -1]
|
67 |
next_word = tokenizer.decode(next_word)
|
68 |
|
@@ -73,6 +77,8 @@ def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
|
|
73 |
while len(words) < max_len:
|
74 |
next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
|
75 |
words.append(next_word)
|
|
|
|
|
76 |
return ' '.join(words)
|
77 |
|
78 |
|
|
|
63 |
x_l = cross_attn2(x_l, latent)
|
64 |
|
65 |
pred = model.to_logits(x_l)
|
66 |
+
pred[:, :, 103] = -100
|
67 |
+
pred[:, :, 101] = -100
|
68 |
+
pred[:, :, 100] = -100
|
69 |
+
pred[:, :, 0] = -100
|
70 |
next_word = pred.argmax(dim=-1)[0, -1]
|
71 |
next_word = tokenizer.decode(next_word)
|
72 |
|
|
|
77 |
while len(words) < max_len:
|
78 |
next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
|
79 |
words.append(next_word)
|
80 |
+
if next_word == '[SEP]':
|
81 |
+
break
|
82 |
return ' '.join(words)
|
83 |
|
84 |
|