--- language: ru pipeline_tag: zero-shot-classification tags: - rubert - russian - nli - rte - zero-shot-classification widget: - text: "Я хочу поехать в Австралию" candidate_labels: "спорт,путешествия,музыка,кино,книги,наука,политика" hypothesis_template: "Тема текста - {}." --- # RuBERT base model (cased) fine-tuned for NLI (natural language inference) The model has been trained on a series of NLI datasets automatically translated to Russian from English [from this repo](https://github.com/felipessalvatore/NLI_datasets). It predicts the logical relationship between two short texts: entailment, contradiction, or neutral. How to run the model for NLI: ```python # !pip install transformers sentencepiece --quiet import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway' tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) if torch.cuda.is_available(): model.cuda() text1 = 'Сократ - человек, а все люди смертны.' text2 = 'Сократ никогда не умрёт.' with torch.inference_mode(): out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device)) proba = torch.softmax(out.logits, -1).cpu().numpy()[0] print({v: proba[k] for k, v in model.config.id2label.items()}) # {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764} ``` You can also use this model for zero-shot short text classification (by labels only), e.g. for sentiment analysis: ```python def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True): label_texts tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True) with torch.inference_mode(): result = torch.softmax(model(**tokens.to(model.device)).logits, -1) proba = result[:, model.config.label2id[label]].cpu().numpy() if normalize: proba /= sum(proba) return proba classes = ['Я доволен', 'Я недоволен'] predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer) # array([0.05609814, 0.9439019 ], dtype=float32) predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer) # array([0.9059292 , 0.09407079], dtype=float32) ``` Alternatively, you can use [Huggingface pipelines](https://huggingface.co/transformers/main_classes/pipelines.html) for inference.