# S3TVR-Demo / models/nllb.py
import os

# Set cache locations BEFORE importing transformers and numba-dependent
# libraries; these environment variables are read at import time, so setting
# them afterwards has no effect.
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"  # writable numba cache dir
os.environ["NUMBA_DISABLE_JIT"] = "1"               # disable numba JIT entirely
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache/huggingface"  # older transformers versions

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def nllb():
"""
Load and return the NLLB (No Language Left Behind) model and tokenizer.
    This function loads the NLLB-200-distilled-1.3B model and tokenizer from the
    Hugging Face Hub using the Transformers library.
The model is configured to use a GPU if available, otherwise it defaults to CPU.
Returns:
tuple: A tuple containing the loaded model and tokenizer.
- model (transformers.AutoModelForSeq2SeqLM): The loaded NLLB model.
- tokenizer (transformers.AutoTokenizer): The loaded tokenizer.
Example usage:
model, tokenizer = nllb()
"""
    # Use a GPU if one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer and model. The Hugging Face cache directories are set
    # at module import time (before transformers is imported), so HF_HOME /
    # TRANSFORMERS_CACHE actually take effect.
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B").to(device)
return model, tokenizer
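
# A minimal optional variant (an assumption, not part of the original Space):
# from_pretrained accepts a torch_dtype argument, so on a CUDA device the 1.3B
# checkpoint can be loaded with fp16 weights to roughly halve GPU memory use.
def nllb_fp16():
    """Sketch: like nllb(), but loads half-precision weights when on GPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-1.3B", torch_dtype=dtype
    ).to(device)
    return model, tokenizer
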
def nllb_translate(model, tokenizer, article, language):
"""
Translate an article using the NLLB model and tokenizer.
Args:
model (transformers.AutoModelForSeq2SeqLM): The NLLB model to use for translation.
Example: model, tokenizer = nllb()
tokenizer (transformers.AutoTokenizer): The tokenizer to use with the NLLB model.
Example: model, tokenizer = nllb()
article (str): The article text to be translated.
Example: "This is a sample article."
        language (str): The target language code. Must be either 'es' (Spanish)
            or 'en' (English).
            Example: "es"
Returns:
str: The translated text.
Example: "Este es un artículo de muestra."
"""
try:
# Tokenize the text
inputs = tokenizer(article, return_tensors="pt")
# Move the tokenized inputs to the same device as the model
inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Map the supported language codes to NLLB target-language tokens.
        lang_tokens = {"es": "spa_Latn", "en": "eng_Latn"}
        if language not in lang_tokens:
            raise ValueError("Unsupported language. Use 'es' or 'en'.")

        # Force the decoder to start with the target-language token.
        # convert_tokens_to_ids is used instead of tokenizer.lang_code_to_id,
        # which was removed in recent transformers releases.
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_tokens[language]),
            max_length=30,  # caps the output at 30 tokens, matching the original setting
        )
# Decode the translation
text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
return text
except Exception as e:
print(f"Error during translation: {e}")
return "Translation failed"