Thomas Müller
Init commit
aa9810a
Metadata:

```yaml
language:
  - ar
  - bg
  - de
  - el
  - en
  - es
  - fr
  - ru
  - th
  - tr
  - ur
  - vi
  - zh
datasets:
  - SNLI
  - MNLI
  - ANLI
  - XNLI
tags:
  - zero-shot-classification
```

A cross-attention NLI model trained for zero-shot and few-shot text classification.

The base model is xlm-roberta-base, fine-tuned on SNLI, MNLI, ANLI and XNLI using the code from here.
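With an NLI model, zero-shot classification recasts each candidate label as a hypothesis and asks whether the input text (the premise) entails it. Because the model applies cross-attention over the joined premise and hypothesis, every candidate label needs its own pair and its own forward pass. Below is a minimal sketch of the pairing step, assuming a hypothesis template modeled on the hypotheses in the usage example; `build_nli_pairs` and the default template are illustrative only, not part of any library.

```python
from typing import List, Tuple


def build_nli_pairs(
    text: str,
    candidate_labels: List[str],
    template: str = "The sentence is {}.",  # hypothetical template, modeled on the usage example
) -> List[Tuple[str, str]]:
    """Pair the text (premise) with one hypothesis per candidate label."""
    return [(text, template.format(label)) for label in candidate_labels]


pairs = build_nli_pairs("I like this pizza.", ["positive", "negative"])
print(pairs)
# [('I like this pizza.', 'The sentence is positive.'), ('I like this pizza.', 'The sentence is negative.')]
```

The resulting list has the same shape as `input_pairs` in the usage example, so it can be fed to the tokenizer the same way.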

Usage:

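```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np

model_name = "symanto/xlm-roberta-base-snli-mnli-anli-xnli"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Each pair is (premise, hypothesis); the hypothesis phrases a candidate label.
input_pairs = [("I like this pizza.", "The sentence is positive."), ("I like this pizza.", "The sentence is negative.")]
# Join premise and hypothesis with XLM-R's pair separator "</s></s>";
# the tokenizer adds the leading <s> and trailing </s> itself.
inputs = tokenizer(["</s></s>".join(input_pair) for input_pair in input_pairs], return_tensors="pt", padding=True)
with torch.no_grad():  # inference only, no gradients needed
    logits = model(**inputs).logits
# One row of class probabilities per pair; model.config.id2label gives the class order.
probs = torch.softmax(logits, dim=1).tolist()
print("probs", probs)
np.testing.assert_almost_equal(probs, [[0.86, 0.14, 0.00], [0.16, 0.15, 0.69]], decimal=2)
```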
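The model returns one probability per NLI class for each (premise, hypothesis) pair, so a label score still has to be derived from those rows. The sketch below shows two common ways to do that. It assumes the class order [entailment, neutral, contradiction], which is what the expected probabilities in the usage example suggest (confirm with `model.config.id2label`). The helper names `single_label_scores` and `multi_label_scores` are hypothetical, and the input is the rounded probabilities from the usage example rather than fresh model output.

```python
from typing import Dict, Sequence

# Assumed class order, inferred from the usage example's expected probabilities;
# verify against model.config.id2label before relying on it.
ENTAILMENT, NEUTRAL, CONTRADICTION = 0, 1, 2


def single_label_scores(probs: Sequence[Sequence[float]], labels: Sequence[str]) -> Dict[str, float]:
    """Exactly one label applies: renormalize P(entailment) across all labels."""
    entail = [row[ENTAILMENT] for row in probs]
    total = sum(entail)
    if total == 0:
        # No label is entailed at all; fall back to a uniform distribution.
        return {label: 1.0 / len(labels) for label in labels}
    return {label: e / total for label, e in zip(labels, entail)}


def multi_label_scores(probs: Sequence[Sequence[float]], labels: Sequence[str]) -> Dict[str, float]:
    """Labels are independent: score each as P(entail) / (P(entail) + P(contradict))."""
    scores = {}
    for label, row in zip(labels, probs):
        denom = row[ENTAILMENT] + row[CONTRADICTION]
        scores[label] = row[ENTAILMENT] / denom if denom > 0 else 0.0
    return scores


# Rounded probabilities from the usage example, one row per (premise, hypothesis) pair.
probs = [[0.86, 0.14, 0.00], [0.16, 0.15, 0.69]]
labels = ["positive", "negative"]

print({k: round(v, 3) for k, v in single_label_scores(probs, labels).items()})
# {'positive': 0.843, 'negative': 0.157}
print({k: round(v, 3) for k, v in multi_label_scores(probs, labels).items()})
# {'positive': 1.0, 'negative': 0.188}
```

The multi-label ratio P(entailment) / (P(entailment) + P(contradiction)) equals a softmax over just the entailment and contradiction logits, the scheme the transformers zero-shot-classification pipeline uses in multi-label mode. The single-label variant here is a simpler heuristic on probabilities and is not identical to that pipeline's softmax over raw entailment logits.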