import argparse

import onnxruntime as ort

from inference.model import ByteTokenizer
from inference.onnx_inference import generate_text

# Single-byte strings treated as DRY "sequence breakers": repeats spanning
# these characters are not penalized by the DRY sampler.
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


def main():
    """CLI entry point: run text generation with an ONNX DiffTransformerLLM.

    Parses command-line options, loads the ONNX model on the CPU execution
    provider, builds the DRY sequence-breaker token-ID set, generates text
    from the prompt, and prints the output plus a tokens/second figure.
    """
    parser = argparse.ArgumentParser(
        description="Inference with ONNX DiffTransformerLLM"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="models/small.onnx", help="Path to ONNX model"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
        help="Prompt for the model",
    )
    parser.add_argument("--max_tokens", type=int, default=100, help="Max new tokens")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument("--top_k", type=int, default=1, help="Top-k for sampling")
    # action="append": the flag may be repeated to supply several stop sequences.
    parser.add_argument(
        "--stop_sequence", type=str, action="append", help="Stop sequence(s)"
    )
    # DRY sampling args
    parser.add_argument(
        "--dry_range", type=int, default=1024, help="Range for DRY sampling"
    )
    parser.add_argument(
        "--dry_allowed_length",
        type=int,
        default=17,
        help="Allowed repeat length for DRY sampling",
    )
    parser.add_argument(
        "--dry_base", type=float, default=1.1, help="Base for DRY penalty"
    )
    parser.add_argument(
        "--dry_multiplier", type=float, default=0.0, help="Multiplier for DRY penalty"
    )
    args = parser.parse_args()

    print(f"Loading ONNX model from {args.onnx_path}")
    session = ort.InferenceSession(args.onnx_path, providers=["CPUExecutionProvider"])
    tokenizer = ByteTokenizer()

    # DRY breaker token IDs: the chat-control tokens plus one ID per
    # single-byte breaker string.
    sequence_breaker_ids = {tokenizer.im_start_id, tokenizer.im_end_id}
    for s in sequence_breaker_strings:
        # These are single-byte tokens, so encode will return a list with one ID
        sequence_breaker_ids.add(tokenizer.encode(s.encode("utf-8"))[0])

    # BUG FIX: --stop_sequence was parsed but never used — the stop list was
    # hard-coded to b"<|im_end|>". Honor user-supplied stop sequences
    # (encoded to bytes), falling back to the previous default when none
    # are given, so existing invocations behave identically.
    if args.stop_sequence:
        stop_sequences = [s.encode("utf-8") for s in args.stop_sequence]
    else:
        stop_sequences = ["<|im_end|>".encode("utf-8")]

    print(f"Prompt: {args.prompt}")
    print("--- Output ---")
    generated_text, tps = generate_text(
        session,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        stop_sequences=stop_sequences,
        dry_sequence_breakers=sequence_breaker_ids,
        dry_range=args.dry_range,
        dry_allowed_length=args.dry_allowed_length,
        dry_base=args.dry_base,
        dry_multiplier=args.dry_multiplier,
    )
    # Print both the raw bytes (debug visibility) and the decoded text.
    print(generated_text)
    print(generated_text.decode("utf-8", "ignore"))
    print("--------------")
    print(f"\nPerformance: {tps:.2f} tokens/second")


if __name__ == "__main__":
    main()