# Spaces:
# Runtime error
# Runtime error
import os

from datasets import load_dataset, load_metric
import evaluate
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

import options as op
def tp_tf_test(model_selector, test_dataset, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
    """Score every query against a test set with an NLI model and write a metric report.

    For each query in the selected queries file, every test title is paired
    with ``prompt + query`` and fed to the sequence-classification model. The
    per-query scores are saved as an inference table, reduced to one score per
    label id, and handed to the selected metric. The metric's result tables
    are appended to a CSV report.

    Args:
        model_selector: Model identifier passed to ``from_pretrained``; also a
            key of ``op.models`` (used in the output file names).
        test_dataset: Dataset identifier passed to ``load_dataset``. Its
            ``'test'`` split is used and must contain the columns ``'title'``,
            ``'label_ids'`` and ``'nli_label'``.
        queries_selector: Data file of ``'gorkaartola/SDG_queries'`` to load;
            also a key of ``op.queries``. Rows must provide ``'SDGquery'`` and
            ``'label_ids'``.
        prompt_selector: Text prepended to every query; also a key of
            ``op.prompts``.
        metric_selector: Identifier passed to ``evaluate.load``. The metric's
            ``compute`` must accept a ``prediction_strategies`` keyword.
        prediction_strategy_selector: Iterable of strings, each a key of
            ``op.prediction_strategy_options``.

    Returns:
        The path of the report CSV that the results were appended to.

    Raises:
        KeyError: If a selector is not a key of the matching ``op`` table.
        ValueError: If the test split lacks ``'label_ids'`` or ``'nli_label'``.
    """
    label_columns = ('label_ids', 'nli_label')
    # DataCollatorWithPadding hands the 'label_ids' column back as 'labels',
    # so these are the non-model keys of a collated batch.
    batch_label_keys = ('labels', 'nli_label')

    # Resolve every selector before any expensive work, so an unknown
    # selector fails before the model is downloaded and run.
    run_tag = (
        'Model-' + op.models[model_selector]
        + '_Queries-' + op.queries[queries_selector]
        + '_Prompt-' + op.prompts[prompt_selector]
    )
    prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]
    table_path = 'Reports/ZS inference tables/ZS-inference-table_' + run_tag + '.csv'

    # Load test dataset and queries________________
    test_split = load_dataset(test_dataset)['test']
    queries_dataset = load_dataset(
        'gorkaartola/SDG_queries',
        data_files={'queries': queries_selector},
    )['queries']

    # Load model, tokenizer and collator___________
    model = AutoModelForSequenceClassification.from_pretrained(model_selector)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_selector)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Calculate and save predictions_______________
    def tokenize_function(example, prompt='', query=''):
        """Tokenize a batch of titles, each paired with the same prompt + query."""
        hypotheses = [prompt + query] * len(example['title'])
        # 'only_first' truncates the title only, never the query text.
        return tokenizer(example['title'], hypotheses, truncation='only_first')

    missing = [name for name in label_columns if name not in test_split.column_names]
    if missing:
        raise ValueError('test split is missing required column(s): ' + ', '.join(missing))
    # Everything from the raw split except the label columns is dropped after
    # tokenization; this does not depend on the query, so compute it once.
    columns_to_remove = [name for name in test_split.column_names if name not in label_columns]

    # Collected as a plain dict and turned into a DataFrame once at the end;
    # inserting columns one by one fragments the frame.
    table_columns = {}
    for query_data in queries_dataset:
        query = query_data['SDGquery']
        tokenized = test_split.map(
            tokenize_function,
            batched=True,
            fn_kwargs={'prompt': prompt_selector, 'query': query},
        )
        inference_set = tokenized.remove_columns(columns_to_remove)
        inference_set.set_format('torch')
        dataloader = DataLoader(inference_set, batch_size=8, collate_fn=data_collator)

        values, labels, nli_labels = [], [], []
        for batch in dataloader:
            data = {k: v.to(device) for k, v in batch.items() if k not in batch_label_keys}
            with torch.no_grad():
                logits = model(**data).logits
            # Renormalise over output columns 0 and 2 only and keep the share
            # of column 2 as the score.
            # NOTE(review): the original names these "entail/contradiction"
            # logits, i.e. it assumes the checkpoint orders its classes
            # (contradiction, neutral, entailment) - confirm per model.
            probs = logits[:, [0, 2]].softmax(dim=1)
            values.extend(probs[:, 1].tolist())
            labels.extend(batch['labels'].tolist())
            nli_labels.extend(batch['nli_label'].tolist())

        # The label columns are the same for every query; re-assigning them
        # keeps them as the first two columns of the table.
        table_columns['dataset_labels'] = labels
        table_columns['nli_labels'] = nli_labels
        table_columns[query] = values

    results_test = pd.DataFrame(table_columns)
    os.makedirs(os.path.dirname(table_path), exist_ok=True)
    results_test.to_csv(table_path, index=False)

    # To skip inference and reuse a previously saved table, replace the
    # "Calculate and save predictions" section above with:
    #     results_test = pd.read_csv(table_path)

    # Analyze predictions__________________________
    def logits_labels(raw, num_labels=17):
        """Reduce the per-query table to one score per label id.

        Returns ``(logits, labels)``: ``logits[:, i]`` is the row-wise maximum
        over all query columns whose ``'label_ids'`` equals ``i``; ``labels``
        holds the ``dataset_labels`` and ``nli_labels`` columns as integers.
        """
        raw_logits = raw.iloc[:, 2:]  # the first two columns are the labels
        logits = np.zeros(shape=(len(raw_logits.index), num_labels))
        for label_id in range(num_labels):
            queries = queries_dataset.filter(lambda x: x['label_ids'] == label_id)['SDGquery']
            logits[:, label_id] = raw_logits[queries].max(axis=1)
        labels = np.array(raw[["dataset_labels", "nli_labels"]]).astype(int)
        return logits, labels

    predictions, references = logits_labels(results_test)
    metric = evaluate.load(metric_selector)
    metric.add_batch(predictions=predictions, references=references)
    results = metric.compute(prediction_strategies=prediction_strategies)

    prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
    output_filename = 'Reports/report-' + run_tag + '_Strategies-' + prediction_strategies_names + '.csv'
    # Append mode: running the same configuration again adds to the existing
    # report instead of replacing it.
    with open(output_filename, 'a') as results_file:
        for result in results:
            results[result].to_csv(results_file, mode='a', index_label=result)
            print(results[result], '\n')
    return output_filename