nickgardner commited on
Commit
c9c480e
1 Parent(s): fd58d47

maybe progress

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -46,12 +46,10 @@ def respond(input):
46
  trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
- out = torch.nn.functional.softmax(out, dim=-1).detach()
50
  print(out.shape)
51
- print(out[:, -1].data[0])
52
- print(out[:, -1].data[0].shape)
53
- print(np.sum(out[:, -1].data[0]))
54
- ix = np.random.choice(np.arange(len(out[:, -1].data[0])), 1, p=out[:, -1].data[0])
55
  # val, ix = out[:, -1].data.topk(1)
56
 
57
  # outputs[i] = ix[0][0]
 
46
  trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
+ out = torch.nn.functional.softmax(out, dim=-1).squeeze().detach().numpy()
50
  print(out.shape)
51
+ print(np.sum(out))
52
+ ix = np.random.choice(np.arange(len(out)), 1, p=out)
 
 
53
  # val, ix = out[:, -1].data.topk(1)
54
 
55
  # outputs[i] = ix[0][0]