| | """ |
| | Embedding Test Suite CLI |
| | |
| | Tests embedding quality by sampling sequences from EC dataset (default), |
| | embedding them with various pooling methods, and reporting statistics |
| | on distribution, NaNs, and sparsity. |
| | """ |
| |
|
| | import json |
| | import argparse |
| | import math |
| | import random |
| | import numpy as np |
| | import torch |
| | from typing import Dict, List, Optional |
| |
|
| | try: |
| | from data.data_mixin import DataMixin, DataArguments |
| | from embedder import Embedder, EmbeddingArguments |
| | from base_models.get_base_models import standard_models |
| | from seed_utils import set_global_seed, get_global_seed |
| | from utils import print_message |
| | except ImportError: |
| | from ..data.data_mixin import DataMixin, DataArguments |
| | from ..embedder import Embedder, EmbeddingArguments |
| | from ..base_models.get_base_models import standard_models |
| | from ..seed_utils import set_global_seed, get_global_seed |
| | from ..utils import print_message |
| |
|
| |
|
| | |
# Datasets sampled by run_test_suite() when no dataset names are supplied
# (e.g. when --datasets is omitted on the command line).
DEFAULT_TEST_DATASETS = ['EC']
| |
|
| |
|
# Seed NumPy's and Python's global RNGs from the project-wide seed, if one is set.
# NOTE: this runs once at import time, so a seed configured later through
# set_global_seed() is not reflected by these two calls.
seed = get_global_seed()
if seed is not None:
    np.random.seed(seed)
    random.seed(seed)
| |
|
def load_and_sample_sequences(
    dataset_names: List[str],
    sample_frac: float = 0.1,
    max_length: int = 1024,
    trim: bool = False
) -> Dict[str, List[str]]:
    """
    Load datasets and randomly sample unique sequences from them.

    Sequences from the train, valid and test splits are pooled (both the
    ``SeqA`` and ``SeqB`` columns for PPI datasets), de-duplicated in
    first-seen order, and sampled with the module-level ``random`` generator.

    Args:
        dataset_names: List of dataset names to load
        sample_frac: Fraction of unique sequences to sample (default 0.1 = 10%).
            At least one sequence is requested, and the sample never exceeds
            the number of unique sequences available.
        max_length: Maximum sequence length, forwarded to ``DataArguments``
        trim: Whether to trim sequences to max_length, forwarded to ``DataArguments``

    Returns:
        Dictionary mapping dataset names to lists of sampled sequences.
        Datasets that fail to load are reported via ``print_message`` and
        left out of the result.
    """
    dataset_seqs = {}

    for dataset_name in dataset_names:
        print_message(f"Loading dataset: {dataset_name}")

        try:
            data_args = DataArguments(
                data_names=[dataset_name],
                max_length=max_length,
                trim=trim
            )
            data_mixin = DataMixin(data_args)
            datasets, all_seqs = data_mixin.get_data()

            sequences = []
            if dataset_name in datasets:
                train_set, valid_set, test_set, _, _, ppi = datasets[dataset_name]
                # PPI datasets keep the two interaction partners in separate columns.
                columns = ('SeqA', 'SeqB') if ppi else ('seqs',)
                for split in (train_set, valid_set, test_set):
                    for column in columns:
                        sequences.extend(split[column])
            else:
                # Fall back to the full sequence collection returned by get_data().
                sequences = list(all_seqs)

            # De-duplicate while keeping first-seen order. Iterating a set of
            # strings depends on PYTHONHASHSEED, which would make the sample
            # below differ between runs even with `random` seeded.
            sequences = list(dict.fromkeys(sequences))
            n_samples = max(1, math.ceil(len(sequences) * sample_frac))
            sampled = random.sample(sequences, min(n_samples, len(sequences)))
            dataset_seqs[dataset_name] = sampled

            print_message(f"Sampled {len(sampled)} sequences from {len(sequences)} total")

        except Exception as e:
            # Best effort: one unavailable dataset should not abort the whole suite.
            print_message(f"Error loading dataset {dataset_name}: {e}")

    return dataset_seqs
