Spaces:
Runtime error
Runtime error
nickgardner
commited on
Commit
•
4117229
1
Parent(s):
304e451
this should not be so hard
Browse files
app.py
CHANGED
@@ -47,11 +47,13 @@ def respond(custom_string):
|
|
47 |
|
48 |
out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
|
49 |
out = torch.nn.functional.softmax(out, dim=-1).detach()
|
50 |
-
ix = np.random.choice(np.arange(out[:, -1]), 1, p=out[:, -1])
|
51 |
# val, ix = out[:, -1].data.topk(1)
|
52 |
|
53 |
-
outputs[i] = ix[0][0]
|
54 |
-
|
|
|
|
|
55 |
break
|
56 |
return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
|
57 |
|
|
|
47 |
|
48 |
out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
|
49 |
out = torch.nn.functional.softmax(out, dim=-1).detach()
|
50 |
+
ix = np.random.choice(np.arange(out[:, -1].data), 1, p=out[:, -1].data)
|
51 |
# val, ix = out[:, -1].data.topk(1)
|
52 |
|
53 |
+
# outputs[i] = ix[0][0]
|
54 |
+
outputs[i] = ix
|
55 |
+
# if ix[0][0] == vocab_token_dict['<eos>']:
|
56 |
+
if ix == vocab_token_dict['<eos>']:
|
57 |
break
|
58 |
return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
|
59 |
|