---
language: ru
pipeline_tag: zero-shot-classification
tags:
  - rubert
  - russian
  - nli
  - rte
  - zero-shot-classification
widget:
  - text: Я хочу поехать в Австралию  # "I want to go to Australia"
    candidate_labels: спорт,путешествия,музыка,кино,книги,наука,политика  # sport, travel, music, cinema, books, science, politics
    hypothesis_template: Тема текста - {}.  # "The topic of the text is {}."
---

RuBERT for NLI (natural language inference)

This is the DeepPavlov/rubert-base-cased model fine-tuned to predict the logical relationship between two short texts: entailment, contradiction, or neutral.

Usage

How to run the model for NLI:

# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()

text1 = 'Сократ - человек, а все люди смертны.'  # "Socrates is a man, and all men are mortal."
text2 = 'Сократ никогда не умрёт.'  # "Socrates will never die."
with torch.inference_mode():
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764} 
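The snippet above turns the model's three raw logits into probabilities with a softmax and names them via model.config.id2label. The following dependency-free sketch reproduces that post-processing and also picks the most likely relationship. The logits and the label mapping in it are hypothetical; real code should read the mapping from model.config.id2label.

```python
import math


def logits_to_label(logits, id2label):
    """Convert raw NLI logits into a probability dict and the top label.

    `logits` is a sequence of floats, one per class; `id2label` maps class
    indices to class names (as model.config.id2label does).
    """
    # Subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    proba = {id2label[i]: e / total for i, e in enumerate(exps)}
    # The predicted relationship is the class with the highest probability
    best = max(proba, key=proba.get)
    return proba, best


# Hypothetical logits and mapping, for illustration only
example_id2label = {0: 'entailment', 1: 'contradiction', 2: 'neutral'}
proba, best = logits_to_label([-1.5, 3.0, 0.2], example_id2label)
print(best)  # contradiction
```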

You can also use this model for zero-shot classification of short texts, using only the texts of the candidate labels (no training examples), e.g. for sentiment analysis:

def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    # Pair the input text (premise) with each candidate label text (hypothesis)
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    # Probability of the chosen NLI class (entailment by default) for each label
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        # Rescale so that the scores over all candidate labels sum to 1
        proba /= proba.sum()
    return proba

classes = ['Я доволен', 'Я недоволен']  # "I am satisfied", "I am dissatisfied"
# Text: "What a vile thing this jellied fish of yours is!"
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.05609814, 0.9439019 ], dtype=float32)
# Text: "How tasty this jellied fish of yours is!"
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.9059292 , 0.09407079], dtype=float32)
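predict_zero_shot treats the input text as the premise and each label text as a hypothesis, applies a three-way softmax to every (text, label) pair, keeps the probability of the chosen NLI class, and optionally rescales these scores to sum to 1. The sketch below isolates that arithmetic on a hypothetical logit matrix so it runs without downloading the model. Treating column 0 as entailment is an assumption; real code should look the index up in model.config.label2id.

```python
import numpy as np


def zero_shot_scores(pair_logits, label_index, normalize=True):
    """Score candidate labels from NLI logits, mirroring predict_zero_shot.

    pair_logits: array of shape (n_labels, n_classes), one row per
    (text, label_text) pair. label_index: column of the NLI class to keep.
    """
    logits = np.asarray(pair_logits, dtype=np.float64)
    # Softmax over the NLI classes within each (text, label) pair
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    scores = probs[:, label_index]
    if normalize:
        # Rescale so that scores over the candidate labels sum to 1
        scores = scores / scores.sum()
    return scores


# Hypothetical logits for two label texts (column 0 assumed to be entailment)
example = [[-2.0, 2.5, 0.5],   # pair (text, label 1): largest logit in column 1
           [3.0, -1.0, 0.0]]   # pair (text, label 2): largest logit in column 0
print(zero_shot_scores(example, label_index=0))  # label 2 gets almost all the mass
```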

Alternatively, you can run inference with Hugging Face pipelines, e.g. the zero-shot-classification pipeline.
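This sketch assumes that the Hugging Face zero-shot-classification pipeline, in its default single-label mode, (1) builds one hypothesis per candidate label by filling hypothesis_template and (2) applies a softmax to the entailment logits across all candidate labels, which differs slightly from the per-pair normalization in predict_zero_shot. It only mimics those two steps on hypothetical logits and is not the pipeline's own code. Because the pipeline's default template is English ("This example is {}."), a Russian template such as the one in the widget metadata ("Тема текста - {}.", i.e. "The topic of the text is {}.") is likely a better fit for this model.

```python
import numpy as np


def build_hypotheses(candidate_labels, hypothesis_template):
    # One hypothesis per label, e.g. 'Тема текста - {}.' -> 'Тема текста - спорт.'
    return [hypothesis_template.format(label) for label in candidate_labels]


def single_label_scores(entailment_logits):
    """Softmax of entailment logits across candidate labels (assumed pipeline behaviour)."""
    x = np.asarray(entailment_logits, dtype=np.float64)
    e = np.exp(x - x.max())
    return e / e.sum()


labels = ['спорт', 'путешествия', 'музыка']  # sport, travel, music
hypotheses = build_hypotheses(labels, 'Тема текста - {}.')
print(hypotheses[1])  # Тема текста - путешествия.

# Hypothetical entailment logits, one per (text, hypothesis) pair
scores = single_label_scores([-1.0, 2.0, 0.0])
print(labels[int(scores.argmax())])  # путешествия
```

With the real pipeline, the equivalent call would be pipeline('zero-shot-classification', model='cointegrated/rubert-base-cased-nli-threeway') applied to a text with the candidate_labels and hypothesis_template arguments.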

Sources

The model has been trained on a series of NLI datasets automatically translated to Russian from English.

Most datasets were taken from the repository of Felipe Salvatore: JOCI, MNLI, MPE, SICK, SNLI.

Some datasets were obtained from the original sources: ANLI, NLI-style FEVER, IMPPRES.

Performance

The table below shows ROC AUC (higher is better; 0.5 corresponds to random guessing) for three models on small samples of the dev sets:

  • tiny: a small BERT predicting entailment vs not_entailment
  • twoway: a base-sized BERT predicting entailment vs not_entailment
  • threeway (this model): a base-sized BERT predicting entailment vs contradiction vs neutral
| dataset | tiny/entailment | twoway/entailment | threeway/entailment | threeway[3]/contradiction | threeway[3]/neutral |
|---|---|---|---|---|---|
| add_one_rte | 0.82 | 0.90 | 0.92 | | |
| anli_r1 | 0.50 | 0.68 | 0.66 | 0.70 | 0.75 |
| anli_r2 | 0.55 | 0.62 | 0.62 | 0.62 | 0.69 |
| anli_r3 | 0.50 | 0.63 | 0.59 | 0.62 | 0.64 |
| copa | 0.55 | 0.60 | 0.62 | | |
| fever | 0.88 | 0.94 | 0.94 | 0.91 | 0.92 |
| help | 0.74 | 0.87 | 0.46 | | |
| iie | 0.79 | 0.85 | 0.54 | | |
| imppres | 0.94 | 0.99 | 0.99 | 0.99 | 0.99 |
| joci | 0.87 | 0.93 | 0.93 | 0.85 | 0.80 |
| mnli | 0.87 | 0.92 | 0.93 | 0.89 | 0.86 |
| monli | 0.94 | 1.00 | 0.67 | | |
| mpe | 0.82 | 0.90 | 0.90 | 0.91 | 0.80 |
| scitail | 0.80 | 0.96 | 0.85 | | |
| sick | 0.97 | 0.99 | 0.99 | 0.98 | 0.96 |
| snli | 0.95 | 0.98 | 0.98 | 0.99 | 0.97 |
| terra | 0.73 | 0.93 | 0.93 | | |
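Each ROC AUC value can be read as the probability that a model scores a randomly chosen example of the target class (e.g. entailment) above a randomly chosen example of any other class. The sketch below computes this one-vs-rest quantity directly from that pairwise definition in plain Python. The gold labels and scores are invented, and the evaluation script actually used for the table is not shown in this card; for real evaluation sets, sklearn.metrics.roc_auc_score computes the same quantity far more efficiently.

```python
def roc_auc_one_vs_rest(labels, scores, positive):
    """ROC AUC of `scores` for class `positive` vs. all other classes.

    Computed as the probability that a random positive example gets a higher
    score than a random negative one (ties count as 1/2).
    """
    pos = [s for y, s in zip(labels, scores) if y == positive]
    neg = [s for y, s in zip(labels, scores) if y != positive]
    if not pos or not neg:
        raise ValueError('need at least one positive and one negative example')
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Invented gold labels and predicted entailment probabilities
gold = ['entailment', 'neutral', 'entailment', 'contradiction', 'neutral']
p_entailment = [0.90, 0.40, 0.35, 0.05, 0.20]
print(roc_auc_one_vs_rest(gold, p_entailment, 'entailment'))  # 0.8333333333333334
```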