Spaces:
Sleeping
Sleeping
| """ | |
| Rigorous Tests for Encryption Service. | |
| Tests cover: | |
| 1. Private key loading (env, file, caching) | |
| 2. Direct RSA-OAEP decryption | |
| 3. Hybrid RSA+AES-GCM decryption | |
| 4. Main decrypt_data entry point | |
| 5. Multiple block decryption | |
| 6. Error handling and edge cases | |
| Uses real cryptographic operations with test keypairs. | |
| """ | |
| import pytest | |
| import base64 | |
| import json | |
| import os | |
| import tempfile | |
| from unittest.mock import patch, MagicMock | |
| from cryptography.hazmat.primitives import serialization, hashes | |
| from cryptography.hazmat.primitives.asymmetric import rsa, padding | |
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |
| from cryptography.hazmat.backends import default_backend | |
| # ============================================================================= | |
| # Test Fixtures - Generate RSA keypair for testing | |
| # ============================================================================= | |
@pytest.fixture(scope="module")
def test_keypair():
    """Provide one 2048-bit RSA keypair shared by the tests in this module.

    Registered as a pytest fixture: without the decorator pytest would collect
    this ``test_*`` function as a test, and every test requesting
    ``test_keypair`` would error with "fixture not found". Module scope means
    the relatively slow key generation runs once; tests only read the keys.

    Returns:
        dict with:
            "private_key": the generated RSA private key object,
            "public_key": its public key object,
            "private_pem": the private key as unencrypted PEM text
                (TraditionalOpenSSL format).
    """
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )
    public_key = private_key.public_key()
    # PEM text is what load_private_key consumes from the env var or key file.
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode('utf-8')
    return {
        "private_key": private_key,
        "public_key": public_key,
        "private_pem": private_pem,
    }
def encrypt_direct(public_key, plaintext: str) -> str:
    """Build a "direct" test payload the way the service expects to receive it.

    The UTF-8 bytes of *plaintext* are RSA-OAEP encrypted (SHA-256 for both
    MGF1 and the hash, no label). The ciphertext is base64-encoded into a JSON
    envelope, and the JSON text itself is base64-encoded.

    Args:
        public_key: RSA public key used for encryption.
        plaintext: text to encrypt.

    Returns:
        Base64 text of ``{"type": "direct", "data": <base64 ciphertext>}``.
    """
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    ciphertext = public_key.encrypt(plaintext.encode('utf-8'), oaep)
    envelope_json = json.dumps({
        "type": "direct",
        "data": base64.b64encode(ciphertext).decode('utf-8'),
    })
    return base64.b64encode(envelope_json.encode('utf-8')).decode('utf-8')
def encrypt_hybrid(public_key, plaintext: str) -> str:
    """Build a "hybrid" test payload: AES-256-GCM data wrapped by an RSA-OAEP key.

    A random 256-bit AES key and 96-bit IV encrypt the UTF-8 plaintext with
    AES-GCM, and the GCM tag is appended to the ciphertext. The AES key is
    RSA-OAEP (SHA-256) encrypted with *public_key*. All binary parts are
    base64-encoded into a JSON envelope, which is itself base64-encoded.

    Args:
        public_key: RSA public key that wraps the AES key.
        plaintext: text to encrypt.

    Returns:
        Base64 text of ``{"type": "hybrid", "key": ..., "iv": ..., "data": ...}``.
    """
    def to_b64(raw: bytes) -> str:
        return base64.b64encode(raw).decode('utf-8')

    session_key = os.urandom(32)  # AES-256 key
    nonce = os.urandom(12)        # 96-bit GCM IV

    gcm = Cipher(
        algorithms.AES(session_key), modes.GCM(nonce), backend=default_backend()
    ).encryptor()
    body = gcm.update(plaintext.encode('utf-8'))
    body += gcm.finalize()
    sealed = body + gcm.tag  # ciphertext || tag

    wrapped_key = public_key.encrypt(
        session_key,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        ),
    )

    envelope_json = json.dumps({
        "type": "hybrid",
        "key": to_b64(wrapped_key),
        "iv": to_b64(nonce),
        "data": to_b64(sealed),
    })
    return to_b64(envelope_json.encode('utf-8'))
