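# NOTE: the lines "Spaces:", "Sleeping", "Sleeping" that headed this file appear
# to be page-status residue from scraping a hosted Space, not source code.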
import argparse
import os
import re

import torch

from inference.inference import generate_text, list_checkpoints, load_model
from inference.model import ByteTokenizer
# Default prompt, in the <|im_start|>/<|im_end|> chat format used by the
# original CLI default.
_DEFAULT_PROMPT = (
    "<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n"
    "<|im_start|>user\nHi<|im_end|>\n"
    "<|im_start|>assistant\n"
)


def _natural_key(name):
    """Return a sort key that compares embedded integers numerically.

    Plain string comparison ranks ``"epoch_10_end.pt"`` below
    ``"epoch_9_end.pt"``. Splitting out digit runs and comparing them as ints
    makes ``max`` pick the higher number instead.

    Args:
        name: A checkpoint file name.

    Returns:
        A tuple ``(parts, name)``. In ``parts``, every odd-index element is an
        ``int`` and every even-index element is a ``str``, so element-wise
        comparison never mixes types. The raw ``name`` breaks ties
        deterministically (e.g. ``"e_01"`` vs ``"e_1"``).
    """
    # re.split with a capturing group alternates non-digit / digit chunks,
    # always starting with a (possibly empty) non-digit chunk.
    parts = [
        int(chunk) if index % 2 else chunk
        for index, chunk in enumerate(re.split(r"(\d+)", name))
    ]
    return parts, name


def _pick_latest_checkpoint(checkpoints):
    """Return the latest checkpoint name from a non-empty list of names.

    Names containing ``"end.pt"`` (epoch-end checkpoints) are preferred when
    any exist. Otherwise all names are considered. The winner is the maximum
    under natural sort order (see ``_natural_key``).

    Args:
        checkpoints: Non-empty list of checkpoint names.

    Raises:
        ValueError: If ``checkpoints`` is empty (from ``max``).
    """
    end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
    return max(end_checkpoints or checkpoints, key=_natural_key)


def _build_parser():
    """Build the command-line parser for the text-generation CLI."""
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default=_DEFAULT_PROMPT,
        help="Prompt text to complete",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    # NOTE(review): with top_k=1 sampling is presumably greedy, which would make
    # --temperature / --top_p moot by default -- confirm against generate_text.
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    # Replaces the previously undefined global `force_CPU`. Off by default, so
    # CUDA is still used when available, as the original intended.
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Run on CPU even if CUDA is available",
    )
    return parser


def main():
    """Parse CLI arguments, load a checkpoint, and print a generated completion.

    With ``--list_checkpoints`` it only prints the available checkpoints.
    Without ``--checkpoint`` it picks the latest one reported by
    ``list_checkpoints()`` (preferring names containing ``"end.pt"``). If none
    exist, it prints a message and returns.
    """
    args = _build_parser().parse_args()

    if args.list_checkpoints:
        print("Available checkpoints:")
        for i, ckpt in enumerate(list_checkpoints(), start=1):
            print(f"{i}. {ckpt}")
        return

    if args.checkpoint:
        checkpoint_path = args.checkpoint
    else:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return
        # Names from list_checkpoints() are joined onto the "checkpoints"
        # directory, i.e. they are assumed to be bare file names.
        checkpoint_path = os.path.join(
            "checkpoints", _pick_latest_checkpoint(checkpoints)
        )

    use_cuda = torch.cuda.is_available() and not args.force_cpu
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    tokenizer = ByteTokenizer()
    model = load_model(checkpoint_path, device)

    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, "
        f"top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")
    # generate_text returns (completion only, prompt + completion).
    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    separator = "-" * 40
    print("\n\nGenerated completion only:")
    print(separator)
    print(generated_text)
    print(separator)
    print("\nFull generated text (prompt + completion):")
    print(separator)
    print(full_text)
    print(separator)
if __name__ == "__main__":
    # Script entry point; argparse is already imported at module level.
    main()