from typing import Tuple

import streamlit as st
from torch import nn, qint8
from torch.nn import Parameter
from torch.quantization import quantize_dynamic
from transformers import MBartForConditionalGeneration, PreTrainedModel, PreTrainedTokenizer
from optimum.bettertransformer import BetterTransformer

from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.tokenization import AMRMBartTokenizer

LANGUAGES = {
    "English": "en_XX",
    "Dutch": "nl_XX",
    "Spanish": "es_XX",
}

# Custom hash functions so that Streamlit can cache these otherwise unhashable objects
st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
                 PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
                 Parameter: lambda param: param.data}

@st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
def get_resources(multilingual: bool, quantize: bool = True) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
    """Get the relevant model, tokenizer, and logits processor. If the multilingual model is requested, a model
    trained on English, Spanish, and Dutch is loaded; otherwise an English-only model is loaded. The model can
    optionally be quantized for better performance.

    :param multilingual: whether to load the multilingual model; if False, the English-only model is loaded
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
    :return: the loaded model, tokenizer, and logits processor
    """
    if multilingual:
        # The tokenizer's src_lang is reset to the right language on every translation call
        tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr", src_lang="nl_XX")
        model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr")
    else:
        tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
        model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")

    # Use optimum's BetterTransformer for faster inference and make sure the embedding
    # matrix matches the tokenizer's vocabulary (which includes the added AMR tokens)
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if quantize:
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)

    return model, tokenizer, logits_processor
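
# A minimal usage sketch (illustrative, not part of the app itself). Because of
# @st.cache, the function body only runs once per (multilingual, quantize)
# combination; subsequent Streamlit reruns reuse the cached objects:
#
#   model, tokenizer, logits_processor = get_resources(multilingual=False, quantize=True)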

@st.cache(show_spinner=False, hash_funcs=st_hash_funcs)
def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
    """Translate a given text in a given source language with a given model and tokenizer. The generation is guided
    by optional keyword arguments, which can include parameters such as the max length, logits processors, etc.

    :param text: source text to translate
    :param src_lang: source language (a key of LANGUAGES)
    :param model: MBART model
    :param tokenizer: MBART tokenizer
    :param gen_kwargs: optional keyword arguments to pass to the generation process
    :return: the translation (linearized AMR graph)
    """
    tokenizer.src_lang = LANGUAGES[src_lang]
    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(**encoded, **gen_kwargs)
    return tokenizer.decode_and_fix(generated)[0]
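
# Tying the two functions together (a sketch with an illustrative sentence and
# generation settings; LogitsProcessorList is the standard transformers wrapper
# that model.generate expects for the logits_processor argument):
#
#   from transformers import LogitsProcessorList
#
#   model, tokenizer, logits_processor = get_resources(multilingual=True)
#   amr = translate(
#       "De hond rent door het park.",
#       "Dutch",
#       model,
#       tokenizer,
#       logits_processor=LogitsProcessorList([logits_processor]),
#       max_length=model.config.max_length,
#   )
#   print(amr)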