"""Evaluate an open_lm checkpoint on a validation set, or generate text from it.

Based on scripts/generate.py from the open_lm repo.

The model size is read from the checkpoint path, which must contain
``params=<N>`` where ``N`` is one of the sizes in ``_PARAMS_TO_WIDTH_DEPTH``.
"""

import argparse
import json
import re

import torch

from open_lm.evaluate import evaluate_loop
from open_lm.data import get_data
from open_lm.model import create_model
from open_lm.distributed import init_distributed_device
from open_lm.params import parse_args
from scripts.generate_without_hf import Generator, GenerationArgs


# Size key parsed from the checkpoint path -> (hidden_dim, n_layers).
_PARAMS_TO_WIDTH_DEPTH = {
    5: (96, 3),
    7: (128, 4),
    9: (160, 5),
    15: (224, 6),
    22: (288, 8),
    28: (320, 9),
    37: (384, 10),
    57: (480, 12),
    84: (576, 14),
    108: (640, 15),
    149: (704, 18),
    220: (832, 21),
    347: (1024, 23),
    455: (1120, 26),
    611: (1312, 26),
    901: (1504, 30),
}

# Shared by the generated model config and by ModelArgs so they cannot drift.
_VOCAB_SIZE = 50432
_SEQ_LEN = 2048

# Key prefix stripped from checkpoint state-dict entries before loading.
_MODULE_PREFIX = "module."


def generate_model_jsonl(params):
    """Write the model config for size ``params`` and return the file path.

    The config is written as a single JSON line to
    ``layers=<depth>_hidden-dim=<width>.json`` in the current working
    directory, overwriting any existing file of that name.

    Args:
        params: Size key; must be one of the keys of
            ``_PARAMS_TO_WIDTH_DEPTH``.

    Returns:
        The relative path of the config file that was written.

    Raises:
        KeyError: If ``params`` is not a known model size.
    """
    try:
        width, depth = _PARAMS_TO_WIDTH_DEPTH[params]
    except KeyError:
        raise KeyError(
            f"Unknown model size params={params!r}; "
            f"expected one of {sorted(_PARAMS_TO_WIDTH_DEPTH)}"
        ) from None

    filepath = f"layers={depth}_hidden-dim={width}.json"
    data = {
        "hidden_dim": width,
        "n_layers": depth,
        "n_heads": 4,
        "seq_len": _SEQ_LEN,
        "vocab_size": _VOCAB_SIZE,
        "post_embed_norm": False,
        "weight_tying": False,
        "qk_norm": True,
    }
    with open(filepath, "w", encoding="utf-8") as file:
        file.write(json.dumps(data) + "\n")
    return filepath


class ModelArgs:
    """Argument namespace handed to the open_lm helpers used by this script.

    Starts from the defaults returned by ``open_lm.params.parse_args`` and then
    overrides the fields this script controls.
    """

    def __init__(self, params, val_data, val_data_key):
        """Build the argument set for one model size and one validation source.

        Args:
            params: Size key forwarded to ``generate_model_jsonl``.
            val_data: Validation data path; stored as a one-element list.
            val_data_key: Validation data key; stored as a one-element list.

        Raises:
            KeyError: If ``params`` is not a known model size.
        """
        default_params = vars(parse_args(""))
        for k, v in default_params.items():
            setattr(self, k, v)

        # Side effect: this writes the model config file to the working directory.
        self.model = generate_model_jsonl(params)
        self.val_data = [val_data]
        self.val_data_key = [val_data_key]
        self.per_gpu_val_batch_size = 16
        self.vocab_size = _VOCAB_SIZE
        self.seq_len = _SEQ_LEN
        self.wandb = False


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with leading ``module.`` prefixes removed.

    Only leading occurrences are stripped, so a key that merely contains
    ``module.`` somewhere in the middle is left intact.
    """
    cleaned = {}
    for key, value in state_dict.items():
        while key.startswith(_MODULE_PREFIX):
            key = key[len(_MODULE_PREFIX):]
        cleaned[key] = value
    return cleaned


def main():
    """Parse the command line, load the checkpoint, then evaluate or generate.

    With ``--val-data`` the model is evaluated and the metrics are printed.
    Otherwise, with ``--input-text``, a continuation is generated and printed.
    If neither is given, a usage hint is printed and nothing is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="path/to/checkpoint")
    parser.add_argument(
        "--val-data",
        default="",
        help="Path to validation data. If empty, generate text.",
    )
    parser.add_argument("--val-data-key", default="json.gz")
    parser.add_argument(
        "--input-text",
        default="",
        type=str,
        help="Input text to generate from. If empty, evaluate on validation data.",
    )
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    args = parser.parse_args()

    # Fail fast: there is nothing to do, so do not pay for loading a checkpoint.
    if not args.val_data and not args.input_text:
        print("Please provide either --val-data or --input-text")
        return

    size_match = re.search(r"params=(\d+)", args.checkpoint)
    if size_match is None:
        parser.error(
            "--checkpoint path must contain 'params=<N>' so the model size "
            f"can be determined (got {args.checkpoint!r})"
        )
    params = int(size_match.group(1))

    # Built before loading the checkpoint so an unsupported size fails early.
    model_args = ModelArgs(
        params=params,
        val_data=args.val_data,
        val_data_key=args.val_data_key,
    )

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" avoids depending on the device the checkpoint was
    # saved from; the model is moved to the GPU below.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    state_dict = _strip_module_prefix(checkpoint["state_dict"])

    device = init_distributed_device(model_args)
    model_args.device = device
    model = create_model(model_args)
    model.load_state_dict(state_dict)
    model.eval().cuda()

    if args.val_data:
        data = get_data(
            model_args,
            skip_train=True,
        )
        metrics = evaluate_loop(model, data["val_list"], 0, model_args, None)
        print(metrics)
    else:
        # The early return above guarantees --input-text is non-empty here.
        model = model.half()
        generator = Generator(model)
        input_text = [
            args.input_text,
        ]
        output = generator.generate(
            input_text,
            GenerationArgs(args.max_gen_len, args.temperature, args.top_p),
        )
        print("".join(output))


if __name__ == "__main__":
    main()