alexrods commited on
Commit
b9c93e0
1 Parent(s): 4da03a6

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +16 -4
inference.py CHANGED
@@ -2,6 +2,18 @@ import torch
2
  from torchtext.data.utils import get_tokenizer
3
  from model_arch import TextClassifierModel, load_state_dict
4
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  model_trained = torch.load('model_checkpoint.pth')
6
  vocab = torch.load('vocab.pt')
7
  tokenizer = get_tokenizer("spacy", language="es")
@@ -17,8 +29,8 @@ model = TextClassifierModel(vocab_size, embed_size, num_class)
17
  model = load_state_dict(model, model_trained, vocab)
18
 
19
  def predict(text, model=model, text_pipeline=text_pipeline):
20
- with torch.no_grad()
21
- model.eval()
22
- text_tensor = torch.tensor(text_pipeline(text))
23
- return model(text_tensor, torch.tensor([0])).argmax(1).item()
24
 
 
2
  from torchtext.data.utils import get_tokenizer
3
  from model_arch import TextClassifierModel, load_state_dict
4
 
5
+ labels = {0: 'messaging',
6
+ 1: 'calling',
7
+ 2: 'event',
8
+ 3: 'timer',
9
+ 4: 'music',
10
+ 5: 'weather',
11
+ 6: 'alarm',
12
+ 7: 'people',
13
+ 8: 'reminder',
14
+ 9: 'recipes',
15
+ 10: 'news'}
16
+
17
  model_trained = torch.load('model_checkpoint.pth')
18
  vocab = torch.load('vocab.pt')
19
  tokenizer = get_tokenizer("spacy", language="es")
 
29
  model = load_state_dict(model, model_trained, vocab)
30
 
31
  def predict(text, model=model, text_pipeline=text_pipeline):
32
+ with torch.no_grad():
33
+ model.eval()
34
+ text_tensor = torch.tensor(text_pipeline(text))
35
+ return labels[model(text_tensor, torch.tensor([0])).argmax(1).item()]
36