RuBERT for NLI (natural language inference)

This is the DeepPavlov/rubert-base-cased fine-tuned to predict the logical relationship between two short texts: entailment, contradiction, or neutral.

Usage

How to run the model for NLI:

# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()

text1 = 'Сократ - человек, а все люди смертны.'
text2 = 'Сократ никогда не умрёт.'
with torch.inference_mode():
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764} 

You can also use this model for zero-shot short text classification (by labels only), e.g. for sentiment analysis:

def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    label_texts
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        proba /= sum(proba)
    return proba

classes = ['Я доволен', 'Я недоволен']
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.05609814, 0.9439019 ], dtype=float32)
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.9059292 , 0.09407079], dtype=float32)

Alternatively, you can use Huggingface pipelines for inference.

Sources

The model has been trained on a series of NLI datasets automatically translated to Russian from English.

Most datasets were taken from the repo of Felipe Salvatore: JOCI, MNLI, MPE, SICK, SNLI.

Some datasets obtained from the original sources: ANLI, NLI-style FEVER, IMPPRES.

Performance

The table below shows ROC AUC (one class vs rest) for five models on the corresponding dev sets:

  • tiny: a small BERT predicting entailment vs not_entailment
  • twoway: a base-sized BERT predicting entailment vs not_entailment
  • threeway (this model): a base-sized BERT predicting entailment vs contradiction vs neutral
  • vicgalle-xlm: a large multilingual NLI model
  • facebook-bart: a large multilingual NLI model
model add_one_rte anli_r1 anli_r2 anli_r3 copa fever help iie imppres joci mnli monli mpe scitail sick snli terra total
n_observations 387 1000 1000 1200 200 20474 3355 31232 7661 939 19647 269 1000 2126 500 9831 307 101128
tiny/entailment 0.77 0.59 0.52 0.53 0.53 0.90 0.81 0.78 0.93 0.81 0.82 0.91 0.81 0.78 0.93 0.95 0.67 0.77
twoway/entailment 0.89 0.73 0.61 0.62 0.58 0.96 0.92 0.87 0.99 0.90 0.90 0.99 0.91 0.96 0.97 0.97 0.87 0.86
threeway/entailment 0.91 0.75 0.61 0.61 0.57 0.96 0.56 0.61 0.99 0.90 0.91 0.67 0.92 0.84 0.98 0.98 0.90 0.80
vicgalle-xlm/entailment 0.88 0.79 0.63 0.66 0.57 0.93 0.56 0.62 0.77 0.80 0.90 0.70 0.83 0.84 0.91 0.93 0.93 0.78
facebook-bart/entailment 0.51 0.41 0.43 0.47 0.50 0.74 0.55 0.57 0.60 0.63 0.70 0.52 0.56 0.68 0.67 0.72 0.64 0.58
threeway/contradiction 0.71 0.64 0.61 0.97 1.00 0.77 0.92 0.89 0.99 0.98 0.85
threeway/neutral 0.79 0.70 0.62 0.91 0.99 0.68 0.86 0.79 0.96 0.96 0.83

For evaluation (and for training of the tiny and twoway models), some extra datasets were used: Add-one RTE, CoPA, IIE, and SCITAIL taken from the repo of Felipe Salvatore and translatted, HELP and MoNLI taken from the original sources and translated, and Russian TERRa.

Downloads last month
242
Hosted inference API
Zero-Shot Classification
This model can be loaded on the Inference API on-demand.