Joshua Lochner commited on
Commit
8f0e2d8
·
1 Parent(s): c78435b

Assign default model for predictions

Browse files
Files changed (1) hide show
  1. src/predict.py +9 -7
src/predict.py CHANGED
@@ -17,13 +17,14 @@ from transformers import HfArgumentParser
17
  from transformers.trainer_utils import get_last_checkpoint
18
  from dataclasses import dataclass, field
19
  import logging
 
20
 
21
 
22
  @dataclass
23
  class TrainingOutputArguments:
24
 
25
  model_path: str = field(
26
- default=None,
27
  metadata={
28
  'help': 'Path to pretrained model used for prediction'
29
  }
@@ -36,12 +37,13 @@ class TrainingOutputArguments:
36
  if self.model_path is not None:
37
  return
38
 
39
- last_checkpoint = get_last_checkpoint(self.output_dir)
40
- if last_checkpoint is not None:
41
- self.model_path = last_checkpoint
42
- else:
43
- raise Exception(
44
- 'Unable to find model, explicitly set `--model_path`')
 
45
 
46
 
47
  @dataclass
 
17
  from transformers.trainer_utils import get_last_checkpoint
18
  from dataclasses import dataclass, field
19
  import logging
20
+ import os
21
 
22
 
23
  @dataclass
24
  class TrainingOutputArguments:
25
 
26
  model_path: str = field(
27
+ default='Xenova/sponsorblock-small',
28
  metadata={
29
  'help': 'Path to pretrained model used for prediction'
30
  }
 
37
  if self.model_path is not None:
38
  return
39
 
40
+ if os.path.exists(self.output_dir):
41
+ last_checkpoint = get_last_checkpoint(self.output_dir)
42
+ if last_checkpoint is not None:
43
+ self.model_path = last_checkpoint
44
+ return
45
+
46
+ raise Exception('Unable to find model, explicitly set `--model_path`')
47
 
48
 
49
  @dataclass