---
language: ru
pipeline_tag: zero-shot-classification
tags:
- rubert
- russian
- nli
- rte
- zero-shot-classification
widget:
- text: "Я хочу поехать в Австралию"
  candidate_labels: "спорт,путешествия,музыка,кино,книги,наука,политика"
  hypothesis_template: "Тема текста - {}." 
datasets:
- cointegrated/nli-rus-translated-v2021
---
# RuBERT for NLI (natural language inference)

This is the [DeepPavlov/rubert-base-cased](https://huggingface.co/DeepPavlov/rubert-base-cased) model fine-tuned to predict the logical relationship between two short texts: entailment, contradiction, or neutral.

## Usage
How to run the model for NLI:
```python
# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()

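text1 = 'Сократ - человек, а все люди смертны.'  # premise: "Socrates is a man, and all men are mortal."
text2 = 'Сократ никогда не умрёт.'  # hypothesis: "Socrates will never die."
with torch.inference_mode():
    # Encode the two texts as one pair input: [CLS] text1 [SEP] text2 [SEP]
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
# Map each class probability to its label name
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764}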
```

You can also use this model for zero-shot classification of short texts, where each class is described only by a short label text, e.g. for sentiment analysis:

```python
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    # Pair the text (premise) with every label text (hypothesis)
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    # Probability of `label` (entailment by default) for each label text
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        # Rescale so that the scores over all label texts sum to one
        proba /= proba.sum()
    return proba

classes = ['Я доволен', 'Я недоволен']  # "I am satisfied", "I am dissatisfied"
# "What vile stuff this jellied fish of yours is!"
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.05609814, 0.9439019 ], dtype=float32)
# "How tasty this jellied fish of yours is!"
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.9059292 , 0.09407079], dtype=float32)
```
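
With `normalize=True`, `predict_zero_shot` rescales the entailment probabilities so they sum to one, which suits single-label tasks where exactly one class applies. With `normalize=False`, each label text is scored independently, so several labels can be accepted at once. The sketch below contrasts the two decision rules; it assumes a hypothetical array of raw entailment probabilities (such as `normalize=False` might return), and both the helper name `decide` and the 0.5 threshold are illustrative choices, not part of this model.

```python
import numpy as np

def decide(entail_proba, labels, threshold=0.5):
    """Hypothetical helper: turn raw (unnormalized) entailment probabilities
    into a single-label and a multi-label decision."""
    entail_proba = np.asarray(entail_proba, dtype=float)
    # Single-label: rescale so the scores sum to one (as normalize=True does),
    # then take the most probable label.
    normalized = entail_proba / entail_proba.sum()
    single = labels[int(normalized.argmax())]
    # Multi-label: keep every label whose own entailment probability
    # exceeds the (assumed) threshold.
    multi = [lab for lab, p in zip(labels, entail_proba) if p > threshold]
    return normalized, single, multi

# Illustrative numbers only, not real model output.
normalized, single, multi = decide([0.8, 0.6, 0.1], ['travel', 'sport', 'music'])
print(single, multi)  # travel ['travel', 'sport']
```

The threshold would need tuning on labelled data for any real task.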

Alternatively, you can use [Hugging Face pipelines](https://huggingface.co/transformers/main_classes/pipelines.html) for inference, e.g. the `zero-shot-classification` pipeline declared in this card's `pipeline_tag`.
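
The `zero-shot-classification` pipeline works on the same principle as `predict_zero_shot`: every candidate label is substituted into `hypothesis_template` to form a hypothesis, and the input text is paired with each hypothesis. In single-label mode, as we understand the transformers implementation, the entailment logits are then softmaxed across the candidate labels. The sketch below reproduces only these two steps, using the text, labels and template from the widget example in this card's metadata. The logits are hypothetical stand-ins for real model output, and `build_pairs` and `score_single_label` are illustrative names, not library functions.

```python
import numpy as np

def build_pairs(text, candidate_labels, hypothesis_template):
    # One (premise, hypothesis) pair per candidate label
    return [(text, hypothesis_template.format(label)) for label in candidate_labels]

def score_single_label(entailment_logits):
    # Softmax of the entailment logits across candidate labels,
    # so that the scores of all labels sum to one
    z = np.asarray(entailment_logits, dtype=float)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

# "I want to go to Australia"; labels "sport", "travel"; template "The topic of the text is {}."
pairs = build_pairs('Я хочу поехать в Австралию', ['спорт', 'путешествия'], 'Тема текста - {}.')
print(pairs[1][1])  # Тема текста - путешествия.
# Hypothetical entailment logits for the two pairs, not real model output
scores = score_single_label([-1.0, 2.0])
```

With the real pipeline, you pass `candidate_labels` and `hypothesis_template` to the pipeline call and it runs the model on each pair before scoring.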

## Sources
The model has been trained on a series of NLI datasets automatically translated to Russian from English.

Most datasets were taken [from the repo of Felipe Salvatore](https://github.com/felipessalvatore/NLI_datasets):
[JOCI](https://github.com/sheng-z/JOCI), 
[MNLI](https://cims.nyu.edu/~sbowman/multinli/), 
[MPE](https://aclanthology.org/I17-1011/), 
[SICK](http://www.lrec-conf.org/proceedings/lrec2014/pdf/363_Paper.pdf), 
[SNLI](https://nlp.stanford.edu/projects/snli/).

Some datasets were obtained from the original sources:
[ANLI](https://github.com/facebookresearch/anli), 
[NLI-style FEVER](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md),
[IMPPRES](https://github.com/facebookresearch/Imppres).

## Performance

The table below shows ROC AUC (one class vs rest) for five models on the corresponding *dev* sets:
- [tiny](https://huggingface.co/cointegrated/rubert-tiny-bilingual-nli): a small BERT predicting entailment vs not_entailment
- [twoway](https://huggingface.co/cointegrated/rubert-base-cased-nli-twoway): a base-sized BERT predicting entailment vs not_entailment
- [threeway](https://huggingface.co/cointegrated/rubert-base-cased-nli-threeway) (**this model**): a base-sized BERT predicting entailment vs contradiction vs neutral
- [vicgalle-xlm](https://huggingface.co/vicgalle/xlm-roberta-large-xnli-anli): a large multilingual NLI model
- [facebook-bart](https://huggingface.co/facebook/bart-large-mnli): a large English NLI model


|model                   |add_one_rte|anli_r1|anli_r2|anli_r3|copa|fever|help|iie  |imppres|joci|mnli |monli|mpe |scitail|sick|snli|terra|total |
|------------------------|-----------|-------|-------|-------|----|-----|----|-----|-------|----|-----|-----|----|-------|----|----|-----|------|
|n_observations          |387        |1000   |1000   |1200   |200 |20474|3355|31232|7661   |939 |19647|269  |1000|2126   |500 |9831|307  |101128|
|tiny/entailment         |0.77       |0.59   |0.52   |0.53   |0.53|0.90 |0.81|0.78 |0.93   |0.81|0.82 |0.91 |0.81|0.78   |0.93|0.95|0.67 |0.77  |
|twoway/entailment       |0.89       |0.73   |0.61   |0.62   |0.58|0.96 |0.92|0.87 |0.99   |0.90|0.90 |0.99 |0.91|0.96   |0.97|0.97|0.87 |0.86  |
|threeway/entailment     |0.91       |0.75   |0.61   |0.61   |0.57|0.96 |0.56|0.61 |0.99   |0.90|0.91 |0.67 |0.92|0.84   |0.98|0.98|0.90 |0.80  |
|vicgalle-xlm/entailment |0.88       |0.79   |0.63   |0.66   |0.57|0.93 |0.56|0.62 |0.77   |0.80|0.90 |0.70 |0.83|0.84   |0.91|0.93|0.93 |0.78  |
|facebook-bart/entailment|0.51       |0.41   |0.43   |0.47   |0.50|0.74 |0.55|0.57 |0.60   |0.63|0.70 |0.52 |0.56|0.68   |0.67|0.72|0.64 |0.58  |
|threeway/contradiction  |           |0.71   |0.64   |0.61   |    |0.97 |    |     |1.00   |0.77|0.92 |     |0.89|       |0.99|0.98|     |0.85  |
|threeway/neutral        |           |0.79   |0.70   |0.62   |    |0.91 |    |     |0.99   |0.68|0.86 |     |0.79|       |0.96|0.96|     |0.83  |
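
ROC AUC (one class vs rest) treats one label, e.g. `entailment`, as the positive class and all other labels as negative. It equals the probability that a randomly chosen positive example receives a higher predicted probability for that class than a randomly chosen negative one, with ties counted as one half. A minimal numpy sketch of this computation follows; the labels and scores are toy values made up for illustration, and `one_vs_rest_auc` is a hypothetical helper (for binary targets, `sklearn.metrics.roc_auc_score` gives the same value).

```python
import numpy as np

def one_vs_rest_auc(y_true, scores, positive):
    """ROC AUC for the `positive` label vs all other labels.

    y_true: gold labels (e.g. 'entailment', 'contradiction', 'neutral').
    scores: predicted probability of the `positive` class for each example.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == positive]
    neg = scores[y_true != positive]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError('need at least one positive and one negative example')
    # Compare every positive with every negative (O(n_pos * n_neg), fine for a sketch)
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return greater + 0.5 * ties

# Toy data, not taken from the evaluation above
gold = ['entailment', 'neutral', 'entailment', 'contradiction']
p_entail = [0.9, 0.4, 0.4, 0.1]
print(one_vs_rest_auc(gold, p_entail, 'entailment'))  # 0.875
```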

For evaluation (and for training of the [tiny](https://huggingface.co/cointegrated/rubert-tiny-bilingual-nli) and [twoway](https://huggingface.co/cointegrated/rubert-base-cased-nli-twoway) models), some extra datasets were used: 
[Add-one RTE](https://cs.brown.edu/people/epavlick/papers/ans.pdf), 
[CoPA](https://people.ict.usc.edu/~gordon/copa.html), 
[IIE](https://aclanthology.org/I17-1100), and
[SCITAIL](https://allenai.org/data/scitail) taken from [the repo of Felipe Salvatore](https://github.com/felipessalvatore/NLI_datasets) and translated,
[HELP](https://github.com/verypluming/HELP) and [MoNLI](https://github.com/atticusg/MoNLI) taken from the original sources and translated, 
and Russian [TERRa](https://russiansuperglue.com/ru/tasks/task_info/TERRa).