gorkaartola commited on
Commit
22ef6fa
1 Parent(s): a72cebb

Upload run.py

Browse files
Files changed (1) hide show
  1. run.py +9 -5
run.py CHANGED
@@ -13,8 +13,12 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
13
  test_dataset = load_dataset(test_dataset)['test']
14
 
15
  #Load queries________________________________
16
- queries_data_files = {'queries': queries_selector}
17
- queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
 
 
 
 
18
 
19
  #Load prompt_________________________________
20
  prompt = prompt_selector
@@ -78,11 +82,11 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
78
  results_test['nli_labels'] = nli_labels
79
  results_test[query] = values
80
 
81
- results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
82
  #'''
83
  #Load saved predictions____________________________
84
  '''
85
- results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
86
  '''
87
  #Analize predictions_______________________________
88
  def logits_labels(raw):
@@ -102,7 +106,7 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
102
  metric.add_batch(predictions = predictions, references = references)
103
  results = metric.compute(prediction_strategies = prediction_strategies)
104
  prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
105
- output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-'+ prediction_strategies_names +'.csv'
106
  with open(output_filename, 'a') as results_file:
107
  for result in results:
108
  results[result].to_csv(results_file, mode='a', index_label = result)
 
13
  test_dataset = load_dataset(test_dataset)['test']
14
 
15
  #Load queries________________________________
16
+ queries_data_file = queries_selector.split('-')[-1]
17
+ queries_dataset_path = queries_selector.replace('-'+queries_data_file, '')
18
+ queries_dataset_split = {'queries': queries_data_file}
19
+ queries_dataset = load_dataset(queries_dataset_path, data_files = queries_dataset_split)['queries']
20
+ #queries_data_files = {'queries': queries_selector}
21
+ #queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
22
 
23
  #Load prompt_________________________________
24
  prompt = prompt_selector
 
82
  results_test['nli_labels'] = nli_labels
83
  results_test[query] = values
84
 
85
+ results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
86
  #'''
87
  #Load saved predictions____________________________
88
  '''
89
+ results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
90
  '''
91
  #Analize predictions_______________________________
92
  def logits_labels(raw):
 
106
  metric.add_batch(predictions = predictions, references = references)
107
  results = metric.compute(prediction_strategies = prediction_strategies)
108
  prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
109
+ output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-'+ prediction_strategies_names +'.csv'
110
  with open(output_filename, 'a') as results_file:
111
  for result in results:
112
  results[result].to_csv(results_file, mode='a', index_label = result)