# CodonTranslator / src/dataset.py
# alegendaryfish's picture
# Public CodonTranslator model and training code release
# 2d8da02 verified
# src/dataset.py
"""
Production-ready dataset + dataloader utilities.
Rules (because we're adults):
- Data drives design. Inputs are rows with columns: ["cds_DNA", "protein_seq", "Taxon", (optional) "RefseqID"].
- Output per sample is a tiny dict the model actually needs. Nothing else.
- We stream Parquet by row groups, CSV by chunks. No full-file pandas nonsense on big data.
- We shard by (FSDP rank × dataloader worker). No DistributedSampler needed.
- We do a simple streaming shuffle buffer for train. Good enough. No fancy "epoch managers".
Fields emitted per sample (for collate_fn and trainer):
{
"species_name": str,
"species_id": int,
"protein_seq": str, # raw AA (ESM tokenized later)
"aa_len": int,
"codon_ids": List[int], # tokenized 3-mer ids + EOS at the end
"refseq_id": str,
"protein_refseq_id": str,
"control_mode": "fixed",
"meta": {"src": "parquet|csv", "file": basename, "row": int}
}
Invariants:
- cds_DNA length divisible by 3 after trimming to match protein length.
- DNA uses only ACGT (uppercase). If not, we skip the row. We don't "helpfully fix" broken data.
- We truncate both DNA and protein to the same min length (codon count).
- EOS appended to codon_ids; PAD is handled at collate time, not here.
Dependencies:
- pyarrow only if you read parquet. If it isn't installed and you pass parquet files, we fail loudly.
"""
from __future__ import annotations
import os
import json
import glob
import random
import logging
import heapq
from typing import Dict, List, Any, Optional, Iterable, Tuple
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, Dataset, DataLoader, get_worker_info
# tqdm is optional: it only draws progress bars while a worker skips to a resume cursor.
_tqdm = None
try:
    from tqdm.auto import tqdm as _tqdm
except Exception:  # pragma: no cover - tqdm might be unavailable in minimal envs
    _tqdm = None

