# text-to-amr / utils.py
# Author: BramVanroy — last commit b4c04ac ("Update utils.py"), 4.83 kB
from typing import Tuple, Union, Dict, List
from multi_amr.data.postprocessing_graph import ParsedStatus
from multi_amr.data.tokenization import AMRTokenizerWrapper
from optimum.bettertransformer import BetterTransformer
import penman
import streamlit as st
import torch
from torch.quantization import quantize_dynamic
from torch import nn, qint8
from transformers import MBartForConditionalGeneration, AutoConfig
import spaces
@st.cache_resource(show_spinner=False)
def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Load (and cache) the AMR parsing model together with its tokenizer wrapper.

    When ``multilingual`` is set, the English/Spanish/Dutch model is loaded regardless of ``src_lang``; otherwise the
    monolingual model for ``src_lang`` is selected. The model is converted with BetterTransformer and then either
    moved to CUDA (when available and not disabled) or, on CPU, optionally dynamically quantized.

    :param multilingual: whether to load the multilingual model instead of a language-specific one
    :param src_lang: source language name ("English", "Spanish" or "Dutch"); only used when not multilingual
    :param quantize: whether to apply PyTorch's 'quantize_dynamic' when running on CPU
    :param no_cuda: whether to stay on CPU even if CUDA is available
    :return: the loaded model and the tokenizer wrapper
    :raises ValueError: if a monolingual model is requested for an unsupported language
    """
    if multilingual:
        model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
    else:
        for language, candidate in (
            ("English", "BramVanroy/mbart-large-cc25-ft-amr30-en"),
            ("Spanish", "BramVanroy/mbart-large-cc25-ft-amr30-es"),
            ("Dutch", "BramVanroy/mbart-large-cc25-ft-amr30-nl"),
        ):
            if src_lang == language:
                model_name = candidate
                break
        else:
            raise ValueError(f"Language {src_lang} not supported")

    # "en_XX" is only a placeholder: translate() assigns tok_wrapper.src_lang for the actual language on every call
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")

    config = AutoConfig.from_pretrained(model_name)
    # Generation starts from the wrapper's AMR token
    config.decoder_start_token_id = tok_wrapper.amr_token_id

    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
    model.eval()

    # Grow the input embedding matrix when the tokenizer knows more tokens than the checkpoint has embeddings for
    num_embeddings = model.get_input_embeddings().weight.shape[0]
    vocab_size = len(tok_wrapper.tokenizer)
    if vocab_size > num_embeddings:
        model.resize_token_embeddings(vocab_size)

    model = BetterTransformer.transform(model, keep_original_model=False)

    use_cuda = torch.cuda.is_available() and not no_cuda
    if use_cuda:
        model = model.to("cuda")
    elif quantize:
        # Quantization is not supported on CUDA, so it is only applied on the CPU path.
        # NOTE(review): torch's default dynamic-quantization mappings presumably only swap nn.Linear here;
        # the nn.Dropout / nn.LayerNorm entries look like no-ops — confirm before relying on them.
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    return model, tok_wrapper
@spaces.GPU
def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Translate text(s) of a given source language into AMR graphs with the given model and tokenizer wrapper.

    Generation is steered by ``gen_kwargs`` (e.g. max length, number of beams, logits processors), which are
    forwarded to ``model.generate``. For every input text, all candidate sequences returned for it are decoded and
    the candidate with the best parse status is kept; ties on status are broken by the highest sequence score.

    :param texts: source text(s) to translate; a single string is treated as a batch of one
    :param src_lang: source language name, must be a key of ``LANGUAGES`` (e.g. "English")
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: keyword arguments for the generation process
    :return: dictionary with keys "graph" and "status", each a list with one entry per input text
    :raises KeyError: if ``src_lang`` is not a key of ``LANGUAGES``
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]

    best_scoring_results = {"graph": [], "status": []}
    if not texts:
        # Nothing to generate; avoid calling the tokenizer/model with an empty batch
        return best_scoring_results

    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)

    sequences = generated["sequences"].cpu()
    # `sequences_scores` may be absent (e.g. for non-beam generation); fall back to neutral scores so that the
    # selection below is then driven by parse status only.
    sequences_scores = getattr(generated, "sequences_scores", None)
    if sequences_scores is None:
        scores = [0.0] * len(sequences)
    else:
        scores = sequences_scores.cpu().tolist()

    # Number of candidates returned per input (`num_return_sequences`). Deriving it from the output instead of
    # gen_kwargs["num_beams"] avoids a KeyError when num_beams is not passed and keeps candidates of different
    # inputs from being mixed when num_return_sequences != num_beams.
    group_size = len(sequences) // len(texts)

    # Select the best item per input: the sequence with best status and highest score
    for start in range(0, len(sequences), group_size):
        stop = start + group_size
        outputs = tok_wrapper.batch_decode_amr_ids(sequences[start:stop])
        candidates = zip(outputs["status"], scores[start:stop], outputs["graph"])
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second. min() returns the first of equal
        # candidates, matching the previous "element 0 of a stable sort" behavior.
        best_status, _, best_graph = min(candidates, key=lambda item: (item[0].value, -item[1]))
        best_scoring_results["graph"].append(best_graph)
        best_scoring_results["status"].append(best_status)

    # Returns dictionary with "graph" and "status" keys
    return best_scoring_results
# User-facing language names mapped to the mBART language codes that translate() assigns to tok_wrapper.src_lang
LANGUAGES = dict(
    English="en_XX",
    Dutch="nl_XX",
    Spanish="es_XX",
)