mohdelgaar commited on
Commit
9c663aa
1 Parent(s): bd1ded7

fix hparams

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -56,9 +56,24 @@ parser.add_argument('--use_crf', type=bool)
56
  parser.add_argument('--print_spans', action='store_true')
57
  args = parser.parse_args()
58
 
59
- args.num_labels = args.num_decs
60
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
62
  tokenizer = load_tokenizer(args.model_name)
63
  model = load_model(args, device)[0]
64
  model.eval()
 
56
  parser.add_argument('--print_spans', action='store_true')
57
  args = parser.parse_args()
58
 
59
+ if args.task == 'seq' and args.pheno_id is not None:
60
+ args.num_labels = 1
61
+ elif args.task == 'seq':
62
+ args.num_labels = args.num_phenos
63
+ elif args.task == 'token':
64
+ if args.use_umls:
65
+ args.num_labels = args.num_umls_tags
66
+ else:
67
+ args.num_labels = args.num_decs
68
+ if args.label_encoding == 'multiclass':
69
+ args.num_labels = args.num_labels * 2 + 1
70
+ elif args.label_encoding == 'bo':
71
+ args.num_labels *= 2
72
+ elif args.label_encoding == 'boe':
73
+ args.num_labels *= 3
74
+
75
 
76
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  tokenizer = load_tokenizer(args.model_name)
78
  model = load_model(args, device)[0]
79
  model.eval()