logger = logging.getLogger(__name__)
# ------------------------------
# Species Embedding Store (kept simple and stable)
# ------------------------------
class SpeciesEmbeddingStore:
    """
    Memory-mapped lookup of per-species embeddings stored under ``embeddings_dir``.

    Files read:
      - ``species_vocab.json`` (always): species name -> integer id.
      - Fixed-size layout: ``species_metadata.json`` (``num_species``,
        ``embedding_dim``, optional ``embedding_type``) plus
        ``species_embeddings.bin``, a row-major ``(num_species, Ds)`` matrix.
        ``get`` returns a ``(1, Ds)`` float32 tensor.
      - Legacy layout: ``species_index.json`` (id -> ``{"offset", "length"}``),
        optional ``metadata.json`` (``embedding_dim``, ``dtype``) plus
        ``species_tok_emb.bin`` holding concatenated ``(tokens, Ds)`` rows.
        ``get`` returns a ``(1, length, Ds)`` float32 tensor.

    The legacy layout is used when ``pooling == "sequence"`` and its files exist,
    or when the fixed-size files are missing. Unknown ids yield an all-zero stub
    instead of raising.
    """

    def __init__(self, embeddings_dir: str, dtype: str = "float32", pin_memory: bool = False, pooling: str = "last"):
        self.embeddings_dir = Path(embeddings_dir)
        # Stored for callers; this class does not read it.
        self.pin_memory = bool(pin_memory)
        self.is_legacy = False
        self.pooling = pooling
        vocab_path = self.embeddings_dir / "species_vocab.json"
        if not vocab_path.exists():
            raise FileNotFoundError(f"Species vocabulary not found at {vocab_path}")
        with open(vocab_path, "r") as f:
            self.vocab: Dict[str, int] = json.load(f)
        meta_path = self.embeddings_dir / "species_metadata.json"
        new_emb_path = self.embeddings_dir / "species_embeddings.bin"
        legacy_index = self.embeddings_dir / "species_index.json"
        legacy_emb = self.embeddings_dir / "species_tok_emb.bin"
        # Sequence pooling needs per-token embeddings, which only the legacy layout has.
        if self.pooling == "sequence" and legacy_index.exists() and legacy_emb.exists():
            self.is_legacy = True
            self._load_legacy_format(dtype)
            return
        if meta_path.exists() and new_emb_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.num_species = int(meta["num_species"])
            self._ds = int(meta["embedding_dim"])
            self.embedding_type = str(meta.get("embedding_type", "fixed_size"))
            # The requested dtype decides how the raw bytes are interpreted.
            np_dtype = np.float16 if dtype == "float16" else np.float32
            self.embeddings = np.memmap(new_emb_path, dtype=np_dtype, mode="r", shape=(self.num_species, self._ds))
            self._np_dtype = np_dtype
            print(f"Loaded fixed-size species embeddings: {len(self.vocab)} species, Ds={self._ds}, dtype={self._np_dtype}")
        else:
            self.is_legacy = True
            self._load_legacy_format(dtype)

    def _load_legacy_format(self, dtype: str):
        """Memory-map the legacy variable-length token embedding file.

        Raises:
            FileNotFoundError: if the index or the embedding file is missing.
            ValueError: if the file size is not a whole number of Ds-wide rows.
        """
        index_path = self.embeddings_dir / "species_index.json"
        if not index_path.exists():
            raise FileNotFoundError(f"Species index not found at {index_path}")
        with open(index_path, "r") as f:
            raw_index = json.load(f)
        self.index: Dict[str, Dict[str, int]] = {str(k): v for k, v in raw_index.items()}
        meta_path = self.embeddings_dir / "metadata.json"
        file_dtype = dtype
        if meta_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self._ds = int(meta.get("embedding_dim", 1024))
            # The on-disk dtype recorded in metadata wins over the requested one.
            file_dtype = str(meta.get("dtype", dtype)).lower()
        else:
            self._ds = 1024
        emb_path = self.embeddings_dir / "species_tok_emb.bin"
        if not emb_path.exists():
            raise FileNotFoundError(f"Species embeddings not found at {emb_path}")
        np_dtype = np.float16 if file_dtype == "float16" else np.float32
        itemsize = np.dtype(np_dtype).itemsize
        file_bytes = os.path.getsize(emb_path)
        if file_bytes % (self._ds * itemsize) != 0:
            raise ValueError(f"Emb file size {file_bytes} not divisible by Ds*itemsize ({self._ds}*{itemsize})")
        total_tokens = file_bytes // (self._ds * itemsize)
        self.embeddings = np.memmap(emb_path, dtype=np_dtype, mode="r", shape=(total_tokens, self._ds))
        self._np_dtype = np_dtype
        self.num_species = len(self.vocab)
        print(f"[LEGACY] variable-length embeddings: {len(self.vocab)} species, {total_tokens} tokens total, Ds={self._ds}.")

    def load_vocab(self) -> Dict[str, int]:
        """Return a copy of the species name -> id mapping."""
        return self.vocab.copy()

    def _deterministic_stub(self, length: Optional[int] = None) -> torch.FloatTensor:
        """All-zero placeholder: ``(1, length, Ds)`` for legacy with a length, else ``(1, Ds)``."""
        if self.is_legacy and length:
            t = torch.zeros(1, length, self._ds, dtype=torch.float32)
        else:
            t = torch.zeros(1, self._ds, dtype=torch.float32)
        return t

    def get(self, species_id: int) -> torch.FloatTensor:
        """Return one species embedding as a float32 tensor with a leading batch dim of 1.

        Out-of-range (fixed-size) or unindexed (legacy) ids return a zero stub.
        """
        if not self.is_legacy:
            if species_id < 0 or species_id >= getattr(self, "num_species", 0):
                return self._deterministic_stub()
            emb = self.embeddings[species_id]
            # Copy out of the read-only memmap so the tensor owns its storage.
            tensor = torch.from_numpy(np.asarray(emb).copy()).float().unsqueeze(0)
            return tensor
        else:
            sid = str(species_id)
            entry = self.index.get(sid)
            if entry is None:
                return self._deterministic_stub(length=8)
            offset = int(entry["offset"])
            length = int(entry["length"])
            view = self.embeddings[offset: offset + length]
            tensor = torch.from_numpy(np.asarray(view).copy()).float().unsqueeze(0)
            return tensor

    def batch_get(self, species_ids: List[int]) -> Any:
        """Fetch embeddings for a batch of species ids.

        Args:
            species_ids: a sequence of ids or a tensor of any shape (flattened).

        Returns:
            Fixed-size layout: a ``(B, Ds)`` float32 tensor.
            Legacy layout: ``(padded, lengths)`` where ``padded`` is
            ``(B, Ls_max, Ds)`` zero-padded on the right and ``lengths`` is ``(B,)``.
        """
        if torch.is_tensor(species_ids):
            # reshape(-1) so a 0-d tensor still yields a list, not a bare scalar.
            species_ids = species_ids.detach().cpu().reshape(-1).tolist()
        # Coerce every id to int (float tensors would otherwise index the memmap with floats).
        species_ids = [int(x) for x in species_ids]
        B = len(species_ids)
        if not self.is_legacy:
            batch_emb = torch.zeros(B, self._ds, dtype=torch.float32)
            for i, sid in enumerate(species_ids):
                batch_emb[i] = self.get(sid).squeeze(0)
            return batch_emb
        tensors = [self.get(sid) for sid in species_ids]
        lengths = torch.tensor([t.shape[1] for t in tensors], dtype=torch.long)
        Ls_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
        padded = torch.zeros(B, Ls_max, self._ds, dtype=torch.float32)
        for i, t in enumerate(tensors):
            L = t.shape[1]
            padded[i, :L] = t.squeeze(0)
        return padded, lengths

    def Ds(self) -> int:
        """Embedding width."""
        return self._ds
