# ssa-perin / model_wrapper.py
# Author: larkkin -- commit 991f07c ("Add code"), 4.37 kB
import os
import json
import tempfile
import sys
import datetime
import re
import string
sys.path.append('mtool')
import torch
from model.model import Model
from data.dataset import Dataset
from config.params import Params
from utility.initialize import initialize
from data.batch import Batch
from mtool.main import main as mtool_main
from tqdm import tqdm
class PredictionModel(torch.nn.Module):
    """Inference wrapper around a trained ``Model`` checkpoint.

    Loads the checkpoint and its stored ``Params``, rebuilds the ``Dataset``
    and ``Model`` from them, and exposes :meth:`predict`. ``predict`` runs the
    model on raw sentences and converts the resulting MRP graphs with mtool's
    ``norec`` writer.
    """

    # Extra mtool command-line flags for each supported graph mode.
    _GRAPH_MODE_FLAGS = {
        'labeled-edge': (),
        'node-centric': ('--node_centric',),
    }

    def __init__(
        self,
        checkpoint_path=os.path.join('models', 'checkpoint.bin'),
        default_mrp_path=os.path.join('models', 'default.mrp'),
        verbose=False,
    ):
        """Load the checkpoint and build the model in evaluation mode.

        Args:
            checkpoint_path: Path of the checkpoint read with ``torch.load``.
                It must contain ``'params'`` and ``'model'`` entries.
            default_mrp_path: MRP file assigned to the training, validation
                and test data parameters. NOTE(review): presumably only needed
                so ``Dataset`` can be constructed at inference time -- confirm.
            verbose: If true, show a tqdm progress bar over batches while
                predicting.
        """
        super().__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Load onto the CPU first; the model is moved to self.device below.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.verbose = verbose
        self.args = Params().load_state_dict(self.checkpoint['params'])
        # Override training-time settings that make no sense for inference.
        self.args.log_wandb = False
        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        # strict=False: missing/unexpected state-dict keys are tolerated.
        self.model.load_state_dict(self.checkpoint["model"], strict=False)
        self.model.eval()

    @classmethod
    def _graph_mode_flags(cls, graph_mode):
        """Return the extra mtool flags for *graph_mode*.

        Raises:
            ValueError: If *graph_mode* is not a supported mode.
        """
        try:
            return cls._GRAPH_MODE_FLAGS[graph_mode]
        except KeyError:
            raise ValueError(f'Unknown graph mode: {graph_mode}') from None

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert MRP graphs to mtool's ``norec`` output.

        The graphs are written as JSON lines to a temporary file and converted
        by mtool (``--read mrp --write norec``). The JSON file mtool writes is
        then loaded and returned. Both temporary files are removed afterwards,
        including when conversion fails.

        Args:
            mrp_list: Iterable of JSON-serialisable MRP graph dicts.
            graph_mode: ``'labeled-edge'`` or ``'node-centric'``. The latter
                passes ``--node_centric`` to mtool.

        Returns:
            The JSON value loaded from mtool's output file.

        Raises:
            ValueError: If ``graph_mode`` is not recognised. This is checked
                before any temporary file is created.
        """
        mode_flags = self._graph_mode_flags(graph_mode)
        framework = 'norec'
        temp_paths = []
        try:
            # Reserve an (empty) output path for mtool to write into.
            with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
                temp_paths.append(output_text_file.name)
                output_text_filename = output_text_file.name
            with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
                temp_paths.append(mrp_file.name)
                mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
                mrp_filename = mrp_file.name
            mtool_main([
                *mode_flags,
                '--strings',
                '--ids',
                '--read', 'mrp',
                '--write', framework,
                mrp_filename, output_text_filename,
            ])
            with open(output_text_filename) as f:
                return json.load(f)
        finally:
            for path in temp_paths:
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    # Already gone; nothing left to clean up.
                    pass

    def clean_texts(self, texts):
        """Space-separate ASCII punctuation in each text.

        Every character in ``string.punctuation`` is surrounded by spaces.
        Runs of spaces are then collapsed to one, so ``"Hei, du!"`` becomes
        ``"Hei , du ! "``. Leading and trailing spaces are not stripped.

        Args:
            texts: Iterable of strings.

        Returns:
            A new list with one cleaned string per input text.
        """
        # Backslash-escape each punctuation char so it is literal inside [...].
        punctuation = ''.join(f'\\{s}' for s in string.punctuation)
        punct_re = re.compile(f'([{punctuation}])')
        spaces_re = re.compile(r' +')
        return [spaces_re.sub(' ', punct_re.sub(' \\1 ', t)) for t in texts]

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model and return MRP-style graph dicts keyed by sentence id.

        Args:
            texts: Iterable of raw sentences. They are cleaned with
                :meth:`clean_texts` before use.
            graph_mode: Accepted for symmetry with :meth:`predict`. It is not
                used during inference.

        Returns:
            A dict mapping ``"0"``, ``"1"``, ... to a graph dict. Each graph
            dict starts with ``input`` (the cleaned text), ``id``, ``time``
            (today's ISO date), ``framework``/``language`` (from the
            checkpoint params) and empty ``nodes``/``edges``/``tops``. Every
            key of the model prediction with the matching ``id`` is then
            merged in.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        date_str = datetime.datetime.now().date().isoformat()
        res_sentences = {
            str(i): {
                'input': sentence,
                'id': str(i),
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }
            for i, sentence in enumerate(texts)
        }
        batches = tqdm(data) if self.verbose else data
        with torch.no_grad():
            for batch in batches:
                predictions = self.model(Batch.to(batch, self.device), inference=True)
                for prediction in predictions:
                    res_sentences[prediction['id']].update(prediction)
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict graphs for raw sentences and return mtool ``norec`` output.

        Args:
            text_list: Iterable of raw input sentences.
            graph_mode: ``'labeled-edge'`` or ``'node-centric'``.
            language: Currently ignored. The language stored in the checkpoint
                params (``self.args.language``) is used instead. Kept for
                backward compatibility.

        Returns:
            The value produced by :meth:`_mrp_to_text`.

        Raises:
            ValueError: If ``graph_mode`` is not recognised. This is checked
                before running inference.
        """
        # Fail fast on a bad mode instead of after (potentially slow) inference.
        self._graph_mode_flags(graph_mode)
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        return self._mrp_to_text(mrp_predictions.values(), graph_mode)

    def forward(self, x):
        """Call :meth:`predict` on *x* with default options."""
        return self.predict(x)