Spaces:
Runtime error
Runtime error
gorkaartola
committed on
Commit
•
b27edec
1
Parent(s):
a65e727
Upload run.py
Browse files
run.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
|
2 |
+
from datasets import load_dataset, load_metric
|
3 |
+
import evaluate
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import options as op
|
9 |
+
|
10 |
+
def tp_tf_test(model_selector, test_dataset, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
    """Run zero-shot NLI inference of SDG queries over a test dataset and score it.

    For every query in the selected SDG query file, each test title is paired
    with ``prompt + query`` and fed to an NLI sequence-classification model;
    the entailment probability becomes that query's score for the title. The
    per-query scores are saved to a CSV inference table, collapsed to one
    logit per class (max over that class's queries), and evaluated with the
    selected metric under the selected prediction strategies.

    Parameters
    ----------
    model_selector : str
        Hub id of the NLI model; also a key into ``op.models`` for naming reports.
    test_dataset : str
        Hub id of the dataset; its ``test`` split must provide ``title``,
        ``label_ids`` and ``nli_label`` columns.
    queries_selector : str
        Data file name inside ``gorkaartola/SDG_queries``; also a key into
        ``op.queries``.
    prompt_selector : str
        Prompt text prepended to every query; also a key into ``op.prompts``.
    metric_selector : str
        Metric id passed to ``evaluate.load``.
    prediction_strategy_selector : list[str]
        Keys into ``op.prediction_strategy_options``.

    Returns
    -------
    str
        Path of the CSV report file that was written.
    """
    # Load test dataset __________________________
    test_dataset = load_dataset(test_dataset)['test']

    # Load queries _______________________________
    queries_data_files = {'queries': queries_selector}
    queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']

    # Load prompt ________________________________
    prompt = prompt_selector

    # Load prediction strategies _________________
    prediction_strategies = prediction_strategy_selector

    # Load model, tokenizer and collator _________
    model = AutoModelForSequenceClassification.from_pretrained(model_selector)
    # Pick the device once instead of re-testing cuda availability per batch;
    # .to("cpu") is a no-op for CPU tensors, so the batch loop needs no branch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference only; explicit even though from_pretrained defaults to eval
    tokenizer = AutoTokenizer.from_pretrained(model_selector)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Calculate and save predictions _____________
    # Toggle: comment this section out and uncomment the "Load saved
    # predictions" section below to reuse a previously written inference table.
    #'''
    def tokenize_function(example, prompt = '', query = ''):
        # Every title in the batch is paired with the same premise/hypothesis
        # text, so replicate it rather than appending in a loop.
        queries = [prompt + query] * len(example['title'])
        # truncation='only_first': trim the title, never the query.
        return tokenizer(example['title'], queries, truncation='only_first')

    results_test = pd.DataFrame()
    for query_data in queries_dataset:
        query = query_data['SDGquery']
        tokenized_test_dataset = test_dataset.map(tokenize_function, batched = True, fn_kwargs = {'prompt' : prompt, 'query' : query})
        # Keep only the label columns plus the tokenizer outputs for inference.
        columns_to_remove = test_dataset.column_names
        for column_name in ['label_ids', 'nli_label']:
            columns_to_remove.remove(column_name)
        tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
        tokenized_test_dataset_for_inference.set_format('torch')
        dataloader = DataLoader(
            tokenized_test_dataset_for_inference,
            batch_size=8,
            collate_fn = data_collator,
        )
        values = []
        labels = []
        nli_labels = []
        for batch in dataloader:
            # Labels stay on CPU; only the model inputs move to the device.
            data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
            with torch.no_grad():
                outputs = model(**data)
            logits = outputs.logits
            # Columns [0, 2] are taken as contradiction/entailment; softmax over
            # the pair and keep P(entailment).
            # NOTE(review): assumes the model's label order puts entailment at
            # index 2 — confirm against the model config.
            entail_contradiction_logits = logits[:, [0, 2]]
            probs = entail_contradiction_logits.softmax(dim=1)
            predictions = probs[:, 1].tolist()
            label_ids = batch['labels'].tolist()
            nli_label_ids = batch['nli_label'].tolist()
            for prediction, label, nli_label in zip(predictions, label_ids, nli_label_ids):
                values.append(prediction)
                labels.append(label)
                nli_labels.append(nli_label)
        # The label columns are identical on every pass; rewriting them is
        # harmless. Each query adds one score column named after the query.
        results_test['dataset_labels'] = labels
        results_test['nli_labels'] = nli_labels
        results_test[query] = values

    results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
    #'''
    # Load saved predictions _____________________
    '''
    results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
    '''
    # Analyze predictions ________________________
    def logits_labels(raw):
        """Collapse per-query scores to one column per class (max over the
        class's queries) and extract the integer label matrix."""
        raw_logits = raw.iloc[:, 2:]  # skip dataset_labels / nli_labels
        # 17 classes — presumably the 17 UN SDGs; confirm against the queries file.
        logits = np.zeros(shape=(len(raw_logits.index), 17))
        for i in range(17):
            queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
            logits[:, i] = raw_logits[queries].max(axis=1)
        labels = raw[["dataset_labels", "nli_labels"]]
        labels = np.array(labels).astype(int)
        return logits, labels

    predictions, references = logits_labels(results_test)
    prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]

    metric = evaluate.load(metric_selector)
    metric.add_batch(predictions = predictions, references = references)
    results = metric.compute(prediction_strategies = prediction_strategies)
    prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
    output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-' + prediction_strategies_names + '.csv'
    # Append mode: repeated runs accumulate result tables in one report file.
    with open(output_filename, 'a') as results_file:
        for result in results:
            results[result].to_csv(results_file, mode='a', index_label = result)
            print(results[result], '\n')
    return output_filename
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|