---
datasets:
- snli
- anli
- multi_nli
- multi_nli_mismatch
- fever
license: mit
---

This is a strong pre-trained RoBERTa-Large NLI model.

The training data is a combination of well-known NLI datasets: [`SNLI`](https://nlp.stanford.edu/projects/snli/), [`MNLI`](https://cims.nyu.edu/~sbowman/multinli/), [`FEVER-NLI`](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md), and [`ANLI (R1, R2, R3)`](https://github.com/facebookresearch/anli).
Other pre-trained NLI models, including `RoBERTa`, `ALBERT`, `BART`, `ELECTRA`, and `XLNet`, are also available.

Trained by [Yixin Nie](https://easonnie.github.io); see the [original source](https://github.com/facebookresearch/anli).

Try the code snippet below.

```
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

if __name__ == '__main__':
    max_length = 256

    premise = "Two women are embracing while holding to go packages."
    hypothesis = "The men are fighting outside a deli."

    hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

    tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
    model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)

    # Encode the (premise, hypothesis) pair as a single sequence-pair input.
    tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis,
                                                     max_length=max_length,
                                                     return_token_type_ids=True,
                                                     truncation=True)

    # Add a batch dimension of 1 to every input.
    input_ids = torch.tensor(tokenized_input_seq_pair['input_ids']).unsqueeze(0)
    attention_mask = torch.tensor(tokenized_input_seq_pair['attention_mask']).unsqueeze(0)

    # BART's forward() does not accept 'token_type_ids', so only pass them to the other models.
    model_kwargs = {"attention_mask": attention_mask}
    if model.config.model_type != "bart":
        model_kwargs["token_type_ids"] = torch.tensor(tokenized_input_seq_pair['token_type_ids']).unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids, **model_kwargs)

    # Note:
    # "id2label": {
    #     "0": "entailment",
    #     "1": "neutral",
    #     "2": "contradiction"
    # },

    # Softmax over the three classes for the single pair in the batch.
    predicted_probability = torch.softmax(outputs.logits, dim=1)[0].tolist()

    print("Premise:", premise)
    print("Hypothesis:", hypothesis)
    print("Entailment:", predicted_probability[0])
    print("Neutral:", predicted_probability[1])
    print("Contradiction:", predicted_probability[2])
```

A fuller interactive evaluation script is available [here](https://github.com/facebookresearch/anli/blob/master/src/hg_api/interactive_eval.py).

Citation:

```
@inproceedings{nie-etal-2020-adversarial,
    title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding",
    author = "Nie, Yixin and
      Williams, Adina and
      Dinan, Emily and
      Bansal, Mohit and
      Weston, Jason and
      Kiela, Douwe",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    year = "2020",
    publisher = "Association for Computational Linguistics",
}
```
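The usage snippet earlier in this card scores one pair at a time and hard-codes the label order. As a minimal sketch (an assumption, not part of the original repository), several pairs can instead be scored in one batch by calling the tokenizer on two lists, with the label names read from `model.config.id2label` rather than from fixed indices. The function names `softmax`, `label_probabilities`, and `classify_pairs` are hypothetical; `torch` and `transformers` are imported lazily inside `classify_pairs` so the two pure helpers run without them.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)  # subtract the max so math.exp never overflows
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_probabilities(logits: Sequence[float], id2label: Dict[int, str]) -> Dict[str, float]:
    """Map one row of logits to {label_name: probability} using an id2label table."""
    return {id2label[i]: p for i, p in enumerate(softmax(logits))}


def classify_pairs(model_name: str, premises: List[str], hypotheses: List[str],
                   max_length: int = 256) -> List[Dict[str, float]]:
    """Hypothetical batched scorer: one {label: probability} dict per (premise, hypothesis) pair."""
    # Lazy imports keep the helpers above usable without torch/transformers installed.
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # Passing two lists pairs premises[i] with hypotheses[i]; padding yields a rectangular batch.
    batch = tokenizer(premises, hypotheses, max_length=max_length,
                      truncation=True, padding=True, return_tensors="pt")
    # Defensive: BART's forward() has no token_type_ids argument.
    if model.config.model_type == "bart":
        batch.pop("token_type_ids", None)

    with torch.no_grad():
        logits = model(**batch).logits

    # id2label keys are ints once the config is loaded.
    return [label_probabilities(row.tolist(), model.config.id2label) for row in logits]
```

For example, `classify_pairs(hg_model_hub_name, [premise], [hypothesis])` would return a list containing one dictionary keyed by `entailment`, `neutral`, and `contradiction`. Reading labels from the config avoids silently mislabeling outputs if a checkpoint orders its classes differently.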