| import os |
| import torch |
| import torch.distributed as dist |
| from tokenizers import Tokenizer |
| from torch.nn.utils.rnn import pad_sequence |
| from typing import Any |
| from dataclasses import dataclass |
|
|
|
|
# Padding token id that the loaded tokenizer is required to use (checked in get_tokenizer).
PAD_TOKEN_ID: int = 0
|
|
|
|
def get_tokenizer() -> Tokenizer:
    """Load the tokenizer stored as ``e1_tokenizer.json`` next to this module.

    Returns:
        The loaded ``Tokenizer``.

    Raises:
        AssertionError: If padding is not configured on the tokenizer, or its pad id
            differs from ``PAD_TOKEN_ID``.
    """
    fname = os.path.join(os.path.dirname(__file__), "e1_tokenizer.json")
    tokenizer: Tokenizer = Tokenizer.from_file(fname)

    # `Tokenizer.padding` is None when padding is disabled; guard it so a misconfigured
    # tokenizer file fails the check below instead of raising a TypeError on `None[...]`.
    padding = tokenizer.padding
    pad_id = padding["pad_id"] if padding is not None else None
    assert pad_id == PAD_TOKEN_ID, f"Padding token id must be {PAD_TOKEN_ID}, but got {pad_id}"

    return tokenizer
|
|
|
|
def is_dist_initialized() -> bool:
    """Report whether torch.distributed is usable and its default process group exists."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
|
|
|
def get_world_size(group: Any = None) -> int:
    """Return the number of processes in ``group``, or 1 outside a distributed run.

    A run is treated as distributed only when the ``RANK`` environment variable is
    present and the default process group has been initialized.
    """
    # Environment values are always strings, so "RANK is set" is just key presence.
    distributed = "RANK" in os.environ and is_dist_initialized()
    return dist.get_world_size(group=group) if distributed else 1
|
|
|
|
def get_rank(group: Any = None) -> int:
    """Return this process's rank within ``group``, or 0 outside a distributed run.

    A run is treated as distributed only when the ``RANK`` environment variable is
    present and the default process group has been initialized.
    """
    # Environment values are always strings, so "RANK is set" is just key presence.
    distributed = "RANK" in os.environ and is_dist_initialized()
    return dist.get_rank(group=group) if distributed else 0
|
|
|
|
def get_device() -> torch.device:
    """Return the currently selected CUDA device if CUDA is available, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    current_index = torch.cuda.current_device()
    return torch.device("cuda", current_index)
|
|
|
|
def get_local_rank() -> int:
    """Return ``LOCAL_RANK`` (default 0) once distributed is initialized; otherwise 0."""
    if not is_dist_initialized():
        return 0
    local_rank = os.environ.get("LOCAL_RANK", 0)
    return int(local_rank)
|
|
|
|
def setup_dist() -> None:
    """Initialize an NCCL process group and select this process's CUDA device.

    This is a no-op unless torch.distributed and CUDA are both available and the
    ``RANK`` environment variable is set to a value other than -1. The device index
    comes from ``LOCAL_RANK`` (default 0).
    """
    rank = int(os.environ.get("RANK", -1))
    can_init = dist.is_available() and torch.cuda.is_available() and rank != -1
    if not can_init:
        return
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
|
|
|
|
def destroy_process_group() -> None:
    """Tear down the default process group if one is initialized; otherwise do nothing."""
    if not is_dist_initialized():
        return
    dist.destroy_process_group()
|
|
|
|
def barrier() -> None:
    """Synchronize all processes when distributed is initialized; otherwise do nothing."""
    if not is_dist_initialized():
        return
    dist.barrier()
|
|
|
|
@dataclass
class DataPrepConfig:
    """Limits and options applied when sequences are turned into model inputs."""

    # Upper bound on the number of comma-separated sequences in one multi-sequence input.
    max_num_sequences: int = 512
    # Upper bound on the character length of each single sequence (framing tokens not counted).
    max_num_positions_within_seq: int = 8192
    # When True, tokens whose id equals the tokenizer's "X" id are dropped from each sequence.
    remove_X_tokens: bool = False
