ITACA_Insurance_Core_v4 / detect_language.py
dperales's picture
Upload 2 files
65ffad7
raw
history blame
1.13 kB
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class LanguageDetector:
    """Predict the language of a text with a sequence-classification model.

    The tokenizer and model are loaded once at construction time and reused
    for every call to :meth:`predict_language`.
    """

    # Checkpoint used when the caller does not supply one.
    DEFAULT_MODEL_NAME = "papluca/xlm-roberta-base-language-detection"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME) -> None:
        """Load the tokenizer and classification model.

        Args:
            model_name: Hub identifier or local path of the checkpoint to
                load. Defaults to ``DEFAULT_MODEL_NAME``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name
        )

    def predict_language(self, text: str) -> str:
        """Return the label of the highest-scoring language for *text*.

        Args:
            text: The text whose language should be identified.

        Returns:
            The entry of ``self.model.config.id2label`` that corresponds to
            the largest logit produced by the model.
        """
        # Truncate to the tokenizer's maximum length so that very long texts
        # are not passed to the model at full length.
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Index of the highest score, mapped to its label via the config.
        prediction_idx = outputs.logits.argmax(dim=-1).item()
        return self.model.config.id2label[prediction_idx]