from typing import Dict, List, Any

import torch
from scipy.special import softmax
from transformers import BertTokenizer

from utils import clean_str, clean_str_nopunct
from utils import MultiHeadModel, BertInputBuilder, get_num_words

MODEL_CHECKPOINT = 'ddemszky/uptake-model'


class EndpointHandler():
    """Inference handler that scores conversational uptake between utterances.

    For every utterance spoken by ``speaker_2`` that directly follows a
    sufficiently long utterance, the handler feeds the (previous, current)
    text pair to a ``MultiHeadModel`` with a 2-way ``"nsp"`` head and reports
    the softmax probability of class 1 as the uptake score.
    """

    def __init__(self, path="."):
        """Load the tokenizer, input builder and model.

        Args:
            path: Location passed to ``MultiHeadModel.from_pretrained``.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)
        # Forwarded to ``build_inputs`` as ``max_length`` for every pair.
        self.max_length = 120

    def get_clean_text(self, text, remove_punct=False):
        """Normalize ``text`` with ``clean_str`` or ``clean_str_nopunct``.

        Args:
            text: Raw utterance text.
            remove_punct: When true, use ``clean_str_nopunct`` instead of
                ``clean_str``.

        Returns:
            Whatever the selected cleaning helper returns.
        """
        if remove_punct:
            return clean_str_nopunct(text)
        return clean_str(text)

    def get_prediction(self, instance):
        """Run the model on a single built instance.

        ``instance`` is modified in place: an ``attention_mask`` of ones is
        added, and ``input_ids``, ``token_type_ids`` and ``attention_mask``
        are replaced by tensors with a leading batch dimension of 1, placed
        on ``self.device``.

        Args:
            instance: Mapping with at least ``input_ids`` and
                ``token_type_ids`` entries.

        Returns:
            The model output for the instance.
        """
        # NOTE(review): the mask is wrapped in an extra list before the
        # unsqueeze below, so it ends up with one more dimension than a flat
        # ``input_ids`` would - confirm this is what MultiHeadModel expects.
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            batched = torch.tensor(instance[key]).unsqueeze(0)  # Batch size = 1
            # Tensor.to() returns a new tensor rather than moving in place;
            # the result has to be stored or the inputs stay on the CPU while
            # the model lives on self.device.
            instance[key] = batched.to(self.device)

        output = self.model(
            input_ids=instance["input_ids"],
            attention_mask=instance["attention_mask"],
            token_type_ids=instance["token_type_ids"],
            return_pooler_output=False,
        )
        return output

    def get_uptake_score(self, textA, textB):
        """Score how much ``textB`` takes up ``textA``.

        Both texts are cleaned (punctuation kept), built into one model
        input, and scored.

        Args:
            textA: The preceding utterance.
            textB: The responding utterance.

        Returns:
            The softmax probability of index 1 of the first ``nsp_logits``
            row.
        """
        textA = self.get_clean_text(textA, remove_punct=False)
        textB = self.get_clean_text(textB, remove_punct=False)

        instance = self.input_builder.build_inputs(
            [textA], textB, max_length=self.max_length, input_str=True
        )
        output = self.get_prediction(instance)
        uptake_score = softmax(output["nsp_logits"][0].tolist())[1]
        return uptake_score

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute uptake scores for a sequence of utterances.

        Args:
            data: Request payload. ``data["inputs"]`` is the sequence of
                utterances, each providing ``"speaker"``, ``"text"`` and
                ``"id"``. ``data["parameters"]`` provides ``"speaker_2"``
                (the speaker whose utterances are scored) and
                ``"speaker_1_min_num_words"`` (minimum word count of the
                preceding utterance). Both entries are popped from ``data``.

        Returns:
            A dict mapping ``str(utt["id"])`` to the uptake score of each
            scored utterance; it will be serialized and returned.

        Raises:
            ValueError: If there are utterances to process but no
                ``"parameters"`` entry was supplied.
        """
        # get inputs
        utterances = data.pop("inputs", data)
        params = data.pop("parameters", None)

        print("EXAMPLES")
        for utt in utterances[:3]:
            print("speaker %s: %s" % (utt["speaker"], utt["text"]))

        print("Running inference on %d examples..." % len(utterances))

        # Fail with a clear message instead of "'NoneType' object is not
        # subscriptable" once the loop first touches params.
        if params is None and utterances:
            raise ValueError(
                'Missing "parameters": expected "speaker_2" and '
                '"speaker_1_min_num_words".'
            )

        self.model.eval()
        prev_num_words = 0
        prev_text = ""
        uptake_scores = {}
        with torch.no_grad():
            for utt in utterances:
                if utt["speaker"] == params["speaker_2"] and (
                    prev_num_words >= params["speaker_1_min_num_words"]
                ):
                    uptake_scores[str(utt["id"])] = self.get_uptake_score(
                        textA=prev_text, textB=utt["text"]
                    )
                # The previous-utterance state is updated for every
                # utterance, regardless of who spoke it.
                prev_num_words = get_num_words(utt["text"])
                prev_text = utt["text"]

        return uptake_scores