def _is_parquet(path: str) -> bool:
lower = path.lower()
return lower.endswith(".parquet") or lower.endswith(".parq")
def _is_csv(path: str) -> bool:
lower = path.lower()
return (
lower.endswith(".csv")
or lower.endswith(".tsv")
or lower.endswith(".csv.gz")
or lower.endswith(".tsv.gz")
)
def _expand_paths(maybe_path_or_glob: str | List[str]) -> List[str]:
"""
Expand a path/glob or list of them into a sorted, de-duplicated list of files.
We prioritize parquet, then csv/tsv.
"""
paths: List[str] = []
if isinstance(maybe_path_or_glob, str):
p = Path(maybe_path_or_glob)
if p.is_dir():
# Scan directory for parquet first, then csv/tsv
paths.extend(sorted(str(x) for x in p.rglob("*.parquet")))
paths.extend(sorted(str(x) for x in p.rglob("*.parq")))
paths.extend(sorted(str(x) for x in p.rglob("*.csv")))
paths.extend(sorted(str(x) for x in p.rglob("*.tsv")))
paths.extend(sorted(str(x) for x in p.rglob("*.csv.gz")))
paths.extend(sorted(str(x) for x in p.rglob("*.tsv.gz")))
else:
paths = sorted(glob.glob(str(p)))
else:
for it in maybe_path_or_glob:
paths.extend(_expand_paths(it))
# Dedup while preserving order
seen = set()
out = []
for x in paths:
if x not in seen:
out.append(x)
seen.add(x)
if not out:
raise FileNotFoundError(f"No input files found for: {maybe_path_or_glob}")
return out
def _dist_info() -> Tuple[int, int]:
"""
Returns (num_global_workers, global_worker_id)
where global_worker_id = rank * num_workers + worker_id.
"""
world_size = 1
rank = 0
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
except Exception:
pass
wi = get_worker_info()
nw = wi.num_workers if wi else 1
wid = wi.id if wi else 0
return world_size * nw, rank * nw + wid
class _ResumeSkipProgress:
    """Progress reporting while a worker skips samples to reach a resume cursor.

    Uses a tqdm bar when ``_tqdm`` is available. Otherwise it logs at INFO level
    each time another ``log_every`` samples have been skipped, and once when the
    total is reached. A non-positive ``total`` makes every method a no-op.
    ``close()`` is idempotent.
    """

    def __init__(self, total: int, label: str, log_every: int = 10000):
        self.total = int(max(0, total))
        self.label = label
        self.count = 0
        self._bar = None
        self._log_every = max(1, int(log_every))
        self._last_logged_bucket = 0
        self._logged_total = False
        self._closed = False
        if self.total <= 0:
            return
        if _tqdm is not None:
            self._bar = _tqdm(total=self.total, desc=label, unit="sample", dynamic_ncols=True, leave=False)
        else:
            logger.info("%s: skipping %d samples to reach resume cursor", label, self.total)

    def update(self, n: int = 1):
        """Record ``n`` more skipped samples."""
        if self.total <= 0 or self._closed:
            return
        self.count += int(n)
        if self._bar is not None:
            self._bar.update(n)
            return
        # Callers skip whole row groups at once, so the count can jump over exact
        # multiples of log_every. Log whenever a new multiple has been crossed.
        bucket = self.count // self._log_every
        finished = self.count >= self.total
        if bucket > self._last_logged_bucket or (finished and not self._logged_total):
            self._last_logged_bucket = bucket
            self._logged_total = self._logged_total or finished
            logger.info("%s: skipped %d / %d", self.label, self.count, self.total)

    def close(self):
        """Close the bar (if any) and log completion exactly once."""
        if self.total <= 0 or self._closed:
            return
        self._closed = True
        if self._bar is not None:
            self._bar.close()
        logger.info("%s: resume skip finished (%d samples)", self.label, self.count)
