nickgardner commited on
Commit
22a141b
1 Parent(s): ab4aa34

maybe this works??

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -33,9 +33,9 @@ model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
33
  filename="alpaca_train_400_epoch.pt"), map_location=device))
34
  model.eval()
35
 
36
- def respond(custom_string):
37
  model.eval()
38
- src = torch.tensor(text_pipeline(custom_string), dtype=torch.int64).unsqueeze(0).to(device)
39
  src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
40
  e_outputs = model.encoder(src, src_mask)
41
 
@@ -47,8 +47,7 @@ def respond(custom_string):
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
  out = torch.nn.functional.softmax(out, dim=-1).detach()
50
- print(out[:, -1].data.shape)
51
- ix = np.random.choice(np.arange(out[:, -1].data), 1, p=out[:, -1].data)
52
  # val, ix = out[:, -1].data.topk(1)
53
 
54
  # outputs[i] = ix[0][0]
 
33
  filename="alpaca_train_400_epoch.pt"), map_location=device))
34
  model.eval()
35
 
36
+ def respond(input):
37
  model.eval()
38
+ src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
39
  src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
40
  e_outputs = model.encoder(src, src_mask)
41
 
 
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
  out = torch.nn.functional.softmax(out, dim=-1).detach()
50
+ ix = np.random.choice(np.arange(out[:, -1].data[0]), 1, p=out[:, -1].data[0])
 
51
  # val, ix = out[:, -1].data.topk(1)
52
 
53
  # outputs[i] = ix[0][0]