gorkaartola commited on
Commit
ce81925
1 Parent(s): 22ef6fa

Upload run.py

Browse files
Files changed (1) hide show
  1. run.py +4 -5
run.py CHANGED
@@ -17,8 +17,6 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
17
  queries_dataset_path = queries_selector.replace('-'+queries_data_file, '')
18
  queries_dataset_split = {'queries': queries_data_file}
19
  queries_dataset = load_dataset(queries_dataset_path, data_files = queries_dataset_split)['queries']
20
- #queries_data_files = {'queries': queries_selector}
21
- #queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
22
 
23
  #Load prompt_________________________________
24
  prompt = prompt_selector
@@ -89,10 +87,11 @@ def tp_tf_test(test_dataset, metric_selector, model_selector, queries_selector,
89
  results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
90
  '''
91
  #Analize predictions_______________________________
92
- def logits_labels(raw):
 
93
  raw_logits = raw.iloc[:,2:]
94
- logits = np.zeros(shape=(len(raw_logits.index),17))
95
- for i in range(17):
96
  queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
97
  logits[:,i]=raw_logits[queries].max(axis=1)
98
  labels = raw[["dataset_labels","nli_labels"]]
 
17
  queries_dataset_path = queries_selector.replace('-'+queries_data_file, '')
18
  queries_dataset_split = {'queries': queries_data_file}
19
  queries_dataset = load_dataset(queries_dataset_path, data_files = queries_dataset_split)['queries']
 
 
20
 
21
  #Load prompt_________________________________
22
  prompt = prompt_selector
 
87
  results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_dataset_path][queries_data_file] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
88
  '''
89
  #Analize predictions_______________________________
90
+ def logits_labels(raw):
91
+ classes = raw["dataset_labels"].unique()
92
  raw_logits = raw.iloc[:,2:]
93
+ logits = np.zeros(shape=(len(raw_logits.index),len(classes)))
94
+ for i in range(len(classes)):
95
  queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
96
  logits[:,i]=raw_logits[queries].max(axis=1)
97
  labels = raw[["dataset_labels","nli_labels"]]