class StreamSeqDataset(IterableDataset):
    """
    Streaming dataset with **non-overlapping Parquet row-group sharding**.

    - Accepts a list of files (parquet and/or csv/tsv, optionally gzipped).
      Files with other extensions are ignored.
    - **Parquet**: we enumerate (file, row_group) tasks and balance them across the
      *global* worker id (rank x dataloader worker). This avoids duplicates and
      keeps all ranks busy even with few files.
    - **CSV/TSV**: assigned at file granularity (one worker reads a file).
      ``.tsv``/``.tsv.gz`` files are parsed tab-separated. With only a few CSV files
      and many ranks, some ranks may get no CSV work. (Parquet is the recommended
      format at scale.)
    - CSV is read with pandas ``chunksize`` to keep memory usage bounded.
    - Each Parquet task reads exactly **one row group** into pandas.

    Minimal resume support:
    - ``set_resume_skip(N)`` skips N samples of this rank's stream, split evenly
      across the rank's dataloader workers. Use a **per-rank** skip value in your
      trainer so multi-node resumes stay in lockstep.
    - Whole Parquet row groups are skipped using their metadata row count. That
      count includes rows that filtering would drop, so the resume point can land
      slightly past the original one.
    - CSV files are always skipped row by row, because their task weight is only
      a file-size estimate.
    - The skip applies to the stream *before* the shuffle buffer.

    Output sample schema:
        {
          "species_name": str,
          "species_id": int,
          "protein_seq": str,           # raw AA (ESM tokenized later)
          "aa_len": int,
          "codon_ids": List[int],       # tokenized 3-mer ids + EOS at the end
          "refseq_id": str,             # RefseqID column, else "<file stem>:<row>"
          "protein_refseq_id": str,
          "control_mode": "fixed",
          "meta": {"src": "parquet|csv", "file": basename, "row": int}  # 0-based data row in the file
        }
    """

    # Canonical required columns. We also accept common aliases (e.g., 'taxon').
    REQUIRED = ["cds_DNA", "protein_seq", "Taxon"]

    def __init__(
        self,
        files: List[str],
        tokenizer,
        species_vocab_path: str,
        unknown_species_id: int = 0,
        csv_chunksize: int = 200_000,
        shuffle_buffer: int = 0,
        seed: int = 1234,
        shard_across_ranks: bool = True,
    ):
        super().__init__()
        self.files = files
        self.tok = tokenizer
        with open(species_vocab_path, "r") as f:
            self.species_vocab: Dict[str, int] = json.load(f)
        self.unknown_species_id = int(unknown_species_id)
        self.csv_chunksize = int(max(1, csv_chunksize))
        self.shuffle_buffer = int(max(0, shuffle_buffer))
        self.seed = int(seed)
        # When False, every rank iterates over the full task list instead of
        # taking a disjoint shard. This keeps FSDP collectives aligned during
        # evaluation even if the validation dataset is smaller than WORLD_SIZE.
        self.shard_across_ranks = bool(shard_across_ranks)
        # Minimal resume cursor
        self._resume_skip_n: int = 0
        self._offset_start: int = 0
        self._emitted: int = 0

    # ---- resume cursor (minimal) ----
    def set_resume_skip(self, n: int) -> None:
        """Skip the first ``n`` samples of this rank's stream on the next iteration."""
        n = int(max(0, n))
        self._resume_skip_n = n
        self._offset_start = n
        self._emitted = 0

    def get_stream_position(self) -> int:
        """Initial skip offset plus samples yielded by iteration in *this* process.

        With DataLoader workers, iteration runs on worker copies of the dataset,
        so the counter on the main-process object does not advance.
        """
        return int(self._offset_start + self._emitted)

    # ---- core row-wise iterator on a pandas DataFrame ----
    def _iter_df(self, df: pd.DataFrame, src: str, file: str) -> Iterable[Dict[str, Any]]:
        """
        Clean, filter and tokenize one DataFrame, yielding one sample per kept row.

        A row is dropped when its DNA (after strip + upper-casing) is empty or
        contains anything other than A/C/G/T. It is also dropped when DNA and
        protein do not share at least one full codon. Kept rows are truncated to
        the same codon count. ``df.index`` values become ``meta["row"]`` and the
        fallback ``refseq_id``.

        Raises:
            ValueError: if a required column is missing.
        """
        # Some shards use lowercase `taxon` instead of `Taxon`.
        if "Taxon" not in df.columns and "taxon" in df.columns:
            df = df.rename(columns={"taxon": "Taxon"})
        # Hard fail if required missing
        for c in self.REQUIRED:
            if c not in df.columns:
                raise ValueError(f"Input missing required column '{c}' in {file}")
        keep_cols = self.REQUIRED + (["RefseqID"] if "RefseqID" in df.columns else [])
        # Explicit copy: the column assignments below must not write into a view.
        df = df[keep_cols].copy()
        for c in self.REQUIRED:
            col = df[c]
            # Nulls (None/NaN from Parquet) become "" rather than the strings
            # "None"/"nan". Otherwise a null protein would survive as "NONE".
            df[c] = col.astype(object).where(col.notna(), "").astype(str).str.strip()
        df["protein_seq"] = df["protein_seq"].str.upper()
        df["cds_DNA"] = df["cds_DNA"].str.upper()
        # DNA must be non-empty and ACGT-only; broken rows are skipped, not repaired.
        df = df[df["cds_DNA"].str.fullmatch(r"[ACGT]+", na=False)]
        if df.empty:
            return
        # Trim protein/DNA to a shared length in codons.
        cds_codons = df["cds_DNA"].str.len().to_numpy() // 3
        prot_len = df["protein_seq"].str.len().to_numpy()
        min_len = np.minimum(cds_codons, prot_len)
        keep = min_len > 0
        if not keep.any():
            return
        df = df[keep]
        min_len = min_len[keep]

        # Loop invariants, hoisted out of the per-row loop.
        special = self.tok.special_ids if hasattr(self.tok, "special_ids") else self.tok._special_ids
        eos_id = special.eos
        vocab = self.species_vocab
        unknown = self.unknown_species_id

        def map_species(x: str) -> int:
            try:
                return int(vocab.get(x, unknown))
            except Exception:
                return unknown

        file_stem = Path(file).stem
        file_base = os.path.basename(file)
        ref_ids = df["RefseqID"].tolist() if "RefseqID" in df.columns else None
        rows = zip(
            df.index.tolist(),
            df["Taxon"].tolist(),
            df["protein_seq"].tolist(),
            df["cds_DNA"].tolist(),
            min_len.tolist(),
        )
        for i, (row_idx, taxon, prot_full, cds_full, ml) in enumerate(rows):
            cds = cds_full[: ml * 3]
            prot = prot_full[:ml]
            if (len(cds) // 3) != len(prot):
                continue
            # Tokenize DNA -> 3-mer ids; append EOS. list() so we never mutate a
            # sequence the tokenizer might hand back by reference.
            codon_ids = list(self.tok.encode_codon_seq(cds, validate=False))
            codon_ids.append(eos_id)
            ref_id = ref_ids[i] if ref_ids is not None else f"{file_stem}:{int(row_idx)}"
            yield {
                "species_name": taxon,
                "species_id": int(map_species(taxon)),
                "protein_seq": prot,
                "aa_len": len(prot),
                "codon_ids": codon_ids,
                "refseq_id": ref_id,
                "protein_refseq_id": ref_id,
                "control_mode": "fixed",
                "meta": {"src": src, "file": file_base, "row": int(row_idx)},
            }

    # ---- Parquet helpers: enumerate row-group tasks & read one row group ----
    def _enumerate_tasks(self, files: List[str]) -> List[Tuple[str, str, Optional[int], int]]:
        """
        Return a task list of tuples:
          ("parquet", path, row_group_idx, weight) for each row group in each Parquet file
          ("csv", path, None, weight)              for each CSV/TSV file
        Parquet weights are metadata row counts. CSV weights are a ~256 bytes/row
        estimate and are only used for load balancing.
        """
        tasks: List[Tuple[str, str, Optional[int], int]] = []
        parquet_files = [f for f in files if _is_parquet(f)]
        csv_files = [f for f in files if _is_csv(f)]
        if parquet_files:
            try:
                import pyarrow.parquet as pq  # type: ignore
            except Exception as e:
                raise ImportError("pyarrow is required to read parquet files") from e
            for fp in parquet_files:
                pf = pq.ParquetFile(fp)
                md = pf.metadata
                # A file with zero row groups holds no rows; it contributes no task
                # (reading "row group 0" of such a file would fail).
                for rg in range(int(pf.num_row_groups or 0)):
                    num_rows = md.row_group(rg).num_rows if md is not None else 0
                    tasks.append(("parquet", fp, rg, max(1, int(num_rows or 0))))
        # CSV/TSV files remain file-level tasks
        for fp in csv_files:
            file_size = os.path.getsize(fp)
            # Assume ~256 bytes per record when estimating CSV row counts (empirical default)
            est_rows = max(1, int(file_size // 256))
            tasks.append(("csv", fp, None, est_rows))
        # Deterministic order (files are already sorted by _expand_paths): every
        # worker computes the same list independently.
        return tasks

    @staticmethod
    def _balanced_partition(tasks: List[Tuple[str, str, Optional[int], int]], groups: int) -> List[List[Tuple[str, str, Optional[int], int]]]:
        """Greedily split ``tasks`` into ``groups`` buckets of similar total weight.

        Heaviest tasks go first into the currently lightest bucket. Each bucket
        keeps the tasks' original relative order. The result is deterministic, so
        all workers agree on it.
        """
        if groups <= 1:
            return [tasks]
        if not tasks:
            return [[] for _ in range(groups)]
        indexed = [(idx, kind, path, rg, weight) for idx, (kind, path, rg, weight) in enumerate(tasks)]
        tasks_sorted = sorted(
            indexed,
            key=lambda entry: (entry[4], -entry[0]),
            reverse=True,
        )
        heap: List[Tuple[int, int]] = [(0, bucket_idx) for bucket_idx in range(groups)]
        heapq.heapify(heap)
        buckets: List[List[Tuple[int, str, str, Optional[int], int]]] = [[] for _ in range(groups)]
        for original_index, kind, path, rg, weight in tasks_sorted:
            load, bucket_idx = heapq.heappop(heap)
            buckets[bucket_idx].append((original_index, kind, path, rg, weight))
            heapq.heappush(heap, (load + weight, bucket_idx))
        partitions: List[List[Tuple[str, str, Optional[int], int]]] = []
        for bucket in buckets:
            bucket.sort(key=lambda entry: entry[0])
            partitions.append([(kind, path, rg, weight) for (_idx, kind, path, rg, weight) in bucket])
        return partitions

    def _parquet_rowgroup_iter(
        self, file: str, row_group_idx: int, cols_cache: Dict[str, List[str]]
    ) -> Iterable[Dict[str, Any]]:
        """Read one Parquet row group and yield its samples.

        Rows are numbered by their position in the whole file. Without this, the
        fallback ids ("<stem>:<row>") would collide across row groups.
        """
        import pyarrow.parquet as pq  # safe: checked in _enumerate_tasks
        pf = pq.ParquetFile(file)
        # Cache the column subset per file so we don't recompute
        if file not in cols_cache:
            names = set(pf.schema.names)
            cols: List[str] = []
            # Required columns, with alias support (notably Taxon vs taxon).
            for c in self.REQUIRED:
                if c in names:
                    cols.append(c)
                elif c == "Taxon" and "taxon" in names:
                    cols.append("taxon")
            # Optional debug id
            if "RefseqID" in names:
                cols.append("RefseqID")
            cols_cache[file] = cols
        table = pf.read_row_group(row_group_idx, columns=cols_cache[file])
        df = table.to_pandas(types_mapper=None)
        md = pf.metadata
        offset = 0
        if md is not None:
            offset = sum(int(md.row_group(i).num_rows or 0) for i in range(row_group_idx))
        df.index = pd.RangeIndex(offset, offset + len(df))
        yield from self._iter_df(df, "parquet", file)

    def _csv_file_iter(self, file: str) -> Iterable[Dict[str, Any]]:
        """Stream one CSV/TSV file in chunks (one worker owns the whole file)."""
        sep = "\t" if file.lower().endswith((".tsv", ".tsv.gz")) else ","
        # Context manager closes the file even if the consumer stops early.
        with pd.read_csv(file, sep=sep, chunksize=self.csv_chunksize, dtype=str, keep_default_na=False) as reader:
            for chunk in reader:
                yield from self._iter_df(chunk, "csv", file)

    def _task_rows(
        self, kind: str, path: str, rg: Optional[int], cols_cache: Dict[str, List[str]]
    ) -> Iterable[Dict[str, Any]]:
        """Dispatch a task tuple to the matching reader."""
        if kind == "parquet":
            if rg is None:
                raise ValueError(f"Parquet task without row group index: {path}")
            return self._parquet_rowgroup_iter(path, int(rg), cols_cache)
        if kind == "csv":
            return self._csv_file_iter(path)
        raise ValueError(f"Unknown task kind: {kind}")

    # ---- main iterator ----
    def __iter__(self):
        wi = get_worker_info()
        num_workers = wi.num_workers if wi else 1
        worker_id = wi.id if wi else 0
        num_global, gid = _dist_info()
        if not self.shard_across_ranks:
            num_global = max(1, num_workers)
            gid = worker_id
        workers_per_rank = max(1, num_workers)
        rank = gid // workers_per_rank if self.shard_across_ranks else 0
        world = max(1, num_global // workers_per_rank)

        # Split the per-rank resume skip evenly across local dataloader workers
        # so the workers' targets sum to the per-rank value.
        per_rank_skip = int(self._resume_skip_n)
        base, rem = divmod(per_rank_skip, workers_per_rank)
        local_skip_target = base + (1 if worker_id < rem else 0)

        # Build the global task list (parquet row groups + csv files) and shard by gid
        tasks = self._enumerate_tasks(self.files)
        if tasks:
            partitions = self._balanced_partition(tasks, max(1, num_global))
            my_tasks_full = partitions[gid] if gid < len(partitions) else []
        else:
            my_tasks_full = []

        progress: Optional[_ResumeSkipProgress] = None
        if local_skip_target > 0 and worker_id == 0:
            label = "resume skip" if world == 1 else f"resume skip (rank {rank}/{world})"
            progress = _ResumeSkipProgress(local_skip_target, label)

        # Fast task-level skip: drop whole Parquet row groups whose exact row count
        # fits in the remaining skip. Stop at the first task that does not fit or
        # is a CSV file (its weight is only an estimate). That task stays at the
        # head of my_tasks and is skipped row by row below, exactly once.
        skip_remaining = int(local_skip_target)
        start_idx = 0
        while skip_remaining > 0 and start_idx < len(my_tasks_full):
            kind, _path, _rg, weight = my_tasks_full[start_idx]
            w = int(weight) if weight is not None else 0
            if kind != "parquet" or w > skip_remaining:
                break
            skip_remaining -= max(0, w)
            start_idx += 1
            if progress is not None:
                progress.update(w)
        if skip_remaining == 0 and progress is not None:
            progress.close()
            progress = None
        my_tasks = my_tasks_full[start_idx:]

        rng = random.Random(self.seed + gid)
        buffer: List[Dict[str, Any]] = []
        bufN = self.shuffle_buffer
        # Skip counter for resume cursor (task-level part already consumed)
        skipped = int(local_skip_target - skip_remaining)
        # Cache for per-file Parquet column selection
        cols_cache: Dict[str, List[str]] = {}
        try:
            for kind, path, rg, _w in my_tasks:
                for sample in self._task_rows(kind, path, rg, cols_cache):
                    # Row-level remainder of the resume skip.
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                            if skip_remaining == 0:
                                progress.close()
                                progress = None
                        continue
                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                        continue
                    # Streaming shuffle: once the buffer is full, emit a random element.
                    buffer.append(sample)
                    if len(buffer) >= bufN:
                        j = rng.randrange(len(buffer))
                        buffer[j], buffer[-1] = buffer[-1], buffer[j]
                        self._emitted += 1
                        yield buffer.pop()
            # Flush leftovers (the buffer is only ever filled when bufN > 0).
            if buffer:
                rng.shuffle(buffer)
                for sample in buffer:
                    self._emitted += 1
                    yield sample
                buffer.clear()
        finally:
            if progress is not None:
                progress.close()
            if local_skip_target > 0:
                # Persist any leftover skip on this (per-worker) dataset copy.
                self._resume_skip_n = max(local_skip_target - skipped, 0)
# ------------------------------
# Simple collate: end-only pad for codon stream, pass-through everything else
# ------------------------------
def stage_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate dataset samples into a batch dict.

    - ``species_ids``: ``(B,)`` long tensor (missing ids default to 0).
    - ``protein_seqs``: list of raw protein strings (missing -> "M"); the ESM
      tokenizer handles them later.
    - ``codon_ids``: ``(B, max_len)`` long tensor, right-padded with 0. The EOS
      appended by the dataset stays inside each row.
    - ``control_mode``: taken from the first sample (default "fixed").
    - ``refseq_id`` / ``protein_refseq_id``: passed through as lists when the
      first sample has them.
    An empty batch returns ``{}``.
    """
    if len(batch) == 0:
        return {}
    n_samples = len(batch)
    head = batch[0]

    ids_tensor = torch.tensor([int(sample.get("species_id", 0)) for sample in batch], dtype=torch.long)
    proteins = [str(sample.get("protein_seq", "M")) for sample in batch]

    token_rows = [sample.get("codon_ids", []) for sample in batch]
    width = max(len(ids) for ids in token_rows)
    pad_value = 0  # tokenizer.pad_token_id is 0 in our tokenizer.
    padded = torch.full((n_samples, width), pad_value, dtype=torch.long)
    for row_idx, ids in enumerate(token_rows):
        if len(ids) == 0:
            continue
        padded[row_idx, : len(ids)] = torch.tensor(ids, dtype=torch.long)

    collated: Dict[str, Any] = {
        "species_ids": ids_tensor,
        "protein_seqs": proteins,
        "codon_ids": padded,
        "control_mode": head.get("control_mode", "fixed"),
    }
    for key in ("refseq_id", "protein_refseq_id"):
        if key in head:
            collated[key] = [sample.get(key) for sample in batch]
    return collated
def _build_dataset(
    path_or_paths: str | List[str],
    tokenizer,
    species_vocab_path: str,
    shuffle_buffer: int,
    csv_chunksize: int,
    shard_across_ranks: bool = True,
) -> StreamSeqDataset:
    """Resolve input paths/globs and wrap them in a ``StreamSeqDataset``.

    Unknown species map to id 0, and the shuffle seed is fixed at 1234.
    """
    resolved_files = _expand_paths(path_or_paths)
    dataset_options = dict(
        tokenizer=tokenizer,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=csv_chunksize,
        shuffle_buffer=shuffle_buffer,
        seed=1234,
        shard_across_ranks=shard_across_ranks,
    )
    return StreamSeqDataset(files=resolved_files, **dataset_options)
def create_precomputed_dataloaders(
    train_path: str | List[str],
    val_path: Optional[str | List[str]],
    embeddings_dir: str,
    tokenizer,
    batch_size: int,
    num_workers: int = 4,
    species_pooling: str = "sequence",
    csv_chunksize: int = 200_000,
    train_shuffle_buffer: int = 8192,
    val_shuffle_buffer: int = 0,
    val_shard_across_ranks: bool = True,
) -> Tuple[DataLoader, Optional[DataLoader], SpeciesEmbeddingStore]:
    """
    Build the training/validation DataLoaders plus the species embedding store.

    Args:
        train_path / val_path: file, directory, glob, or list of them (val optional).
        embeddings_dir: directory holding ``species_vocab.json`` and the embeddings.
        tokenizer: codon tokenizer passed to the datasets.
        batch_size: per-loader batch size.
        num_workers: DataLoader workers (negative values are clamped to 0).
        species_pooling: forwarded to ``SpeciesEmbeddingStore(pooling=...)``.
        csv_chunksize: rows per pandas chunk when reading CSV/TSV.
        train_shuffle_buffer / val_shuffle_buffer: streaming shuffle buffer sizes.
        val_shard_across_ranks: pass False so every rank iterates the full
            validation set (keeps collectives aligned when the validation set is
            smaller than the world size).

    Returns:
        (train_loader, val_loader or None, species_store)
    """
    species_store = SpeciesEmbeddingStore(embeddings_dir, pin_memory=True, pooling=species_pooling)
    species_vocab_path = os.path.join(embeddings_dir, "species_vocab.json")
    num_workers = int(max(0, num_workers))
    train_ds = _build_dataset(
        path_or_paths=train_path,
        tokenizer=tokenizer,
        species_vocab_path=species_vocab_path,
        shuffle_buffer=int(train_shuffle_buffer),
        csv_chunksize=int(csv_chunksize),
    )
    val_ds = None
    if val_path:
        val_ds = _build_dataset(
            path_or_paths=val_path,
            tokenizer=tokenizer,
            species_vocab_path=species_vocab_path,
            shuffle_buffer=int(val_shuffle_buffer),
            csv_chunksize=int(csv_chunksize),
            shard_across_ranks=bool(val_shard_across_ranks),
        )
    # NOTE: IterableDataset can't be shuffled by DataLoader. We already "shuffle" inside the dataset.
    kwargs_common = dict(
        num_workers=num_workers,
        collate_fn=stage_collate_fn,
        # Pinned host memory only helps GPU transfers. Without CUDA, PyTorch warns
        # and ignores the flag, so only request it when CUDA is available.
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(num_workers > 0),
    )
    if num_workers > 0:
        kwargs_common["prefetch_factor"] = 4
    # Drop last for train to keep batch shapes stable under FSDP.
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **kwargs_common,
    )
    val_loader = None
    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            **kwargs_common,
        )
    return train_loader, val_loader, species_store