import argparse

import onnxruntime as ort

from inference.model import ByteTokenizer
from inference.onnx_inference import generate_text

# Strings whose single-byte tokens act as DRY sequence breakers: a repeated
# run is not tracked across newlines, punctuation, or chat-markup characters.
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]
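
# For orientation, a minimal self-contained sketch of the DRY ("Don't Repeat
# Yourself") penalty configured by the --dry_* flags below. The real logic
# lives inside inference.onnx_inference.generate_text; this helper is
# hypothetical, is never called, and only illustrates the usual formulation:
#   penalty = multiplier * base ** (match_len - allowed_length)
# applied to any token that would extend a verbatim repeat of the recent
# context, with breaker tokens resetting the match.
def _dry_penalty_sketch(logits, tokens, breakers, dry_range, allowed_length,
                        base, multiplier):
    if multiplier == 0.0 or len(tokens) < 2:
        return logits
    window = tokens[-dry_range:]
    last = window[-1]
    if last in breakers:
        return logits
    # For each earlier occurrence of the last token, measure how far the match
    # extends backwards, then record the token that followed that occurrence.
    match_len = {}
    for i in range(len(window) - 2, -1, -1):
        if window[i] != last:
            continue
        length = 1
        while (i - length >= 0
               and window[i - length] == window[-1 - length]
               and window[i - length] not in breakers):
            length += 1
        follower = window[i + 1]
        match_len[follower] = max(match_len.get(follower, 0), length)
    # Penalize every token that would extend a long-enough repeat.
    for tok, length in match_len.items():
        if length >= allowed_length:
            logits[tok] -= multiplier * base ** (length - allowed_length)
    return logits
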
def main():
    parser = argparse.ArgumentParser(
        description="Inference with ONNX DiffTransformerLLM"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="models/small.onnx", help="Path to ONNX model"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
        help="Prompt for the model",
    )
    parser.add_argument("--max_tokens", type=int, default=100, help="Max new tokens")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument("--top_k", type=int, default=1, help="Top-k for sampling")
    parser.add_argument(
        "--stop_sequence", type=str, action="append", help="Stop sequence(s)"
    )
    # DRY sampling args
    parser.add_argument(
        "--dry_range", type=int, default=1024, help="Range for DRY sampling"
    )
    parser.add_argument(
        "--dry_allowed_length",
        type=int,
        default=17,
        help="Allowed repeat length for DRY sampling",
    )
    parser.add_argument(
        "--dry_base", type=float, default=1.1, help="Base for DRY penalty"
    )
    parser.add_argument(
        "--dry_multiplier", type=float, default=0.0, help="Multiplier for DRY penalty"
    )
    args = parser.parse_args()

print(f"Loading ONNX model from {args.onnx_path}") | |
session = ort.InferenceSession(args.onnx_path, providers=["CPUExecutionProvider"]) | |
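    # The session above pins CPU execution; with a GPU-enabled onnxruntime
    # build installed, "CUDAExecutionProvider" could be listed first instead.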
    tokenizer = ByteTokenizer()

    sequence_breaker_ids = {tokenizer.im_start_id, tokenizer.im_end_id}
    for s in sequence_breaker_strings:
        # These are single-byte tokens, so encode will return a list with one ID
        sequence_breaker_ids.add(tokenizer.encode(s.encode("utf-8"))[0])
print(f"Prompt: {args.prompt}") | |
print("--- Output ---") | |
generated_text, tps = generate_text( | |
session, | |
tokenizer, | |
args.prompt, | |
max_new_tokens=args.max_tokens, | |
temperature=args.temperature, | |
top_k=args.top_k, | |
stop_sequences=["<|im_end|>".encode("utf-8")], | |
dry_sequence_breakers=sequence_breaker_ids, | |
dry_range=args.dry_range, | |
dry_allowed_length=args.dry_allowed_length, | |
dry_base=args.dry_base, | |
dry_multiplier=args.dry_multiplier, | |
) | |
    # generate_text returns raw bytes (byte-level tokenizer); decode for display.
    print(generated_text.decode("utf-8", "ignore"))
    print("--------------")
    print(f"\nPerformance: {tps:.2f} tokens/second")


if __name__ == "__main__":
    main()
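
# Example invocation (the script name is illustrative; DRY is enabled here by
# giving --dry_multiplier a nonzero value, since it defaults to 0.0):
#   python onnx_infer.py --onnx_path models/small.onnx \
#       --max_tokens 200 --top_k 40 \
#       --dry_multiplier 0.8 --stop_sequence "<|im_end|>"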