File size: 1,955 Bytes
e53b6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
```
# Load model directly
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
tokenizer = AutoTokenizer.from_pretrained("newsmediabias/UnBIAS-Roberta-NER")
model = AutoModelForTokenClassification.from_pretrained("newsmediabias/UnBIAS-Roberta-NER")
# from_pretrained already returns the model in evaluation mode; repeating it
# here makes the inference-only intent explicit (dropout disabled).
model.eval()

# Example batch of sentences
sentences = [
    "The corrupt politician embezzled funds.",
    "Immigrants are causing a surge in crime.",
    "The movie star is an idiot for their political views.",
    "Only a fool would believe in climate change.",
    "The new policy will destroy the economy.",
]

# Tokenize the batch. padding=True pads every sentence to the longest one so the
# batch forms a single tensor; padded positions are dropped again below.
encoding = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

# Get model predictions. Inference only, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(**encoding)

# Apply softmax over the last (label) axis to get per-token probabilities.
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
# Highest-probability label id for every token position.
predicted_labels = torch.argmax(predictions, dim=-1)

# Define a mapping for the label ids.
# NOTE(review): hard-coded; confirm it matches model.config.id2label for this checkpoint.
label_mapping = {
    0: "O",  # No bias
    1: "B-BIAS",  # Beginning of a biased sequence
    2: "I-BIAS",  # Inside a biased sequence
}

# Convert predicted label ids to label names. This is still one entry per
# sub-word token, including special-token and padding positions.
labels = [
    [label_mapping[label_id] for label_id in sentence_ids]
    for sentence_ids in predicted_labels.tolist()
]

# Align labels with the words in the sentences. RoBERTa's byte-level BPE can
# split one word into several sub-word tokens, so the token-level list above
# does not line up with the words. word_ids() (fast tokenizers only; AutoTokenizer
# loads one by default) maps each token to the index of the word it came from,
# or None for special tokens such as <s>/</s> and for padding. Each word takes
# the label of its first sub-word token. "Words" are as split by the tokenizer's
# pre-tokenizer, so punctuation may appear as its own word.
aligned_words = []
aligned_labels = []
for i, sentence_labels in enumerate(labels):
    words = []
    word_labels = []
    previous_word_id = None
    for word_id, label in zip(encoding.word_ids(batch_index=i), sentence_labels):
        if word_id is None or word_id == previous_word_id:
            continue  # special/padding token, or a continuation sub-token
        previous_word_id = word_id
        span = encoding.word_to_chars(i, word_id)
        words.append(sentences[i][span.start:span.end])
        word_labels.append(label)
    aligned_words.append(words)
    aligned_labels.append(word_labels)

# Print each sentence with (word, label) pairs so the labels read against the text.
for sentence, words, word_labels in zip(sentences, aligned_words, aligned_labels):
    print(f"Sentence: {sentence}\nLabels: {list(zip(words, word_labels))}\n")
```