Spaces:
Runtime error
Runtime error
Update inference.py
Browse files- inference.py +6 -3
inference.py
CHANGED
@@ -11,11 +11,14 @@ text_pipeline = lambda x: vocab(tokenizer(x))
|
|
11 |
num_class = 11
|
12 |
vocab_size = len(vocab)
|
13 |
embed_size = 300
|
14 |
-
lr = 0.4
|
15 |
|
16 |
model = TextClassifierModel(vocab_size, embed_size, num_class)
|
17 |
-
optimizer = torch.optim.SGD(model_test.parameters(), lr=0.4)
|
18 |
|
|
|
19 |
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
|
|
|
11 |
num_class = 11
|
12 |
vocab_size = len(vocab)
|
13 |
embed_size = 300
|
|
|
14 |
|
15 |
model = TextClassifierModel(vocab_size, embed_size, num_class)
|
|
|
16 |
|
17 |
+
model = load_state_dict(model, model_trained, vocab)
|
18 |
|
19 |
+
def predict(text, model=model, text_pipeline=text_pipeline):
|
20 |
+
with torch.no_grad()
|
21 |
+
model.eval()
|
22 |
+
text_tensor = torch.tensor(text_pipeline(text))
|
23 |
+
return model(text_tensor, torch.tensor([0])).argmax(1).item()
|
24 |
|