BramVanroy commited on
Commit
b4c04ac
β€’
1 Parent(s): e0b12db

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +2 -0
utils.py CHANGED
@@ -10,6 +10,7 @@ from torch.quantization import quantize_dynamic
10
  from torch import nn, qint8
11
  from transformers import MBartForConditionalGeneration, AutoConfig
12
 
 
13
 
14
  @st.cache_resource(show_spinner=False)
15
  def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
@@ -57,6 +58,7 @@ def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_c
57
  return model, tok_wrapper
58
 
59
 
 
60
  def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
61
  """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
62
  potential keyword-arguments, which can include arguments such as max length, logits processors, etc.
 
10
  from torch import nn, qint8
11
  from transformers import MBartForConditionalGeneration, AutoConfig
12
 
13
+ import spaces
14
 
15
  @st.cache_resource(show_spinner=False)
16
  def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
 
58
  return model, tok_wrapper
59
 
60
 
61
+ @spaces.GPU
62
  def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
63
  """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
64
  potential keyword-arguments, which can include arguments such as max length, logits processors, etc.