from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset
import evaluate
from torch.utils.data import DataLoader
import torch
import numpy as np
import pandas as pd

# Experiment configuration selectors.
model_selector = '0'
queries_selector = '0'
prompt_selector = '0'
metric_selector = '0'
prediction_strategy_selector = ['1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3']

models = {
    '0': 'joeddav/xlm-roberta-large-xnli',
    '1': '',
}
queries = {
    '0': 'SDG_Titles.csv',
    '1': 'SDG_Headlines.csv',
    '2': 'SDG_Subjects.csv',
    '3': 'SDG_Targets.csv',
    '4': 'SDG_Numbers.csv',
}
prompts = {
    '0': '',
    '1': 'This is ',
    '2': 'The subject is ',
    '3': 'The Sustainable Development Goal is ',
}
metrics = {
    '0': 'gorkaartola/metric_for_tp_fp_samples',
}
prediction_strategy_options = {
    '1': ["argmax_max"],
    '2.0': ["threshold", 0.05],
    '2.1': ["threshold", 0.25],
    '2.2': ["threshold", 0.5],
    '2.3': ["threshold", 0.75],
    '3.0': ["topk", 9],
    '3.1': ["topk", 7],
    '3.2': ["topk", 5],
    '3.3': ["topk", 3],
}
saved_inference_tables_path = 'Reports/ZS inference tables/'

#Load test dataset___________________________
test_dataset = load_dataset('gorkaartola/SC-ZS-test_AURORA-Gold-SDG_True-Positives-and-False-Positives')['test']

#Load queries________________________________
queries_data_files = {'queries': queries[queries_selector]}
queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files=queries_data_files)['queries']

#Load prompt_________________________________
prompt = prompts[prompt_selector]

#Load prediction strategies__________________
prediction_strategies = [prediction_strategy_options[x] for x in prediction_strategy_selector]

#Calculate and save predictions_______________________
#'''
def tokenize_function(example, prompt='', query=''):
    # Pair each title with the same "prompt + query" hypothesis for NLI-style zero-shot classification.
    queries = [prompt + query for _ in range(len(example['title']))]
    tokenized = tokenizer(example['title'], queries, truncation='only_first')
    #tokenized['query'] = queries
    return tokenized

model = AutoModelForSequenceClassification.from_pretrained(models[model_selector])
#device = torch.device("cuda")
#model.to(device)
tokenizer = AutoTokenizer.from_pretrained(models[model_selector])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

results_test = pd.DataFrame()
for query_data in queries_dataset:
    query = query_data['SDGquery']
    tokenized_test_dataset = test_dataset.map(
        tokenize_function,
        batched=True,
        fn_kwargs={'prompt': prompt, 'query': query},
    )
    # Keep only the label columns plus the tokenizer outputs for inference.
    columns_to_remove = test_dataset.column_names
    for column_name in ['label_ids', 'nli_label']:
        columns_to_remove.remove(column_name)
    tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
    tokenized_test_dataset_for_inference.set_format('torch')
    dataloader = DataLoader(
        tokenized_test_dataset_for_inference,
        batch_size=8,
        collate_fn=data_collator,
    )
    values = []
    labels = []
    nli_labels = []
    for batch in dataloader:
        #data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
        data = {k: v for k, v in batch.items() if k not in ['labels', 'nli_label']}
        with torch.no_grad():
            outputs = model(**data)
        logits = outputs.logits
        # Keep only the entailment and contradiction logits and take the entailment probability as the score.
        entail_contradiction_logits = logits[:, [0, 2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        predictions = probs[:, 1].tolist()
        label_ids = batch['labels'].tolist()
        nli_label_ids = batch['nli_label'].tolist()
        for prediction, label, nli_label in zip(predictions, label_ids, nli_label_ids):
            values.append(prediction)
            labels.append(label)
            nli_labels.append(nli_label)
    results_test['dataset_labels'] = labels
    results_test['nli_labels'] = nli_labels
    results_test[query] = values

results_test.to_csv(
    saved_inference_tables_path + 'ZS-inference-table_Model-' + model_selector
    + '_Queries-' + queries_selector + '_Prompt-' + prompt_selector + '.csv',
    index=False,
)
#'''

#Load saved predictions____________________________
'''
results_test = pd.read_csv(saved_inference_tables_path + 'ZS-inference-table_Model-' + model_selector + '_Queries-' + queries_selector + '_Prompt-' + prompt_selector + '.csv')
'''

#Analyze predictions_______________________________
def logits_labels(raw):
    # Collapse the per-query entailment scores into one column per SDG (17 classes):
    # for each sample, take the maximum score over all queries belonging to that SDG.
    raw_logits = raw.iloc[:, 2:]
    logits = np.zeros(shape=(len(raw_logits.index), 17))
    for i in range(17):
        queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
        logits[:, i] = raw_logits[queries].max(axis=1)
    labels = raw[["dataset_labels", "nli_labels"]]
    labels = np.array(labels).astype(int)
    return logits, labels

predictions, references = logits_labels(results_test)

prediction_strategies = [prediction_strategy_options[x] for x in prediction_strategy_selector]
metric = evaluate.load(metrics[metric_selector])
metric.add_batch(predictions=predictions, references=references)
results = metric.compute(prediction_strategies=prediction_strategies)

with open('Reports/report-Model-' + model_selector + '_Queries-' + queries_selector
          + '_Prompt-' + prompt_selector + '.csv', 'a') as results_file:
    for result in results:
        results[result].to_csv(results_file, mode='a', index_label=result)
        print(results[result], '\n')
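
# Illustrative sketch (never called above): roughly how the three strategy families listed in
# prediction_strategy_options could turn the (num_samples, 17) score matrix into hard class choices.
# The authoritative behavior lives inside the 'gorkaartola/metric_for_tp_fp_samples' metric loaded
# via evaluate.load; this helper and its name are assumptions added only for reference.
def sketch_predicted_classes(scores, strategy):
    # scores: np.ndarray of shape (num_samples, 17) with per-SDG entailment scores.
    # strategy: one entry of prediction_strategy_options, e.g. ["topk", 3].
    name = strategy[0]
    if name == "argmax_max":
        # One class per sample: the SDG with the highest score.
        return [np.array([np.argmax(row)]) for row in scores]
    if name == "threshold":
        # Every SDG whose score exceeds the threshold counts as predicted.
        return [np.where(row > strategy[1])[0] for row in scores]
    if name == "topk":
        # The k highest-scoring SDGs count as predicted.
        return [np.argsort(row)[::-1][:strategy[1]] for row in scores]
    raise ValueError(f"Unknown prediction strategy: {name}")

# Example: sketch_predicted_classes(predictions, ["topk", 3]) would return, for each sample,
# the indices of its 3 highest-scoring SDGs.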