Spaces:
Runtime error
Runtime error
mohdelgaar
commited on
Commit
•
9c663aa
1
Parent(s):
bd1ded7
fix hparams
Browse files
app.py
CHANGED
@@ -56,9 +56,24 @@ parser.add_argument('--use_crf', type=bool)
|
|
56 |
parser.add_argument('--print_spans', action='store_true')
|
57 |
args = parser.parse_args()
|
58 |
|
59 |
-
args.
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
62 |
tokenizer = load_tokenizer(args.model_name)
|
63 |
model = load_model(args, device)[0]
|
64 |
model.eval()
|
|
|
56 |
parser.add_argument('--print_spans', action='store_true')
|
57 |
args = parser.parse_args()
|
58 |
|
59 |
+
if args.task == 'seq' and args.pheno_id is not None:
|
60 |
+
args.num_labels = 1
|
61 |
+
elif args.task == 'seq':
|
62 |
+
args.num_labels = args.num_phenos
|
63 |
+
elif args.task == 'token':
|
64 |
+
if args.use_umls:
|
65 |
+
args.num_labels = args.num_umls_tags
|
66 |
+
else:
|
67 |
+
args.num_labels = args.num_decs
|
68 |
+
if args.label_encoding == 'multiclass':
|
69 |
+
args.num_labels = args.num_labels * 2 + 1
|
70 |
+
elif args.label_encoding == 'bo':
|
71 |
+
args.num_labels *= 2
|
72 |
+
elif args.label_encoding == 'boe':
|
73 |
+
args.num_labels *= 3
|
74 |
+
|
75 |
|
76 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
77 |
tokenizer = load_tokenizer(args.model_name)
|
78 |
model = load_model(args, device)[0]
|
79 |
model.eval()
|