""" Define equivariance testing task. """ from __future__ import annotations from collections.abc import Sequence from pathlib import Path import numpy as np from ase import Atoms from prefect import task from scipy.spatial.transform import Rotation as R from tqdm import tqdm def generate_random_unit_vector(): """Generate a random unit vector.""" vec = np.random.normal(0, 1, 3) return vec / np.linalg.norm(vec) def rotate_molecule_arbitrary( atoms: Atoms, angle: float, axis: np.ndarray ) -> tuple[Atoms, np.ndarray]: """Rotate molecule around arbitrary axis.""" rotated_atoms = atoms.copy() positions = rotated_atoms.get_positions() rot = R.from_rotvec(np.radians(angle) * axis) rotation_mat = rot.as_matrix() rotated_positions = rot.apply(positions) rotated_atoms.set_positions(rotated_positions) cell = atoms.get_cell() rotated_cell = rot.apply(cell) rotated_atoms.set_cell(rotated_cell) return rotated_atoms, rotation_mat def compare_forces( original_forces: np.ndarray, rotated_forces: np.ndarray, rotation_mat: np.ndarray, zero_threshold: float = 1e-10, ) -> tuple[float, np.ndarray, np.ndarray, np.ndarray]: """ Compare forces before and after rotation, with handling of 0 force case. Args: original_forces: Forces before rotation (N x 3 array) rotated_forces: Forces after rotation (N x 3 array) rotation_mat: 3 x 3 rotation matrix zero_threshold: Threshold below which forces are considered zero Returns: tuple containing: - mae: Mean absolute error between forces - cosine_similarity: Cosine similarity between force vectors """ rotated_original_forces = np.dot(original_forces, rotation_mat.T) force_diff = rotated_original_forces - rotated_forces mae = np.mean(np.abs(force_diff)) original_magnitudes = np.linalg.norm(rotated_original_forces, axis=1) rotated_magnitudes = np.linalg.norm(rotated_forces, axis=1) zero_original = original_magnitudes < zero_threshold zero_rotated = rotated_magnitudes < zero_threshold both_zero = zero_original & zero_rotated either_zero = zero_original | zero_rotated one_zero = either_zero & ~both_zero cosine_similarity = np.zeros(len(original_forces)) valid_forces = ~either_zero if np.any(valid_forces): norms_product = np.linalg.norm( rotated_original_forces[valid_forces], axis=1 ) * np.linalg.norm(rotated_forces[valid_forces], axis=1) dot_products = np.sum( rotated_original_forces[valid_forces] * rotated_forces[valid_forces], axis=1 ) cosine_similarity[valid_forces] = dot_products / norms_product # If both forces are 0, cosine similarity should be 1. If one is 0, we take the conservative -1. cosine_similarity[both_zero] = 1.0 cosine_similarity[one_zero] = -1.0 return mae, cosine_similarity def save_molecule_results( aggregate_results: dict, idx_list: np.ndarray, save_path: str | Path ) -> None: """ Save all molecule results from equivariance testing to .npy files. Save the index list of the atoms for further analysis. Args: aggregate_results: Dictionary containing the aggregated results from run() idx_list: List of the indices of the atoms in the original dataset save_path: Path to save the .npy files """ save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) all_molecule_results = aggregate_results["molecule_results"] rotation_angles = list(all_molecule_results[0]["results_by_angle"].keys()) num_molecules = len(all_molecule_results) num_angles = len(rotation_angles) num_random_axes = len( all_molecule_results[0]["results_by_angle"][rotation_angles[0]]["maes"] ) num_atoms = len( all_molecule_results[0]["results_by_angle"][rotation_angles[0]][ "cosine_similarities" ][0] ) maes = np.zeros((num_molecules, num_angles, num_random_axes)) cosine_similarities = np.zeros((num_molecules, num_angles, num_random_axes)) for mol_idx, molecule in enumerate(all_molecule_results): for angle_idx, angle in enumerate(rotation_angles): angle_results = molecule["results_by_angle"][angle] maes[mol_idx, angle_idx, :] = angle_results["maes"] cosine_similarities[mol_idx, angle_idx, :] = np.mean( angle_results["cosine_similarities"], axis=-1 ) np.save(save_path.with_name(f"{save_path.stem}_maes.npy"), maes) np.save( save_path.with_name(f"{save_path.stem}_cosine_similarities.npy"), cosine_similarities, ) np.save(save_path.with_name(f"{save_path.stem}_idx_list.npy"), idx_list) @task( name="Equivariance testing", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS, ) def run( atoms_list: Sequence[Atoms], idx_list: np.ndarray, calculator: BaseCalculator, save_path: str | Path | None = None, rotation_angles: list[float] | np.ndarray = None, num_random_axes: int = 100, threshold: float = 1e-3, seed: int | None = None, ) -> dict: """ Test equivariance of force predictions under rotations for multiple structures. Args: atoms_list: List of input atomic structures idx_list: List of the indices of the atoms in the original dataset calculator: Calculator to use num_rotations: Number of random rotations to test rotation_angle: Angle of rotation in degrees threshold: Threshold for considering forces equivariant seed: Random seed Returns: Dictionary containing test results """ if seed is not None: np.random.seed(seed) if rotation_angles is None: rotation_angles = np.arange(30, 361, 30) rotation_angles = np.array(rotation_angles) all_results = [] cross_molecule_cosine_sims = {angle: [] for angle in rotation_angles} cross_molecule_mae = {angle: [] for angle in rotation_angles} rotation_axes = [generate_random_unit_vector() for _ in range(num_random_axes)] total_tests = len(atoms_list) * len(rotation_angles) * num_random_axes pbar = tqdm(total=total_tests, desc="Testing rotations") for atom_idx, atoms in enumerate(atoms_list): atoms = atoms.copy() atoms.calc = calculator original_forces = atoms.get_forces() results_by_angle = { angle: { "mae": [], "cosine_similarities": [], "passed_tests": 0, "passed_mae": 0, "passed_cosine_similarity": 0, } for angle in rotation_angles } # Test each angle with multiple random axes for angle in rotation_angles: for axis in rotation_axes: rotated_atoms, rotation_mat = rotate_molecule_arbitrary( atoms, angle, axis ) rotated_atoms.calc = calculator rotated_forces = rotated_atoms.get_forces() mae, cosine_similarity = compare_forces( original_forces, rotated_forces, rotation_mat ) results_by_angle[angle]["mae"].append(mae) results_by_angle[angle]["cosine_similarities"].append(cosine_similarity) cross_molecule_cosine_sims[angle].append( float(np.mean(cosine_similarity)) ) cross_molecule_mae[angle].append(float(np.mean(mae))) mae_check = mae < threshold cosine_check = all(cosine_similarity > (1 - threshold)) results_by_angle[angle]["passed_tests"] += int( mae_check and cosine_check ) results_by_angle[angle]["passed_mae"] += int(mae_check) results_by_angle[angle]["passed_cosine_similarity"] += int(cosine_check) pbar.update(1) # Compute summary statistics for angle in rotation_angles: results = results_by_angle[angle] results["mean_cosine_similarity"] = float( np.mean(results["cosine_similarities"]) ) results["avg_mae"] = float(np.mean(results["mae"])) results["equivariant_ratio"] = results["passed_tests"] / num_random_axes results["mae_passed_ratio"] = results["passed_mae"] / num_random_axes results["cosine_passed_ratio"] = ( results["passed_cosine_similarity"] / num_random_axes ) results["passed"] = results["passed_tests"] == num_random_axes results["passed_mae"] = results["passed_mae"] == num_random_axes results["passed_cosine_similarity"] = ( results["passed_cosine_similarity"] == num_random_axes ) results["maes"] = [float(x) for x in results["mae"]] results["cosine_similarities"] = [ [float(y) for y in x] for x in results["cosine_similarities"] ] molecule_results = { "mol_idx": idx_list[atom_idx], "results_by_angle": results_by_angle, "all_passed": all( results_by_angle[angle]["passed"] for angle in rotation_angles ), "avg_cosine_similarity_by_molecule": float( np.mean( [ results_by_angle[angle]["mean_cosine_similarity"] for angle in rotation_angles ] ) ), "avg_mae_by_molecule": float( np.mean( [results_by_angle[angle]["avg_mae"] for angle in rotation_angles] ) ), "overall_equivariant_ratio": float( np.mean( [ results_by_angle[angle]["equivariant_ratio"] for angle in rotation_angles ] ) ), } all_results.append(molecule_results) pbar.close() aggregate_results = { "num_molecules": len(atoms_list), "all_molecules_passed": all(result["all_passed"] for result in all_results), "average_equivariant_ratio": float( np.mean([result["overall_equivariant_ratio"] for result in all_results]) ), "average_cosine_similarity_by_angle": { angle: float(np.mean(sims)) for angle, sims in cross_molecule_cosine_sims.items() }, "average_mae_by_angle": { angle: float(np.mean(diffs)) for angle, diffs in cross_molecule_mae.items() }, "molecule_results": all_results, } if save_path: save_molecule_results(aggregate_results, idx_list, save_path) np.save( str(save_path.with_name(f"{save_path.stem}_molecule_results.npy")), all_results, ) return aggregate_results