Bram Vanroy committed
Commit 51df785 • 1 Parent(s): 1184fa3

add CUDA support

Files changed (1): utils.py +11 -5
utils.py CHANGED
@@ -2,8 +2,9 @@ from typing import Tuple
 
 import streamlit as st
 
+import torch
 from torch.quantization import quantize_dynamic
-from torch import nn, qint8
+from torch import nn, qint8, Tensor
 from torch.nn import Parameter
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from optimum.bettertransformer import BetterTransformer
@@ -14,17 +15,19 @@ from transformers import MBartForConditionalGeneration
 
 st_hash_funcs = {PreTrainedModel: lambda model: model.name_or_path,
                  PreTrainedTokenizer: lambda tokenizer: tokenizer.name_or_path,
-                 Parameter: lambda param: param.data}
+                 Parameter: lambda parameter: parameter.data,
+                 Tensor: lambda tensor: tensor.cpu()}
 
 
 @st.cache(show_spinner=False, hash_funcs=st_hash_funcs, allow_output_mutation=True)
-def get_resources(multilingual: bool, quantize: bool = True) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
+def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
     """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
     model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
     for better performance.
 
     :param multilingual: whether or not to load the multilingual model. If not, loads the English-only model
     :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
+    :param no_cuda: whether to disable CUDA, even if it is available
     :return: the loaded model, tokenizer, and logits processor
     """
     if multilingual:
@@ -38,7 +41,9 @@ def get_resources(multilingual: bool, quantize: bool = True) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
     model = BetterTransformer.transform(model, keep_original_model=False)
     model.resize_token_embeddings(len(tokenizer))
 
-    if quantize:
+    if torch.cuda.is_available() and not no_cuda:
+        model = model.to("cuda")
+    elif quantize:  # Quantization not supported on CUDA
         model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
 
     logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)
@@ -60,7 +65,8 @@ def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, to
     """
     tokenizer.src_lang = LANGUAGES[src_lang]
     encoded = tokenizer(text, return_tensors="pt")
-    generated = model.generate(**encoded, **gen_kwargs)
+    encoded = {k: v.to(model.device) for k, v in encoded.items()}
+    generated = model.generate(**encoded, **gen_kwargs).cpu()
     return tokenizer.decode_and_fix(generated)[0]
 
 
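A note on the new Tensor entry in st_hash_funcs: Streamlit's legacy st.cache hashes a function's inputs and outputs to decide when to recompute, and it needs custom hash_funcs for types it cannot hash natively. Once the model may live on the GPU, cached values can contain CUDA tensors, which are copied to the CPU first so their contents can be hashed. A minimal sketch of the pattern, assuming the legacy st.cache API this repo uses (expensive_tensor is a hypothetical function, not part of the repo):

import streamlit as st
import torch
from torch import Tensor

# Map each type Streamlit cannot hash to a function returning a hashable
# stand-in; moving the tensor to CPU makes this safe for CUDA tensors too.
hash_funcs = {Tensor: lambda tensor: tensor.cpu()}

@st.cache(show_spinner=False, hash_funcs=hash_funcs, allow_output_mutation=True)
def expensive_tensor(seed: int) -> Tensor:
    torch.manual_seed(seed)
    return torch.rand(1024, 1024)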
 
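The core of the change is the load-time decision in get_resources: use the GPU whenever it is available and not explicitly disabled, and only otherwise fall back to dynamic qint8 quantization, since PyTorch's dynamic quantization runs on CPU, not CUDA. A standalone sketch of that decision for an arbitrary module (place_or_quantize is an illustrative name, not part of the repo):

import torch
from torch import nn, qint8
from torch.quantization import quantize_dynamic

def place_or_quantize(model: nn.Module, quantize: bool = True, no_cuda: bool = False) -> nn.Module:
    if torch.cuda.is_available() and not no_cuda:
        return model.to("cuda")  # full precision on the GPU
    if quantize:
        # Dynamic quantization is a CPU-only speed/memory optimization
        return quantize_dynamic(model, {nn.Linear}, dtype=qint8)
    return model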
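Callers can opt out of the GPU with the new flag, for example (hypothetical call sites, mirroring how the cached loader would be used elsewhere in the app):

# Uses CUDA automatically when available
model, tokenizer, logits_processor = get_resources(multilingual=True)

# Forces CPU; the quantize branch then applies
model, tokenizer, logits_processor = get_resources(multilingual=True, no_cuda=True)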
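In translate, the inputs now follow the model: the tokenizer's tensors are moved to model.device before generation, and the generated ids are moved back to the CPU before decoding, so the function behaves the same whether the model sits on CPU or CUDA. The same round trip works for any Hugging Face generate call; a generic, self-contained sketch (the checkpoint name is only a placeholder):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-cc25")
if torch.cuda.is_available():
    model = model.to("cuda")

encoded = tokenizer("An example sentence.", return_tensors="pt")
# Inputs must live on the same device as the model's weights...
encoded = {k: v.to(model.device) for k, v in encoded.items()}
# ...and the generated ids come back to the CPU for decoding.
generated = model.generate(**encoded).cpu()
print(tokenizer.decode(generated[0], skip_special_tokens=True))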