|
|
|
|
| def get_context(sequence: str) -> str | None: |
| if "," in sequence: |
| return sequence.rsplit(",", 1)[0] |
| return None |
|
|
|
|
class E1BatchPreparer:
    """Convert comma-separated multi-sequence strings into padded model-input tensors.

    A multi-sequence string such as ``"SEQ1,SEQ2,QUERY"`` is split on commas. Each
    single sequence is framed as ``<bos> 1 <characters...> 2 <eos>``, and the framed
    sequences are concatenated. Every sequence except the last is treated as context
    (see ``prepare_multiseq``).
    """

    def __init__(
        self,
        data_prep_config: DataPrepConfig | None = None,
        tokenizer: Tokenizer | None = None,
        preserve_context_labels: bool = False,
        device: torch.device | None = None,
    ):
        """Create a batch preparer.

        Args:
            data_prep_config: Length limits and X-token handling; defaults to ``DataPrepConfig()``.
            tokenizer: Tokenizer used for vocabulary lookups and decoding; defaults to
                ``get_tokenizer()``.
            preserve_context_labels: If False, labels of all context tokens are overwritten
                with the pad id in ``prepare_multiseq``.
            device: Device on which ``boundary_token_ids`` is allocated; defaults to
                ``get_device()``.
        """
        self.tokenizer = tokenizer if tokenizer is not None else get_tokenizer()
        self.data_prep_config = data_prep_config if data_prep_config is not None else DataPrepConfig()
        self.pad_token_id = self.tokenizer.token_to_id("<pad>")
        self.preserve_context_labels = preserve_context_labels
        # Ids of the framing/padding tokens, used by get_boundary_token_mask.
        self.boundary_token_ids = torch.tensor(
            [self.tokenizer.token_to_id(token) for token in ["<bos>", "<eos>", "1", "2", "<pad>"]],
            device=device if device is not None else get_device(),
        ).long()
        self.mask_token = "?"
        self.mask_token_id = self.tokenizer.token_to_id(self.mask_token)
        self.X_token_id = self.tokenizer.token_to_id("X")
        self.vocab = self.tokenizer.get_vocab()

    def get_batch_kwargs(
        self,
        sequences: list[str],
        device: torch.device | str = torch.device("cpu"),
        non_blocking: bool = False,
    ) -> dict[str, torch.Tensor | list[str] | list[int]]:
        """Encode each multi-sequence string and pad the results into one batch.

        Args:
            sequences: Comma-separated multi-sequence strings, one per batch element.
            device: Target device (a ``torch.device`` or a device string such as ``"cuda:0"``).
            non_blocking: Request asynchronous copies; honoured only for CUDA targets.

        Returns:
            The dictionary produced by ``pad_encodings``.
        """
        sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences]
        return self.pad_encodings(sequence_encodings, device, non_blocking)

    def pad_encodings(
        self,
        sequence_encodings: list[dict[str, torch.Tensor | str | int]],
        device: torch.device | str = torch.device("cpu"),
        non_blocking: bool = False,
    ) -> dict[str, torch.Tensor | list[str] | list[int]]:
        """Right-pad per-example encodings into ``(batch, max_len)`` long tensors.

        ``input_ids`` and ``labels`` are padded with the pad token id; the id/position
        tensors are padded with -1. ``context`` and ``context_len`` are passed through
        as plain Python lists.
        """
        # Accept device strings as well as torch.device objects.
        device = torch.device(device)
        # Asynchronous copies are only requested for CUDA targets.
        non_blocking = non_blocking and device.type == "cuda"

        padding_values = {
            "input_ids": self.pad_token_id,
            "sequence_ids": -1,
            "within_seq_position_ids": -1,
            "global_position_ids": -1,
            "labels": self.pad_token_id,
        }
        padded_encodings = {
            key: pad_sequence(
                [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value
            ).to(device=device, dtype=torch.long, non_blocking=non_blocking)
            for key, padding_value in padding_values.items()
        }

        padded_encodings["context"] = [enc["context"] for enc in sequence_encodings]
        padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings]

        return padded_encodings

    def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]:
        """Encode one comma-separated multi-sequence string.

        Returns:
            A dict with:
              - ``input_ids``: concatenated framed token ids of all sequences.
              - ``sequence_ids``: index of the source sequence for every token.
              - ``within_seq_position_ids``: positions restarting at 0 for each sequence.
              - ``global_position_ids``: positions that continue across sequences, so each
                sequence occupies a disjoint range.
              - ``labels``: copy of ``input_ids`` with context labels set to the pad id
                unless ``preserve_context_labels`` is True.
              - ``context``: decoded text of all tokens before the last sequence.
              - ``context_len``: number of tokens before the last sequence.

        Raises:
            ValueError: If there are more sequences than ``max_num_sequences`` or a single
                sequence is invalid or too long (see ``prepare_singleseq``).
        """
        single_sequences = sequence.split(",")
        max_num_sequences = self.data_prep_config.max_num_sequences
        if len(single_sequences) > max_num_sequences:
            raise ValueError(
                f"Number of sequences {len(single_sequences)} exceeds max number of sequences {max_num_sequences}"
                " in the provided multi-sequence instance. Please remove some homologous sequences before trying again."
            )

        encodings = [self.prepare_singleseq(single) for single in single_sequences]

        num_tokens = [len(enc["input_ids"]) for enc in encodings]
        # torch.cat allocates new tensors, so the in-place label edit below cannot touch input_ids.
        input_ids = torch.cat([enc["input_ids"] for enc in encodings])
        labels = torch.cat([enc["labels"] for enc in encodings])
        within_seq_position_ids = torch.cat([enc["position_ids"] for enc in encodings])

        # Each sequence's global positions start one past the previous sequence's largest
        # position (position ids are non-negative, so the offset only grows).
        global_position_ids, offset = [], 0
        for enc in encodings:
            global_position_ids.append(enc["position_ids"] + offset)
            offset += int(enc["position_ids"].max().item()) + 1
        global_position_ids = torch.cat(global_position_ids)

        # e.g. num_tokens [3, 2] -> sequence_ids [0, 0, 0, 1, 1]
        sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens))

        # Everything before the last sequence is context.
        context_len = sum(num_tokens[:-1])
        context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False)
        if not self.preserve_context_labels:
            labels[:context_len] = self.pad_token_id

        assert (
            input_ids.shape
            == sequence_ids.shape
            == within_seq_position_ids.shape
            == global_position_ids.shape
            == labels.shape
        ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape"

        assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length"

        return {
            "input_ids": input_ids,
            "sequence_ids": sequence_ids,
            "within_seq_position_ids": within_seq_position_ids,
            "global_position_ids": global_position_ids,
            "labels": labels,
            "context": context,
            "context_len": context_len,
        }

    def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]:
        """Tokenize one sequence as ``<bos> 1 <characters...> 2 <eos>``.

        Returns:
            ``input_ids`` and ``labels`` (independent copies of the token ids) and
            ``position_ids`` (0-based positions in the framed sequence). When
            ``remove_X_tokens`` is set, "X" tokens are dropped, and the surviving tokens
            keep their original positions.

        Raises:
            ValueError: If the sequence fails ``validate_sequence`` or is longer than
                ``max_num_positions_within_seq`` characters.
        """
        if not self.validate_sequence(sequence):
            raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only")

        max_len = self.data_prep_config.max_num_positions_within_seq
        if len(sequence) > max_len:
            raise ValueError(f"Sequence length {len(sequence)} exceeds max length {max_len}")

        # NOTE(review): the limit above counts characters only; the four framing tokens
        # push position ids up to len(sequence) + 3. Confirm this matches the model's limit.
        tokens = torch.tensor([self.vocab[token] for token in ["<bos>", "1", *sequence, "2", "<eos>"]])
        position_ids = torch.arange(len(tokens))

        if self.data_prep_config.remove_X_tokens:
            keep = torch.where(tokens != self.X_token_id)[0]  # indices of non-X tokens
            tokens = tokens[keep]
            position_ids = position_ids[keep]

        # Clone labels so a caller mutating one tensor does not silently mutate the other.
        return {"input_ids": tokens, "labels": tokens.clone(), "position_ids": position_ids}

    def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
        """Return a boolean mask marking ``<bos>``, ``<eos>``, ``1``, ``2`` and ``<pad>`` tokens."""
        return torch.isin(tokens, self.boundary_token_ids.to(tokens.device))

    def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
        """Return a boolean mask marking mask-token (``?``) positions."""
        return tokens == self.mask_token_id

    def validate_sequence(self, sequence: str) -> bool:
        """Return True if ``sequence`` is non-empty and consists only of ASCII A-Z and ``?``.

        An all-mask sequence (e.g. ``"???"``) is valid. The empty string is not.
        """
        assert isinstance(sequence, str), "Sequence must be a string"
        if not sequence:
            return False
        residues = sequence.replace(self.mask_token, "")
        # str.isalpha/isupper are Unicode-aware (they accept e.g. Greek capitals), so also
        # require ASCII to match the [A-Z] contract stated in prepare_singleseq's error message.
        return not residues or (residues.isascii() and residues.isalpha() and residues.isupper())
|
|