File size: 3,265 Bytes
f8b0e70
 
1e0a2f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8b0e70
 
 
 
 
 
 
 
 
11f8e5f
 
 
 
 
 
 
 
1e0a2f8
 
 
 
 
 
 
 
 
11f8e5f
 
 
f8b0e70
 
 
 
 
 
 
 
 
 
 
11f8e5f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Tuple

import streamlit as st

from torch.quantization import quantize_dynamic
from torch import nn, qint8
from torch.nn import Parameter
from transformers import PreTrainedModel, PreTrainedTokenizer
from optimum.bettertransformer import BetterTransformer
from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.tokenization import AMRMBartTokenizer
from transformers import MBartForConditionalGeneration


# Custom hash functions handed to ``st.cache`` via ``hash_funcs``: models and
# tokenizers are hashed by their ``name_or_path``, parameters by their ``data``.
st_hash_funcs = {
    PreTrainedModel: lambda m: m.name_or_path,
    PreTrainedTokenizer: lambda t: t.name_or_path,
    Parameter: lambda p: p.data,
}


@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
def get_resources(multilingual: bool, quantize: bool = True) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
    """Load the model, tokenizer and logits processor needed for text-to-AMR generation.

    Either the multilingual checkpoint or the English-only checkpoint is loaded. The model is converted with
    ``BetterTransformer``, its token embeddings are resized to the tokenizer's length, and it is optionally
    quantized dynamically to ``qint8``.

    :param multilingual: if True, load the multilingual checkpoint; otherwise load the English-only checkpoint
    :param quantize: if True, apply PyTorch's 'quantize_dynamic' to the model
    :return: a tuple of the loaded model, its tokenizer, and the logits processor
    """
    if multilingual:
        # Tokenizer src_lang is reset during translation to the right language
        checkpoint, initial_src_lang = "BramVanroy/mbart-en-es-nl-to-amr", "nl_XX"
    else:
        checkpoint, initial_src_lang = "BramVanroy/mbart-en-to-amr", "en_XX"

    tokenizer = AMRMBartTokenizer.from_pretrained(checkpoint, src_lang=initial_src_lang)
    model = MBartForConditionalGeneration.from_pretrained(checkpoint)

    # Convert first, then make the embedding matrix match the tokenizer's vocabulary size
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if quantize:
        quantizable_modules = {nn.Linear, nn.Dropout, nn.LayerNorm}
        model = quantize_dynamic(model, quantizable_modules, dtype=qint8)

    return model, tokenizer, AMRLogitsProcessor(tokenizer, model.config.max_length)


@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
    """Generate a linearized AMR graph for a text written in the given source language.

    Any extra keyword arguments are forwarded unchanged to ``model.generate`` (e.g. max length, logits processors).

    :param text: the source text to translate
    :param src_lang: name of the source language; must be a key of ``LANGUAGES``
    :param model: MBART model used for generation
    :param tokenizer: MBART tokenizer; its ``src_lang`` attribute is overwritten by this call
    :param gen_kwargs: optional keyword arguments passed on to the generation process
    :return: the translation (linearized AMR graph)
    """
    # Set the tokenizer to the requested source language before encoding
    tokenizer.src_lang = LANGUAGES[src_lang]
    model_inputs = tokenizer(text, return_tensors="pt")
    output_ids = model.generate(**model_inputs, **gen_kwargs)

    # A single text was encoded, so only the first decoded item is returned
    decoded = tokenizer.decode_and_fix(output_ids)
    return decoded[0]


# Maps the language names accepted by ``translate`` (its ``src_lang`` argument)
# to the language codes that are assigned to the tokenizer's ``src_lang``.
LANGUAGES = dict(
    (
        ("English", "en_XX"),
        ("Dutch", "nl_XX"),
        ("Spanish", "es_XX"),
    )
)