| |
|
| |
|
def compute_diagnostics(embeddings: torch.Tensor, zero_eps: float = 1e-8) -> Dict[str, float]:
    """
    Summarize the numeric health of a batch of embeddings.

    Args:
        embeddings: Tensor whose first axis indexes samples, normally of shape
            (n_samples, embedding_dim). A 0-D/1-D tensor is treated as a single
            sample; axes beyond the second are flattened into the embedding dim.
        zero_eps: Finite values whose absolute value is below this threshold
            count as "near zero".

    Returns:
        Dictionary with the same keys for every input:
        - ``n_samples``, ``embedding_dim``
        - ``finite_*``, ``nan_*``, ``inf_*`` counts and fractions (fractions are
          taken over all entries)
        - ``zero_eps``, ``near_zero_count``, ``near_zero_fraction`` (fraction
          taken over the finite entries)
        - ``mean``, ``std``, ``min``, ``max``, ``p25``..``p99`` over the finite entries
        - ``mean_l2``, ``std_l2``, ``p95_l2`` over samples whose entries are all finite
        Statistics with nothing to summarize are ``float('nan')``; fractions
        with an empty denominator are ``0.0``.
    """
    emb = embeddings.detach().float().cpu().numpy()
    if emb.ndim < 2:
        emb = emb.reshape(1, -1)
    elif emb.ndim > 2:
        # Explicit trailing size (not -1) so an empty batch still reshapes.
        emb = emb.reshape(emb.shape[0], int(np.prod(emb.shape[1:])))

    flat = emb.ravel()
    total = int(flat.size)
    finite = flat[np.isfinite(flat)]

    nan_count = int(np.isnan(flat).sum())
    inf_count = int(np.isinf(flat).sum())
    finite_count = int(finite.size)
    near_zero_count = int((np.abs(finite) < zero_eps).sum())

    def fraction(count: int, denominator: int) -> float:
        # Avoid NaN (and a NumPy mean-of-empty warning) for empty inputs.
        return float(count / denominator) if denominator else 0.0

    missing = float('nan')
    diagnostics = {
        "n_samples": int(emb.shape[0]),
        "embedding_dim": int(emb.shape[1]),

        "finite_count": finite_count,
        "finite_fraction": fraction(finite_count, total),

        "nan_count": nan_count,
        "nan_fraction": fraction(nan_count, total),

        "inf_count": inf_count,
        "inf_fraction": fraction(inf_count, total),

        "zero_eps": float(zero_eps),
        "near_zero_count": near_zero_count,
        "near_zero_fraction": fraction(near_zero_count, finite_count),
    }

    # Value distribution over finite entries only; NaN/Inf are counted above.
    if finite_count:
        diagnostics["mean"] = float(np.mean(finite))
        diagnostics["std"] = float(np.std(finite))
        diagnostics["min"] = float(np.min(finite))
        diagnostics["max"] = float(np.max(finite))
        quantiles = (25, 50, 75, 95, 99)
        for q, value in zip(quantiles, np.percentile(finite, quantiles)):
            diagnostics[f"p{q}"] = float(value)
    else:
        for key in ("mean", "std", "min", "max", "p25", "p50", "p75", "p95", "p99"):
            diagnostics[key] = missing

    # Per-sample L2 norms over fully finite rows, so a single NaN/Inf row does
    # not turn every norm statistic into NaN/Inf.
    row_norms = np.linalg.norm(emb[np.isfinite(emb).all(axis=1)], axis=1)
    if row_norms.size:
        diagnostics["mean_l2"] = float(np.mean(row_norms))
        diagnostics["std_l2"] = float(np.std(row_norms))
        diagnostics["p95_l2"] = float(np.percentile(row_norms, 95))
    else:
        diagnostics["mean_l2"] = missing
        diagnostics["std_l2"] = missing
        diagnostics["p95_l2"] = missing

    return diagnostics
| |
|
| |
|
def _parse_pooling_specs(pooling_types: List[str]) -> Dict[str, List[str]]:
    """Map each pooling spec string to the list of pooling names it combines.

    A spec is a single name (``'cls'``) or a comma-separated combination
    (``'mean,var'``). Names are whitespace-stripped and empty entries (e.g. from
    a trailing comma) are dropped, so a spec may map to an empty list.
    """
    return {
        spec: [name.strip() for name in spec.split(',') if name.strip()]
        for spec in pooling_types
    }


