import entrypoint_setup

import os
import torch
import warnings
import sqlite3
import gzip
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from typing import Any, Optional, Callable, List
from huggingface_hub import hf_hub_download

try:
    from seed_utils import seed_worker, dataloader_generator, get_global_seed
    from data.dataset_classes import SimpleProteinDataset
    from base_models.get_base_models import get_base_model
    from pooler import Pooler
    from utils import torch_load, print_message, maybe_compile
except ImportError:
    from .seed_utils import seed_worker, dataloader_generator, get_global_seed
    from .data.dataset_classes import SimpleProteinDataset
    from .base_models.get_base_models import get_base_model
    from .pooler import Pooler
    from .utils import torch_load, print_message, maybe_compile


def build_collator(tokenizer) -> Callable[[List[str]], dict[str, torch.Tensor]]:
    def _collate_fn(sequences: List[str]) -> dict[str, torch.Tensor]:
        """Collate function for batching sequences: tokenize and pad to a multiple of 8."""
        return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
    return _collate_fn
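
# Example for build_collator (illustrative): with a HuggingFace-style tokenizer, the returned
# _collate_fn maps a list of sequence strings to a dict-like batch containing 'input_ids' and
# 'attention_mask', padded to the longest sequence and to a multiple of 8, e.g.
#   collate_fn = build_collator(tokenizer)
#   batch = collate_fn(['MKTAYIA', 'MGSSHH'])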


def get_embedding_filename(model_name: str, matrix_embed: bool, pooling_types: List[str], extension: str = 'pth') -> str:
    """
    Generate embedding filename with pooling types for vector embeddings.

    Args:
        model_name: Name of the model
        matrix_embed: Whether embeddings are matrices (True) or vectors (False)
        pooling_types: List of pooling types used (only relevant for vector embeddings)
        extension: File extension ('pth' or 'db')

    Returns:
        Filename string in format: {model_name}_{matrix_embed}[_{pooling_types}].{extension}
    """
    base_name = f'{model_name}_{matrix_embed}'
    if not matrix_embed and pooling_types:
        pooling_str = '_'.join(sorted(pooling_types))
        base_name = f'{base_name}_{pooling_str}'
    return f'{base_name}.{extension}'
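
# Examples for get_embedding_filename (illustrative, 'esm2' is a placeholder model name):
#   get_embedding_filename('esm2', matrix_embed=False, pooling_types=['var', 'mean'])            -> 'esm2_False_mean_var.pth'
#   get_embedding_filename('esm2', matrix_embed=True, pooling_types=['mean'], extension='db')    -> 'esm2_True.db'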


class EmbeddingArguments:
    def __init__(
        self,
        embedding_batch_size: int = 4,
        embedding_num_workers: int = 0,
        download_embeddings: bool = False,
        download_dir: str = 'Synthyra/vector_embeddings',
        matrix_embed: bool = False,
        embedding_pooling_types: Optional[List[str]] = None,
        save_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        model_dtype: Optional[torch.dtype] = None,
        sql: bool = False,
        embedding_save_dir: str = 'embeddings',
        **kwargs
    ):
        self.batch_size = embedding_batch_size
        self.num_workers = embedding_num_workers
        self.download_embeddings = download_embeddings
        self.download_dir = download_dir
        self.matrix_embed = matrix_embed
        # Resolve None here instead of using a mutable list as a default argument
        self.pooling_types = embedding_pooling_types if embedding_pooling_types is not None else ['mean']
        self.save_embeddings = save_embeddings
        self.embed_dtype = embed_dtype
        self.model_dtype = model_dtype
        self.sql = sql
        self.embedding_save_dir = embedding_save_dir
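
# Example for EmbeddingArguments (illustrative): unknown keyword names are silently absorbed
# by **kwargs, so arguments must match the constructor signature exactly, e.g.
#   EmbeddingArguments(embedding_batch_size=8, embedding_pooling_types=['mean', 'var'],
#                      save_embeddings=True, embed_dtype=torch.float16)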


class Embedder:
    def __init__(self, args: EmbeddingArguments, all_seqs: List[str]):
        self.args = args
        self.all_seqs = all_seqs
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.matrix_embed = args.matrix_embed
        self.pooling_types = args.pooling_types
        self.download_embeddings = args.download_embeddings
        self.download_dir = args.download_dir
        self.save_embeddings = args.save_embeddings
        self.embed_dtype = args.embed_dtype
        self.model_dtype = args.model_dtype
        self.sql = args.sql
        self.embedding_save_dir = args.embedding_save_dir

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print_message(f'Device {self.device} found')

    def _download_embeddings(self, model_name: str):
        filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
        try:
            local_path = hf_hub_download(
                repo_id=self.download_dir,
                filename=f'embeddings/{filename}.gz',
                repo_type='dataset'
            )
        except Exception:
            print_message(f'No embeddings found for {model_name} in {self.download_dir}')
            return

        print_message(f'Unzipping {local_path}')
        with gzip.open(local_path, 'rb') as f_in:
            with open(local_path.replace('.gz', ''), 'wb') as f_out:
                f_out.write(f_in.read())

        unzipped_path = local_path.replace('.gz', '')
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        final_path = os.path.join(self.embedding_save_dir, filename)

        if os.path.exists(final_path):
            print_message(f'Found existing embeddings in {final_path}')
            downloaded_embeddings = torch_load(unzipped_path)
            existing_embeddings = torch_load(final_path)

            download_dtype = torch.float16
            if self.embed_dtype != download_dtype:
                print_message(f"Warning:\nDownloaded embeddings are {download_dtype} but the current setting is {self.embed_dtype}\nWhen combining with existing embeddings, this could result in unintended biases or reductions in performance")

            # Locally computed embeddings take precedence over downloaded ones
            print_message('Combining and casting')
            downloaded_embeddings.update(existing_embeddings)

            for seq in downloaded_embeddings:
                downloaded_embeddings[seq] = downloaded_embeddings[seq].to(self.embed_dtype)

            print_message(f'Saving combined embeddings to {final_path}')
            torch.save(downloaded_embeddings, final_path)
        else:
            print_message(f'Downloading embeddings from {self.download_dir}, no previous embeddings found')
            downloaded_embeddings = torch_load(unzipped_path)
            torch.save(downloaded_embeddings, final_path)
        return final_path
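
    # Note on _download_embeddings above (illustrative): the dataset repo is expected to hold
    # gzipped .pth files under 'embeddings/', e.g. 'embeddings/esm2_False_mean_var.pth.gz'
    # ('esm2' is a placeholder name), which is the layout the __main__ block below uploads.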

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read sequences from SQLite database."""
        sequences = []
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            while True:
                row = c.fetchone()
                if row is None:
                    break
                sequences.append(row[0])
        return set(sequences)

    def _read_embeddings_from_disk(self, model_name: str):
        if self.sql:
            filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'db')
            save_path = os.path.join(self.embedding_save_dir, filename)
            if os.path.exists(save_path):
                # Make sure the table exists before querying it, then release the connection
                conn = sqlite3.connect(save_path)
                c = conn.cursor()
                c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
                conn.commit()
                conn.close()
                already_embedded = self._read_sequences_from_db(save_path)
                to_embed = [seq for seq in self.all_seqs if seq not in already_embedded]
                print_message(f"Loaded {len(already_embedded)} already embedded sequences from {save_path}\nEmbedding {len(to_embed)} new sequences")
                return to_embed, save_path, {}
            else:
                print_message(f"No embeddings found in {save_path}")
                return self.all_seqs, save_path, {}
        else:
            embeddings_dict = {}
            filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
            save_path = os.path.join(self.embedding_save_dir, filename)
            if os.path.exists(save_path):
                print_message(f"Loading embeddings from {save_path}")
                embeddings_dict = torch_load(save_path)
                print_message(f"Loaded {len(embeddings_dict)} embeddings from {save_path}")
                to_embed = [seq for seq in self.all_seqs if seq not in embeddings_dict]
                return to_embed, save_path, embeddings_dict
            else:
                print_message(f"No embeddings found in {save_path}")
                return self.all_seqs, save_path, {}

    @torch.inference_mode()
    def _embed_sequences(
        self,
        to_embed: List[str],
        save_path: str,
        embedding_model: Any,
        tokenizer: Any,
        embeddings_dict: dict[str, torch.Tensor]
    ) -> Optional[dict[str, torch.Tensor]]:
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        model = embedding_model.to(self.device).eval()
        model = maybe_compile(model)
        device = self.device
        collate_fn = build_collator(tokenizer)
        print_message(f'Pooling types: {self.pooling_types}')
        if self.matrix_embed:
            pooler = None
        else:
            pooler = Pooler(self.pooling_types)

        def _get_embeddings(
            residue_embeddings: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            attentions: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            if residue_embeddings.ndim == 2 or self.matrix_embed:
                return residue_embeddings
            else:
                return pooler(emb=residue_embeddings, attention_mask=attention_mask, attentions=attentions)

        dataset = SimpleProteinDataset(to_embed)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=2 if self.num_workers > 0 else None,
            collate_fn=collate_fn,
            shuffle=False,
            pin_memory=True,
            worker_init_fn=seed_worker,
            generator=dataloader_generator(get_global_seed())
        )

        if self.sql:
            conn = sqlite3.connect(save_path)
            c = conn.cursor()
            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')

        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
            # shuffle=False above guarantees this slice lines up with the current batch
            seqs = to_embed[i * self.batch_size:(i + 1) * self.batch_size]
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            if 'attention_mask' in batch:
                attention_mask = batch['attention_mask']
            elif 'sequence_ids' in batch:
                attention_mask = (batch['sequence_ids'] != -1).long().to(device)
            else:
                attention_mask = torch.ones_like(batch['input_ids'], device=device)

            if 'parti' in self.pooling_types:
                try:
                    residue_embeddings, attentions = model(**batch, output_attentions=True)
                    embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask, attentions=attentions).cpu()
                except Exception as e:
                    print_message(f"Error in parti pooling: {e}\nDefaulting to mean pooling")
                    self.pooling_types = ['mean']
                    pooler = Pooler(self.pooling_types)
                    residue_embeddings = model(**batch)
                    embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()
            else:
                residue_embeddings = model(**batch)
                embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()

            for seq, emb, mask in zip(seqs, embeddings, attention_mask.cpu()):
                if self.matrix_embed:
                    # Keep only real residue positions; drop padding
                    emb = emb[mask.bool()]

                if self.sql:
                    c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                              (seq, emb.numpy().tobytes()))
                else:
                    embeddings_dict[seq] = emb.to(self.embed_dtype)

            # Commit periodically so a long run does not lose everything on interruption
            if (i + 1) % 100 == 0 and self.sql:
                conn.commit()

        if self.sql:
            conn.commit()
            conn.close()
            return embeddings_dict

        if self.save_embeddings:
            print_message(f"Saving embeddings to {save_path}")
            torch.save(embeddings_dict, save_path)

        return embeddings_dict
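
    # Note on the SQL branch above (illustrative): embeddings are stored as raw bytes via
    # tensor.numpy().tobytes(), so reading them back requires knowing the dtype and shape,
    # e.g. assuming float32 vector embeddings:
    #   row = c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (seq,)).fetchone()
    #   emb = torch.from_numpy(np.frombuffer(row[0], dtype=np.float32).copy())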

    def __call__(self, model_name: str, model_type: Optional[str] = None, model_path: Optional[str] = None):
        if self.download_embeddings:
            self._download_embeddings(model_name)

        # Compare the device type rather than the device object against a string
        if self.device.type == 'cpu':
            warnings.warn("Downloading embeddings is recommended for CPU usage - Embedding on CPU will be extremely slow!")

        to_embed, save_path, embeddings_dict = self._read_embeddings_from_disk(model_name)

        if len(to_embed) > 0:
            print_message(f"Embedding {len(to_embed)} sequences with {model_name}")
            dispatch_name = model_type or model_name
            model, tokenizer = get_base_model(dispatch_name, dtype=self.model_dtype, model_path=model_path)
            return self._embed_sequences(to_embed, save_path, model, tokenizer, embeddings_dict)
        else:
            print_message(f"No sequences to embed with {model_name}")
            return embeddings_dict


if __name__ == '__main__':
    import argparse
    from huggingface_hub import upload_file, login
    from data.supported_datasets import vector_benchmark
    from data.data_mixin import DataArguments, DataMixin
    from base_models.get_base_models import BaseModelArguments, get_base_model
    from seed_utils import set_global_seed

    parser = argparse.ArgumentParser()
    parser.add_argument('--token', default=None, help='Huggingface token')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--embed_dtype', type=str, default='float16')
    parser.add_argument('--model_names', nargs='+', default=['standard'])
    parser.add_argument('--models_to_skip', nargs='+', default=[], help='When checking for existing embeddings, skip these models.')
    parser.add_argument('--embedding_save_dir', type=str, default='embeddings')
    parser.add_argument('--download_dir', type=str, default='Synthyra/vector_embeddings')
    parser.add_argument('--embedding_pooling_types', nargs='+', default=['mean', 'var'], help='Pooling types for embeddings.')
    args = parser.parse_args()

    chosen_seed = set_global_seed()

    if args.token is not None:
        login(args.token)

    if args.embed_dtype == 'float16':
        dtype = torch.float16
    elif args.embed_dtype == 'bfloat16':
        dtype = torch.bfloat16
    elif args.embed_dtype == 'float32':
        dtype = torch.float32
    else:
        raise ValueError(f"Invalid embedding dtype: {args.embed_dtype}")

    data_args = DataArguments(
        data_names=vector_benchmark,
        max_length=1024,
        trim=False
    )
    all_seqs = DataMixin(data_args).get_data()[1]

    model_args = BaseModelArguments(model_names=args.model_names)
    for model_name in model_args.model_names:
        # EmbeddingArguments ignores unknown keyword names via **kwargs,
        # so the embedding_* argument names must be used here
        embedder_args = EmbeddingArguments(
            embedding_batch_size=args.batch_size,
            embedding_num_workers=args.num_workers,
            download_embeddings=model_name not in args.models_to_skip,
            download_dir=args.download_dir,
            matrix_embed=False,
            embedding_pooling_types=args.embedding_pooling_types,
            save_embeddings=True,
            embed_dtype=dtype,
            sql=False,
            embedding_save_dir=args.embedding_save_dir
        )
        embedder = Embedder(embedder_args, all_seqs)

        _ = embedder(model_name)
        filename = get_embedding_filename(model_name, False, embedder_args.pooling_types, 'pth')
        save_path = os.path.join(args.embedding_save_dir, filename)

        compressed_path = f"{save_path}.gz"
        print(f"Compressing {save_path} to {compressed_path}")
        with open(save_path, 'rb') as f_in:
            with gzip.open(compressed_path, 'wb') as f_out:
                f_out.write(f_in.read())
        upload_path = compressed_path
        path_in_repo = f'embeddings/{filename}.gz'

        upload_file(
            path_or_fileobj=upload_path,
            path_in_repo=path_in_repo,
            repo_id=args.download_dir,
            repo_type='dataset'
        )

    print('Done')