---
language: ru
pipeline_tag: zero-shot-classification
tags:
  - rubert
  - russian
  - nli
  - rte
  - zero-shot-classification
widget:
  - text: Я хочу поехать в Австралию
    candidate_labels: спорт,путешествия,музыка,кино,книги,наука,политика
    hypothesis_template: Тема текста - {}.
---

# RuBERT for NLI (natural language inference)

This is the DeepPavlov/rubert-base-cased model fine-tuned to predict the logical relationship between two short texts: entailment, contradiction, or neutral.

## Usage

How to run the model for NLI:

```python
# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()

text1 = 'Сократ - человек, а все люди смертны.'  # "Socrates is a man, and all men are mortal."
text2 = 'Сократ никогда не умрёт.'  # "Socrates will never die."
with torch.inference_mode():
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764}
```

You can also use this model for zero-shot short text classification (by labels only), e.g. for sentiment analysis:

```python
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    """Score each candidate label text by the probability of the chosen NLI class."""
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        proba /= sum(proba)
    return proba

classes = ['Я доволен', 'Я недоволен']  # "I am pleased", "I am displeased"
# "What a nasty thing, this aspic of yours!"
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.05609814, 0.9439019 ], dtype=float32)
# "How tasty is this aspic of yours!"
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.9059292 , 0.09407079], dtype=float32)
```

Alternatively, you can use Hugging Face pipelines for inference, as in the sketch below.
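
For instance, a minimal sketch with the zero-shot classification pipeline; the input text, candidate labels, and hypothesis template below simply mirror the widget example from the metadata above.

```python
from transformers import pipeline

classifier = pipeline('zero-shot-classification',
                      model='cointegrated/rubert-base-cased-nli-threeway')

result = classifier(
    'Я хочу поехать в Австралию',  # "I want to go to Australia"
    candidate_labels=['спорт', 'путешествия', 'музыка', 'кино', 'книги', 'наука', 'политика'],
    hypothesis_template='Тема текста - {}.',  # "The topic of the text is {}."
)
print(result['labels'][0], result['scores'][0])
# the top label should be 'путешествия' ("travel")
```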

## Sources

The model has been trained on a series of NLI datasets automatically translated to Russian from English.

Most datasets were taken from the repo of Felipe Salvatore: JOCI, MNLI, MPE, SICK, SNLI.

Some datasets were obtained from the original sources: ANLI, NLI-style FEVER, IMPPRES.

## Performance

The table below shows ROC AUC for three models on small samples of the dev sets:

- tiny: a small BERT predicting entailment vs not_entailment
- twoway: a base-sized BERT predicting entailment vs not_entailment
- threeway (this model): a base-sized BERT predicting entailment vs contradiction vs neutral

| dataset | tiny/entailment | twoway/entailment | threeway/entailment | threeway/contradiction | threeway/neutral |
|---|---|---|---|---|---|
| add_one_rte | 0.82 | 0.90 | 0.92 | | |
| anli_r1 | 0.50 | 0.68 | 0.66 | 0.70 | 0.75 |
| anli_r2 | 0.55 | 0.62 | 0.62 | 0.62 | 0.69 |
| anli_r3 | 0.50 | 0.63 | 0.59 | 0.62 | 0.64 |
| copa | 0.55 | 0.60 | 0.62 | | |
| fever | 0.88 | 0.94 | 0.94 | 0.91 | 0.92 |
| help | 0.74 | 0.87 | 0.46 | | |
| iie | 0.79 | 0.85 | 0.54 | | |
| imppres | 0.94 | 0.99 | 0.99 | 0.99 | 0.99 |
| joci | 0.87 | 0.93 | 0.93 | 0.85 | 0.80 |
| mnli | 0.87 | 0.92 | 0.93 | 0.89 | 0.86 |
| monli | 0.94 | 1.00 | 0.67 | | |
| mpe | 0.82 | 0.90 | 0.90 | 0.91 | 0.80 |
| scitail | 0.80 | 0.96 | 0.85 | | |
| sick | 0.97 | 0.99 | 0.99 | 0.98 | 0.96 |
| snli | 0.95 | 0.98 | 0.98 | 0.99 | 0.97 |
| terra | 0.73 | 0.93 | 0.93 | | |

The larger comparison below also lists the number of dev examples per dataset (n) and adds two existing zero-shot NLI models, vicgalle-xlm and facebook-bart, as baselines:

| model | add_one_rte | anli_r1 | anli_r2 | anli_r3 | copa | fever | help | iie | imppres | joci | mnli | monli | mpe | scitail | sick | snli | terra | mean |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| n | 387 | 1000 | 1000 | 1200 | 200 | 20474 | 3355 | 31232 | 7661 | 939 | 19647 | 269 | 1000 | 2126 | 500 | 9831 | 307 | 101128 |
| tiny/entailment | 0.77 | 0.59 | 0.52 | 0.53 | 0.53 | 0.90 | 0.81 | 0.78 | 0.93 | 0.81 | 0.82 | 0.91 | 0.81 | 0.78 | 0.93 | 0.95 | 0.67 | 0.77 |
| twoway/entailment | 0.89 | 0.73 | 0.61 | 0.62 | 0.58 | 0.96 | 0.92 | 0.87 | 0.99 | 0.90 | 0.90 | 0.99 | 0.91 | 0.96 | 0.97 | 0.97 | 0.87 | 0.86 |
| threeway/entailment | 0.91 | 0.75 | 0.61 | 0.61 | 0.57 | 0.96 | 0.56 | 0.61 | 0.99 | 0.90 | 0.91 | 0.67 | 0.92 | 0.84 | 0.98 | 0.98 | 0.90 | 0.80 |
| vicgalle-xlm/entailment | 0.88 | 0.79 | 0.63 | 0.66 | 0.57 | 0.93 | 0.56 | 0.62 | 0.77 | 0.80 | 0.90 | 0.70 | 0.83 | 0.84 | 0.91 | 0.93 | 0.93 | 0.78 |
| facebook-bart/entailment | 0.51 | 0.41 | 0.43 | 0.47 | 0.50 | 0.74 | 0.55 | 0.57 | 0.60 | 0.63 | 0.70 | 0.52 | 0.56 | 0.68 | 0.67 | 0.72 | 0.64 | 0.58 |
| threeway/contradiction | | 0.71 | 0.64 | 0.61 | | 0.97 | | | 1.00 | 0.77 | 0.92 | | 0.89 | | 0.99 | 0.98 | | 0.85 |
| threeway/neutral | | 0.79 | 0.70 | 0.62 | | 0.91 | | | 0.99 | 0.68 | 0.86 | | 0.79 | | 0.96 | 0.96 | | 0.83 |
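
For reference, a minimal sketch (with made-up labels and scores, not the actual dev data) of how a one-vs-rest ROC AUC such as threeway/entailment can be computed:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical gold labels and model probabilities P(entailment) for a few dev examples
y_true = np.array(['entailment', 'neutral', 'contradiction', 'entailment', 'neutral'])
p_entailment = np.array([0.81, 0.34, 0.05, 0.62, 0.48])

# 'entailment' vs the rest: binarize the gold labels and score with the class probability
auc = roc_auc_score((y_true == 'entailment').astype(int), p_entailment)
print(f'entailment ROC AUC: {auc:.2f}')
```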