"""Unit tests for the ``Evaluator`` class in ``src.backend.evaluate_model``."""
import unittest
from unittest.mock import patch

import pandas as pd

import src.backend.evaluate_model as evaluate_model
import src.envs as envs
class TestEvaluator(unittest.TestCase):
    """Unit tests for ``evaluate_model.Evaluator``.

    NOTE(review): each test method receives injected mock arguments, so the
    methods must carry ``@patch`` decorators (they were missing, which makes
    every test fail with a ``TypeError`` before reaching any assertion).
    The decorators are restored below; the patch targets
    (``SummaryGenerator``, ``EvaluationModel``, ``pandas.read_csv``,
    ``util.format_results``) are inferred from the assertions in the tests —
    confirm them against ``src.backend.evaluate_model``.
    """

    def setUp(self):
        # Constructor arguments shared by every test.
        self.model_name = 'test_model'
        self.revision = 'test_revision'
        self.precision = 'test_precision'
        self.batch_size = 10
        self.device = 'test_device'
        self.no_cache = False
        self.limit = 10

    # ``@patch`` decorators apply bottom-up: the decorator nearest the
    # function supplies the first mock parameter after ``self``.
    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    def test_evaluator_initialization(self, mock_eval_model, mock_summary_generator):
        """Constructing an Evaluator builds both collaborators exactly once."""
        evaluator = evaluate_model.Evaluator(self.model_name, self.revision,
                                             self.precision, self.batch_size,
                                             self.device, self.no_cache, self.limit)
        mock_summary_generator.assert_called_once_with(self.model_name, self.revision)
        mock_eval_model.assert_called_once_with(envs.HEM_PATH)
        self.assertEqual(evaluator.model, self.model_name)

    # NOTE(review): the mock parameters here are in the opposite order to the
    # test above, so the decorator stack is reversed to match.
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('src.backend.evaluate_model.SummaryGenerator')
    def test_evaluator_initialization_error(self, mock_summary_generator, mock_eval_model):
        """A failure while building the evaluation model propagates to the caller."""
        mock_eval_model.side_effect = Exception('test_exception')
        with self.assertRaises(Exception):
            evaluate_model.Evaluator(self.model_name, self.revision,
                                     self.precision, self.batch_size,
                                     self.device, self.no_cache, self.limit)

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('pandas.read_csv')
    @patch('src.backend.evaluate_model.util.format_results')
    def test_evaluate_method(self, mock_format_results, mock_read_csv, mock_eval_model,
                             mock_summary_generator):
        """evaluate() reads the source data, scores it, and formats the results."""
        evaluator = evaluate_model.Evaluator(self.model_name, self.revision,
                                             self.precision, self.batch_size,
                                             self.device, self.no_cache, self.limit)
        # Mock setup: canned dataframes and fixed metric values.
        mock_format_results.return_value = {'test': 'result'}
        mock_read_csv.return_value = pd.DataFrame({'column1': ['data1', 'data2']})
        mock_summary_generator.return_value.generate_summaries.return_value = pd.DataFrame({'column1': ['summary1', 'summary2']})
        mock_summary_generator.return_value.avg_length = 100
        mock_summary_generator.return_value.answer_rate = 1.0
        mock_summary_generator.return_value.error_rate = 0.0
        mock_eval_model.return_value.compute_accuracy.return_value = 1.0
        mock_eval_model.return_value.hallucination_rate = 0.0
        mock_eval_model.return_value.evaluate_hallucination.return_value = [0.5]
        # Method call and assertions.
        results = evaluator.evaluate()
        mock_format_results.assert_called_once_with(model_name=self.model_name,
                                                    revision=self.revision,
                                                    precision=self.precision,
                                                    accuracy=1.0, hallucination_rate=0.0,
                                                    answer_rate=1.0, avg_summary_len=100,
                                                    error_rate=0.0)
        mock_read_csv.assert_called_once_with(envs.SOURCE_PATH)

    @patch('src.backend.evaluate_model.SummaryGenerator')
    @patch('src.backend.evaluate_model.EvaluationModel')
    @patch('pandas.read_csv')
    def test_evaluate_with_file_not_found(self, mock_read_csv, mock_eval_model,
                                          mock_summary_generator):
        """A missing source file surfaces as FileNotFoundError from evaluate()."""
        mock_read_csv.side_effect = FileNotFoundError('test_exception')
        evaluator = evaluate_model.Evaluator(self.model_name, self.revision,
                                             self.precision, self.batch_size,
                                             self.device, self.no_cache, self.limit)
        with self.assertRaises(FileNotFoundError):
            evaluator.evaluate()
if __name__ == '__main__':
    # Allow running this test module directly (e.g. ``python test_evaluator.py``).
    unittest.main()