| # ============================================================================= | |
| # 1. Private Key Loading Tests | |
| # ============================================================================= | |
class TestPrivateKeyLoading:
    """Test load_private_key: env variable, file fallback, caching, invalid PEM."""

    @pytest.fixture(autouse=True)
    def _isolate_key_cache(self):
        """Start each test with an empty key cache and restore the prior value after.

        load_private_key stores its result in the module-global ``_private_key``.
        Without restoring it, a key loaded here would leak into tests that run later.
        """
        import services.encryption_service as es
        with patch.object(es, "_private_key", None):
            yield

    def test_load_key_from_env_variable(self, test_keypair):
        """Load private key from PRIVATE_KEY env variable."""
        import services.encryption_service as es
        with patch.dict(os.environ, {"PRIVATE_KEY": test_keypair["private_pem"]}):
            # No usable key file, so a loaded key can only have come from the env var.
            with patch.object(es, 'PRIVATE_KEY_PATH', '/nonexistent/path.pem'):
                key = es.load_private_key()
        assert key is not None
        assert (key.public_key().public_numbers()
                == test_keypair["public_key"].public_numbers())

    def test_load_key_from_file(self, test_keypair, tmp_path):
        """Load private key from file when env var missing."""
        import services.encryption_service as es
        key_file = tmp_path / "private.pem"
        key_file.write_text(test_keypair["private_pem"])
        # clear=True removes PRIVATE_KEY (and everything else) for the duration.
        with patch.dict(os.environ, {}, clear=True):
            with patch.object(es, 'PRIVATE_KEY_PATH', str(key_file)):
                key = es.load_private_key()
        assert key is not None
        assert (key.public_key().public_numbers()
                == test_keypair["public_key"].public_numbers())

    def test_returns_none_when_no_key(self):
        """Return None when both env and file are missing."""
        import services.encryption_service as es
        with patch.dict(os.environ, {}, clear=True):
            with patch.object(es, 'PRIVATE_KEY_PATH', '/nonexistent/path.pem'):
                key = es.load_private_key()
        assert key is None

    def test_key_is_cached(self, test_keypair):
        """Key is cached after first load."""
        import services.encryption_service as es
        with patch.dict(os.environ, {"PRIVATE_KEY": test_keypair["private_pem"]}):
            key1 = es.load_private_key()
            key2 = es.load_private_key()
        # Without this check, two failed loads (None is None) would pass.
        assert key1 is not None
        assert key1 is key2

    def test_invalid_pem_handling(self):
        """Invalid PEM content falls back to file."""
        import services.encryption_service as es
        with patch.dict(os.environ, {"PRIVATE_KEY": "not-valid-pem"}):
            with patch.object(es, 'PRIVATE_KEY_PATH', '/nonexistent/path.pem'):
                key = es.load_private_key()
        assert key is None  # invalid env PEM and no file -> None
