Update macbert/infer_all.py
Browse files- macbert/infer_all.py +4 -4
macbert/infer_all.py
CHANGED
|
@@ -88,10 +88,10 @@ def test_epoch(travel_model, name_model, epoch, dataloader, tokenizer):
|
|
| 88 |
return travel_predictions, travel_probs, name_predictions, name_probs, sms_ids
|
| 89 |
|
| 90 |
def inference():
|
| 91 |
-
travel_checkpoint_file = '
|
| 92 |
-
name_checkpoint_file = '
|
| 93 |
-
ann_file_test = r'
|
| 94 |
-
output_file = r'/
|
| 95 |
cache_dir = 'cache'
|
| 96 |
|
| 97 |
model_cfg = {
|
|
|
|
| 88 |
return travel_predictions, travel_probs, name_predictions, name_probs, sms_ids
|
| 89 |
|
| 90 |
def inference():
|
| 91 |
+
travel_checkpoint_file = 'checkpoints/saved_checkpoints/travel_checkpoint15_train8000.pth.tar'
|
| 92 |
+
name_checkpoint_file = 'checkpoints/saved_checkpoints/name_checkpoint17_train9000.pth.tar'
|
| 93 |
+
ann_file_test = r'dataset/datagame_sms_stage1(in).csv'
|
| 94 |
+
output_file = r'/both_macbertBase_20250731_2.csv'
|
| 95 |
cache_dir = 'cache'
|
| 96 |
|
| 97 |
model_cfg = {
|