Spaces:
Running
Running
File size: 4,832 Bytes
05de9a6 f8b0e70 05de9a6 1e0a2f8 e10ccfa b818293 05de9a6 1e0a2f8 b4c04ac 1e0a2f8 b818293 05de9a6 f8b0e70 05de9a6 f8b0e70 51df785 05de9a6 f8b0e70 05de9a6 11f8e5f 1e0a2f8 51df785 1e0a2f8 05de9a6 1e0a2f8 11f8e5f b4c04ac 05de9a6 f8b0e70 05de9a6 f8b0e70 05de9a6 f8b0e70 05de9a6 11f8e5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
from typing import Tuple, Union, Dict, List
from multi_amr.data.postprocessing_graph import ParsedStatus
from multi_amr.data.tokenization import AMRTokenizerWrapper
from optimum.bettertransformer import BetterTransformer
import penman
import streamlit as st
import torch
from torch.quantization import quantize_dynamic
from torch import nn, qint8
from transformers import MBartForConditionalGeneration, AutoConfig
import spaces
@st.cache_resource(show_spinner=False)
def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Load the text-to-AMR model together with its tokenizer wrapper.

    The multilingual checkpoint is used when ``multilingual`` is set; otherwise the checkpoint that
    was fine-tuned for ``src_lang`` only is selected. The result is cached by Streamlit
    (``st.cache_resource``), keyed on the call arguments.

    :param multilingual: whether to load the multilingual model or not
    :param src_lang: source language ("English", "Spanish" or "Dutch"); only used to pick a checkpoint
     when ``multilingual`` is False
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic' (only applied when
     the model is not moved to CUDA)
    :param no_cuda: whether to disable CUDA, even if it is available
    :return: the loaded model, and tokenizer wrapper
    :raises ValueError: if a monolingual model is requested for an unsupported language
    """
    if multilingual:
        checkpoint = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
    else:
        monolingual_checkpoints = {
            "English": "BramVanroy/mbart-large-cc25-ft-amr30-en",
            "Spanish": "BramVanroy/mbart-large-cc25-ft-amr30-es",
            "Dutch": "BramVanroy/mbart-large-cc25-ft-amr30-nl",
        }
        if src_lang not in monolingual_checkpoints:
            raise ValueError(f"Language {src_lang} not supported")
        checkpoint = monolingual_checkpoints[src_lang]

    # The tokenizer's src_lang is only a placeholder here: it is reset to the right language at translation time
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(checkpoint, src_lang="en_XX")

    # Decoding has to start from the AMR token rather than from the checkpoint's default start token
    model_config = AutoConfig.from_pretrained(checkpoint)
    model_config.decoder_start_token_id = tok_wrapper.amr_token_id
    model = MBartForConditionalGeneration.from_pretrained(checkpoint, config=model_config)
    model.eval()

    # Grow the embedding matrix when the tokenizer knows more tokens than the model has embeddings for
    vocab_size = len(tok_wrapper.tokenizer)
    num_embeddings = model.get_input_embeddings().weight.shape[0]
    if vocab_size > num_embeddings:
        model.resize_token_embeddings(vocab_size)

    model = BetterTransformer.transform(model, keep_original_model=False)

    use_cuda = torch.cuda.is_available() and not no_cuda
    if use_cuda:
        model = model.to("cuda")
    elif quantize:
        # Quantization not supported on CUDA, so it is only attempted on the CPU path
        quantizable_layers = {nn.Linear, nn.Dropout, nn.LayerNorm}
        model = quantize_dynamic(model, quantizable_layers, dtype=qint8)

    return model, tok_wrapper
@spaces.GPU
def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
    potential keyword-arguments, which can include arguments such as max length, logits processors, etc.

    For every input text, all candidates that the generation returned for it are decoded, and the candidate
    with the best parse status (and, among those, the highest sequence score) is kept.

    :param texts: source text to translate (potentially a batch); a single string is treated as a batch of one
    :param src_lang: source language; must be a key of ``LANGUAGES``
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: potential keyword arguments for the generation process
    :return: dictionary with "graph" and "status" keys, each holding one entry per input text
    :raises KeyError: if ``src_lang`` is not a key of ``LANGUAGES``
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]

    best_scoring_results = {"graph": [], "status": []}
    if not texts:
        # Nothing to translate: skip tokenization and generation altogether
        return best_scoring_results

    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)

    sequences = generated["sequences"].cpu()
    num_candidates = len(sequences)

    try:
        scores = generated["sequences_scores"].cpu().tolist()
    except KeyError:
        # No sequence scores in the output (expected for non-beam decoding -- see the `generate` docs):
        # fall back to equal scores so that the selection is driven by the parse status only
        scores = [0.0] * num_candidates

    # `generate` returns the candidates of all inputs in one flat, input-ordered batch. Derive the number of
    # candidates per input from the output itself instead of from `num_beams`: that keyword may be absent, and
    # it only matches the number of returned candidates when `num_return_sequences == num_beams`.
    group_size = num_candidates // len(texts)

    # Select the best candidate for each input: the sequence with best status and highest score
    for start in range(0, num_candidates, group_size):
        stop = start + group_size
        outputs = tok_wrapper.batch_decode_amr_ids(sequences[start:stop])
        candidates = zip(outputs["status"], scores[start:stop], outputs["graph"])
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second.
        # `min` returns the first minimal item, exactly like taking element 0 of a (stable) sort.
        best_status, _, best_graph = min(candidates, key=lambda item: (item[0].value, -item[1]))
        best_scoring_results["graph"].append(best_graph)
        best_scoring_results["status"].append(best_status)

    # Returns dictionary with "graph" and "status" keys
    return best_scoring_results
# Maps the human-readable language names that are passed around as `src_lang`
# to the language codes (e.g. "en_XX") that are assigned to the tokenizer wrapper.
LANGUAGES: Dict[str, str] = dict(
    English="en_XX",
    Dutch="nl_XX",
    Spanish="es_XX",
)
|