import os

import torch
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from ..hellaswag import render_example, iterate_examples, get_most_likely_row
from ..ModelGPT2 import GPT, log_file
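
# When launched via torchrun, RANK / LOCAL_RANK / WORLD_SIZE are set in the environment and the
# script runs under DDP; a plain `python` launch falls back to a single process. Because of the
# package-relative imports above, it has to be run as a module, e.g.
# `torchrun --standalone --nproc_per_node=<num_gpus> -m <package>.<this_module>`.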
ddp = int(os.environ.get('RANK', -1)) != -1  # RANK is only present when launched via torchrun
if ddp:
    assert torch.cuda.is_available(), "DDP evaluation requires CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # rank 0 does the printing and logging
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    # not DDP: autodetect the best available device (the DDP branch already pinned cuda:<local_rank>)
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
print(f"Using device: {device}")
device_type = "cuda" if device.startswith("cuda") else "cpu"

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)
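
# the checkpoint is expected to hold the trained weights under 'model' and the model config under 'config'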
checkpoint_path = '../log/model_final.pt'
if master_process:
    print(f"Loading checkpoint from {checkpoint_path}")
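
# note: recent PyTorch versions default torch.load to weights_only=True, which can refuse to unpickle
# a custom config object; if loading fails, pass weights_only=False (only for checkpoints you trust)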
checkpoint = torch.load(checkpoint_path, map_location=device)

model_config = checkpoint['config']
model_config.vocab_size = 50304  # padded vocab size (multiple of 128); should match the training-time config
model = GPT(model_config)
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()  # evaluation mode (disables dropout, if any)
if ddp:
    # wrap in DDP only when actually running distributed, and only after the model is on its device
    model = DDP(model, device_ids=[ddp_local_rank])
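

# HellaSwag evaluation: each rank scores a disjoint slice of the validation examples,
# the per-rank counts are summed across processes, and rank 0 reports the length-normalized accuracy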
def evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process):
    num_correct_norm = 0
    num_total = 0
    for i, example in enumerate(iterate_examples("val")):
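        # round-robin sharding: each rank only processes examples where i % world_size == rank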
        if ddp:
            if i % ddp_world_size != ddp_rank:
                continue
        _, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)
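        # forward pass under bf16 autocast, then pick the most likely completion from the logits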
        with torch.no_grad():
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, loss = model(tokens)
            pred_norm = get_most_likely_row(tokens, mask, logits)
        num_total += 1
        num_correct_norm += int(pred_norm == label)
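
    # sum the per-rank counts so every process sees the global totals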
    if ddp:
        num_total = torch.tensor(num_total, dtype=torch.long, device=device)
        num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
        dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
        dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
        num_total = num_total.item()
        num_correct_norm = num_correct_norm.item()
    acc_norm = num_correct_norm / num_total
    if master_process:
        print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
        with open(log_file, "a") as f:
            f.write(f"Final HellaSwag accuracy: {acc_norm:.4f}\n")
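

# run the evaluation, then tear down the process group if one was created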
evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process)
if ddp:
    destroy_process_group()