nickgardner commited on
Commit
25e0f62
1 Parent(s): c654b20

add search type radio

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -33,7 +33,7 @@ model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
33
  filename="alpaca_train_400_epoch.pt"), map_location=device))
34
  model.eval()
35
 
36
- def respond(input):
37
  model.eval()
38
  src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
39
  src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
@@ -46,18 +46,23 @@ def respond(input):
46
  trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
- out = torch.nn.functional.softmax(out, dim=-1)[:, -1].squeeze().detach().numpy()
50
- print(out.shape)
51
- print(np.sum(out))
52
- ix = np.random.choice(np.arange(len(out)), 1, p=out)
53
- # val, ix = out[:, -1].data.topk(1)
54
-
55
- # outputs[i] = ix[0][0]
56
- outputs[i] = ix[0]
57
- # if ix[0][0] == vocab_token_dict['<eos>']:
58
- if ix[0] == vocab_token_dict['<eos>']:
59
- break
 
 
 
60
  return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
61
 
62
- iface = gr.Interface(fn=respond, inputs="text", outputs="text")
 
 
63
  iface.launch()
 
33
  filename="alpaca_train_400_epoch.pt"), map_location=device))
34
  model.eval()
35
 
36
+ def respond(search_type, input):
37
  model.eval()
38
  src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
39
  src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
 
46
  trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
47
 
48
  out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
49
+ if search_type == "Greedy":
50
+ out = torch.nn.functional.softmax(out, dim=-1)
51
+ val, ix = out[:, -1].data.topk(1)
52
+
53
+ outputs[i] = ix[0][0]
54
+ if ix[0][0] == vocab_token_dict['<eos>']:
55
+ break
56
+ else:
57
+ out = torch.nn.functional.softmax(out, dim=-1)[:, -1].squeeze().detach().numpy()
58
+ ix = np.random.choice(np.arange(len(out)), 1, p=out)
59
+
60
+ outputs[i] = ix[0]
61
+ if ix[0] == vocab_token_dict['<eos>']:
62
+ break
63
  return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
64
 
65
+ iface = gr.Interface(fn=respond,
66
+ inputs=[gr.Radio(["Greedy", "Probabilistic"], label="Search Type"), "text"],
67
+ outputs="text")
68
  iface.launch()