Spaces:
Runtime error
Runtime error
Commit
·
a3533fc
1
Parent(s):
f63d4de
Upload run.py
Browse files
run.py
CHANGED
@@ -92,7 +92,7 @@ def tp_tf_test(metric_selector, test_dataset, model_selector, queries_selector,
|
|
92 |
raw_logits = raw.iloc[:,2:]
|
93 |
logits = np.zeros(shape=(len(raw_logits.index),len(classes)))
|
94 |
for i in range(len(classes)):
|
95 |
-
queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['
|
96 |
logits[:,i]=raw_logits[queries].max(axis=1)
|
97 |
labels = raw[["dataset_labels","nli_labels"]]
|
98 |
labels = np.array(labels).astype(int)
|
|
|
92 |
raw_logits = raw.iloc[:,2:]
|
93 |
logits = np.zeros(shape=(len(raw_logits.index),len(classes)))
|
94 |
for i in range(len(classes)):
|
95 |
+
queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['query']
|
96 |
logits[:,i]=raw_logits[queries].max(axis=1)
|
97 |
labels = raw[["dataset_labels","nli_labels"]]
|
98 |
labels = np.array(labels).astype(int)
|