kdevoe commited on
Commit
6fbe5f8
1 Parent(s): 3884ec7

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -1
inference.py CHANGED
@@ -22,11 +22,15 @@ class DebertaEvaluator(nn.Module):
22
 
23
  return linear_output
24
 
25
- def inference():
26
  saved_model_path = './'
27
  model = torch.load(saved_model_path + 'fine-tuned-model.pt', map_location=torch.device('cpu'))
28
  tokenizer = torch.load(saved_model_path + 'fine-tuned-tokenizer.pt', map_location=torch.device('cpu'))
29
  model.eval()
 
 
 
 
30
 
31
  if __name__ == "__main__":
32
  inference()
 
22
 
23
  return linear_output
24
 
25
+ def inference(input_text):
26
  saved_model_path = './'
27
  model = torch.load(saved_model_path + 'fine-tuned-model.pt', map_location=torch.device('cpu'))
28
  tokenizer = torch.load(saved_model_path + 'fine-tuned-tokenizer.pt', map_location=torch.device('cpu'))
29
  model.eval()
30
+ input = tokenizer(input_text)
31
+ output = model(input_data['input_ids'].squeeze(1), input_data['attention_mask'])
32
+
33
+ return output.tolist()
34
 
35
  if __name__ == "__main__":
36
  inference()