from typing import Tuple, Union, Dict, List

from multi_amr.data.postprocessing_graph import ParsedStatus
from multi_amr.data.tokenization import AMRTokenizerWrapper
from optimum.bettertransformer import BetterTransformer
import penman
import streamlit as st
import torch
from torch.quantization import quantize_dynamic
from torch import nn, qint8
from transformers import MBartForConditionalGeneration, AutoConfig


@st.cache_resource(show_spinner=False)
def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
    model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
    for better performance.

    :param multilingual: whether to load the multilingual model or not
    :param src_lang: source language
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
    :param no_cuda: whether to disable CUDA, even if it is available
    :return: the loaded model, and tokenizer wrapper
    """
    model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
    if not multilingual:
        if src_lang == "English":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en"
        elif src_lang == "Spanish":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-es"
        elif src_lang == "Dutch":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-nl"
        else:
            raise ValueError(f"Language {src_lang} not supported")

    # The tokenizer's src_lang is reset to the requested source language during translation
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")

    config = AutoConfig.from_pretrained(model_name)
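    # Start decoding from the special AMR token so that generation produces a linearized AMR graph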
    config.decoder_start_token_id = tok_wrapper.amr_token_id

    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
    model.eval()

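    # Resize the embedding matrix if the tokenizer (wrapper) contains more tokens than the model currently has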
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tok_wrapper.tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tok_wrapper.tokenizer))

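    # Let optimum's BetterTransformer swap in its faster attention implementation (the original model is not kept)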
    model = BetterTransformer.transform(model, keep_original_model=False)

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization not supported on CUDA
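        # Dynamic quantization stores the weights of supported layers (chiefly nn.Linear) in int8,
        # shrinking the model and typically speeding up CPU inference at a small accuracy cost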
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    return model, tok_wrapper


def translate(texts: Union[str, List[str]], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
    potential keyword-arguments, which can include arguments such as max length, logits processors, etc.

    :param texts: source text to translate (potentially a batch)
    :param src_lang: source language
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: potential keyword arguments for the generation process
    :return: the translation (linearized AMR graph)
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]
    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)

    generated["sequences"] = generated["sequences"].cpu()
    generated["sequences_scores"] = generated["sequences_scores"].cpu()
    best_scoring_results = {"graph": [], "status": []}
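    # Beam search is assumed: "sequences_scores" is only returned for beam search, so 'num_beams' must be in gen_kwargs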
    beam_size = gen_kwargs["num_beams"]

    # Select the best item from the beam: the sequence with best status and highest score
    for sample_idx in range(0, len(generated["sequences_scores"]), beam_size):
        sequences = generated["sequences"][sample_idx: sample_idx + beam_size]
        scores = generated["sequences_scores"][sample_idx: sample_idx + beam_size].tolist()
        outputs = tok_wrapper.batch_decode_amr_ids(sequences)
        statuses = outputs["status"]
        graphs = outputs["graph"]
        zipped = zip(statuses, scores, graphs)
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second
        best = sorted(zipped, key=lambda item: (item[0].value, -item[1]))[0]
        best_scoring_results["graph"].append(best[2])
        best_scoring_results["status"].append(best[0])

    # Returns dictionary with "graph" and "status" keys
    return best_scoring_results


LANGUAGES = {
    "English": "en_XX",
    "Dutch": "nl_XX",
    "Spanish": "es_XX",
}
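

# Minimal usage sketch (not part of the Streamlit app itself). Assumptions: calling the
# `st.cache_resource`-decorated `get_resources` outside a running Streamlit app still works
# (it simply runs uncached, possibly with a warning), and the example sentence and generation
# settings below are illustrative only. `num_beams` is required by `translate`.
if __name__ == "__main__":
    model, tok_wrapper = get_resources(multilingual=False, src_lang="English", quantize=True, no_cuda=True)
    results = translate(
        ["The boy wants to go to New York."],
        src_lang="English",
        model=model,
        tok_wrapper=tok_wrapper,
        num_beams=5,
        max_length=512,
    )
    for graph, status in zip(results["graph"], results["status"]):
        print(status)
        print(penman.encode(graph) if isinstance(graph, penman.Graph) else graph)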