Spaces:
Runtime error
Runtime error
Update inference.py
Browse files- inference.py +12 -2
inference.py
CHANGED
@@ -1,11 +1,21 @@
|
|
1 |
import torch
|
2 |
from torchtext.data.utils import get_tokenizer
|
3 |
-
from model_arch import TextClassifierModel
|
4 |
|
5 |
-
|
6 |
vocab = torch.load('vocab.pt')
|
7 |
tokenizer = get_tokenizer("spacy", language="es")
|
8 |
|
9 |
text_pipeline = lambda x: vocab(tokenizer(x))
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
|
|
1 |
import torch
|
2 |
from torchtext.data.utils import get_tokenizer
|
3 |
+
from model_arch import TextClassifierModel, load_state_dict
|
4 |
|
5 |
+
model_trained = torch.load('model_checkpoint.pth')
|
6 |
vocab = torch.load('vocab.pt')
|
7 |
tokenizer = get_tokenizer("spacy", language="es")
|
8 |
|
9 |
text_pipeline = lambda x: vocab(tokenizer(x))
|
10 |
|
11 |
+
num_class = 11
|
12 |
+
vocab_size = len(vocab)
|
13 |
+
embed_size = 300
|
14 |
+
lr = 0.4
|
15 |
+
|
16 |
+
model = TextClassifierModel(vocab_size, embed_size, num_class)
|
17 |
+
optimizer = torch.optim.SGD(model_test.parameters(), lr=0.4)
|
18 |
+
|
19 |
+
|
20 |
+
model, optimizer = load_state_dict(model, optimizer, model_trained, vocab)
|
21 |
|