# gorkaartola's picture
# Upload run.py
# b27edec
# raw
# history blame
# 4.86 kB
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset, load_metric
import evaluate
from torch.utils.data import DataLoader
import torch
import numpy as np
import pandas as pd
import options as op
def tp_tf_test(model_selector, test_dataset, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
#Load test dataset___________________________
test_dataset = load_dataset(test_dataset)['test']
#Load queries________________________________
queries_data_files = {'queries': queries_selector}
queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
#Load prompt_________________________________
prompt = prompt_selector
#Load prediction strategias__________________
prediction_strategies = prediction_strategy_selector
#Load model, tokenizer and collator__________
model = AutoModelForSequenceClassification.from_pretrained(model_selector)
if torch.cuda.is_available():
device = torch.device("cuda")
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_selector)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
#Calculate and save predictions______________
#'''
def tokenize_function(example, prompt = '', query = ''):
queries = []
for i in range(len(example['title'])):
queries.append(prompt + query)
tokenize = tokenizer(example['title'], queries, truncation='only_first')
#tokenize['query'] = queries
return tokenize
results_test = pd.DataFrame()
for query_data in queries_dataset:
query = query_data['SDGquery']
tokenized_test_dataset = test_dataset.map(tokenize_function, batched = True, fn_kwargs = {'prompt' : prompt, 'query' : query})
columns_to_remove = test_dataset.column_names
for column_name in ['label_ids', 'nli_label']:
columns_to_remove.remove(column_name)
tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
tokenized_test_dataset_for_inference.set_format('torch')
dataloader = DataLoader(
tokenized_test_dataset_for_inference,
batch_size=8,
collate_fn = data_collator,
)
values = []
labels = []
nli_labels =[]
for batch in dataloader:
if torch.cuda.is_available():
data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
else:
data = {k: v for k, v in batch.items() if k not in ['labels', 'nli_label']}
with torch.no_grad():
outputs = model(**data)
logits = outputs.logits
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
predictions = probs[:,1].tolist()
label_ids = batch['labels'].tolist()
nli_label_ids = batch['nli_label'].tolist()
for prediction, label, nli_label in zip(predictions, label_ids, nli_label_ids):
values.append(prediction)
labels.append(label)
nli_labels.append(nli_label)
results_test['dataset_labels'] = labels
results_test['nli_labels'] = nli_labels
results_test[query] = values
results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
#'''
#Load saved predictions____________________________
'''
results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
'''
#Analize predictions_______________________________
def logits_labels(raw):
raw_logits = raw.iloc[:,2:]
logits = np.zeros(shape=(len(raw_logits.index),17))
for i in range(17):
queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
logits[:,i]=raw_logits[queries].max(axis=1)
labels = raw[["dataset_labels","nli_labels"]]
labels = np.array(labels).astype(int)
return logits, labels
predictions, references = logits_labels(results_test)
prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]
metric = evaluate.load(metric_selector)
metric.add_batch(predictions = predictions, references = references)
results = metric.compute(prediction_strategies = prediction_strategies)
prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-'+ prediction_strategies_names +'.csv'
with open(output_filename, 'a') as results_file:
for result in results:
results[result].to_csv(results_file, mode='a', index_label = result)
print(results[result], '\n')
return output_filename