File size: 3,489 Bytes
adf0368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import os

import torch

from inference.inference import generate_text, list_checkpoints, load_model
from inference.model import ByteTokenizer


def main():
    """CLI entry point: generate text from a DiffAttention LLM checkpoint.

    Parses command-line arguments, resolves which checkpoint to load (an
    explicit ``--checkpoint`` path, otherwise the newest one found by
    ``list_checkpoints()``), loads the model, runs sampling via
    ``generate_text``, and prints both the completion alone and the full
    prompt + completion.
    """
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    # BUGFIX: the original referenced an undefined `force_CPU` name when
    # selecting the device, which raised NameError on every run. Expose the
    # intent as a proper CLI flag instead (defaults to False, so behaviour
    # on CUDA machines is what the original clearly intended).
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Force CPU inference even if CUDA is available",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i + 1}. {ckpt}")
        return

    # If no checkpoint specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Prefer an "epoch end" checkpoint when one exists.
        # NOTE(review): max() picks the lexicographically greatest filename,
        # which only matches the newest checkpoint if epoch numbers are
        # zero-padded — verify against the checkpoint naming scheme.
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device (CUDA when available, unless explicitly forced to CPU)
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model weights from the resolved checkpoint onto the device
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    # `argparse` is already imported at module top level; the original
    # re-imported it here redundantly.
    main()