gorkaartola commited on
Commit
6bbf43b
1 Parent(s): 3928d85

Delete run.py

Browse files
Files changed (1) hide show
  1. run.py +0 -123
run.py DELETED
@@ -1,123 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
2
- from datasets import load_dataset, load_metric
3
- import evaluate
4
- from torch.utils.data import DataLoader
5
- import torch
6
- import numpy as np
7
- import pandas as pd
8
- import options as op
9
-
10
def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selector, prediction_strategy_selector):
    """Run a zero-shot NLI classification test over the AURORA SDG dataset.

    For every SDG query, each test title is paired with ``prompt + query`` and
    scored with an NLI sequence-classification model; the entailment
    probability (softmax over the [contradiction, entailment] logits) is the
    prediction score. Per-query scores are collected into a DataFrame, reduced
    to one logit per SDG class (max over that class's queries), evaluated with
    the selected metric, and the report is appended to a CSV file.

    Args:
        model_selector: HF model id/path for tokenizer and NLI model; also a
            key into ``op.models`` for report naming.
        queries_selector: data-file name of the SDG query set; also a key into
            ``op.queries``.
        prompt_selector: hypothesis prompt string prepended to each query;
            also a key into ``op.prompts``.
        metric_selector: id passed to ``evaluate.load``.
        prediction_strategy_selector: iterable of keys into
            ``op.prediction_strategy_options``.

    Returns:
        Path of the CSV report file that was written.
    """

    # Load test dataset __________________________
    test_dataset = load_dataset('gorkaartola/SC-ZS-test_AURORA-Gold-SDG_True-Positives-and-False-Positives')['test']

    # Load queries _______________________________
    queries_data_files = {'queries': queries_selector}
    queries_dataset = load_dataset('gorkaartola/SDG_queries', data_files = queries_data_files)['queries']

    # Load prompt ________________________________
    prompt = prompt_selector

    # NOTE: the original code assigned ``prediction_strategies =
    # prediction_strategy_selector`` here; that value was dead code, always
    # overwritten below after the predictions are computed, so it is removed.

    # Calculate and save predictions _____________
    # Toggle: comment out this section (and un-comment the "Load saved
    # predictions" section below) to re-analyze previously saved inference
    # tables without re-running the model.
    #'''
    def tokenize_function(example, prompt = '', query = ''):
        # Pair every title with the same hypothesis (prompt + query);
        # truncation only shortens the title, never the hypothesis.
        queries = [prompt + query for _ in range(len(example['title']))]
        return tokenizer(example['title'], queries, truncation='only_first')

    model = AutoModelForSequenceClassification.from_pretrained(model_selector)
    # Resolve the device once instead of re-querying CUDA in the batch loop.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_selector)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    results_test = pd.DataFrame()
    labels = []
    nli_labels = []
    for query_data in queries_dataset:
        query = query_data['SDGquery']
        tokenized_test_dataset = test_dataset.map(tokenize_function, batched = True, fn_kwargs = {'prompt' : prompt, 'query' : query})
        # Keep only the tokenizer outputs plus the two label columns.
        columns_to_remove = test_dataset.column_names
        for column_name in ['label_ids', 'nli_label']:
            columns_to_remove.remove(column_name)
        tokenized_test_dataset_for_inference = tokenized_test_dataset.remove_columns(columns_to_remove)
        tokenized_test_dataset_for_inference.set_format('torch')
        dataloader = DataLoader(
            tokenized_test_dataset_for_inference,
            batch_size=8,
            collate_fn = data_collator,
        )
        values = []
        labels = []
        nli_labels = []
        for batch in dataloader:
            data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
            with torch.no_grad():
                outputs = model(**data)
            logits = outputs.logits
            # NLI head layout assumed [contradiction, neutral, entailment];
            # drop the neutral logit and softmax over the remaining two.
            entail_contradiction_logits = logits[:, [0, 2]]
            probs = entail_contradiction_logits.softmax(dim=1)
            predictions = probs[:, 1].tolist()  # P(entailment)
            values.extend(predictions)
            labels.extend(batch['labels'].tolist())
            nli_labels.extend(batch['nli_label'].tolist())
        results_test[query] = values
    # The label columns are identical for every query (same test set in the
    # same order), so assign them once after the loop instead of per query.
    results_test.insert(0, 'dataset_labels', labels)
    results_test.insert(1, 'nli_labels', nli_labels)

    results_test.to_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv', index = False)
    #'''
    # Load saved predictions _____________________
    '''
    results_test = pd.read_csv('Reports/ZS inference tables/ZS-inference-table_Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '.csv')
    '''
    # Analyze predictions ________________________
    def logits_labels(raw):
        """Collapse per-query scores into one logit per SDG class (17 total).

        The class logit is the max score over all queries belonging to that
        class. Returns ``(logits, labels)`` where labels holds the
        dataset_labels / nli_labels pair as an int array.
        """
        raw_logits = raw.iloc[:, 2:]  # skip the two label columns
        logits = np.zeros(shape=(len(raw_logits.index), 17))
        for i in range(17):
            queries = queries_dataset.filter(lambda x: x['label_ids'] == i)['SDGquery']
            logits[:, i] = raw_logits[queries].max(axis=1)
        labels = raw[["dataset_labels", "nli_labels"]]
        labels = np.array(labels).astype(int)
        return logits, labels

    predictions, references = logits_labels(results_test)
    prediction_strategies = [op.prediction_strategy_options[x] for x in prediction_strategy_selector]

    metric = evaluate.load(metric_selector)
    metric.add_batch(predictions = predictions, references = references)
    results = metric.compute(prediction_strategies = prediction_strategies)
    prediction_strategies_names = '-'.join(prediction_strategy_selector).replace(" ", "")
    output_filename = 'Reports/report-Model-' + op.models[model_selector] + '_Queries-' + op.queries[queries_selector] + '_Prompt-' + op.prompts[prompt_selector] + '_Strategies-'+ prediction_strategies_names +'.csv'
    # Append mode so repeated runs accumulate reports in one file.
    with open(output_filename, 'a') as results_file:
        for result in results:
            results[result].to_csv(results_file, mode='a', index_label = result)
            print(results[result], '\n')
    return output_filename
110
-
111
-
112
-
113
-
114
-
115
-
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
-