| | import torch |
| |
|
| | from e1_fastplms.modeling_e1 import E1BatchPreparer |
| |
|
| |
|
def _shortened_list(values: list, max_items: int = 8) -> str:
    """Render *values* as a list literal, eliding everything past *max_items*."""
    if len(values) <= max_items:
        return str(values)
    return str(values[:max_items] + [f"... (+{len(values) - max_items} more)"])


def _print_tensor_summary(name: str, tensor: torch.Tensor, pad_value: int) -> None:
    """Print shape, dtype, device, a first-row preview and padding stats.

    Args:
        name: Label printed as the section header.
        tensor: Batch tensor to summarise; indexed as ``tensor[0]`` for the
            row preview, so it needs at least one dimension.
        pad_value: Value counted as padding for this particular tensor.
    """
    assert isinstance(tensor, torch.Tensor)
    print()
    print(f"{name}:")
    print(f" shape={tuple(tensor.shape)} dtype={tensor.dtype} device={tensor.device}")

    total = tensor.numel()
    if total == 0:
        # min()/max() raise on empty tensors and the percentage would divide by zero.
        print(" (empty tensor)")
        return

    first_index = (0,) * tensor.ndim
    print(f" first_element={tensor[first_index].item()}")
    # Hand the whole row to the helper so it can report how much was elided.
    first_row = [int(x) for x in tensor[0].tolist()]
    print(f" first_row_prefix={_shortened_list(first_row, max_items=8)}")

    non_pad = int((tensor != pad_value).sum().item())
    print(f" non_padding_count={non_pad} / total={total} ({non_pad / total * 100:.2f}%)")

    if tensor.dtype.is_floating_point:
        value_stats = f"min={tensor.min().item():.4f}, max={tensor.max().item():.4f}"
    else:
        value_stats = f"min={tensor.min().item()}, max={tensor.max().item()}"
    print(f" {value_stats}")


