Update inference.py
Browse files- inference.py +5 -1
inference.py
CHANGED
@@ -22,11 +22,15 @@ class DebertaEvaluator(nn.Module):
|
|
22 |
|
23 |
return linear_output
|
24 |
|
25 |
-
def inference():
|
26 |
saved_model_path = './'
|
27 |
model = torch.load(saved_model_path + 'fine-tuned-model.pt', map_location=torch.device('cpu'))
|
28 |
tokenizer = torch.load(saved_model_path + 'fine-tuned-tokenizer.pt', map_location=torch.device('cpu'))
|
29 |
model.eval()
|
|
|
|
|
|
|
|
|
30 |
|
31 |
if __name__ == "__main__":
|
32 |
inference()
|
|
|
22 |
|
23 |
return linear_output
|
24 |
|
25 |
+
def inference(input_text):
|
26 |
saved_model_path = './'
|
27 |
model = torch.load(saved_model_path + 'fine-tuned-model.pt', map_location=torch.device('cpu'))
|
28 |
tokenizer = torch.load(saved_model_path + 'fine-tuned-tokenizer.pt', map_location=torch.device('cpu'))
|
29 |
model.eval()
|
30 |
+
input = tokenizer(input_text)
|
31 |
+
output = model(input_data['input_ids'].squeeze(1), input_data['attention_mask'])
|
32 |
+
|
33 |
+
return output.tolist()
|
34 |
|
35 |
if __name__ == "__main__":
|
36 |
inference()
|