---
language: ru
pipeline_tag: zero-shot-classification
tags:
- rubert
- russian
- nli
- rte
- zero-shot-classification
widget:
- text: "Я хочу поехать в Австралию"
  candidate_labels: "спорт,путешествия,музыка,кино,книги,наука,политика"
  hypothesis_template: "Тема текста - {}."
---
# RuBERT for NLI (natural language inference)
This is the [DeepPavlov/rubert-base-cased](https://huggingface.co/DeepPavlov/rubert-base-cased) model, fine-tuned to predict the logical relationship between two short texts: entailment, contradiction, or neutral.
## Usage
How to run the model for NLI:
```python
# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()  # run on GPU when one is available

text1 = 'Сократ - человек, а все люди смертны.'  # premise: "Socrates is a human, and all humans are mortal."
text2 = 'Сократ никогда не умрёт.'  # hypothesis: "Socrates will never die."
with torch.inference_mode():
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764}
```
You can also use this model for zero-shot classification of short texts, where each class is described only by a label text and no training examples are needed, e.g. for sentiment analysis:
```python
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    # Pair the text with every candidate label and score all pairs in one batch.
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    # Keep the probability of the chosen NLI class (entailment by default) for each label.
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        proba /= sum(proba)  # rescale so the scores over the labels sum to 1
    return proba

classes = ['Я доволен', 'Я недоволен']  # "I am satisfied", "I am dissatisfied"
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)  # "How disgusting this jellied fish of yours is!"
# array([0.05609814, 0.9439019 ], dtype=float32)
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)  # "How tasty this jellied fish of yours is!"
# array([0.9059292 , 0.09407079], dtype=float32)
```
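The array returned by `predict_zero_shot` is aligned with `label_texts`, so the predicted class is simply the label with the highest score. A minimal sketch using the first output quoted above; `pick_label` is a hypothetical helper for illustration, not part of the model or of `transformers`:

```python
def pick_label(labels, proba):
    """Return the best-scoring label and its score; works on lists and NumPy arrays."""
    best = max(range(len(proba)), key=lambda i: proba[i])
    return labels[best], float(proba[best])

classes = ['Я доволен', 'Я недоволен']  # "I am satisfied", "I am dissatisfied"
proba = [0.05609814, 0.9439019]         # scores quoted above for the negative review

label, score = pick_label(classes, proba)
print(label, round(score, 3))  # Я недоволен 0.944
```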
Alternatively, you can use [Hugging Face pipelines](https://huggingface.co/transformers/main_classes/pipelines.html) for inference.
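A minimal sketch of the pipeline route, assuming the standard `zero-shot-classification` pipeline from `transformers`. The labels and the hypothesis template are taken from the widget configuration in this card's metadata; the helper names `build_hypotheses` and `classify_topic` are illustrative and not part of the library.

```python
HYPOTHESIS_TEMPLATE = 'Тема текста - {}.'  # "The topic of the text is {}."
CANDIDATE_LABELS = ['спорт', 'путешествия', 'музыка', 'кино', 'книги', 'наука', 'политика']

def build_hypotheses(labels, template=HYPOTHESIS_TEMPLATE):
    """Show the hypothesis sentences that get paired with the input text."""
    return [template.format(label) for label in labels]

def classify_topic(text, labels=CANDIDATE_LABELS, template=HYPOTHESIS_TEMPLATE):
    """Score each label for `text`; the first call downloads the model weights."""
    # Imported here so the helper above can be used without transformers installed.
    from transformers import pipeline
    classifier = pipeline('zero-shot-classification',
                          model='cointegrated/rubert-base-cased-nli-threeway')
    return classifier(text, candidate_labels=labels, hypothesis_template=template)

print(build_hypotheses(['спорт', 'путешествия']))
# ['Тема текста - спорт.', 'Тема текста - путешествия.']
```

Calling `classify_topic('Я хочу поехать в Австралию')` should return a dict with the input sequence, the labels sorted by score, and the scores themselves. In real use, build the pipeline once and reuse it instead of recreating it on every call.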
## Sources
The model has been trained on a series of NLI datasets automatically translated to Russian from English.
Most datasets were taken [from the repo of Felipe Salvatore](https://github.com/felipessalvatore/NLI_datasets):
[JOCI](https://github.com/sheng-z/JOCI),
[MNLI](https://cims.nyu.edu/~sbowman/multinli/),
[MPE](https://aclanthology.org/I17-1011/),
[SICK](http://www.lrec-conf.org/proceedings/lrec2014/pdf/363_Paper.pdf),
[SNLI](https://nlp.stanford.edu/projects/snli/).
Some datasets were obtained from the original sources:
[ANLI](https://github.com/facebookresearch/anli),
[NLI-style FEVER](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md),
[IMPPRES](https://github.com/facebookresearch/Imppres).
## Performance
The table below shows ROC AUC for three models on small samples of the DEV sets:
- [tiny](https://huggingface.co/cointegrated/rubert-tiny-bilingual-nli): a small BERT predicting entailment vs not_entailment
- [twoway](https://huggingface.co/cointegrated/rubert-base-cased-nli-twoway): a base-sized BERT predicting entailment vs not_entailment
- [threeway](https://huggingface.co/cointegrated/rubert-base-cased-nli-threeway) (**this model**): a base-sized BERT predicting entailment vs contradiction vs neutral
|dataset    |tiny/entailment|twoway/entailment|threeway/entailment|threeway/contradiction|threeway/neutral|
|-----------|---------------|-----------------|-------------------|-------------------------|-------------------|
|add_one_rte|0.82 |0.90 |0.92 | | |
|anli_r1 |0.50 |0.68 |0.66 |0.70 |0.75 |
|anli_r2 |0.55 |0.62 |0.62 |0.62 |0.69 |
|anli_r3 |0.50 |0.63 |0.59 |0.62 |0.64 |
|copa |0.55 |0.60 |0.62 | | |
|fever |0.88 |0.94 |0.94 |0.91 |0.92 |
|help |0.74 |0.87 |0.46 | | |
|iie |0.79 |0.85 |0.54 | | |
|imppres |0.94 |0.99 |0.99 |0.99 |0.99 |
|joci |0.87 |0.93 |0.93 |0.85 |0.80 |
|mnli |0.87 |0.92 |0.93 |0.89 |0.86 |
|monli |0.94 |1.00 |0.67 | | |
|mpe |0.82 |0.90 |0.90 |0.91 |0.80 |
|scitail |0.80 |0.96 |0.85 | | |
|sick |0.97 |0.99 |0.99 |0.98 |0.96 |
|snli |0.95 |0.98 |0.98 |0.99 |0.97 |
|terra |0.73 |0.93 |0.93 | | |
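
The card does not include the evaluation script. Purely as an illustration of the metric, a one-vs-rest ROC AUC for a single class (for example, the `threeway/entailment` column) could be computed as sketched below. The gold labels and scores are made-up toy values, and `roc_auc` is a hypothetical helper written from the pairwise-ranking definition of AUC, not the code used to produce the table.

```python
def roc_auc(is_positive, scores):
    """ROC AUC as the share of (positive, negative) pairs ranked correctly; ties count as 0.5."""
    pos = [s for y, s in zip(is_positive, scores) if y]
    neg = [s for y, s in zip(is_positive, scores) if not y]
    if not pos or not neg:
        raise ValueError('ROC AUC needs at least one positive and one negative example')
    correct = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return correct / (len(pos) * len(neg))

# Toy data (NOT from the actual dev sets): gold labels and predicted P(entailment).
gold = ['entailment', 'neutral', 'contradiction', 'entailment']
p_entailment = [0.9, 0.5, 0.1, 0.4]

# One-vs-rest: entailment is the positive class, everything else is negative.
auc = roc_auc([g == 'entailment' for g in gold], p_entailment)
print(auc)  # 0.75
```

On this definition, 0.5 corresponds to random ranking and 1.0 to a perfect separation of the class from the rest.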