"""
"""
import logging
import json
import os
import pickle
import scipy.spatial as sp
from filelock import FileLock
import numpy as np
import torch
from .base import BaseModule, create_trainer
logger = logging.getLogger(__name__)
class XSentRetrieval(BaseModule):
    """Cross-lingual sentence retrieval evaluation module.

    Runs the encoder over an English ('en') and an in-language ('in')
    test set, mean-pools the last hidden states into sentence vectors,
    accumulates them in a pickle file (file-locked, so multiple test
    workers can append safely), and finally scores retrieval
    precision@10 between the two row-aligned sets.
    """

    mode = 'base'
    output_mode = 'classification'
    example_type = 'text'

    def __init__(self, hparams):
        # Sentence vectors are accumulated into this pickle file across
        # test_epoch_end calls; remove any stale file so a previous run
        # cannot contaminate the results.
        self.test_results_fpath = 'test_results'
        if os.path.exists(self.test_results_fpath):
            os.remove(self.test_results_fpath)
        super().__init__(hparams)

    def forward(self, **inputs):
        """Return mean-pooled last-hidden-state sentence embeddings.

        NOTE(review): the mean is taken over *all* token positions,
        including padding tokens — confirm this is intended (an
        attention-mask-weighted mean would exclude padding).
        """
        outputs = self.model(**inputs)
        last_hidden = outputs[0]
        mean_pooled = torch.mean(last_hidden, 1)
        return mean_pooled

    def test_dataloader_en(self):
        """Dataloader over the English test split."""
        test_features = self.load_features('en')
        dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
        return dataloader

    def test_dataloader_in(self):
        """Dataloader over the in-language test split."""
        test_features = self.load_features('in')
        dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
        return dataloader

    def test_step(self, batch, batch_idx):
        """Embed one batch; prepend labels as column 0 for later alignment."""
        inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
                  'attention_mask': batch[1]}
        labels = batch[3].detach().cpu().numpy()
        sentvecs = self(**inputs)
        sentvecs = sentvecs.detach().cpu().numpy()
        # Column 0 carries the label so run_module can sort both sets
        # into the same row order before comparing them.
        sentvecs = np.hstack([labels[:, None], sentvecs])
        return {'sentvecs': sentvecs}

    def test_epoch_end(self, outputs):
        """Append this worker's vectors to the shared results pickle."""
        all_sentvecs = np.vstack([x['sentvecs'] for x in outputs])
        # FileLock serializes concurrent writers (e.g. DDP processes).
        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = np.vstack([data, all_sentvecs])
            else:
                data = all_sentvecs
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)
        return {'sentvecs': all_sentvecs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """This module adds no extra CLI arguments."""
        return parser

    def _collect_results(self):
        """Load the accumulated result matrix and delete the file.

        Uses a context manager so the file handle is closed (the
        original `pickle.load(open(...))` leaked the handle).
        """
        with open(self.test_results_fpath, 'rb') as fp:
            data = pickle.load(fp)
        os.remove(self.test_results_fpath)
        return data

    def run_module(self):
        """Embed both test sets and write precision@10 to output_dir."""
        self.eval()
        self.freeze()
        trainer = create_trainer(self, self.hparams)

        trainer.test(self, self.test_dataloader_en())
        sentvecs1 = self._collect_results()
        trainer.test(self, self.test_dataloader_in())
        sentvecs2 = self._collect_results()

        # Sort both sets by the label column so rows correspond, then
        # strip the label column, keeping only the embeddings.
        sentvecs1 = sentvecs1[sentvecs1[:, 0].argsort()][:, 1:]
        sentvecs2 = sentvecs2[sentvecs2[:, 0].argsort()][:, 1:]

        result_path = os.path.join(self.hparams['output_dir'], 'test_results.txt')
        with open(result_path, 'w') as fp:
            metrics = {'test_acc': precision_at_10(sentvecs1, sentvecs2)}
            json.dump(metrics, fp)
def precision_at_10(sentvecs1, sentvecs2, k=10):
    """Retrieval precision@k between two row-aligned sets of vectors.

    Row i of `sentvecs1` counts as correctly retrieved when row i of
    `sentvecs2` is among its k nearest neighbours by cosine distance
    (each set is mean-centered first).

    Args:
        sentvecs1: (n, d) array of query vectors.
        sentvecs2: (n, d) array of candidate vectors, row-aligned with
            `sentvecs1`.
        k: number of nearest neighbours to consider. Defaults to 10,
            preserving the original hard-coded behavior.

    Returns:
        Fraction of rows (in [0, 1]) whose true match is in the top k.
    """
    n = sentvecs1.shape[0]
    # Mean-center each set so cosine distance reflects direction
    # relative to the set, not a shared offset.
    sentvecs1 = sentvecs1 - np.mean(sentvecs1, axis=0)
    sentvecs2 = sentvecs2 - np.mean(sentvecs2, axis=0)
    # sim[i, j] is the cosine *distance* (smaller = more similar).
    sim = sp.distance.cdist(sentvecs1, sentvecs2, 'cosine')
    actual = np.arange(n)
    preds = sim.argsort(axis=1)[:, :k]
    matches = np.any(preds == actual[:, None], axis=1)
    return matches.mean()