| """ | |
| A number of functions that help with evaluating a base model. | |
| """ | |
| import math | |
| import torch | |
| import torch.distributed as dist | |


@torch.no_grad()  # evaluation only, no need to build an autograd graph
def evaluate_bpb(model, batches, steps, token_bytes):
| """ | |
| Instead of the naive 'mean loss', this function returns the bits per byte (bpb), | |
| which is a tokenization vocab size-indepedent metric, meaning you are still comparing | |
| apples:apples if you change the vocab size. The way this works is that instead of just | |
| calculating the average loss as usual, you calculate the sum loss, and indepependently | |
| also the sum bytes (of all the target tokens), and divide. This normalizes the loss by | |
| the number of bytes that the target tokens represent. | |
| The added complexity is so that: | |
| 1) All "normal" tokens are normalized by the length of the token in bytes | |
| 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out. | |
| 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric. | |
| In addition to evaluate_loss, we need the token_bytes tensor: | |
| It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for | |
| each token id, or 0 if the token is to not be counted (e.g. special tokens). | |
| """ | |
    # record the losses
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)
        loss2d = model(x, y, loss_reduction='none')  # (B, T)
        loss2d = loss2d.view(-1)  # flatten
        y = y.view(-1)  # flatten
        if (y < 0).any():
            # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
            # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
            valid = y >= 0
            y_safe = torch.where(valid, y, torch.zeros_like(y))
            # map valid targets to their byte length; ignored targets contribute 0 bytes
            num_bytes2d = torch.where(
                valid,
                token_bytes[y_safe],
                torch.zeros_like(y, dtype=token_bytes.dtype),
            )
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
        else:
            # fast path: no ignored targets, safe to index directly
            num_bytes2d = token_bytes[y]
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
    # sum reduce across all ranks
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
    # move both to cpu, calculate bpb and return
    total_nats = total_nats.item()
    total_bytes = total_bytes.item()
    bpb = total_nats / (math.log(2) * total_bytes)
    return bpb
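

# A minimal sketch of building the token_bytes tensor described in the docstring
# above: the byte length of each token's decoded text, with special tokens left
# at 0 so they are excluded from the metric. Note: `tokenizer` here is a
# hypothetical stand-in, and the methods called on it (get_vocab_size, decode)
# are assumptions, not a confirmed API of this repo's tokenizer.
def build_token_bytes(tokenizer, special_token_ids, device):
    vocab_size = tokenizer.get_vocab_size()  # assumed method
    token_bytes = torch.zeros(vocab_size, dtype=torch.int64, device=device)
    for token_id in range(vocab_size):
        if token_id in special_token_ids:
            continue  # leave special tokens at 0 bytes -> masked out of bpb
        text = tokenizer.decode([token_id])  # assumed method
        token_bytes[token_id] = len(text.encode("utf-8"))
    return token_bytes

# Example call, assuming `model`, `batches`, `tokenizer`, and `special_token_ids`
# exist in the caller's scope:
#   token_bytes = build_token_bytes(tokenizer, special_token_ids, model.get_device())
#   bpb = evaluate_bpb(model, batches, steps=100, token_bytes=token_bytes)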