Stereotype_Detection / stereotype_detector /stereotype_detector.py
Zekun Wu
update
bfea8bb
raw
history blame contribute delete
No virus
2.63 kB
from typing import List
from transformers import pipeline, AutoTokenizer
import os
class Detector:
    """
    Detect stereotypes in text using a pre-trained Hugging Face pipeline.

    Two classifier flavours are supported, selected by name:

    * ``"Token"``    -- runs a ``"ner"`` pipeline and reports an entity tag
      and score per recognised word.
    * ``"Sentence"`` -- runs a ``"text-classification"`` pipeline and reports
      a score per label for the whole text.
    """

    def __init__(self, classifier):
        """
        Initialize the detector and load the matching pipeline.

        Args:
            classifier (str): Which detector to load, ``"Token"`` or
                ``"Sentence"``.

        Raises:
            ValueError: If ``classifier`` is not one of the supported names.
        """
        self.classifier = classifier
        self.classifier_model_mapping = {
            "Token": "wu981526092/Token-Level-Stereotype-Detector",
            "Sentence": "wu981526092/Sentence-Level-Stereotype-Detector",
        }

        if classifier not in self.classifier_model_mapping:
            raise ValueError(f"Invalid classifier. Expected one of: {list(self.classifier_model_mapping.keys())}")

        self.model_path = self.classifier_model_mapping[classifier]

        # Access token for the model hub, read from the environment; it is
        # None when the variable is unset and is passed through as-is.
        API_token = os.getenv("BIAS_DETECTOR_API_KEY")

        # Using pipeline for inference
        if classifier == 'Token':
            task_type = "ner"
            self.model = pipeline(
                task_type,
                model=self.model_path,
                tokenizer=self.model_path,
                token=API_token,
            )
        else:
            task_type = "text-classification"
            # return_all_scores=True: predict() iterates over one
            # label/score entry per label for every text.
            self.model = pipeline(
                task_type,
                model=self.model_path,
                tokenizer=self.model_path,
                token=API_token,
                return_all_scores=True,
            )

    def predict(self, texts: List[str]) -> List[dict]:
        """
        Predict the bias of the given text or list of texts.

        Args:
            texts (List[str]): The strings to analyze. A single ``str`` is
                also accepted and is treated as a one-element list.

        Returns:
            A list with one single-key dictionary per input text, in input
            order. The key is the text itself; the value depends on the
            classifier:

            * ``"Token"``: ``{word: {entity: score}}``. If the same word is
              reported more than once for a text, the last occurrence wins
              because words are used as dictionary keys.
            * ``"Sentence"``: ``{label: score}``, where ``label`` is the part
              of the model label after the last ``"__"``.

        Raises:
            ValueError: If any element of ``texts`` is not a ``str``.
        """
        # A bare string is an iterable of one-character strings, so it would
        # pass the element check below and then be paired with the model
        # output character by character. Treat it as a single text instead.
        if isinstance(texts, str):
            texts = [texts]
        else:
            # Materialise once so validation cannot exhaust a one-shot
            # iterable before the model sees it.
            texts = list(texts)

        if not all(isinstance(text, str) for text in texts):
            raise ValueError("All elements in 'texts' should be of str type")

        # Nothing to analyze: skip the model call entirely.
        if not texts:
            return []

        results = []
        predictions = self.model(texts)
        for text, prediction in zip(texts, predictions):
            result = {}
            if self.classifier == 'Token':
                for item in prediction:
                    result[item['word']] = {item['entity']: item['score']}
            elif self.classifier == 'Sentence':
                result = {item['label'].split('__')[-1]: item['score'] for item in prediction}
            results.append({text: result})
        return results
if __name__ == "__main__":
    # Smoke test: run the sentence-level detector on two sample sentences.
    sample_sentences = [
        "The girl performed poorly at reading in school.",
        "Sir is an address showing respect for a person. It usually refers to a male.",
    ]
    sentence_detector = Detector("Sentence")
    outcome = sentence_detector.predict(sample_sentences)

    # Full output first, then only the label scores of the second sentence.
    print(outcome)
    second_sentence = sample_sentences[1]
    second_entry = outcome[1]
    print(second_entry[second_sentence])