Extracting Label

#1
by Rogerx98 - opened

Hi!

I'm trying to classify the text in a specific column into one of three labels with BART Large MNLI (facebook/bart-large-mnli). The problem is that the model's output is the sentence plus the three labels plus the score for each label. Example:

{'sequence': 'Growing special event set production/fabrication company is seeking a full-time accountant with experience in entertainment accounting. This position is located in a fast-paced production office located near downtown Los Angeles.Responsibilities:• Payroll management for 12+ employees, including processing new employee paperwork.', 'labels': ['senior', 'middle', 'junior'], 'scores': [0.5461998581886292, 0.327671617269516, 0.12612852454185486]}

What I need is a single column containing only the highest-scoring label, in this case "senior".
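Here is a minimal sketch of that extraction for a single result dict, using the example above with the sequence shortened. `top_label` is a hypothetical helper name. It only assumes the dict has parallel `labels` and `scores` lists, as in the output shown. In that output the labels already appear sorted by score, so `result["labels"][0]` would give the same answer. Taking the maximum explicitly avoids relying on that ordering.

```python
def top_label(result: dict) -> str:
    """Return the label with the highest score from one zero-shot result dict."""
    # Pair each label with its score and keep the label of the best-scoring pair
    best_label, _ = max(zip(result["labels"], result["scores"]), key=lambda pair: pair[1])
    return best_label


# Example result from above (sequence shortened for readability)
example = {
    "sequence": "Growing special event set production/fabrication company ...",
    "labels": ["senior", "middle", "junior"],
    "scores": [0.5461998581886292, 0.327671617269516, 0.12612852454185486],
}
print(top_label(example))  # senior
```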

Any feedback that could help me do this? Right now my code looks like this:

```python
from transformers import pipeline

# df is the already-loaded DataFrame with a "full_description" column
df_test = df.sample(frac=0.0025)
#print(df_test)

classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli")

sequence_to_classify = df_test["full_description"]
candidate_labels = ['senior', 'middle', 'junior']
#result = classifier(sequence_to_classify, candidate_labels)

# Each cell of the new column ends up holding the whole result dict
df_test["seniority_label"] = df_test.apply(lambda x: classifier(x.full_description, candidate_labels, multi_label=True), axis=1)
print(df_test)

df_test.to_csv("Seniority_Classified_SampleTest.csv")
```

The code I followed comes from this page, where they do get a column of labels as output, though I don't know how: https://practicaldatascience.co.uk/machine-learning/how-to-classify-customer-service-emails-with-bart-mnli

(@Rogerx98, I took the liberty of formatting your code blocks.)

Sorry, I don't fully understand, @Rogerx98. If the output is in the format you described above, can't you always just take the label corresponding to the highest probability?

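```python
import torch

# classifier_output: the dict returned by classifier(text, candidate_labels)
index = torch.argmax(torch.tensor(classifier_output["scores"])).item()  # "scores" is a plain list, so wrap it in a tensor
label = classifier_output["labels"][index]
```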

I assume I'm missing something here?
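Plugging that idea into the original `df.apply` call could look like the sketch below. `top_label` and `add_top_label_column` are hypothetical helpers. `classify` stands for any callable that behaves like the pipeline: called with a text and the candidate labels, it returns a dict with `labels` and `scores`. Because the scores are an ordinary Python list, `list.index(max(...))` is used here in place of `torch.argmax`.

```python
import pandas as pd


def top_label(result: dict) -> str:
    """Return the label whose score is highest in one zero-shot result dict."""
    scores = result["scores"]
    return result["labels"][scores.index(max(scores))]


def add_top_label_column(df: pd.DataFrame, classify, text_column: str,
                         candidate_labels, new_column: str = "seniority_label") -> pd.DataFrame:
    """Return a copy of df with one extra column holding only the top label per row.

    Assumption: classify(text, candidate_labels) returns a dict with parallel
    "labels" and "scores" lists, like the zero-shot-classification pipeline.
    """
    out = df.copy()
    # Classify each text and keep only the winning label, not the whole dict
    out[new_column] = out[text_column].apply(
        lambda text: top_label(classify(text, candidate_labels))
    )
    return out
```

With the real pipeline, the call would be `add_top_label_column(df_test, classifier, "full_description", candidate_labels)`. To keep `multi_label=True`, pass `functools.partial(classifier, multi_label=True)` in place of `classifier`. In that mode each label is scored independently, so the scores need not sum to 1, but the highest-scoring label is still chosen the same way.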

Thanks, @patrickvonplaten, that worked! With a for loop I was able to extract the labels and add them to a new column.
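A for-loop version might look like the sketch below. This is an assumption about the approach, not the code that was actually used. All texts are passed to the classifier in one call, because the zero-shot pipeline accepts a list of strings and returns one result dict per string. A loop then keeps the top label of each result. `add_labels_with_loop` and `classify` are hypothetical names.

```python
import pandas as pd


def add_labels_with_loop(df: pd.DataFrame, classify, text_column: str,
                         candidate_labels, new_column: str = "seniority_label") -> pd.DataFrame:
    """Classify all texts in one call, then loop over the results to keep the top labels.

    Assumption: classify(list_of_texts, candidate_labels) returns one dict per text,
    each with parallel "labels" and "scores" lists.
    """
    results = classify(df[text_column].tolist(), candidate_labels)
    if isinstance(results, dict):  # guard in case a single input comes back as a bare dict
        results = [results]

    top_labels = []
    for result in results:
        scores = result["scores"]
        # Position of the highest score (an argmax over a plain list)
        best_index = max(range(len(scores)), key=scores.__getitem__)
        top_labels.append(result["labels"][best_index])

    out = df.copy()
    out[new_column] = top_labels  # one label per row, in the original row order
    return out
```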
