# DAT-Byte-Demo / test-trad.py
# hudsongouge's picture
# Update space
# adf0368
import argparse
import os

import torch

from inference.inference import generate_text, list_checkpoints, load_model
from inference.model import ByteTokenizer
def main():
    """Parse CLI options, load a checkpoint, and print a generated completion.

    Behaviour, in order:

    * ``--list_checkpoints`` prints the names returned by
      ``list_checkpoints()`` and returns without loading a model.
    * If ``--checkpoint`` is omitted, a checkpoint name is chosen from
      ``list_checkpoints()`` (names containing ``"end.pt"`` are preferred)
      and joined onto the ``checkpoints`` directory. If none exist, a message
      is printed and the function returns.
    * The model is loaded onto CUDA when available (unless ``--force_cpu`` is
      given), text is generated from ``--prompt`` with the requested sampling
      settings, and both the completion and the full text are printed.

    Returns:
        None. All results are written to standard output.
    """
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n""",
        help="Prompt text to generate a completion for",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    # NOTE(review): a default of top_k=1 presumably restricts sampling to the
    # single most likely token, which would make temperature/top_p
    # ineffective -- confirm this default is intended.
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    # Replaces the previously undefined ``force_CPU`` name, which raised a
    # NameError on any machine where CUDA was available.
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Run on CPU even if CUDA is available",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        for i, ckpt in enumerate(list_checkpoints(), start=1):
            print(f"{i}. {ckpt}")
        return

    if args.checkpoint:
        checkpoint_path = args.checkpoint
    else:
        # If no checkpoint specified, use the latest one
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return
        # Prefer end-of-epoch checkpoints; fall back to every checkpoint.
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        # NOTE(review): max() compares names lexicographically, so "latest"
        # is only correct if the naming scheme sorts that way (e.g.
        # zero-padded epoch numbers) -- verify against the checkpoint names.
        latest_checkpoint = max(end_checkpoints or checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)

    # Set device
    use_cuda = torch.cuda.is_available() and not args.force_cpu
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)
# Run the CLI only when this file is executed as a script, not when imported.
# (argparse is already imported at the top of the file, so no re-import here.)
if __name__ == "__main__":
    main()