alexrods commited on
Commit
7064c13
1 Parent(s): 3e4e0da

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -2
inference.py CHANGED
@@ -1,11 +1,21 @@
1
  import torch
2
  from torchtext.data.utils import get_tokenizer
3
- from model_arch import TextClassifierModel
4
 
5
- model_loaded = torch.load('model_checkpoint.pth')
6
  vocab = torch.load('vocab.pt')
7
  tokenizer = get_tokenizer("spacy", language="es")
8
 
9
  text_pipeline = lambda x: vocab(tokenizer(x))
10
 
 
 
 
 
 
 
 
 
 
 
11
 
 
1
  import torch
2
  from torchtext.data.utils import get_tokenizer
3
+ from model_arch import TextClassifierModel, load_state_dict
4
 
5
+ model_trained = torch.load('model_checkpoint.pth')
6
  vocab = torch.load('vocab.pt')
7
  tokenizer = get_tokenizer("spacy", language="es")
8
 
9
  text_pipeline = lambda x: vocab(tokenizer(x))
10
 
11
+ num_class = 11
12
+ vocab_size = len(vocab)
13
+ embed_size = 300
14
+ lr = 0.4
15
+
16
+ model = TextClassifierModel(vocab_size, embed_size, num_class)
17
+ optimizer = torch.optim.SGD(model_test.parameters(), lr=0.4)
18
+
19
+
20
+ model, optimizer = load_state_dict(model, optimizer, model_trained, vocab)
21