def embed_and_diagnose(
    sequences: List[str],
    model_name: str,
    pooling_types: List[str],
    batch_size: int = 16,
    num_workers: int = 0
) -> Dict[str, Dict[str, float]]:
    """
    Embed sequences and compute diagnostics for each pooling type.

    The model is loaded once and reused for every pooling spec. Failures are
    reported and skipped rather than raised: if the model cannot be loaded the
    result is empty, and a pooling spec that fails to embed is omitted.

    Args:
        sequences: List of sequences to embed
        model_name: Name of the model to use
        pooling_types: List of pooling specs to test; each is a single pooling
            name or a comma-separated combination such as ``'mean,var'``
        batch_size: Batch size for embedding
        num_workers: Number of workers for data loading

    Returns:
        Dictionary mapping each successfully embedded pooling spec to the
        output of ``compute_diagnostics``.
    """
    import traceback

    print_message(f"Embedding {len(sequences)} sequences with {model_name}")

    pooling_list = _parse_pooling_specs(pooling_types)
    results = {}

    print_message(f"Loading model: {model_name}")
    try:
        try:
            from base_models.get_base_models import get_base_model
        except ImportError:
            # Same absolute-then-relative fallback as the module-level imports.
            from ..base_models.get_base_models import get_base_model
        model, tokenizer = get_base_model(model_name)
    except Exception as e:
        # Best effort: one broken model should not abort the whole suite.
        print_message(f"Error loading model {model_name}: {e}")
        traceback.print_exc()
        return results

    for pool_type, pool_list in pooling_list.items():
        if not pool_list:
            print_message(f"Warning: Skipping empty pooling spec {pool_type!r}")
            continue

        print_message(f"Testing pooling: {pool_type} (types: {pool_list})")

        try:
            embedder_args = EmbeddingArguments(
                embedding_batch_size=batch_size,
                embedding_num_workers=num_workers,
                download_embeddings=False,
                matrix_embed=False,
                embedding_pooling_types=pool_list,
                save_embeddings=False,
                embed_dtype=torch.float32,
                sql=False,
                embedding_save_dir='embeddings'
            )
            embedder = Embedder(embedder_args, sequences)

            to_embed, save_path, embeddings_dict = embedder._read_embeddings_from_disk(model_name)

            if len(to_embed) > 0:
                result = embedder._embed_sequences(
                    to_embed, save_path, model, tokenizer, embeddings_dict
                )
                if result is not None:
                    embeddings_dict = result

            if embeddings_dict is None or len(embeddings_dict) == 0:
                print_message(f"Warning: No embeddings returned for {model_name} with {pool_type}")
                continue

            embedding_tensors = [embeddings_dict[seq] for seq in sequences if seq in embeddings_dict]

            if len(embedding_tensors) == 0:
                print_message(f"Error: No embeddings found for {pool_type}")
                continue

            missing = len(sequences) - len(embedding_tensors)
            if missing:
                print_message(
                    f"Warning: {missing} of {len(sequences)} sequences have no embedding for {pool_type}"
                )

            embeddings = torch.stack(embedding_tensors)
            results[pool_type] = compute_diagnostics(embeddings)

        except Exception as e:
            print_message(f"Error embedding with {model_name} using {pool_type}: {e}")
            traceback.print_exc()

    return results
