# zeroshotGPU/tests/test_logging.py
# Arjunvir Singh
# Initial commit: zeroshotGPU MVP with full eval surface
# db06ffa
"""Tests for the logging configuration and structured log emission."""
from __future__ import annotations
import io
import json
import logging
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from zsgdp.logging_config import configure_logging, get_logger
from zsgdp.pipeline import parse_document
class ConfigureLoggingTests(unittest.TestCase):
def setUp(self) -> None:
# Reset between tests so each one configures cleanly.
root = logging.getLogger("zsgdp")
for handler in list(root.handlers):
root.removeHandler(handler)
def test_idempotent_configuration(self):
stream = io.StringIO()
configure_logging(level="INFO", json_format=False, stream=stream)
configure_logging(level="INFO", json_format=False, stream=stream)
root = logging.getLogger("zsgdp")
# Idempotent: still exactly one handler attached.
self.assertEqual(len(root.handlers), 1)
def test_text_format_emits_human_readable(self):
stream = io.StringIO()
configure_logging(level="INFO", json_format=False, stream=stream)
get_logger("zsgdp.test").info("hello", extra={"doc_id": "d1"})
output = stream.getvalue()
self.assertIn("INFO", output)
self.assertIn("zsgdp.test", output)
self.assertIn("hello", output)
def test_json_format_emits_one_line_records(self):
stream = io.StringIO()
configure_logging(level="INFO", json_format=True, stream=stream)
get_logger("zsgdp.test").info("event", extra={"doc_id": "abc", "count": 3})
output = stream.getvalue().strip()
record = json.loads(output)
self.assertEqual(record["level"], "INFO")
self.assertEqual(record["logger"], "zsgdp.test")
self.assertEqual(record["message"], "event")
self.assertEqual(record["doc_id"], "abc")
self.assertEqual(record["count"], 3)
def test_default_level_is_warning(self):
stream = io.StringIO()
with patch.dict("os.environ", {"ZSGDP_LOG_LEVEL": "", "ZSGDP_LOG_JSON": ""}, clear=False):
configure_logging(stream=stream)
get_logger("zsgdp.test").info("hidden_at_default_level")
self.assertNotIn("hidden_at_default_level", stream.getvalue())
get_logger("zsgdp.test").warning("visible_at_default_level")
self.assertIn("visible_at_default_level", stream.getvalue())
def test_get_logger_namespacing(self):
self.assertEqual(get_logger().name, "zsgdp")
self.assertEqual(get_logger("zsgdp.foo").name, "zsgdp.foo")
# Bare names get prefixed.
self.assertEqual(get_logger("foo").name, "zsgdp.foo")
class PipelineLogEmissionTests(unittest.TestCase):
def test_parse_emits_start_and_end_records(self):
# Reset handlers so assertLogs works against the named logger.
root = logging.getLogger("zsgdp")
for handler in list(root.handlers):
root.removeHandler(handler)
root.setLevel(logging.DEBUG)
root.propagate = True
with tempfile.TemporaryDirectory() as tmp:
input_path = Path(tmp) / "doc.md"
input_path.write_text("# Doc\n\nHello.\n", encoding="utf-8")
with self.assertLogs("zsgdp.pipeline", level="INFO") as captured:
parse_document(input_path, Path(tmp) / "out")
messages = [record.message for record in captured.records]
self.assertIn("parse_start", messages)
self.assertIn("parse_end", messages)
# Find a parse_end record and assert structured fields are present.
parse_end = next(record for record in captured.records if record.message == "parse_end")
self.assertTrue(hasattr(parse_end, "doc_id"))
self.assertTrue(hasattr(parse_end, "elapsed_seconds"))
self.assertTrue(hasattr(parse_end, "quality_score"))
self.assertTrue(hasattr(parse_end, "element_count"))
class RepairLogEmissionTests(unittest.TestCase):
def test_repair_emits_iteration_record(self):
root = logging.getLogger("zsgdp")
for handler in list(root.handlers):
root.removeHandler(handler)
root.setLevel(logging.DEBUG)
root.propagate = True
with tempfile.TemporaryDirectory() as tmp:
input_path = Path(tmp) / "report.md"
# Malformed table forces a repair iteration.
input_path.write_text(
"# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 | 3 |\n",
encoding="utf-8",
)
with self.assertLogs("zsgdp.repair.controller", level="INFO") as captured:
parse_document(input_path, Path(tmp) / "out")
repair_records = [r for r in captured.records if r.message == "repair_iteration"]
self.assertGreaterEqual(len(repair_records), 1)
# Each record carries the iteration index.
for record in repair_records:
self.assertTrue(hasattr(record, "iteration"))
self.assertTrue(hasattr(record, "status"))
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()