def analyze_batch_kwargs(batch_kwargs: dict, preparer: "E1BatchPreparer", sequences: list[str]) -> None:
    """Print a human-readable report about a prepared batch.

    The report has three parts: global batch facts, a per-tensor summary, and
    a per-sample breakdown (decoded tokens, sub-sequence lengths, label
    masking and position-id ranges).

    Args:
        batch_kwargs: Mapping with the tensor entries ``input_ids``,
            ``within_seq_position_ids``, ``global_position_ids``,
            ``sequence_ids`` and ``labels`` (all of one shared 2-D shape) plus
            the per-sample sequences ``context`` and ``context_len``.
        preparer: Object exposing ``pad_token_id`` and ``tokenizer.decode``.
        sequences: Raw input strings, one per batch row.

    Raises:
        KeyError: If a required entry is missing from ``batch_kwargs``.
        AssertionError: If shapes or lengths disagree, or if the valid
            ``sequence_ids`` of a row are not contiguous.
    """
    print("==== Batch kwargs analysis ====")

    input_ids = batch_kwargs["input_ids"]
    within_seq_position_ids = batch_kwargs["within_seq_position_ids"]
    global_position_ids = batch_kwargs["global_position_ids"]
    sequence_ids = batch_kwargs["sequence_ids"]
    labels = batch_kwargs["labels"]
    context = batch_kwargs["context"]
    context_len = batch_kwargs["context_len"]
    pad_token_id = preparer.pad_token_id

    assert (
        input_ids.shape
        == within_seq_position_ids.shape
        == global_position_ids.shape
        == sequence_ids.shape
        == labels.shape
    )
    batch_size, max_len = input_ids.shape
    assert len(context) == batch_size == len(context_len) == len(sequences)

    print(f"batch_size: {batch_size}")
    print(f"max_length: {max_len}")
    print(f"pad_token_id: {pad_token_id}")
    print(f"kwargs keys: {list(batch_kwargs.keys())}")

    # Token-valued tensors are checked against pad_token_id, id/position
    # tensors against -1. This mirrors the per-sample checks further down
    # (input_id_pads and the label masking use pad_token_id, the valid length
    # uses -1), so the summary and the per-sample report agree.
    tensor_specs = (
        ("input_ids", input_ids, pad_token_id),
        ("within_seq_position_ids", within_seq_position_ids, -1),
        ("global_position_ids", global_position_ids, -1),
        ("sequence_ids", sequence_ids, -1),
        ("labels", labels, pad_token_id),
    )
    for name, tensor, pad_value in tensor_specs:
        _print_tensor_summary(name, tensor, pad_value)

    print()
    print("context tokens (metadata):")
    if batch_size > 0:
        print(f" first_context: '{str(context[0])[:50]}'")
        print(f" first_context_len: {context_len[0]}")
        print(f" first_sequence: '{sequences[0]}'")

    for i, (raw_sequence, decoded_context, ctx_len, raw_ids) in enumerate(
        zip(sequences, context, context_len, sequence_ids)
    ):
        valid_len = int((raw_ids != -1).sum().item())
        ctx_len = int(ctx_len)
        print(f" sample[{i}] raw sequence input: {raw_sequence}")
        print(f" valid_length={valid_len}, context_len={ctx_len}, context='{decoded_context}'")

        if valid_len == 0:
            # Nothing to slice: the boundary lookup and max()/min() below would raise.
            print(" no valid tokens; skipping per-token analysis")
            print()
            continue

        row_input_ids = input_ids[i, :valid_len]
        row_sequence_ids = raw_ids[:valid_len]
        row_within = within_seq_position_ids[i, :valid_len]
        row_global = global_position_ids[i, :valid_len]
        row_labels = labels[i, :valid_len]

        print(f" decoded_input_ids: {preparer.tokenizer.decode(row_input_ids.tolist(), skip_special_tokens=False)}")

        print(f" input_id_pads: {int((row_input_ids == pad_token_id).sum().item())}")
        print(f" sequence_id_tail: {row_sequence_ids[-5:].tolist()}")

        # The valid ids must form one contiguous run inside the slice.
        valid_positions = torch.where(row_sequence_ids != -1)[0]
        assert valid_positions.numel() > 0, "no valid sequence_ids inside the first valid_len positions"
        first_valid = int(valid_positions[0])
        last_valid = int(valid_positions[-1])
        assert torch.equal(
            row_sequence_ids[first_valid : last_valid + 1],
            row_sequence_ids[valid_positions],
        )
        unique_sequence_ids = torch.unique(row_sequence_ids[valid_positions]).tolist()
        print(f" unique sequence_ids: {unique_sequence_ids}")

        # A boundary is every index whose sequence id differs from its predecessor.
        seq_boundaries = torch.where(row_sequence_ids[1:] != row_sequence_ids[:-1])[0] + 1
        seq_breaks = seq_boundaries.tolist() + [valid_len]
        seq_lens = [end - start for start, end in zip([0] + seq_breaks[:-1], seq_breaks)]
        print(f" per-subsequence token counts (from concatenated encoding): {seq_lens}")

        # Build the mask on the labels' device so indexing never mixes devices.
        in_context = torch.arange(valid_len, device=row_labels.device) < ctx_len
        context_masked = int((row_labels[in_context] == pad_token_id).sum().item())
        target_tokens = int((row_labels[~in_context] != pad_token_id).sum().item())
        print(f" context tokens masked in labels: {context_masked} / {ctx_len}")
        print(f" non-context target tokens kept: {target_tokens}")

        print(f" within_seq_position_ids unique: {torch.unique(row_within).tolist()}")
        print(f" global_position_ids max: {int(row_global.max().item())}, min: {int(row_global.min().item())}")
        print()
| |
|
| |
|
def main() -> None:
    """Prepare a two-sequence demo batch on the CPU and print its analysis."""
    # NOTE(review): the second entry contains a ',' — presumably a separator
    # understood by E1BatchPreparer; confirm against its documentation.
    demo_sequences = ["ACDEFGHIKLMNPQRSTVWY", "MKTFFLILV,LKQMN"]

    batch_preparer = E1BatchPreparer()
    cpu_device = torch.device("cpu")
    prepared = batch_preparer.get_batch_kwargs(demo_sequences, device=cpu_device)

    analyze_batch_kwargs(prepared, batch_preparer, demo_sequences)
| |
|
| |
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|