# M2-Encoder-1B / examples/run_onnx_inference.py
# Runnable ONNX example script (author: malusama, commit 8b09a83, verified).
import argparse
import importlib
import json
import os
import sys
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from PIL import Image
def resolve_model_dir(args):
    """Pick the model directory to run from.

    Precedence: an explicit ``--model-dir`` (made absolute), then a
    ``--repo-id`` fetched via ``snapshot_download``, and finally the parent
    directory of this script (the repo root when run in place).
    """
    if args.model_dir:
        return os.path.abspath(args.model_dir)
    return (
        snapshot_download(repo_id=args.repo_id)
        if args.repo_id
        else os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    )
def load_processors(model_dir):
    """Build the tokenizer and image processor bundled with the model repo.

    The repo ships its own ``tokenization_glm`` and
    ``image_processing_m2_encoder`` modules, so ``model_dir`` is added to
    ``sys.path`` before importing them dynamically.

    Args:
        model_dir: Absolute path to the downloaded/local model directory.

    Returns:
        A ``(tokenizer, image_processor)`` tuple.
    """
    # Fix: guard the insert so repeated calls don't keep prepending
    # duplicate entries to sys.path.
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)

    tokenizer_config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(tokenizer_config_path, "r", encoding="utf-8") as f:
        tokenizer_config = json.load(f)

    # Repo-local modules; only importable thanks to the sys.path entry above.
    GLMChineseTokenizer = importlib.import_module(
        "tokenization_glm"
    ).GLMChineseTokenizer
    M2EncoderImageProcessor = importlib.import_module(
        "image_processing_m2_encoder"
    ).M2EncoderImageProcessor

    tokenizer = GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = M2EncoderImageProcessor.from_pretrained(model_dir)
    return tokenizer, image_processor
def softmax(x):
    """Numerically stable softmax along the last axis of *x*."""
    # Shift by the row max so exp() cannot overflow; softmax is invariant
    # to a constant offset per row.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)
def main():
    """CLI entry point: rank candidate text labels against one image.

    Loads the ONNX text/image encoders from the resolved model directory,
    embeds the image and each ``--text`` candidate, and prints the labels
    ranked by softmax probability as pretty-printed JSON.
    """
    parser = argparse.ArgumentParser(description="Run M2-Encoder ONNX inference.")
    parser.add_argument("--repo-id", help="Hugging Face repo id to download.")
    parser.add_argument(
        "--model-dir", help="Local model directory. Defaults to this repo root."
    )
    parser.add_argument("--image", required=True, help="Local image path.")
    parser.add_argument(
        "--text",
        nargs="+",
        required=True,
        help="Candidate text labels. Example: --text 杰尼龟 妙蛙种子 小火龙 皮卡丘",
    )
    args = parser.parse_args()

    model_dir = resolve_model_dir(args)
    tokenizer, image_processor = load_processors(model_dir)

    text_inputs = tokenizer(
        args.text,
        padding="max_length",
        truncation=True,
        max_length=52,
        return_special_tokens_mask=True,
        return_tensors="np",
    )
    # Fix: close the image file handle deterministically. convert() returns
    # a new in-memory image, so the source file can be closed immediately.
    with Image.open(args.image) as img:
        image_inputs = image_processor(img.convert("RGB"), return_tensors="np")

    text_session = ort.InferenceSession(
        os.path.join(model_dir, "onnx", "text_encoder.onnx"),
        providers=["CPUExecutionProvider"],
    )
    image_session = ort.InferenceSession(
        os.path.join(model_dir, "onnx", "image_encoder.onnx"),
        providers=["CPUExecutionProvider"],
    )

    # First output of each session is the embedding matrix.
    text_embeds = text_session.run(
        None,
        {
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
        },
    )[0]
    image_embeds = image_session.run(
        None,
        {"pixel_values": image_inputs["pixel_values"]},
    )[0]

    # One image vs. N texts -> (1, N) similarity matrix.
    scores = image_embeds @ text_embeds.T
    probs = softmax(scores)
    ranked = [
        {
            "label": label,
            "score": float(score),
            "prob": float(prob),
        }
        for label, score, prob in sorted(
            zip(args.text, scores[0].tolist(), probs[0].tolist()),
            key=lambda item: item[2],  # sort by probability, highest first
            reverse=True,
        )
    ]
    print(json.dumps({"ranked_results": ranked}, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()