---
license: apache-2.0
datasets:
- common_language
language:
- ar
- eu
- br
- ca
- zh
- cv
- cs
- nl
- en
- eo
- et
- fr
- ka
- de
- el
- id
- ia
- it
- ja
- rw
- ky
- lv
- mt
- mn
- fa
- pl
- pt
- ro
- rm
- ru
- sl
- es
- sv
- ta
- tt
- tr
- uk
- cy
metrics:
- accuracy
- precision
- recall
- f1
tags:
- language-detection
- Frisian
- Dhivehi
- Hakha_Chin
- Kabyle
- Sakha
---

### Overview

This model detects **45** languages. It was fine-tuned from the **multilingual-e5-base** model on the **common_language** dataset.

The overall accuracy is **98.37%**; detailed evaluation results are shown below.

### Download the model

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained('Mike0307/multilingual-e5-language-detection')
model = AutoModelForSequenceClassification.from_pretrained('Mike0307/multilingual-e5-language-detection', num_labels=45)
```

### Example of language detection

```python
import torch

# Class labels, listed in the order of the model's output indices.
languages = [
    "Arabic", "Basque", "Breton", "Catalan", "Chinese_China", "Chinese_Hongkong",
    "Chinese_Taiwan", "Chuvash", "Czech", "Dhivehi", "Dutch", "English",
    "Esperanto", "Estonian", "French", "Frisian", "Georgian", "German", "Greek",
    "Hakha_Chin", "Indonesian", "Interlingua", "Italian", "Japanese", "Kabyle",
    "Kinyarwanda", "Kyrgyz", "Latvian", "Maltese", "Mongolian", "Persian", "Polish",
    "Portuguese", "Romanian", "Romansh_Sursilvan", "Russian", "Sakha", "Slovenian",
    "Spanish", "Swedish", "Tamil", "Tatar", "Turkish", "Ukranian", "Welsh"
]

def predict(text, model, tokenizer, device=torch.device('cpu')):
    """Return the softmax probabilities over the language labels for `text`."""
    model.to(device)
    model.eval()
    # Inputs are padded or truncated to 128 tokens.
    tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors="pt")
    input_ids = tokenized['input_ids']
    attention_mask = tokenized['attention_mask']
    with torch.no_grad():
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    return probabilities

def get_topk(probabilities, languages, k=3):
    """Return the k highest probabilities and their labels for the first input."""
    topk_prob, topk_indices = torch.topk(probabilities, k)
    topk_prob = topk_prob.cpu().numpy()[0].tolist()
    topk_indices = topk_indices.cpu().numpy()[0].tolist()
    topk_labels = [languages[index] for index in topk_indices]
    return topk_prob, topk_labels

text = "你的測試句子"
probabilities = predict(text, model, tokenizer)
topk_prob, topk_labels = get_topk(probabilities, languages)
print(topk_prob, topk_labels)

# [0.999620258808, 0.00025940246996469, 2.7690215574693e-05]
# ['Chinese_Taiwan', 'Chinese_Hongkong', 'Chinese_China']
```
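
In the sample output above, all three top predictions are regional variants of Chinese. If the regional variant does not matter for your use case, one option is to merge those labels after prediction. The following is a minimal sketch, not part of the model: `merge_variants`, `top_label`, and the `0.5` threshold are hypothetical names and values chosen for illustration. It works on the plain lists returned by `get_topk`, so it needs neither the model nor `torch`.

```python
def merge_variants(topk_prob, topk_labels, prefix="Chinese_", merged_label="Chinese"):
    """Sum the probabilities of labels starting with `prefix` under one label.

    Hypothetical helper: it only sees the top-k entries it is given, so the
    merged probability is a lower bound when a variant falls outside the top k.
    """
    merged = {}
    for prob, label in zip(topk_prob, topk_labels):
        key = merged_label if label.startswith(prefix) else label
        merged[key] = merged.get(key, 0.0) + prob
    # Sort by probability, highest first
    return sorted(merged.items(), key=lambda item: item[1], reverse=True)


def top_label(merged, threshold=0.5):
    """Return the best label, or None when its probability is below `threshold`."""
    label, prob = merged[0]
    return label if prob >= threshold else None


# Values copied from the example output above
topk_prob = [0.999620258808, 0.00025940246996469, 2.7690215574693e-05]
topk_labels = ["Chinese_Taiwan", "Chinese_Hongkong", "Chinese_China"]

merged = merge_variants(topk_prob, topk_labels)
print(top_label(merged))  # Chinese
```

The threshold is an arbitrary starting point; a suitable value depends on how costly a wrong label is compared with returning no label at all.
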

### Evaluation Results

Results are reported on the test split of the **common_language** dataset.

| language | precision | recall | f1-score | support |
| --- | --- | --- | --- | --- |
| Arabic | 1.00 | 1.00 | 1.00 | 151 |
| Basque | 0.99 | 1.00 | 1.00 | 111 |
| Breton | 1.00 | 0.90 | 0.95 | 252 |
| Catalan | 0.96 | 0.99 | 0.97 | 96 |
| Chinese_China | 0.98 | 1.00 | 0.99 | 100 |
| Chinese_Hongkong | 0.97 | 0.87 | 0.92 | 115 |
| Chinese_Taiwan | 0.92 | 0.98 | 0.95 | 170 |
| Chuvash | 0.98 | 1.00 | 0.99 | 137 |
| Czech | 0.98 | 1.00 | 0.99 | 128 |
| Dhivehi | 1.00 | 1.00 | 1.00 | 111 |
| Dutch | 0.99 | 1.00 | 0.99 | 144 |
| English | 0.96 | 1.00 | 0.98 | 98 |
| Esperanto | 0.98 | 0.98 | 0.98 | 107 |
| Estonian | 1.00 | 0.99 | 0.99 | 93 |
| French | 0.95 | 1.00 | 0.98 | 106 |
| Frisian | 1.00 | 0.98 | 0.99 | 117 |
| Georgian | 1.00 | 1.00 | 1.00 | 110 |
| German | 1.00 | 1.00 | 1.00 | 101 |
| Greek | 1.00 | 1.00 | 1.00 | 153 |
| Hakha_Chin | 0.99 | 1.00 | 0.99 | 202 |
| Indonesian | 0.99 | 0.99 | 0.99 | 150 |
| Interlingua | 0.96 | 0.97 | 0.96 | 182 |
| Italian | 0.99 | 0.94 | 0.96 | 100 |
| Japanese | 1.00 | 1.00 | 1.00 | 144 |
| Kabyle | 1.00 | 0.96 | 0.98 | 156 |
| Kinyarwanda | 0.97 | 1.00 | 0.99 | 103 |
| Kyrgyz | 0.98 | 1.00 | 0.99 | 129 |
| Latvian | 0.98 | 0.98 | 0.98 | 171 |
| Maltese | 0.99 | 0.98 | 0.98 | 152 |
| Mongolian | 1.00 | 1.00 | 1.00 | 112 |
| Persian | 1.00 | 1.00 | 1.00 | 123 |
| Polish | 0.91 | 0.99 | 0.95 | 128 |
| Portuguese | 0.94 | 0.99 | 0.96 | 124 |
| Romanian | 1.00 | 1.00 | 1.00 | 152 |
| Romansh_Sursilvan | 0.99 | 0.95 | 0.97 | 106 |
| Russian | 0.99 | 0.99 | 0.99 | 100 |
| Sakha | 0.99 | 1.00 | 1.00 | 105 |
| Slovenian | 0.99 | 1.00 | 1.00 | 166 |
| Spanish | 0.96 | 0.95 | 0.95 | 94 |
| Swedish | 0.99 | 1.00 | 0.99 | 190 |
| Tamil | 1.00 | 1.00 | 1.00 | 135 |
| Tatar | 1.00 | 0.96 | 0.98 | 173 |
| Turkish | 1.00 | 1.00 | 1.00 | 137 |
| Ukranian | 0.99 | 1.00 | 1.00 | 126 |
| Welsh | 0.98 | 1.00 | 0.99 | 103 |
| | | | | |
| *macro avg* | 0.98 | 0.99 | 0.98 | 5963 |
| *weighted avg* | 0.98 | 0.98 | 0.98 | 5963 |
| | | | | |
| *overall accuracy* | | | 0.9837 | 5963 |

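
To make the columns easier to read, the sketch below recomputes the f1-score for three rows of the table and shows how a macro average differs from a weighted average. It assumes the table follows the usual classification-report conventions: F1 is the harmonic mean of precision and recall, the macro average weights every language equally, and the weighted average weights each language by its support. The source does not state how the table was generated, and the helper `f1_score` is written here for illustration only.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, support) copied from three rows of the table
rows = {
    "Breton": (1.00, 0.90, 252),
    "Chinese_Hongkong": (0.97, 0.87, 115),
    "Polish": (0.91, 0.99, 128),
}

f1 = {name: f1_score(p, r) for name, (p, r, _) in rows.items()}
total_support = sum(n for _, _, n in rows.values())

# Macro average: every language counts equally
macro_f1 = sum(f1.values()) / len(f1)
# Weighted average: every language counts in proportion to its support
weighted_f1 = sum(f1[name] * n for name, (_, _, n) in rows.items()) / total_support

for name, score in f1.items():
    print(f"{name}: {score:.2f}")
# Breton: 0.95
# Chinese_Hongkong: 0.92
# Polish: 0.95
```

Because only three rows are used, `macro_f1` and `weighted_f1` here will not match the averages in the table, which cover all 45 languages. The published precision and recall are also rounded to two decimals, so an f1-score recomputed this way can differ from the table in the last digit for other rows.
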