
Fb_improved_zeroshot

A zero-shot model designed to classify academic search logs in German and English, developed by students at ETH Zürich.

This model was initialized from the bart-large-mnli checkpoint that Meta publishes on Hugging Face and then fine-tuned to suit the needs of this project.

NLI-based Zero-Shot Text Classification

This method is based on Natural Language Inference (NLI); see Yin et al. The following examples are adapted from the bart-large-mnli model card.
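The core idea of NLI-based zero-shot classification is to turn each candidate label into a hypothesis and ask the NLI model whether the input text (the premise) entails it. The sketch below only builds those premise/hypothesis pairs. The helper name `build_nli_pairs` is hypothetical, and the template "This example is {}." is assumed here to mirror the default `hypothesis_template` of the transformers zero-shot pipeline.

```python
from typing import List, Tuple


def build_nli_pairs(sequence: str, labels: List[str],
                    template: str = "This example is {}.") -> List[Tuple[str, str]]:
    """Pair the input text (premise) with one hypothesis per candidate label.

    Hypothetical helper for illustration; the NLI model would later score
    whether each premise entails its hypothesis.
    """
    # Slot every candidate label into the hypothesis template
    return [(sequence, template.format(label)) for label in labels]


pairs = build_nli_pairs("natural language processing", ["Studies", "Science"])
for premise, hypothesis in pairs:
    print(f"premise={premise!r}  hypothesis={hypothesis!r}")
```

Each pair is then run through the NLI model separately, so the cost grows linearly with the number of candidate labels.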

With the zero-shot classification pipeline

The model can be loaded with the zero-shot-classification pipeline like so:

from transformers import pipeline
classifier = pipeline("zero-shot-classification",
                      model="oigele/Fb_improved_zeroshot")

You can then use this pipeline to classify sequences into any of the class names you specify.

sequence_to_classify = "natural language processing"
candidate_labels = ['Location & Address', 'Employment', 'Organizational', 'Name', 'Service', 'Studies', 'Science']
classifier(sequence_to_classify, candidate_labels)
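In the default single-label mode, the pipeline scores every label with one NLI pass and then, as far as we understand the transformers implementation, applies a softmax to the entailment logits across all labels, so the scores sum to 1 and labels are returned sorted by score. The sketch below reproduces only that normalization step. The logits are made up for illustration (they are not outputs of this model), and the label order contradiction=0, neutral=1, entailment=2 is assumed from bart-large-mnli.

```python
import math
from typing import Dict, List

ENTAILMENT = 2  # assumed MNLI order: 0=contradiction, 1=neutral, 2=entailment


def single_label_scores(logits_per_label: Dict[str, List[float]]) -> Dict[str, float]:
    """Softmax the entailment logit across all candidate labels."""
    entail = {label: logits[ENTAILMENT] for label, logits in logits_per_label.items()}
    peak = max(entail.values())  # subtract the max for numerical stability
    exps = {label: math.exp(v - peak) for label, v in entail.items()}
    total = sum(exps.values())
    scores = {label: e / total for label, e in exps.items()}
    # Highest score first, like the pipeline's "labels"/"scores" output
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


# Hypothetical [contradiction, neutral, entailment] logits, NOT real model output
fake_logits = {
    "Science": [-2.0, 0.1, 3.0],
    "Studies": [-1.0, 0.3, 1.5],
    "Employment": [2.5, 0.2, -2.0],
}
print(single_label_scores(fake_logits))
```

Because the scores compete with each other, adding or removing a candidate label changes the scores of all other labels.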

If more than one candidate label can be correct, pass multi_label=True to score each class independently (older transformers releases called this argument multi_class, which is now deprecated):

candidate_labels = ['Location & Address', 'Employment', 'Organizational', 'Name', 'Service', 'Studies', 'Science']
classifier(sequence_to_classify, candidate_labels, multi_label=True)
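With independent scoring, each label is judged on its own: the neutral logit is dropped and a two-way softmax over (contradiction, entailment) gives the probability that the label applies. This is the same computation as the manual PyTorch example below. The sketch reuses made-up logits (not outputs of this model) and assumes the bart-large-mnli label order; the helper name is hypothetical.

```python
import math
from typing import Dict, List

CONTRADICTION, ENTAILMENT = 0, 2  # assumed MNLI order: 0=contradiction, 2=entailment


def multi_label_scores(logits_per_label: Dict[str, List[float]]) -> Dict[str, float]:
    """Score each label independently via softmax over (contradiction, entailment)."""
    scores = {}
    for label, logits in logits_per_label.items():
        c, e = logits[CONTRADICTION], logits[ENTAILMENT]
        # exp(e) / (exp(c) + exp(e)) rewritten as a logistic function of (e - c)
        scores[label] = 1.0 / (1.0 + math.exp(c - e))
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


# Hypothetical [contradiction, neutral, entailment] logits, NOT real model output
fake_logits = {
    "Science": [-2.0, 0.1, 3.0],
    "Studies": [-1.0, 0.3, 1.5],
    "Employment": [2.5, 0.2, -2.0],
}
print(multi_label_scores(fake_logits))
```

Here the scores no longer sum to 1, so several labels (for example both "Science" and "Studies") can score high at once; a threshold then decides which labels to keep.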

With manual PyTorch

# pose sequence as an NLI premise and label as a hypothesis
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
nli_model = AutoModelForSequenceClassification.from_pretrained('oigele/Fb_improved_zeroshot').to(device)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

sequence = "natural language processing"  # text to classify
label = "Science"                          # one candidate label
premise = sequence
hypothesis = f'This is {label}.'
# run through model pre-trained on MNLI; truncate only the premise if too long
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation='only_first')
with torch.no_grad():
    logits = nli_model(x.to(device))[0]
# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:, [0, 2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:, 1]