File size: 4,856 Bytes
b27edec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

import options as op

def _run_zero_shot_inference(model, tokenizer, test_dataset, queries_dataset, prompt, device, batch_size=8):
	"""Score every test example against every query and return the score table.

	For each row of ``queries_dataset`` the test titles are paired with
	``prompt + query``, run through ``model``, and the resulting per-example
	score is stored in a column named after the query.

	Args:
		model: Sequence-classification model, already moved to ``device``.
		tokenizer: Tokenizer matching ``model``.
		test_dataset: Dataset with the columns ``title``, ``label_ids`` and
			``nli_label``.
		queries_dataset: Dataset whose rows carry an ``SDGquery`` string.
		prompt: Text prepended to every query.
		device: ``torch.device`` the input tensors are moved to.
		batch_size: Number of examples per forward pass.

	Returns:
		A ``pandas.DataFrame`` whose first two columns are ``dataset_labels``
		and ``nli_labels`` and whose remaining columns (one per query, in
		query order) hold the scores.

	Raises:
		ValueError: If ``test_dataset`` lacks ``label_ids`` or ``nli_label``.
	"""
	label_columns = ('label_ids', 'nli_label')
	missing = [name for name in label_columns if name not in test_dataset.column_names]
	if missing:
		raise ValueError('test dataset is missing required column(s): ' + ', '.join(missing))

	# Everything except the label columns is dropped before batching; the
	# tokenizer outputs added by map() are not in this list and so survive.
	# This does not depend on the query, so it is computed once.
	columns_to_remove = [name for name in test_dataset.column_names if name not in label_columns]
	data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

	def tokenize_function(example, prompt='', query=''):
		# Batched map: pair every title with the same prompt + query text.
		hypotheses = [prompt + query] * len(example['title'])
		# 'only_first' truncates the title only, never the query.
		return tokenizer(example['title'], hypotheses, truncation='only_first')

	# Columns are collected in a dict and turned into a DataFrame once at the
	# end; adding them to a DataFrame one by one fragments it on large query sets.
	columns = {}
	for query_data in queries_dataset:
		query = query_data['SDGquery']
		tokenized = test_dataset.map(
			tokenize_function,
			batched=True,
			fn_kwargs={'prompt': prompt, 'query': query},
		)
		inference_dataset = tokenized.remove_columns(columns_to_remove)
		inference_dataset.set_format('torch')
		dataloader = DataLoader(
			inference_dataset,
			batch_size=batch_size,
			collate_fn=data_collator,
		)

		values = []
		labels = []
		nli_labels = []
		for batch in dataloader:
			# The collator exposes the 'label_ids' column as 'labels'; neither
			# label tensor is a model input.
			model_inputs = {
				key: tensor.to(device)
				for key, tensor in batch.items()
				if key not in ('labels', 'nli_label')
			}
			with torch.no_grad():
				logits = model(**model_inputs).logits
			# Keep logit columns 0 and 2, renormalise over the pair and take the
			# share of column 2.
			# NOTE(review): this presumes the usual MNLI head order
			# (contradiction, neutral, entailment) -- confirm against
			# model.config.id2label for the selected checkpoint.
			entail_contradiction_logits = logits[:, [0, 2]]
			probs = entail_contradiction_logits.softmax(dim=1)
			values.extend(probs[:, 1].tolist())
			labels.extend(batch['labels'].tolist())
			nli_labels.extend(batch['nli_label'].tolist())

		# The label columns are identical for every query; re-assigning keeps
		# them in first position without special-casing the first iteration.
		columns['dataset_labels'] = labels
		columns['nli_labels'] = nli_labels
		columns[query] = values

	return pd.DataFrame(columns)


