from datetime import datetime

from datasets import load_dataset
from fastapi import APIRouter, HTTPException
from sklearn.metrics import accuracy_score

from .data.data_loaders import TextDataLoader
from .models.text_classifiers import BaselineModel, ModelFactory
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData

# define models
# NOTE(review): embedding_ml_model is instantiated at import time but not
# referenced elsewhere in this module — presumably kept so it can be selected
# via model_to_evaluate (or imported by other modules); confirm before removing.
embedding_ml_model = ModelFactory.create_model({"model_type": "embeddingML"})
distilbert_model = ModelFactory.create_model(
    {
        "model_type": "distilbert-pretrained",
        "model_name": "2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased",
    }
)

# Single switch selecting the model served by the endpoint. Both DESCRIPTION
# and the endpoint's default model are derived from it, so they cannot drift.
model_to_evaluate = distilbert_model

# define router
router = APIRouter()

DESCRIPTION = model_to_evaluate.description
ROUTE = "/text"


@router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
async def evaluate_text(
    request: TextEvaluationRequest,
    track_emissions: bool = True,
    model=None,
    light_dataset: bool = False,
) -> dict:
    """
    Evaluate text classification for climate disinformation detection.

    Parameters:
    -----------
    request: TextEvaluationRequest
        The request object containing the dataset configuration.
    track_emissions: bool
        Whether to track emissions or not. When False, zero energy/emissions
        are reported.
    model: TextClassifier, optional
        The model to use for inference (must provide ``predict``). Defaults to
        the module-level ``model_to_evaluate``. Intended for direct Python
        callers; it cannot be chosen through the HTTP API.
    light_dataset: bool
        Whether to use a light dataset or not.

    Returns:
    --------
    dict
        A dictionary containing the evaluation results.

    Raises:
    -------
    HTTPException
        422 if ``model`` is given but has no ``predict`` method (e.g. a string
        sent by an HTTP client as a query parameter).
    """
    # Resolve the model. FastAPI exposes this unannotated parameter as a query
    # parameter, so any value arriving over HTTP is a plain string.
    if model is None:
        model = model_to_evaluate
    elif not hasattr(model, "predict"):
        raise HTTPException(
            status_code=422,
            detail="The 'model' parameter cannot be selected through the API.",
        )

    # Get space info
    username, space_url = get_space_info()

    # Load the dataset
    test_dataset = TextDataLoader(request, light=light_dataset).get_test_dataset()

    # Start tracking emissions
    if track_emissions:
        tracker = get_tracker()
        tracker.start()
        tracker.start_task("inference")

    # Model inference. The "inference" task is closed in `finally` so that a
    # failing prediction does not leave it open. NOTE(review): this matters if
    # get_tracker() returns a tracker shared across requests — confirm.
    try:
        predictions = [model.predict(quote) for quote in test_dataset["quote"]]
    finally:
        if track_emissions:
            emissions_data = tracker.stop_task()
        else:
            emissions_data = EmissionsData(0, 0)

    # Calculate accuracy
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary. The *1000 factors convert to the units named
    # by the keys (Wh, gCO2eq); assumes the tracker reports kWh and kg — confirm.
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        # Describe the model that actually ran, not just the module default.
        "model_description": getattr(model, "description", DESCRIPTION),
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results