import torch
from torchtext.data.utils import get_tokenizer
from model_arch import TextClassifierModel, load_state_dict

# Intent labels the classifier was trained to predict.
labels = {0: 'messaging',
          1: 'calling',
          2: 'event',
          3: 'timer',
          4: 'music',
          5: 'weather',
          6: 'alarm',
          7: 'people',
          8: 'reminder',
          9: 'recipes',
          10: 'news'}
    
# Trained weights and the vocabulary built at training time.
model_trained = torch.load('model_checkpoint.pth')
vocab = torch.load('vocab.pt')

# Spanish spaCy tokenizer; spaCy 3.x requires the full pipeline name
# ("es_core_news_sm") rather than the old "es" shortcut.
tokenizer = get_tokenizer("spacy", language="es_core_news_sm")

# Map raw text to a list of token indices in the vocabulary.
text_pipeline = lambda x: vocab(tokenizer(x))

num_class = 11  # number of intent labels above
vocab_size = len(vocab)
embed_size = 300

model = TextClassifierModel(vocab_size, embed_size, num_class)

# Helper from model_arch that loads the trained weights into the fresh model.
model = load_state_dict(model, model_trained, vocab)

def predict(text, model=model, text_pipeline=text_pipeline):
    """Return the predicted intent label for a piece of Spanish text."""
    model.eval()
    with torch.no_grad():
        text_tensor = torch.tensor(text_pipeline(text))
        # The second tensor is the offsets argument (the whole input is
        # treated as a single sequence starting at position 0).
        output = model(text_tensor, torch.tensor([0]))
        return labels[output.argmax(1).item()]
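
# Example usage: a minimal, illustrative sketch. It assumes model_checkpoint.pth,
# vocab.pt and the Spanish spaCy pipeline are available locally; the sample
# utterance below is made up for demonstration.
if __name__ == "__main__":
    sample = "pon una alarma para las siete de la mañana"
    print(predict(sample))  # prints the predicted intent label, one of `labels`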