File size: 4,678 Bytes
77d2229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef3524
 
 
77d2229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef3524
 
 
 
77d2229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef3524
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset, load_metric
import evaluate
from torch.utils.data import DataLoader
import torch
import numpy as np
import pandas as pd
import options as op

def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
	"""Run zero-shot NLI inference over the AURORA-Gold-SDG test split and score it.

	For every SDG query, each test title is paired with the hypothesis
	``prompt + query`` and fed through an NLI sequence-classification model;
	the entailment probability is recorded per (title, query).  The per-query
	scores are then collapsed to one score per SDG (max over that SDG's
	queries) and evaluated with the selected metric under the selected
	prediction strategies.  An inference table and a report CSV are written
	under ``Reports/``.

	Parameters
	----------
	model_selector : str
		HF hub id of the NLI model; also a key of ``op.models`` (used in file names).
	queries_selector : str
		Data file name inside 'gorkaartola/SDG_queries'; also a key of ``op.queries``.
	prompt_selector : str
		Prompt text prepended to every query; also a key of ``op.prompts``.
	metric_selector : str
		Metric id/path accepted by ``evaluate.load``.
	prediction_strategy_selector : iterable
		Keys of ``op.prediction_strategy_options``.

	Returns
	-------
	dict
		Result-name -> pandas object mapping as produced by the metric's
		``compute``; each entry is also appended to the report CSV and printed.
	"""
	#Load test dataset___________________________
	test_dataset = load_dataset('gorkaartola/SC-ZS-test_AURORA-Gold-SDG_True-Positives-and-False-Positives')['test']
	
	#Load queries________________________________
	queries_data_files = {'queries': queries_selector}
	queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']
	
	#Load prompt_________________________________
	prompt = prompt_selector
	
	#Calculate and save predictions_______________________
	#NOTE: comment this section out (and uncomment "Load saved predictions"
	#below) to reuse a previously saved inference table.
	#'''
	def tokenize_function(example, prompt = '', query = ''):
		#Pair every title with the same hypothesis (prompt + query).  With
		#truncation='only_first' the title (first segment) is truncated if the
		#pair exceeds the model's maximum length, keeping the hypothesis intact.
		hypotheses = [prompt + query] * len(example['title'])
		return tokenizer(example['title'], hypotheses, truncation='only_first')
	
	model = AutoModelForSequenceClassification.from_pretrained(model_selector)
	#Pick the device once up front; batches are moved to the same device below,
	#so CPU and GPU runs share a single code path.
	device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
	model.to(device)
	tokenizer = AutoTokenizer.from_pretrained(model_selector)
	data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
	
	results_test = pd.DataFrame()
	for query_data in queries_dataset:
		query = query_data['SDGquery']
		tokenized_test_dataset = test_dataset.map(tokenize_function, batched = True, fn_kwargs = {'prompt' : prompt, 'query' : query})
		#Keep only the tokenizer outputs plus the two label columns for inference
		#(non-mutating: .column_names is not modified in place).
		columns_to_remove = [name for name in test_dataset.column_names if name not in ('label_ids', 'nli_label')]
		tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
		tokenized_test_dataset_for_inference.set_format('torch')
		dataloader = DataLoader(
				tokenized_test_dataset_for_inference,
				batch_size=8,
				collate_fn = data_collator,
				)
		values = []
		labels = []
		nli_labels = []
		for batch in dataloader:
			#'label_ids' is renamed to 'labels' by the data collator; both label
			#columns are kept out of the model inputs.
			data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
			with torch.no_grad():
				outputs = model(**data)
			#Drop the neutral class and renormalize over [contradiction, entailment];
			#column 1 is then the entailment probability used as the score.
			entail_contradiction_logits = outputs.logits[:, [0, 2]]
			probs = entail_contradiction_logits.softmax(dim=1)
			values.extend(probs[:, 1].tolist())
			labels.extend(batch['labels'].tolist())
			nli_labels.extend(batch['nli_label'].tolist())
		#Label columns are identical for every query (same dataset order), so
		#re-assigning them each iteration is harmless.
		results_test['dataset_labels'] = labels
		results_test['nli_labels'] = nli_labels
		results_test[query] = values
	
	results_test.to_csv(f'Reports/ZS inference tables/ZS-inference-table_Model-{op.models[model_selector]}_Queries-{op.queries[queries_selector]}_Prompt-{op.prompts[prompt_selector]}.csv', index = False)
	#'''
	#Load saved predictions____________________________
	'''
	results_test = pd.read_csv(f'Reports/ZS inference tables/ZS-inference-table_Model-{op.models[model_selector]}_Queries-{op.queries[queries_selector]}_Prompt-{op.prompts[prompt_selector]}.csv')
	'''
	#Analize predictions_______________________________
	def logits_labels(raw):
		"""Collapse per-query scores into one score per SDG and extract labels.

		Columns 0-1 of *raw* are the label columns; the rest are per-query
		scores.  For each of the 17 SDGs the score is the max over that SDG's
		queries (as tagged by 'label_ids' in the queries dataset).
		"""
		raw_logits = raw.iloc[:, 2:]
		logits = np.zeros(shape=(len(raw_logits.index), 17))
		for i in range(17):
			queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
			logits[:, i] = raw_logits[queries].max(axis=1)
		labels = raw[["dataset_labels", "nli_labels"]]
		labels = np.array(labels).astype(int)
		return logits, labels
	
	predictions, references = logits_labels(results_test)
	#Map selector keys to the concrete strategy objects expected by the metric.
	prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]
	
	metric = evaluate.load(metric_selector)
	metric.add_batch(predictions = predictions, references = references)
	results = metric.compute(prediction_strategies = prediction_strategies)
	report_path = f'Reports/report-Model-{op.models[model_selector]}_Queries-{op.queries[queries_selector]}_Prompt-{op.prompts[prompt_selector]}.csv'
	with open(report_path, 'a') as results_file:
		for result in results:
			results[result].to_csv(results_file, mode='a', index_label = result)
			print(results[result], '\n')
	return results