gorkaartola commited on
Commit
ec0be50
1 Parent(s): 54ef3a8

Delete run.py

Browse files
Files changed (1) hide show
  1. run.py +0 -163
run.py DELETED
@@ -1,163 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
2
- from datasets import load_dataset, load_metric
3
- import evaluate
4
- from torch.utils.data import DataLoader
5
- import torch
6
- import numpy as np
7
- import pandas as pd
8
-
9
- model_selector = '0'
10
- queries_selector = '0'
11
- prompt_selector = '0'
12
- metric_selector = '0'
13
- prediction_strategy_selector = ['1','2.0','2.1','2.2','2.3','3.0','3.1','3.2','3.3']
14
-
15
- models = {
16
- '0' : 'joeddav/xlm-roberta-large-xnli',
17
- '1' : '',
18
- }
19
-
20
- queries = {
21
- '0' : 'SDG_Titles.csv',
22
- '1' : 'SDG_Headlines.csv',
23
- '2' : 'SDG_Subjects.csv',
24
- '3' : 'SDG_Targets.csv',
25
- '4' : 'SDG_Numbers.csv',
26
- }
27
-
28
- prompts = {
29
- '0' : '',
30
- '1' : 'This is ',
31
- '2' : 'The subject is ',
32
- '3' : 'The Sustainable Development Goal is ',
33
- }
34
-
35
- metrics = {
36
- '0' : 'gorkaartola/metric_for_tp_fp_samples',
37
- }
38
-
39
- prediction_strategy_options = {
40
- '1': ["argmax_max"],
41
- '2.0': ["threshold", 0.05],
42
- '2.1': ["threshold", 0.25],
43
- '2.2': ["threshold", 0.5],
44
- '2.3': ["threshold", 0.75],
45
- '3.0': ["topk", 9],
46
- '3.1': ["topk", 7],
47
- '3.2': ["topk", 5],
48
- '3.3': ["topk", 3],
49
- }
50
-
51
- saved_inference_tables_path = 'Reports/ZS inference tables/'
52
-
53
- #Load test dataset___________________________
54
- test_dataset = load_dataset('gorkaartola/SC-ZS-test_AURORA-Gold-SDG_True-Positives-and-False-Positives')['test']
55
-
56
- #Load queries________________________________
57
- queries_data_files = {'queries': queries[queries_selector]}
58
- queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
59
-
60
- #Load prompt_________________________________
61
- prompt = prompts[prompt_selector]
62
-
63
- #Load prediction strategias__________________
64
- prediction_strategies = [prediction_strategy_options[x] for x in prediction_strategy_selector]
65
-
66
- #Calculate and save predictions_______________________
67
- #'''
68
- def tokenize_function(example, prompt = '', query = ''):
69
- queries = []
70
- for i in range(len(example['title'])):
71
- queries.append(prompt + query)
72
- tokenize = tokenizer(example['title'], queries, truncation='only_first')
73
- #tokenize['query'] = queries
74
- return tokenize
75
-
76
- model = AutoModelForSequenceClassification.from_pretrained(models[model_selector])
77
- #device = torch.device("cuda")
78
- #model.to(device)
79
- tokenizer = AutoTokenizer.from_pretrained(models[model_selector])
80
- data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
81
-
82
- results_test = pd.DataFrame()
83
- for query_data in queries_dataset:
84
- query = query_data['SDGquery']
85
- tokenized_test_dataset = test_dataset.map(tokenize_function, batched = True, fn_kwargs = {'prompt' : prompt, 'query' : query})
86
- columns_to_remove = test_dataset.column_names
87
- for column_name in ['label_ids', 'nli_label']:
88
- columns_to_remove.remove(column_name)
89
- tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
90
- tokenized_test_dataset_for_inference.set_format('torch')
91
- dataloader = DataLoader(
92
- tokenized_test_dataset_for_inference,
93
- batch_size=8,
94
- collate_fn = data_collator,
95
- )
96
- values = []
97
- labels = []
98
- nli_labels =[]
99
- for batch in dataloader:
100
- #data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
101
- data = {k: v for k, v in batch.items() if k not in ['labels', 'nli_label']}
102
- with torch.no_grad():
103
- outputs = model(**data)
104
- logits = outputs.logits
105
- entail_contradiction_logits = logits[:,[0,2]]
106
- probs = entail_contradiction_logits.softmax(dim=1)
107
- predictions = probs[:,1].tolist()
108
- label_ids = batch['labels'].tolist()
109
- nli_label_ids = batch['nli_label'].tolist()
110
- for prediction, label, nli_label in zip(predictions, label_ids, nli_label_ids):
111
- values.append(prediction)
112
- labels.append(label)
113
- nli_labels.append(nli_label)
114
- results_test['dataset_labels'] = labels
115
- results_test['nli_labels'] = nli_labels
116
- results_test[query] = values
117
-
118
- results_test.to_csv(saved_inference_tables_path + 'ZS-inference-table_Model-' + model_selector + '_Queries-' + queries_selector + '_Prompt-' + prompt_selector + '.csv', index = False)
119
- #'''
120
- #Load saved predictions____________________________
121
- '''
122
- results_test = pd.read_csv(saved_inference_tables_path + 'ZS-inference-table_Model-' + model_selector + '_Queries-' + queries_selector + '_Prompt-' + prompt_selector + '.csv')
123
- '''
124
- #Analize predictions_______________________________
125
- def logits_labels(raw):
126
- raw_logits = raw.iloc[:,2:]
127
- logits = np.zeros(shape=(len(raw_logits.index),17))
128
- for i in range(17):
129
- queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
130
- logits[:,i]=raw_logits[queries].max(axis=1)
131
- labels = raw[["dataset_labels","nli_labels"]]
132
- labels = np.array(labels).astype(int)
133
- return logits, labels
134
-
135
- predictions, references = logits_labels(results_test)
136
- prediction_strategies = [prediction_strategy_options[x] for x in prediction_strategy_selector]
137
-
138
- metric = evaluate.load(metrics['0'])
139
- metric.add_batch(predictions = predictions, references = references)
140
- results = metric.compute(prediction_strategies = prediction_strategies)
141
- with open('Reports/report-Model-' + model_selector + '_Queries-' + queries_selector + '_Prompt-' + prompt_selector + '.csv', 'a') as results_file:
142
- for result in results:
143
- results[result].to_csv(results_file, mode='a', index_label = result)
144
- print(results[result], '\n')
145
-
146
-
147
-
148
-
149
-
150
-
151
-
152
-
153
-
154
-
155
-
156
-
157
-
158
-
159
-
160
-
161
-
162
-
163
-