Spaces:
Runtime error
Runtime error
Commit
•
ce81925
1
Parent(s):
22ef6fa
Upload run.py
Browse files
run.py
CHANGED
@@ -17,8 +17,6 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
|
|
17 |
queries_dataset_path = queries_selector.replace('-'+queries_data_file, '')
|
18 |
queries_dataset_split = {'queries': queries_data_file}
|
19 |
queries_dataset = load_dataset(queries_dataset_path, data_files = queries_dataset_split)['queries']
|
20 |
-
#queries_data_files = {'queries': queries_selector}
|
21 |
-
#queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
|
22 |
|
23 |
#Load prompt_________________________________
|
24 |
prompt = prompt_selector
|
@@ -89,10 +87,11 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
|
|
89 |
results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
|
90 |
'''
|
91 |
#Analize predictions_______________________________
|
92 |
-
def logits_labels(raw):
|
|
|
93 |
raw_logits = raw.iloc[:,2:]
|
94 |
-
logits = np.zeros(shape=(len(raw_logits.index),
|
95 |
-
for i in range(
|
96 |
queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
|
97 |
logits[:,i]=raw_logits[queries].max(axis=1)
|
98 |
labels = raw[["dataset_labels","nli_labels"]]
|
|
|
17 |
queries_dataset_path = queries_selector.replace('-'+queries_data_file, '')
|
18 |
queries_dataset_split = {'queries': queries_data_file}
|
19 |
queries_dataset = load_dataset(queries_dataset_path, data_files = queries_dataset_split)['queries']
|
|
|
|
|
20 |
|
21 |
#Load prompt_________________________________
|
22 |
prompt = prompt_selector
|
|
|
87 |
results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
|
88 |
'''
|
89 |
#Analize predictions_______________________________
|
90 |
+
def logits_labels(raw):
|
91 |
+
classes = raw["dataset_labels"].unique()
|
92 |
raw_logits = raw.iloc[:,2:]
|
93 |
+
logits = np.zeros(shape=(len(raw_logits.index),len(classes)))
|
94 |
+
for i in range(len(classes)):
|
95 |
queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
|
96 |
logits[:,i]=raw_logits[queries].max(axis=1)
|
97 |
labels = raw[["dataset_labels","nli_labels"]]
|