alexrods commited on
Commit
8ab7bd5
1 Parent(s): 7064c13

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +6 -3
inference.py CHANGED
@@ -11,11 +11,14 @@ text_pipeline = lambda x: vocab(tokenizer(x))
11
  num_class = 11
12
  vocab_size = len(vocab)
13
  embed_size = 300
14
- lr = 0.4
15
 
16
  model = TextClassifierModel(vocab_size, embed_size, num_class)
17
- optimizer = torch.optim.SGD(model_test.parameters(), lr=0.4)
18
 
 
19
 
20
- model, optimizer = load_state_dict(model, optimizer, model_trained, vocab)
 
 
 
 
21
 
 
11
  num_class = 11
12
  vocab_size = len(vocab)
13
  embed_size = 300
 
14
 
15
  model = TextClassifierModel(vocab_size, embed_size, num_class)
 
16
 
17
+ model = load_state_dict(model, model_trained, vocab)
18
 
19
+ def predict(text, model=model, text_pipeline=text_pipeline):
20
+ with torch.no_grad()
21
+ model.eval()
22
+ text_tensor = torch.tensor(text_pipeline(text))
23
+ return model(text_tensor, torch.tensor([0])).argmax(1).item()
24