Spaces:
Running
Running
CUDA or not CUDA
Browse files- model_translation.py +6 -3
model_translation.py
CHANGED
@@ -25,7 +25,7 @@ model_names = {
|
|
25 |
# Registry for all loaded bilingual models
|
26 |
tokenizer_model_registry = {}
|
27 |
|
28 |
-
device = 'cuda'
|
29 |
|
30 |
def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModelForSeq2SeqLM):
|
31 |
"""
|
@@ -47,7 +47,7 @@ def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModel
|
|
47 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
48 |
if model.config.torch_dtype != torch.float16:
|
49 |
model = model.half()
|
50 |
-
model
|
51 |
tokenizer_model_registry[src_lang] = (tokenizer, model)
|
52 |
|
53 |
return (tokenizer, model)
|
@@ -65,5 +65,8 @@ model_MADLAD_name = "google/madlad400-3b-mt"
|
|
65 |
#model_MADLAD_name = "google/madlad400-7b-mt-bt"
|
66 |
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
|
67 |
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
|
68 |
-
model_MADLAD_name,
|
|
|
|
|
|
|
69 |
|
|
|
25 |
# Registry for all loaded bilingual models
|
26 |
tokenizer_model_registry = {}
|
27 |
|
28 |
+
device = 'cpu'
|
29 |
|
30 |
def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModelForSeq2SeqLM):
|
31 |
"""
|
|
|
47 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
48 |
if model.config.torch_dtype != torch.float16:
|
49 |
model = model.half()
|
50 |
+
model.to(device)
|
51 |
tokenizer_model_registry[src_lang] = (tokenizer, model)
|
52 |
|
53 |
return (tokenizer, model)
|
|
|
65 |
#model_MADLAD_name = "google/madlad400-7b-mt-bt"
|
66 |
tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
|
67 |
model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
|
68 |
+
model_MADLAD_name,
|
69 |
+
device_map="auto",
|
70 |
+
torch_dtype=torch.float16,
|
71 |
+
low_cpu_mem_usage=True)
|
72 |
|