gorkaartola commited on
Commit
a3533fc
1 Parent(s): f63d4de

Upload run.py

Browse files
Files changed (1) hide show
  1. run.py +1 -1
run.py CHANGED
@@ -92,7 +92,7 @@ def tp_tf_test(metric_selector, test_dataset, model_selector, queries_selector,
92
  raw_logits = raw.iloc[:,2:]
93
  logits = np.zeros(shape=(len(raw_logits.index),len(classes)))
94
  for i in range(len(classes)):
95
- queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
96
  logits[:,i]=raw_logits[queries].max(axis=1)
97
  labels = raw[["dataset_labels","nli_labels"]]
98
  labels = np.array(labels).astype(int)
 
92
  raw_logits = raw.iloc[:,2:]
93
  logits = np.zeros(shape=(len(raw_logits.index),len(classes)))
94
  for i in range(len(classes)):
95
+ queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['query']
96
  logits[:,i]=raw_logits[queries].max(axis=1)
97
  labels = raw[["dataset_labels","nli_labels"]]
98
  labels = np.array(labels).astype(int)