| |
|
| |
|
def run_test_suite(
    dataset_names: Optional[List[str]] = None,
    model_names: Optional[List[str]] = None,
    pooling_methods: Optional[List[str]] = None,
    sample_frac: float = 0.1,
    batch_size: int = 16,
    num_workers: int = 0
) -> Dict:
    """
    Run the embedding test suite and print the results as a table and as JSON.

    Args:
        dataset_names: Datasets to sample from; defaults to ``DEFAULT_TEST_DATASETS``.
        model_names: Models to evaluate; defaults to ``standard_models``.
        pooling_methods: Pooling specs to test; each entry is a single pooling
            name or a comma-separated combination (e.g. ``'mean,var'``).
            Defaults to ``['cls', 'mean,var']``.
        sample_frac: Fraction of each dataset's unique sequences to embed.
        batch_size: Embedding batch size.
        num_workers: Data-loading workers used while embedding.

    Returns:
        Nested dict ``{dataset: {model: {pooling_spec: diagnostics}}}``.
        Models that produced no results are omitted; the dict is empty when
        no dataset could be loaded.
    """
    if dataset_names is None:
        dataset_names = DEFAULT_TEST_DATASETS

    if model_names is None:
        model_names = standard_models

    # None sentinel instead of a list default, which would be one shared
    # mutable object across all calls.
    if pooling_methods is None:
        pooling_methods = ['cls', 'mean,var']

    print_message("Running embedding test suite")
    print_message(f"Datasets: {dataset_names}")
    print_message(f"Models: {model_names}")
    print_message(f"Pooling methods: {pooling_methods}")
    print_message(f"Sample fraction: {sample_frac}")

    dataset_seqs = load_and_sample_sequences(dataset_names, sample_frac=sample_frac)

    if len(dataset_seqs) == 0:
        print_message("Error: No sequences loaded")
        return {}

    all_results = {}

    for dataset_name, sequences in dataset_seqs.items():
        print_message(f"\nProcessing dataset: {dataset_name}")
        all_results[dataset_name] = {}

        for model_name in model_names:
            print_message(f"Model: {model_name}")
            model_results = embed_and_diagnose(
                sequences,
                model_name,
                pooling_methods,
                batch_size=batch_size,
                num_workers=num_workers
            )

            if model_results:
                all_results[dataset_name][model_name] = model_results

    print_table_results(all_results)
    print_json_results(all_results)

    return all_results
| |
|
| |
|
def _diagnostic_fraction(diagnostics: Dict, name: str) -> float:
    """Return ``diagnostics['<name>_fraction']``, deriving it from ``'<name>_count'`` if absent.

    The derived fraction is taken over all ``n_samples * embedding_dim``
    entries; an empty embedding yields 0.0.
    """
    stored = diagnostics.get(f"{name}_fraction")
    if stored is not None:
        return stored
    total = diagnostics.get('n_samples', 0) * diagnostics.get('embedding_dim', 0)
    count = diagnostics.get(f"{name}_count", 0)
    return count / total if total else 0.0


def print_table_results(results: Dict):
    """Print results in table format.

    Args:
        results: Nested dict ``{dataset: {model: {pooling_spec: diagnostics}}}``.
            Diagnostics with no finite values (``finite_count == 0`` or no
            ``'mean'`` key) are reported as an anomaly instead of raising
            ``KeyError``.
    """
    print("\n" + "="*100)
    print("EMBEDDING TEST SUITE RESULTS")
    print("="*100)

    for dataset_name, dataset_results in results.items():
        print(f"\nDataset: {dataset_name}")
        print("-" * 100)

        for model_name, model_results in dataset_results.items():
            print(f"\n Model: {model_name}")

            for pool_type, diagnostics in model_results.items():
                nan_fraction = _diagnostic_fraction(diagnostics, 'nan')
                inf_fraction = _diagnostic_fraction(diagnostics, 'inf')
                # Degenerate diagnostics may lack the distribution stats entirely,
                # or carry finite_count == 0 with placeholder values.
                has_finite = 'mean' in diagnostics and diagnostics.get('finite_count', 1) > 0

                print(f"\nPooling: {pool_type}")
                print(f"Samples: {diagnostics['n_samples']}, Dim: {diagnostics['embedding_dim']}")
                if has_finite:
                    print(f"Mean: {diagnostics['mean']:.6f}, Std: {diagnostics['std']:.6f}")
                    print(f"Min: {diagnostics['min']:.6f}, Max: {diagnostics['max']:.6f}")
                    print(f"Percentiles: P25={diagnostics['p25']:.6f}, P50={diagnostics['p50']:.6f}, "
                          f"P75={diagnostics['p75']:.6f}, P95={diagnostics['p95']:.6f}, P99={diagnostics['p99']:.6f}")
                else:
                    print("Distribution: unavailable (no finite values)")
                print(f"NaN: {diagnostics['nan_count']} ({nan_fraction*100:.2f}%)")
                if 'near_zero_count' in diagnostics:
                    print(f"Near zeros: {diagnostics['near_zero_count']} ({diagnostics['near_zero_fraction']*100:.2f}%)")
                print(f"Inf: {diagnostics['inf_count']} ({inf_fraction*100:.2f}%)")

                anomalies = []
                if not has_finite:
                    anomalies.append("No finite values")
                if nan_fraction > 0:
                    anomalies.append(f"NaNs detected ({nan_fraction*100:.2f}%)")
                if 'near_zero_fraction' in diagnostics and diagnostics['near_zero_fraction'] > 0.2:
                    anomalies.append(f"High sparsity ({diagnostics['near_zero_fraction']*100:.2f}%)")
                if inf_fraction > 0:
                    anomalies.append(f"Infs detected ({inf_fraction*100:.2f}%)")
                if has_finite and abs(diagnostics['mean']) > 100:
                    anomalies.append(f"Extreme mean ({diagnostics['mean']:.2f})")
                if has_finite and diagnostics['std'] > 100:
                    anomalies.append(f"Extreme std ({diagnostics['std']:.2f})")

                if anomalies:
                    print(f"Anomalies: {', '.join(anomalies)}")
                else:
                    print("No anomalies detected")
