File size: 2,693 Bytes
fc75f91
238ab50
 
 
1fb2550
 
238ab50
 
fc75f91
1fb2550
 
 
238ab50
 
 
8d7ed76
238ab50
 
25b8856
238ab50
 
 
 
 
 
 
 
 
 
 
 
952f87e
bf67bcc
238ab50
 
25e0f62
238ab50
22a141b
238ab50
 
 
 
 
 
 
 
 
 
25e0f62
 
 
 
 
 
 
 
 
 
 
 
 
 
25b8856
238ab50
25e0f62
6144666
25e0f62
fc75f91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import importlib.util
import subprocess
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchtext.data.utils import get_tokenizer

from transformer import Transformer

# The torchtext 'spacy' tokenizer below needs the en_core_web_sm pipeline.
# Install it only when it is not importable yet, and use the running
# interpreter's own pip so the package lands in the environment this process
# actually imports from (a bare "pip" on PATH may belong to another Python).
model_url = "https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl"
if importlib.util.find_spec("en_core_web_sm") is None:
    # check=True: without the model the tokenizer cannot be built, so surface
    # the pip failure here instead of a less obvious error further down.
    subprocess.run([sys.executable, "-m", "pip", "install", model_url], check=True)
    # Make sure the import system notices the freshly installed package.
    importlib.invalidate_caches()

# Upper bound on the decoder buffer: <sos> plus at most MAX_LEN - 1 tokens.
MAX_LEN = 350

tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code. Only load artifacts from a repository you trust.
vocab = torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
                                   filename="vocab.pth"))
vocab_token_dict = vocab.get_stoi()    # token string -> index
indices_to_tokens = vocab.get_itos()   # index -> token string
pad_token = vocab_token_dict['<pad>']
unknown_token = vocab_token_dict['<unk>']
sos_token = vocab_token_dict['<sos>']
eos_token = vocab_token_dict['<eos>']


def text_pipeline(x):
    """Tokenize the string ``x`` and map every token to its vocabulary index."""
    return vocab(tokenizer(x))


# Hyperparameters passed to the Transformer constructor; presumably they have
# to match the configuration the checkpoint was trained with -- TODO confirm.
d_model = 512
heads = 8
N = 6
src_vocab = len(vocab)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The same vocabulary size is used for both the source and the target side.
model = Transformer(src_vocab, src_vocab, d_model, N, heads).to(device)
# map_location moves the stored tensors onto `device` while loading.
model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
                                      filename="alpaca_train_440_epoch.pt"), map_location=device))
model.eval()

def respond(search_type, input):
    """Generate a reply to ``input`` by decoding one token at a time.

    The prompt is encoded once; the decoder is then re-run on the growing
    output prefix until ``<eos>`` is produced or the MAX_LEN buffer is full.

    Args:
        search_type: "Greedy" picks the highest-scoring token at every step.
            Any other value (the UI sends "Probabilistic") samples the next
            token from the softmax distribution.
        input: Prompt text; tokenized and mapped to ids by ``text_pipeline``.

    Returns:
        The generated tokens joined by single spaces, without the ``<sos>``
        and ``<eos>`` markers. An empty string if the prompt has no tokens.
    """
    model.eval()
    prompt_ids = text_pipeline(input)
    if not prompt_ids:
        # An empty prompt would hand the encoder a zero-length sequence;
        # there is nothing to respond to, so return early.
        return ''

    generated = []
    # Inference only: skip building autograd graphs on every decoding step.
    with torch.no_grad():
        src = torch.tensor(prompt_ids, dtype=torch.int64, device=device).unsqueeze(0)
        # Padding and unknown tokens are both masked out of the encoder input.
        src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2)
        e_outputs = model.encoder(src, src_mask)

        outputs = torch.zeros(MAX_LEN, dtype=torch.int64, device=device)
        outputs[0] = sos_token
        for i in range(1, MAX_LEN):
            # Causal mask: position j may attend to positions <= j only.
            trg_mask = torch.ones((1, i, i), device=device).tril().bool()

            logits = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
            # Scores for the token following the current prefix (batch of 1).
            step_logits = logits[0, -1]

            if search_type == "Greedy":
                # softmax is monotonic, so the arg-max of the logits is the
                # most probable token.
                next_token = int(step_logits.argmax())
            else:
                # .cpu() is required before .numpy() when running on CUDA.
                probs = torch.nn.functional.softmax(step_logits, dim=-1).cpu().numpy()
                next_token = int(np.random.choice(len(probs), p=probs))

            if next_token == eos_token:
                break
            outputs[i] = next_token
            # Collect tokens explicitly so the last one is kept even when the
            # loop ends because the buffer is full rather than on <eos>.
            generated.append(next_token)

    return ' '.join(indices_to_tokens[token] for token in generated)

# Web UI: a radio button selects the decoding strategy and a text box takes the
# prompt; both are passed, in that order, to respond(), whose return value is
# shown as text.
search_type_selector = gr.Radio(
    ["Greedy", "Probabilistic"],
    value="Greedy",
    label="Search Type",
)
iface = gr.Interface(
    fn=respond,
    inputs=[search_type_selector, "text"],
    outputs="text",
)
iface.launch()