# DAT-Byte-Demo / test.py
# hudsongouge — "Update space" (commit adf0368)
from inference.onnx_inference import generate_text
import argparse
import onnxruntime as ort
from inference.model import ByteTokenizer
# Characters whose (single-byte) token IDs main() passes to generate_text as
# DRY sequence breakers, alongside the <|im_start|> / <|im_end|> token IDs.
sequence_breaker_strings = [
    "\n",
    ":",
    '"',
    "*",
    "<",
    ">",
    "|",
]
def _build_arg_parser():
    """Build the command-line parser for the ONNX inference demo."""
    parser = argparse.ArgumentParser(
        description="Inference with ONNX DiffTransformerLLM"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="models/small.onnx", help="Path to ONNX model"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
        help="Prompt for the model",
    )
    parser.add_argument("--max_tokens", type=int, default=100, help="Max new tokens")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument("--top_k", type=int, default=1, help="Top-k for sampling")
    parser.add_argument(
        "--stop_sequence",
        type=str,
        action="append",
        help="Extra stop sequence (repeatable); <|im_end|> always stops generation",
    )
    # DRY sampling args
    parser.add_argument(
        "--dry_range", type=int, default=1024, help="Range for DRY sampling"
    )
    parser.add_argument(
        "--dry_allowed_length",
        type=int,
        default=17,
        help="Allowed repeat length for DRY sampling",
    )
    parser.add_argument(
        "--dry_base", type=float, default=1.1, help="Base for DRY penalty"
    )
    parser.add_argument(
        "--dry_multiplier", type=float, default=0.0, help="Multiplier for DRY penalty"
    )
    return parser


def _sequence_breaker_ids(tokenizer):
    """Return the token IDs handed to generate_text as DRY sequence breakers.

    The set contains the tokenizer's im_start/im_end IDs plus the token ID of
    each character in the module-level ``sequence_breaker_strings``.
    """
    breaker_ids = {tokenizer.im_start_id, tokenizer.im_end_id}
    for s in sequence_breaker_strings:
        # These are single-byte tokens, so encode returns a list with one ID.
        breaker_ids.add(tokenizer.encode(s.encode("utf-8"))[0])
    return breaker_ids


def main():
    """Run one ONNX text-generation pass configured from the command line.

    Loads the model on the CPU execution provider, generates a continuation
    of ``--prompt`` with the requested sampling/DRY settings, prints the
    decoded output and the tokens/second figure returned by generate_text.
    """
    args = _build_arg_parser().parse_args()

    print(f"Loading ONNX model from {args.onnx_path}")
    session = ort.InferenceSession(args.onnx_path, providers=["CPUExecutionProvider"])
    tokenizer = ByteTokenizer()
    sequence_breaker_ids = _sequence_breaker_ids(tokenizer)

    # "<|im_end|>" always ends the assistant turn; user-supplied
    # --stop_sequence values (None when the flag is absent) are added on top.
    # generate_text receives stop sequences as UTF-8 bytes.
    stop_sequences = [b"<|im_end|>"]
    stop_sequences.extend(s.encode("utf-8") for s in (args.stop_sequence or []))

    print(f"Prompt: {args.prompt}")
    print("--- Output ---")
    generated_text, tps = generate_text(
        session,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        stop_sequences=stop_sequences,
        dry_sequence_breakers=sequence_breaker_ids,
        dry_range=args.dry_range,
        dry_allowed_length=args.dry_allowed_length,
        dry_base=args.dry_base,
        dry_multiplier=args.dry_multiplier,
    )
    # generated_text is bytes; undecodable (e.g. truncated multi-byte)
    # sequences are dropped rather than raising.
    print(generated_text.decode("utf-8", "ignore"))
    print("--------------")
    print(f"\nPerformance: {tps:.2f} tokens/second")
if __name__ == "__main__":
    # Only launch the CLI when run as a script, not when imported.
    main()