Spaces:
Runtime error
Runtime error
Commit
·
22ef6fa
1
Parent(s):
a72cebb
Upload run.py
Browse files
run.py
CHANGED
@@ -13,8 +13,12 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
|
|
13 |
test_dataset = load_dataset(test_dataset)['test']
|
14 |
|
15 |
#Load queries________________________________
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
|
19 |
#Load prompt_________________________________
|
20 |
prompt = prompt_selector
|
@@ -78,11 +82,11 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
|
|
78 |
results_test['nli_labels'] = nli_labels
|
79 |
results_test[query] = values
|
80 |
|
81 |
-
results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[
|
82 |
#'''
|
83 |
#Load saved predictions____________________________
|
84 |
'''
|
85 |
-
results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[
|
86 |
'''
|
87 |
#Analize predictions_______________________________
|
88 |
def logits_labels(raw):
|
@@ -102,7 +106,7 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
|
|
102 |
metric.add_batch(predictions = predictions, references = references)
|
103 |
results = metric.compute(prediction_strategies = prediction_strategies)
|
104 |
prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
|
105 |
-
output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[
|
106 |
with open(output_filename, 'a') as results_file:
|
107 |
for result in results:
|
108 |
results[result].to_csv(results_file, mode='a', index_label = result)
|
|
|
13 |
test_dataset = load_dataset(test_dataset)['test']
|
14 |
|
15 |
#Load queries________________________________
|
16 |
+
queries_data_file = queries_selector.split('-')[-1]
|
17 |
+
queries_dataset_path = queries_selector.replace('-'+queries_data_file, '')
|
18 |
+
queries_dataset_split = {'queries': queries_data_file}
|
19 |
+
queries_dataset = load_dataset(queries_dataset_path, data_files = queries_dataset_split)['queries']
|
20 |
+
#queries_data_files = {'queries': queries_selector}
|
21 |
+
#queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
|
22 |
|
23 |
#Load prompt_________________________________
|
24 |
prompt = prompt_selector
|
|
|
82 |
results_test['nli_labels'] = nli_labels
|
83 |
results_test[query] = values
|
84 |
|
85 |
+
results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
|
86 |
#'''
|
87 |
#Load saved predictions____________________________
|
88 |
'''
|
89 |
+
results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
|
90 |
'''
|
91 |
#Analize predictions_______________________________
|
92 |
def logits_labels(raw):
|
|
|
106 |
metric.add_batch(predictions = predictions, references = references)
|
107 |
results = metric.compute(prediction_strategies = prediction_strategies)
|
108 |
prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
|
109 |
+
output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-'+ prediction_strategies_names +'.csv'
|
110 |
with open(output_filename, 'a') as results_file:
|
111 |
for result in results:
|
112 |
results[result].to_csv(results_file, mode='a', index_label = result)
|