| """Tests for the style fingerprinting module.""" |
|
|
| import pytest |
| import torch |
| from src.style.fingerprinter import StyleFingerprinter, StyleProjectionMLP |
| from src.style.style_vector import cosine_similarity, average_style_vectors |
|
|
|
|
@pytest.fixture
def fingerprinter(tmp_path):
    """Provide a StyleFingerprinter configured with a small temporary word list.

    A four-word file named ``awl.txt`` is written under pytest's ``tmp_path``
    and its path is handed to the fingerprinter as ``awl_path``.
    (Presumably "awl" stands for Academic Word List -- confirm against the
    fingerprinter module.)
    """
    word_list_file = tmp_path / "awl.txt"
    word_list_file.write_text("analysis\nconsider\nestablish\nsignificant\n")
    instance = StyleFingerprinter(
        awl_path=str(word_list_file),
        spacy_model="en_core_web_sm",
    )
    return instance
|
|
|
|
def test_style_vector_shape(fingerprinter):
    """Check that an extracted style vector is one-dimensional with 512 entries."""
    expected_shape = (512,)
    sample_text = "This is a test sentence for analysis."
    style_vec = fingerprinter.extract_vector(sample_text)
    assert style_vec.shape == expected_shape
|
|
|
|
def test_style_vector_different_texts(fingerprinter):
    """Check that a formal and an informal sentence do not map to near-identical vectors."""
    similarity_ceiling = 0.99
    formal_text = "The analysis demonstrates significant correlations between variables."
    informal_text = "yo this stuff is like totally awesome and cool"
    # Extract in the same order as before: formal first, then informal.
    formal_vec = fingerprinter.extract_vector(formal_text)
    informal_vec = fingerprinter.extract_vector(informal_text)
    assert cosine_similarity(formal_vec, informal_vec) < similarity_ceiling
|
|
|
|
def test_style_blend(fingerprinter):
    """Check that blending two style vectors yields a vector of (approximately) unit norm."""
    academic_vec = fingerprinter.extract_vector("Academic formal text with analysis.")
    casual_vec = fingerprinter.extract_vector("Casual informal text with stuff.")
    mixed = fingerprinter.blend_vectors(academic_vec, casual_vec, alpha=0.6)
    # Distance of the blended vector's norm from 1.0; must stay within 0.01.
    deviation = abs(torch.norm(mixed).item() - 1.0)
    assert deviation < 0.01
|
|
|
|
def test_raw_features_keys(fingerprinter):
    """Check that the raw feature mapping exposes the expected feature names."""
    sample_text = "The quick brown fox jumps over the lazy dog."
    raw = fingerprinter.extract_raw_features(sample_text)
    # Checked in the same order as the original individual assertions.
    for feature_name in (
        "sentence_length_mean",
        "type_token_ratio",
        "passive_voice_ratio",
        "lexical_density",
    ):
        assert feature_name in raw
|
|