Spaces:
Runtime error
Runtime error
from typing import List | |
from transformers import pipeline, AutoTokenizer | |
import os | |
class Detector:
    """
    Detect stereotype-related bias in text using pre-trained Hugging Face models.

    Two classifier flavours are supported, selected by name at construction:

    * ``"Token"``    -- runs a ``"ner"`` pipeline and reports a tag per token.
    * ``"Sentence"`` -- runs a ``"text-classification"`` pipeline and reports
      a score per label for each whole text.
    """

    def __init__(self, classifier):
        """
        Initializes the detector with a specific model.

        The access token passed to the pipeline is read from the
        ``BIAS_DETECTOR_API_KEY`` environment variable (``None`` when unset).

        Args:
            classifier (str): The type of classifier to use; must be one of
                the keys of ``classifier_model_mapping``
                (``"Token"`` or ``"Sentence"``).

        Raises:
            ValueError: If ``classifier`` is not a supported classifier name.
        """
        self.classifier = classifier
        self.classifier_model_mapping = {
            "Token": "wu981526092/Token-Level-Stereotype-Detector",
            "Sentence": "wu981526092/Sentence-Level-Stereotype-Detector",
        }

        if classifier not in self.classifier_model_mapping:
            raise ValueError(f"Invalid classifier. Expected one of: {list(self.classifier_model_mapping.keys())}")

        self.model_path = self.classifier_model_mapping[classifier]
        api_token = os.getenv("BIAS_DETECTOR_API_KEY")

        # Using pipeline for inference. Both flavours share the model,
        # tokenizer and token arguments; only the task (and one extra flag
        # for the sentence classifier) differ.
        pipeline_kwargs = {
            "model": self.model_path,
            "tokenizer": self.model_path,
            "token": api_token,
        }
        if classifier == 'Token':
            task_type = "ner"
        else:
            task_type = "text-classification"
            # Ask for the score of every label rather than only the best one.
            # NOTE(review): recent transformers releases deprecate
            # `return_all_scores` in favour of `top_k=None` -- confirm against
            # the pinned transformers version before switching.
            pipeline_kwargs["return_all_scores"] = True
        self.model = pipeline(task_type, **pipeline_kwargs)

    def predict(self, texts: List[str]):
        """
        Predicts the bias of the given text or list of texts.

        Args:
            texts (List[str]): The strings to analyze. A single ``str`` is
                also accepted and is treated as a one-element list.

        Returns:
            A list with one single-key dictionary per input text, in input
            order, of the form ``{text: result}``:

            * ``"Sentence"`` classifier: ``result`` maps each label (the part
              of the model label after the last ``"__"``) to its score.
            * ``"Token"`` classifier: ``result`` maps each reported word to
              ``{entity: score}``. If the same word is reported more than
              once for a text, the last occurrence wins.

            An empty input yields an empty list without invoking the model.

        Raises:
            ValueError: If any element of ``texts`` is not a ``str``.
        """
        if isinstance(texts, str):
            # A bare string would otherwise pass the per-element check below
            # (every character is a str) and then be zipped character by
            # character against the predictions.
            texts = [texts]
        else:
            # Materialize once so one-shot iterables (e.g. generators) are not
            # exhausted by validation before they reach the model.
            texts = list(texts)

        if not all(isinstance(text, str) for text in texts):
            raise ValueError("All elements in 'texts' should be of str type")

        if not texts:
            return []

        predictions = self.model(texts)

        results = []
        for text, prediction in zip(texts, predictions):
            result = {}
            if self.classifier == 'Token':
                for item in prediction:
                    result[item['word']] = {item['entity']: item['score']}
            elif self.classifier == 'Sentence':
                result = {item['label'].split('__')[-1]: item['score'] for item in prediction}
            results.append({text: result})
        return results
if __name__ == '__main__':
    # Manual smoke test: run the sentence-level detector on two sample texts,
    # print the full result list, then the label scores for the second text.
    sentence_detector = Detector("Sentence")
    samples = [
        "The girl performed poorly at reading in school.",
        "Sir is an address showing respect for a person. It usually refers to a male.",
    ]
    outputs = sentence_detector.predict(samples)
    print(outputs)

    # Each entry is keyed by its input text, so look the scores up by text.
    second_text = samples[1]
    print(outputs[1][second_text])