from typing import Tuple

import streamlit as st
from torch import nn, qint8
from torch.nn import Parameter
from torch.quantization import quantize_dynamic
from optimum.bettertransformer import BetterTransformer
from transformers import MBartForConditionalGeneration, PreTrainedModel, PreTrainedTokenizer

from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.tokenization import AMRMBartTokenizer

LANGUAGES = {
    "English": "en_XX",
    "Dutch": "nl_XX",
    "Spanish": "es_XX",
}

# Custom hash functions so that Streamlit can cache models, tokenizers, and parameters,
# which it cannot hash by default
st_hash_funcs = {
    PreTrainedModel: lambda model: model.name_or_path,
    PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
    Parameter: lambda param: param.data,
}


@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
def get_resources(
    multilingual: bool, quantize: bool = True
) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
    """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the
    multilingual model is requested or not. If not, an English-only model is loaded. The model can
    optionally be quantized for better performance.

    :param multilingual: whether to load the multilingual model. If not, loads the English-only model
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
    :return: the loaded model, tokenizer, and logits processor
    """
    if multilingual:
        # Tokenizer src_lang is reset during translation to the right language
        tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr", src_lang="nl_XX")
        model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr")
    else:
        tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
        model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")

    # Convert to a BetterTransformer model for faster inference
    model = BetterTransformer.transform(model, keep_original_model=False)
    # The tokenizer was extended with AMR-specific tokens, so resize the embeddings to match
    model.resize_token_embeddings(len(tokenizer))

    if quantize:
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)

    return model, tokenizer, logits_processor


@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
def translate(
    text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs
) -> str:
    """Translates a given text in a given source language with a given model and tokenizer. The generation
    is guided by potential keyword arguments, which can include arguments such as max length, logits
    processors, etc.

    :param text: source text to translate
    :param src_lang: source language, as a human-readable key in LANGUAGES (e.g. "English")
    :param model: MBART model
    :param tokenizer: MBART tokenizer
    :param gen_kwargs: potential keyword arguments for the generation process
    :return: the translation (linearized AMR graph)
    """
    tokenizer.src_lang = LANGUAGES[src_lang]
    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(**encoded, **gen_kwargs)

    return tokenizer.decode_and_fix(generated)[0]
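
# A minimal sketch of how these helpers could be wired together in a Streamlit page
# (run with `streamlit run <this file>`). The widget labels and the generation settings
# below are illustrative assumptions; they are not part of the original module.
if __name__ == "__main__":
    from transformers import LogitsProcessorList

    st.title("Text-to-AMR")
    use_multilingual = st.checkbox("Use multilingual model", value=False)
    source_lang = st.selectbox("Source language", list(LANGUAGES)) if use_multilingual else "English"
    input_text = st.text_area("Text to parse")

    if input_text.strip():
        # Cached after the first run thanks to st.cache and st_hash_funcs
        with st.spinner("Loading resources..."):
            model, tokenizer, logits_processor = get_resources(use_multilingual)
        linearized = translate(
            input_text,
            source_lang,
            model,
            tokenizer,
            max_length=model.config.max_length,
            logits_processor=LogitsProcessorList([logits_processor]),
        )
        st.code(linearized)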