Didier committed on
Commit
efcd81a
·
1 Parent(s): 690d91f

CUDA or not CUDA

Browse files
Files changed (1) hide show
  1. model_translation.py +6 -3
model_translation.py CHANGED
@@ -25,7 +25,7 @@ model_names = {
25
  # Registry for all loaded bilingual models
26
  tokenizer_model_registry = {}
27
 
28
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
 
30
  def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModelForSeq2SeqLM):
31
  """
@@ -47,7 +47,7 @@ def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModel
47
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
48
  if model.config.torch_dtype != torch.float16:
49
  model = model.half()
50
- model = model.to(device)
51
  tokenizer_model_registry[src_lang] = (tokenizer, model)
52
 
53
  return (tokenizer, model)
@@ -65,5 +65,8 @@ model_MADLAD_name = "google/madlad400-3b-mt"
65
  #model_MADLAD_name = "google/madlad400-7b-mt-bt"
66
  tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
67
  model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
68
- model_MADLAD_name, device_map="auto", torch_dtype=torch.float16)
 
 
 
69
 
 
25
  # Registry for all loaded bilingual models
26
  tokenizer_model_registry = {}
27
 
28
+ device = 'cpu'
29
 
30
  def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModelForSeq2SeqLM):
31
  """
 
47
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
48
  if model.config.torch_dtype != torch.float16:
49
  model = model.half()
50
+ model.to(device)
51
  tokenizer_model_registry[src_lang] = (tokenizer, model)
52
 
53
  return (tokenizer, model)
 
65
  #model_MADLAD_name = "google/madlad400-7b-mt-bt"
66
  tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
67
  model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
68
+ model_MADLAD_name,
69
+ device_map="auto",
70
+ torch_dtype=torch.float16,
71
+ low_cpu_mem_usage=True)
72