| |
|
| |
|
def _json_safe(value):
    """Return a copy of *value* with non-finite floats replaced by ``None``.

    Recurses into dicts, lists and tuples (tuples come back as lists, which is
    how ``json`` serializes them anyway); other values are returned unchanged.
    """
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


def print_json_results(results: Dict):
    """Print results in JSON format.

    Diagnostics may contain non-finite floats (e.g. statistics over embeddings
    that contain NaNs). ``json.dumps`` would emit them as ``NaN``/``Infinity``,
    which is not valid JSON, so they are printed as ``null`` instead.
    """
    print("\n" + "="*50)
    print("JSON RESULTS")
    print("="*50)
    print(json.dumps(_json_safe(results), indent=2, allow_nan=False))
| |
|
| |
|
def main():
    """Parse command-line arguments and run the embedding test suite.

    Returns:
        The nested results dict returned by ``run_test_suite``.
    """
    parser = argparse.ArgumentParser(
        description='Embedding Test Suite - Test embedding quality across datasets and models'
    )

    parser.add_argument(
        '--datasets',
        nargs='+',
        default=None,
        help=f"List of dataset names to test (default: {', '.join(DEFAULT_TEST_DATASETS)})"
    )

    parser.add_argument(
        '--model_names',
        nargs='+',
        default=None,
        help='List of model names to test (default: all standard_models)'
    )

    parser.add_argument(
        '--pooling_methods',
        nargs='+',
        default=['cls', 'mean,var'],
        help='List of pooling methods to test; join names with commas to combine them '
             '(default: cls mean,var)'
    )

    parser.add_argument(
        '--sample_frac',
        type=float,
        default=0.1,
        help='Fraction of sequences to sample from each dataset (default: 0.1)'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=16,
        help='Batch size for embedding (default: 16)'
    )

    parser.add_argument(
        '--num_workers',
        type=int,
        default=0,
        help='Number of workers for data loading (default: 0)'
    )

    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Random seed for reproducibility'
    )

    args = parser.parse_args()

    if args.seed is not None:
        set_global_seed(args.seed)
        # The module-level seeding of `random`/`np.random` already ran at import
        # time, before this seed was known. Reseed both explicitly so sequence
        # sampling honours --seed regardless of what set_global_seed covers.
        random.seed(args.seed)
        np.random.seed(args.seed)

    results = run_test_suite(
        dataset_names=args.datasets,
        model_names=args.model_names,
        pooling_methods=args.pooling_methods,
        sample_frac=args.sample_frac,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )

    return results
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point; the results dict returned by main() is not used here.
    main()
| |
|
| |
|