def _aggregate_by_label(raw, queries_dataset, num_labels=17):
	"""Collapse per-query scores into one score per label.

	Args:
		raw: Score table as produced by ``_run_zero_shot_inference``: two
			label columns followed by one column per query.
		queries_dataset: Dataset with ``SDGquery`` and ``label_ids`` columns,
			mapping each query to a label id.
		num_labels: Number of label ids, ``0 .. num_labels - 1`` (presumably
			the 17 Sustainable Development Goals).

	Returns:
		``(logits, labels)`` where ``logits`` has shape
		``(n_examples, num_labels)`` and holds, per label, the maximum score
		over that label's queries, and ``labels`` is an integer array of shape
		``(n_examples, 2)`` built from ``dataset_labels`` and ``nli_labels``.
	"""
	raw_logits = raw.iloc[:, 2:]
	# Read the two query columns once instead of filtering the dataset
	# separately for every label.
	query_texts = queries_dataset['SDGquery']
	query_label_ids = queries_dataset['label_ids']

	logits = np.zeros(shape=(len(raw_logits.index), num_labels))
	for label_id in range(num_labels):
		label_queries = [
			text
			for text, query_label in zip(query_texts, query_label_ids)
			if query_label == label_id
		]
		logits[:, label_id] = raw_logits[label_queries].max(axis=1)

	labels = np.array(raw[['dataset_labels', 'nli_labels']]).astype(int)
	return logits, labels


def tp_tf_test(model_selector, test_dataset, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
	"""Run zero-shot inference on a test set and write an evaluation report.

	The function loads the ``test`` split of ``test_dataset`` and the queries
	file ``queries_selector`` from ``gorkaartola/SDG_queries``, scores every
	test title against every query with the model ``model_selector``, saves
	the score table as CSV, evaluates it with the metric ``metric_selector``
	and appends the resulting tables to a report CSV.

	Args:
		model_selector: Model identifier passed to ``from_pretrained`` for both
			the model and the tokenizer; also a key of ``op.models``.
		test_dataset: Dataset identifier passed to ``load_dataset``.
		queries_selector: Data file inside ``gorkaartola/SDG_queries``; also a
			key of ``op.queries``.
		prompt_selector: Text prepended to every query; also a key of
			``op.prompts``.
		metric_selector: Metric identifier passed to ``evaluate.load``.
		prediction_strategy_selector: Sequence of strings, each a key of
			``op.prediction_strategy_options``.

	Returns:
		The path of the report CSV, as a string.

	Raises:
		KeyError: If a selector is not a key of the matching ``options`` table.
			This is raised before any model is loaded.
		ValueError: If the test split lacks ``label_ids`` or ``nli_label``.
	"""
	# Resolve every name and path first, so a mistyped selector or a missing
	# output directory fails immediately instead of after the inference run.
	run_tag = 'Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector]
	prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]
	prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
	inference_table_path = 'Reports/ZS inference tables/ZS-inference-table_' + run_tag + '.csv'
	output_filename = 'Reports/report-' + run_tag + '_Strategies-' + prediction_strategies_names + '.csv'
	# Creating the deepest directory also creates 'Reports' itself.
	os.makedirs(os.path.dirname(inference_table_path), exist_ok=True)

	# Load the test split and the queries.
	test_split = load_dataset(test_dataset)['test']
	queries_data_files = {'queries': queries_selector}
	queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files=queries_data_files)['queries']

	# Load model and tokenizer; use the GPU when one is available.
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	model = AutoModelForSequenceClassification.from_pretrained(model_selector)
	model.to(device)
	tokenizer = AutoTokenizer.from_pretrained(model_selector)

	# Calculate and save predictions.
	# To re-analyse an existing table without re-running inference, replace
	# the next statement with: results_test = pd.read_csv(inference_table_path)
	results_test = _run_zero_shot_inference(model, tokenizer, test_split, queries_dataset, prompt_selector, device)
	results_test.to_csv(inference_table_path, index=False)

	# Analyse predictions.
	predictions, references = _aggregate_by_label(results_test, queries_dataset)
	metric = evaluate.load(metric_selector)
	metric.add_batch(predictions=predictions, references=references)
	results = metric.compute(prediction_strategies=prediction_strategies)

	# The report is opened in append mode, so repeated runs with the same
	# settings accumulate in one file. newline='' is what pandas requires for
	# file objects, otherwise Windows gets doubled line endings.
	with open(output_filename, 'a', newline='') as results_file:
		for result in results:
			results[result].to_csv(results_file, mode='a', index_label=result)
			print(results[result], '\n')
	return output_filename