---
language: ru
pipeline_tag: zero-shot-classification
tags:
- rubert
- russian
- nli
- rte
- zero-shot-classification
widget:
- text: Я хочу поехать в Австралию
  candidate_labels: спорт,путешествия,музыка,кино,книги,наука,политика
  hypothesis_template: Тема текста - {}.
---
# RuBERT for NLI (natural language inference)
This is [DeepPavlov/rubert-base-cased](https://huggingface.co/DeepPavlov/rubert-base-cased) fine-tuned to predict the logical relationship between two short texts: entailment, contradiction, or neutral.
## Usage
How to run the model for NLI:
```python
# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()

# "Socrates is a human, and all humans are mortal." / "Socrates will never die."
text1 = 'Сократ - человек, а все люди смертны.'
text2 = 'Сократ никогда не умрёт.'
with torch.inference_mode():
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764}
```
You can also use this model for zero-shot short text classification (by labels only), e.g. for sentiment analysis:
```python
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    # Pair the text with every candidate label and keep the entailment probability for each pair
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        proba /= sum(proba)
    return proba

classes = ['Я доволен', 'Я недоволен']  # "I am pleased", "I am displeased"

# "What rubbish this jellied fish of yours is!"
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.05609814, 0.9439019 ], dtype=float32)

# "How delicious this jellied fish of yours is!"
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.9059292 , 0.09407079], dtype=float32)
```
Alternatively, you can use the Hugging Face zero-shot classification pipeline for inference.
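For example, a minimal sketch with the `zero-shot-classification` pipeline might look like this (the label set and hypothesis template are illustrative, mirroring the widget example above):

```python
from transformers import pipeline

# The zero-shot pipeline wraps the NLI model: each candidate label is inserted
# into the hypothesis template and scored against the text via the entailment class.
classifier = pipeline('zero-shot-classification', model='cointegrated/rubert-base-cased-nli-threeway')
result = classifier(
    'Я хочу поехать в Австралию',                          # "I want to go to Australia"
    candidate_labels=['спорт', 'путешествия', 'музыка'],   # sport, travel, music
    hypothesis_template='Тема текста - {}.',               # "The topic of the text is {}."
)
print(result['labels'][0])  # expected to be 'путешествия' (travel)
```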
## Sources
The model has been trained on a series of NLI datasets automatically translated to Russian from English.
Most datasets were taken from the repo of Felipe Salvatore: JOCI, MNLI, MPE, SICK, SNLI.
Some datasets were obtained from their original sources: ANLI, NLI-style FEVER, IMPPRES.
## Performance
The table below shows ROC AUC for three models on small samples of the DEV sets:
- tiny: a small BERT predicting entailment vs not_entailment
- twoway: a base-sized BERT predicting entailment vs not_entailment
- threeway (this model): a base-sized BERT predicting entailment vs contradiction vs neutral
| dataset | tiny/entailment | twoway/entailment | threeway/entailment | threeway/contradiction | threeway/neutral |
|---|---|---|---|---|---|
| add_one_rte | 0.82 | 0.90 | 0.92 | | |
| anli_r1 | 0.50 | 0.68 | 0.66 | 0.70 | 0.75 |
| anli_r2 | 0.55 | 0.62 | 0.62 | 0.62 | 0.69 |
| anli_r3 | 0.50 | 0.63 | 0.59 | 0.62 | 0.64 |
| copa | 0.55 | 0.60 | 0.62 | | |
| fever | 0.88 | 0.94 | 0.94 | 0.91 | 0.92 |
| help | 0.74 | 0.87 | 0.46 | | |
| iie | 0.79 | 0.85 | 0.54 | | |
| imppres | 0.94 | 0.99 | 0.99 | 0.99 | 0.99 |
| joci | 0.87 | 0.93 | 0.93 | 0.85 | 0.80 |
| mnli | 0.87 | 0.92 | 0.93 | 0.89 | 0.86 |
| monli | 0.94 | 1.00 | 0.67 | | |
| mpe | 0.82 | 0.90 | 0.90 | 0.91 | 0.80 |
| scitail | 0.80 | 0.96 | 0.85 | | |
| sick | 0.97 | 0.99 | 0.99 | 0.98 | 0.96 |
| snli | 0.95 | 0.98 | 0.98 | 0.99 | 0.97 |
| terra | 0.73 | 0.93 | 0.93 | | |
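The exact evaluation script is not included here. As a rough sketch of how per-class one-vs-rest ROC AUC can be computed, assuming a dev set given as (premise, hypothesis, gold_label) triples (the helper name `per_class_roc_auc` and the data format are illustrative, not from the original):

```python
import torch
from sklearn.metrics import roc_auc_score

def per_class_roc_auc(pairs, model, tokenizer, batch_size=32):
    """pairs: list of (premise, hypothesis, gold_label) triples; returns label -> one-vs-rest ROC AUC."""
    labels = [model.config.id2label[i] for i in range(len(model.config.id2label))]
    probs, gold = [], []
    for i in range(0, len(pairs), batch_size):
        batch = pairs[i:i + batch_size]
        tokens = tokenizer(
            [p for p, h, y in batch], [h for p, h, y in batch],
            truncation=True, padding=True, return_tensors='pt',
        ).to(model.device)
        with torch.inference_mode():
            probs.append(torch.softmax(model(**tokens).logits, -1).cpu())
        gold.extend(y for p, h, y in batch)
    probs = torch.cat(probs).numpy()
    # Score each class against the rest using its predicted probability
    return {
        label: roc_auc_score([int(y == label) for y in gold], probs[:, model.config.label2id[label]])
        for label in labels
    }
```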