nickgardner commited on
Commit
4117229
1 Parent(s): 304e451

this should not be so hard

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -47,11 +47,13 @@ def respond(custom_string):
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
  out = torch.nn.functional.softmax(out, dim=-1).detach()
50
- ix = np.random.choice(np.arange(out[:, -1]), 1, p=out[:, -1])
51
  # val, ix = out[:, -1].data.topk(1)
52
 
53
- outputs[i] = ix[0][0]
54
- if ix[0][0] == vocab_token_dict['<eos>']:
 
 
55
  break
56
  return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
57
 
 
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
  out = torch.nn.functional.softmax(out, dim=-1).detach()
50
+ ix = np.random.choice(np.arange(out[:, -1].data), 1, p=out[:, -1].data)
51
  # val, ix = out[:, -1].data.topk(1)
52
 
53
+ # outputs[i] = ix[0][0]
54
+ outputs[i] = ix
55
+ # if ix[0][0] == vocab_token_dict['<eos>']:
56
+ if ix == vocab_token_dict['<eos>']:
57
  break
58
  return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
59