Number of classes

#32
by tdiaschmidt - opened

I would like to run this model with more than 10 classes. I read in another discussion that this is possible if I run the model locally. How can I do this?

I have the same problem

I am having the same problem

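When you run the model locally, the zero-shot-classification pipeline does not cap the number of candidate labels, so you can simply pass more than 10. To label a whole dataset, you can also do the pairing yourself and run it in batches, as follows. Let's use a dummy dataset for demonstration purposes.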

Let's say we want to zero-shot classify all the texts present in the stanfordnlp/imdb dataset.

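Models like "facebook/bart-large-mnli" do zero-shot classification by turning it into natural language inference (NLI): the text is paired with each candidate label, so there are as many text-label pairs as there are candidate labels. The zero-shot-classification pipeline wraps each label in the hypothesis template "This example is {}." by default. For each pair, the model then predicts whether the hypothesis is entailed by, neutral to, or contradicted by the text.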
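A single pair's NLI prediction is not yet a classification. To turn the per-pair logits into scores over the candidate labels, the pipeline (as far as I know) either applies a softmax over the entailment logits of all candidate labels (single-label mode) or, with multi_label=True, a softmax over the contradiction and entailment logits of each label separately. Below is a dependency-free sketch of that step. It assumes the label order of facebook/bart-large-mnli (0 = contradiction, 1 = neutral, 2 = entailment; check model.config.label2id for other models), and the logits are made-up numbers. `zero_shot_scores` is a hypothetical helper, not a library function.

```python
import math

# Assumed label order for facebook/bart-large-mnli; check model.config.label2id.
CONTRADICTION, NEUTRAL, ENTAILMENT = 0, 1, 2


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_scores(pair_logits, multi_label=False):
    """Turn the NLI logits of one text's (text, label) pairs into label scores.

    pair_logits: one [contradiction, neutral, entailment] logit triple per
    candidate label, in the same order as the candidate labels.
    """
    if multi_label:
        # each label is scored independently: entailment vs. contradiction
        return [softmax([row[CONTRADICTION], row[ENTAILMENT]])[1] for row in pair_logits]
    # single label: the labels compete via a softmax over their entailment logits
    return softmax([row[ENTAILMENT] for row in pair_logits])


# made-up logits for one text and the labels "sports", "science", "politics"
logits = [[2.0, 0.5, -1.0],
          [-1.5, 0.0, 3.0],
          [0.5, 1.0, 0.0]]
print(zero_shot_scores(logits))                    # sums to 1, highest for "science"
print(zero_shot_scores(logits, multi_label=True))  # independent per-label probabilities
```

Because single-label scores share one softmax, adding more candidate labels spreads the probability mass thinner; the multi-label scores do not depend on how many labels you pass.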

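from datasets import load_dataset
from transformers import AutoTokenizer

model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = load_dataset("stanfordnlp/imdb")

# let's say we have 3 candidate labels (you can use as many as you like)
candidate_labels = ["sports", "science", "politics"]

# same default hypothesis template as the zero-shot-classification pipeline
hypothesis_template = "This example is {}."

def expand_dataset(examples, indices):
    all_texts = []
    all_hypotheses = []
    text_ids = []   # which review each pair came from
    label_ids = []  # which candidate label each pair tests

    # one (text, hypothesis) pair per candidate label, text by text
    for idx, text in zip(indices, examples["text"]):
        for label_id, label in enumerate(candidate_labels):
            all_texts.append(text)
            all_hypotheses.append(hypothesis_template.format(label))
            text_ids.append(idx)
            label_ids.append(label_id)

    # truncate only the review, never the hypothesis; no padding here,
    # each batch is padded to its own longest sequence at inference time
    inputs = tokenizer(all_texts, all_hypotheses, truncation="only_first")
    inputs["text_id"] = text_ids
    inputs["label_id"] = label_ids
    return inputs

# expand the dataset by converting each text into a set of (text, candidate label) pairs for the model
dataset = dataset["train"].map(
    expand_dataset,
    batched=True,
    with_indices=True,
    remove_columns=dataset["train"].column_names,
)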

Once the data is prepared, we can run the forward pass in batches (set the batch size as large as your hardware allows):

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

@torch.no_grad()
def predict_labels(batch):
    # the batch arrives as lists; pad it to its longest sequence and build tensors
    model_inputs = tokenizer.pad(
        {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
        return_tensors="pt",
    )

    # move batch to device
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    # forward pass
    logits = model(**model_inputs).logits

    # predicted NLI label for each pair ("contradiction", "neutral" or "entailment")
    predicted_ids = logits.argmax(-1).tolist()
    predicted_labels = [model.config.id2label[i] for i in predicted_ids]

    # return only the new columns; keep the raw logits for per-text scoring later
    return {"nli_label": predicted_labels, "logits": logits.cpu().tolist()}

dataset = dataset.map(predict_labels, batched=True, batch_size=4)
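To get one label per review, the pairs have to be regrouped. Because the expansion above builds the pairs text by text (all candidate labels for the first review, then all for the second, and so on), and `map` keeps row order, the rows can be chunked by the number of candidate labels. A minimal sketch, assuming you kept each pair's three NLI logits (for example by returning `logits.cpu().tolist()` from the mapping function) and that index 2 is entailment, as in bart-large-mnli's config. `best_label_per_text` is a hypothetical helper and the logits are made-up numbers:

```python
def best_label_per_text(flat_logits, candidate_labels, entailment_id=2):
    """Group per-pair logits back into one predicted label per text.

    Assumes flat_logits is in expanded-dataset order: all candidate labels
    for text 0, then all candidate labels for text 1, and so on.
    """
    k = len(candidate_labels)
    if len(flat_logits) % k != 0:
        raise ValueError("number of pairs is not a multiple of the number of labels")
    predictions = []
    for start in range(0, len(flat_logits), k):
        chunk = flat_logits[start:start + k]
        # the label whose hypothesis is most strongly entailed wins
        entail = [row[entailment_id] for row in chunk]
        predictions.append(candidate_labels[entail.index(max(entail))])
    return predictions


candidate_labels = ["sports", "science", "politics"]
flat_logits = [
    [2.0, 0.5, -1.0], [-1.5, 0.0, 3.0], [0.5, 1.0, 0.0],  # text 0
    [-2.0, 0.0, 2.5], [1.0, 0.2, -0.5], [0.0, 0.0, 1.0],  # text 1
]
print(best_label_per_text(flat_logits, candidate_labels))  # ['science', 'sports']
```

Taking the argmax of the entailment logits gives the same top label as the single-label softmax, since softmax preserves order. With the mapped dataset, `dataset["logits"]` would play the role of `flat_logits`, assuming that column was returned.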