| # ============================================================================= | |
| # 2. Direct RSA Decryption Tests | |
| # ============================================================================= | |
class TestDirectDecryption:
    """Exercise decrypt_direct on a good RSA-OAEP payload and on malformed ones."""

    @staticmethod
    def _oaep_sha256():
        """Return RSA-OAEP padding with SHA-256 for MGF1 and the hash, no label."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    def test_decrypt_valid_rsa_data(self, test_keypair):
        """Decrypt valid RSA-OAEP encrypted data."""
        from services.encryption_service import decrypt_direct

        message = "Hello, World!"
        cipher_bytes = test_keypair["public_key"].encrypt(
            message.encode('utf-8'), self._oaep_sha256()
        )
        encoded = base64.b64encode(cipher_bytes).decode('utf-8')
        assert decrypt_direct({"data": encoded}, test_keypair["private_key"]) == message

    def test_invalid_base64(self, test_keypair):
        """Handle invalid base64 input."""
        from services.encryption_service import decrypt_direct

        garbage = {"data": "not-valid-base64!!!"}
        with pytest.raises(Exception):
            decrypt_direct(garbage, test_keypair["private_key"])

    def test_corrupted_encrypted_data(self, test_keypair):
        """Handle corrupted encrypted data."""
        from services.encryption_service import decrypt_direct

        # 256 random bytes match the 2048-bit modulus size but are not a valid
        # OAEP ciphertext.
        noise = base64.b64encode(os.urandom(256)).decode('utf-8')
        with pytest.raises(Exception):
            decrypt_direct({"data": noise}, test_keypair["private_key"])
| # ============================================================================= | |
| # 3. Hybrid RSA+AES-GCM Decryption Tests | |
| # ============================================================================= | |
class TestHybridDecryption:
    """Test decrypt_hybrid function."""

    @staticmethod
    def _hybrid_payload(public_key, plaintext, *, tamper=False):
        """Encrypt *plaintext* into the payload dict consumed by decrypt_hybrid.

        A random 256-bit AES key and 96-bit IV encrypt the UTF-8 plaintext with
        AES-GCM, and the GCM tag is appended to the ciphertext. The AES key is
        wrapped with RSA-OAEP (SHA-256). If ``tamper`` is true, the first
        ciphertext byte is flipped after encryption, so GCM authentication
        must fail on decrypt.

        Returns:
            dict with base64 text under "key", "iv" and "data" (no "type").
        """
        aes_key = os.urandom(32)
        iv = os.urandom(12)
        encryptor = Cipher(
            algorithms.AES(aes_key), modes.GCM(iv), backend=default_backend()
        ).encryptor()
        ciphertext = encryptor.update(plaintext.encode('utf-8')) + encryptor.finalize()
        encrypted_data = bytearray(ciphertext + encryptor.tag)
        if tamper:
            encrypted_data[0] ^= 0xFF  # flip bits in the first ciphertext byte
        encrypted_aes_key = public_key.encrypt(
            aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            ),
        )
        return {
            "key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
            "iv": base64.b64encode(iv).decode('utf-8'),
            "data": base64.b64encode(bytes(encrypted_data)).decode('utf-8'),
        }

    def test_decrypt_valid_hybrid_data(self, test_keypair):
        """Decrypt valid hybrid RSA+AES-GCM data."""
        from services.encryption_service import decrypt_hybrid
        # Repeat the sentence so the message really is longer than the 190-byte
        # RSA-OAEP/SHA-256 limit of a 2048-bit key (one copy is only 77 bytes).
        plaintext = (
            "This is a longer message that exceeds 190 bytes and needs hybrid encryption! " * 3
        )
        assert len(plaintext.encode('utf-8')) > 190
        payload = self._hybrid_payload(test_keypair["public_key"], plaintext)
        result = decrypt_hybrid(payload, test_keypair["private_key"])
        assert result == plaintext

    def test_tampered_ciphertext_fails(self, test_keypair):
        """Tampered ciphertext fails GCM authentication."""
        from services.encryption_service import decrypt_hybrid
        payload = self._hybrid_payload(
            test_keypair["public_key"], "Original message", tamper=True
        )
        with pytest.raises(Exception):  # GCM auth failure
            decrypt_hybrid(payload, test_keypair["private_key"])

    def test_invalid_aes_key(self, test_keypair):
        """Handle corrupted/invalid AES key."""
        from services.encryption_service import decrypt_hybrid
        payload = {
            "key": base64.b64encode(os.urandom(256)).decode('utf-8'),  # random, not RSA encrypted
            "iv": base64.b64encode(os.urandom(12)).decode('utf-8'),
            "data": base64.b64encode(os.urandom(100)).decode('utf-8'),
        }
        with pytest.raises(Exception):
            decrypt_hybrid(payload, test_keypair["private_key"])
| # ============================================================================= | |
| # 4. Main decrypt_data Entry Point Tests | |
| # ============================================================================= | |
class TestDecryptData:
    """Test decrypt_data main entry function."""

    @pytest.fixture(autouse=True)
    def _install_test_key(self, test_keypair):
        """Seed the service's cached ``_private_key`` with the test key for each test.

        ``patch.object`` restores the previous value on teardown, so the test key
        (or a None set inside a test) does not leak into tests that run later.
        """
        import services.encryption_service as es
        with patch.object(es, "_private_key", test_keypair["private_key"]):
            yield

    def test_no_key_returns_status(self):
        """Return no_key_available when no private key."""
        import services.encryption_service as es
        es._private_key = None  # undone by _install_test_key on teardown
        with patch.object(es, 'load_private_key', return_value=None):
            result = es.decrypt_data("some-encrypted-data")
        assert result["decryption_status"] == "no_key_available"
        assert "encrypted_data" in result

    def test_decrypt_direct_type(self, test_keypair):
        """Decrypt data with type='direct'."""
        import services.encryption_service as es
        encrypted = encrypt_direct(test_keypair["public_key"], '{"message": "hello"}')
        result = es.decrypt_data(encrypted)
        assert result["message"] == "hello"

    def test_decrypt_hybrid_type(self, test_keypair):
        """Decrypt data with type='hybrid'."""
        import services.encryption_service as es
        encrypted = encrypt_hybrid(
            test_keypair["public_key"], '{"data": "long message here"}'
        )
        result = es.decrypt_data(encrypted)
        assert result["data"] == "long message here"

    def test_unknown_type_returns_error(self):
        """Unknown encryption type returns error."""
        import services.encryption_service as es
        payload = {"type": "unknown_type", "data": "something"}
        encrypted = base64.b64encode(json.dumps(payload).encode()).decode()
        result = es.decrypt_data(encrypted)
        assert "decryption_error" in result
        assert "unknown" in result["decryption_error"].lower()

    def test_invalid_outer_base64(self):
        """Invalid outer base64 returns error."""
        import services.encryption_service as es
        result = es.decrypt_data("not-valid-base64!!!")
        assert "decryption_error" in result

    def test_invalid_json_payload(self):
        """Invalid JSON payload returns error."""
        import services.encryption_service as es
        # Valid base64 but not JSON
        encrypted = base64.b64encode(b"not json content").decode()
        result = es.decrypt_data(encrypted)
        assert "decryption_error" in result

    def test_non_json_decrypted_returns_raw(self, test_keypair):
        """Non-JSON decrypted content returns raw_data."""
        import services.encryption_service as es
        plaintext = "just plain text, not JSON"
        encrypted = encrypt_direct(test_keypair["public_key"], plaintext)
        result = es.decrypt_data(encrypted)
        assert result["raw_data"] == plaintext
| # ============================================================================= | |
| # 5. Multiple Blocks Tests | |
| # ============================================================================= | |
class TestMultipleBlocks:
    """Test decrypt_multiple_blocks function."""

    @pytest.fixture(autouse=True)
    def _install_test_key(self, test_keypair):
        """Seed the service's cached ``_private_key`` with the test key for each test.

        ``patch.object`` restores the previous value on teardown, so the test
        key does not leak into tests that run later.
        """
        import services.encryption_service as es
        with patch.object(es, "_private_key", test_keypair["private_key"]):
            yield

    def test_decrypt_multiple_valid_blocks(self, test_keypair):
        """Decrypt multiple valid encrypted blocks."""
        import services.encryption_service as es
        encrypted1 = encrypt_direct(test_keypair["public_key"], '{"id": 1}')
        encrypted2 = encrypt_direct(test_keypair["public_key"], '{"id": 2}')
        results = es.decrypt_multiple_blocks(f"{encrypted1},{encrypted2}")
        assert len(results) == 2
        assert results[0]["id"] == 1
        assert results[1]["id"] == 2

    def test_empty_input_returns_empty_list(self):
        """Empty input returns empty list."""
        import services.encryption_service as es
        assert es.decrypt_multiple_blocks("") == []

    def test_handles_whitespace(self, test_keypair):
        """Handle extra whitespace in input."""
        import services.encryption_service as es
        encrypted = encrypt_direct(test_keypair["public_key"], '{"id": 1}')
        # Surround both blocks and the separator with whitespace.
        results = es.decrypt_multiple_blocks(f" {encrypted} , {encrypted} ")
        assert len(results) == 2
        # Both blocks must actually decrypt, not just be counted.
        assert [block["id"] for block in results] == [1, 1]

    def test_mixed_valid_invalid_blocks(self, test_keypair):
        """Handle mixed valid and invalid blocks."""
        import services.encryption_service as es
        valid = encrypt_direct(test_keypair["public_key"], '{"valid": true}')
        invalid = "not-valid-encrypted-data"
        results = es.decrypt_multiple_blocks(f"{valid},{invalid}")
        assert len(results) == 2
        assert results[0]["valid"] is True
        assert "decryption_error" in results[1]
| # ============================================================================= | |
| # 6. Edge Cases and Security Tests | |
| # ============================================================================= | |
class TestEdgeCases:
    """Test edge cases and security scenarios."""

    @pytest.fixture(autouse=True)
    def _install_test_key(self, test_keypair):
        """Seed the service's cached ``_private_key`` with the test key for each test.

        ``patch.object`` restores the previous value on teardown, so the test
        key does not leak into tests that run later.
        """
        import services.encryption_service as es
        with patch.object(es, "_private_key", test_keypair["private_key"]):
            yield

    def test_empty_plaintext(self, test_keypair):
        """Handle empty plaintext."""
        import services.encryption_service as es
        encrypted = encrypt_direct(test_keypair["public_key"], "")
        result = es.decrypt_data(encrypted)
        assert result["raw_data"] == ""

    def test_unicode_plaintext(self, test_keypair):
        """Handle unicode plaintext (4-byte emoji and 3-byte CJK UTF-8 sequences)."""
        import services.encryption_service as es
        # Escapes keep the test independent of the source file's encoding.
        emoji = "\U0001F510\U0001F511"  # lock with key, key
        chinese = "\u4e2d\u6587"        # "Chinese"
        plaintext = json.dumps({"emoji": emoji, "chinese": chinese}, ensure_ascii=False)
        encrypted = encrypt_direct(test_keypair["public_key"], plaintext)
        result = es.decrypt_data(encrypted)
        assert result["emoji"] == emoji
        assert result["chinese"] == chinese

    def test_large_payload_hybrid(self, test_keypair):
        """Handle large payload with hybrid encryption."""
        import services.encryption_service as es
        # Well beyond the 190-byte RSA-OAEP limit, so hybrid encryption is required.
        plaintext = json.dumps({"data": "x" * 1000})
        encrypted = encrypt_hybrid(test_keypair["public_key"], plaintext)
        result = es.decrypt_data(encrypted)
        assert len(result["data"]) == 1000

    def test_payload_at_rsa_limit(self, test_keypair):
        """Handle a direct payload exactly at the RSA-OAEP size limit."""
        import services.encryption_service as es
        public_key = test_keypair["public_key"]
        # RSAES-OAEP max message length: k - 2*hLen - 2 bytes
        # (k = modulus length in bytes, hLen = SHA-256 digest size) -> 190 for 2048 bits.
        max_len = public_key.key_size // 8 - 2 * hashes.SHA256.digest_size - 2
        prefix, suffix = '{"d":"', '"}'
        filler = max_len - len(prefix) - len(suffix)
        plaintext = prefix + 'x' * filler + suffix
        assert len(plaintext.encode('utf-8')) == max_len
        encrypted = encrypt_direct(public_key, plaintext)
        result = es.decrypt_data(encrypted)
        assert len(result["d"]) == filler
if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run reports failures to the shell/CI.
    raise SystemExit(pytest.main([__file__, "-v"]))