cointegrated committed
Commit f3b1f58
1 Parent(s): aa3b56b

Update README.md

Files changed (1): README.md (+51 -2)
README.md CHANGED
@@ -4,10 +4,59 @@ pipeline_tag: zero-shot-classification
  tags:
  - rubert
  - russian
+ - nli
+ - rte
+ - zero-shot-classification
  widget:
  - text: "Я хочу поехать в Австралию"
    candidate_labels: "спорт,путешествия,музыка,кино,книги,наука,политика"
    hypothesis_template: "Тема текста - {}."
  ---
- # RuBERT base model (cased) for NLI
- The model was trained on a series of NLI datasets translated to Russian from English [from this repo](https://github.com/felipessalvatore/NLI_datasets).
+ # RuBERT base model (cased) fine-tuned for NLI (natural language inference)
+ The model was trained on a series of NLI datasets automatically translated from English to Russian, taken [from this repo](https://github.com/felipessalvatore/NLI_datasets).
+ 
+ It predicts the logical relationship between two short texts: entailment, contradiction, or neutral.
+ 
+ How to run the model for NLI:
+ ```python
+ # !pip install transformers sentencepiece --quiet
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ 
+ model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+ model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
+ if torch.cuda.is_available():
+     model.cuda()  # move the model to GPU when one is available
+ 
+ text1 = 'Сократ - человек, а все люди смертны.'  # "Socrates is a man, and all men are mortal."
+ text2 = 'Сократ никогда не умрёт.'  # "Socrates will never die."
+ with torch.inference_mode():
+     out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
+     proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
+ print({v: proba[k] for k, v in model.config.id2label.items()})
+ # {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764}
+ ```
+ 
+ You can also use this model for zero-shot classification of short texts, where the only supervision comes from the label texts themselves, e.g. for sentiment analysis:
+ 
+ ```python
+ def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
+     """Return a probability for each candidate label, scored by how strongly the text entails it."""
+     tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
+     with torch.inference_mode():
+         result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
+     proba = result[:, model.config.label2id[label]].cpu().numpy()
+     if normalize:
+         proba /= sum(proba)
+     return proba
+ 
+ classes = ['Я доволен', 'Я недоволен']  # "I am pleased" / "I am displeased"
+ predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)  # "This aspic of yours is disgusting!"
+ # array([0.05609814, 0.9439019 ], dtype=float32)
+ predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)  # "This aspic of yours is delicious!"
+ # array([0.9059292 , 0.09407079], dtype=float32)
+ ```
+ 
+ Alternatively, you can use [Hugging Face pipelines](https://huggingface.co/transformers/main_classes/pipelines.html) for inference; a sketch follows below.
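+ 
+ For instance, a minimal sketch using the `zero-shot-classification` pipeline (the input text, candidate labels, and hypothesis template below are illustrative, borrowed from the widget example above):
+ 
+ ```python
+ from transformers import pipeline
+ 
+ classifier = pipeline('zero-shot-classification', model='cointegrated/rubert-base-cased-nli-threeway')
+ result = classifier(
+     'Я хочу поехать в Австралию',  # "I want to go to Australia"
+     candidate_labels=['спорт', 'путешествия', 'музыка'],  # sport, travel, music
+     hypothesis_template='Тема текста - {}.',  # "The topic of the text is {}."
+ )
+ # result['labels'] holds the candidate labels sorted by score; 'путешествия' (travel) should rank first
+ print(result['labels'][0], result['scores'][0])
+ ```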