malusama committed on
Commit
cfb8cb0
·
verified ·
1 Parent(s): 1e07cbe

Add runnable ONNX example script

Browse files
Files changed (2) hide show
  1. README.md +19 -0
  2. examples/run_onnx_inference.py +116 -0
README.md CHANGED
@@ -132,6 +132,25 @@ image_embeds = image_session.run(
132
  )[0]
133
  ```
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  ## Upload
136
 
137
  Option 1:
 
132
  )[0]
133
  ```
134
 
135
+ Standalone script:
136
+
137
+ `examples/run_onnx_inference.py`
138
+
139
+ ```bash
140
+ python examples/run_onnx_inference.py \
141
+ --image pokemon.jpeg \
142
+ --text 杰尼龟 妙蛙种子 小火龙 皮卡丘
143
+ ```
144
+
145
+ You can also download from the Hub first:
146
+
147
+ ```bash
148
+ python examples/run_onnx_inference.py \
149
+ --repo-id malusama/M2-Encoder-0.4B \
150
+ --image pokemon.jpeg \
151
+ --text 杰尼龟 妙蛙种子 小火龙 皮卡丘
152
+ ```
153
+
154
  ## Upload
155
 
156
  Option 1:
examples/run_onnx_inference.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import importlib
3
+ import json
4
+ import os
5
+ import sys
6
+
7
+ import numpy as np
8
+ import onnxruntime as ort
9
+ from huggingface_hub import snapshot_download
10
+ from PIL import Image
11
+
12
+
13
def resolve_model_dir(args):
    """Choose the model directory: explicit --model-dir wins, then a Hub
    snapshot for --repo-id, otherwise the repository root above examples/."""
    local_dir = args.model_dir
    if local_dir:
        return os.path.abspath(local_dir)
    if args.repo_id:
        # Downloads (or reuses a cached copy of) the full repo snapshot.
        return snapshot_download(repo_id=args.repo_id)
    repo_root = os.path.join(os.path.dirname(__file__), "..")
    return os.path.abspath(repo_root)
19
+
20
+
21
def load_processors(model_dir):
    """Load the repo-local tokenizer and image processor from *model_dir*.

    The processor classes ship as Python modules next to the weights
    (``tokenization_glm.py`` and ``image_processing_m2_encoder.py``), so
    *model_dir* must be importable before ``importlib.import_module`` runs.

    Args:
        model_dir: Directory containing ``tokenizer_config.json``,
            ``sp.model``, and the processor modules.

    Returns:
        A ``(tokenizer, image_processor)`` tuple.
    """
    # Guard the insertion: the original inserted unconditionally, so calling
    # this more than once piled up duplicate sys.path entries.
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)

    tokenizer_config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(tokenizer_config_path, "r", encoding="utf-8") as f:
        tokenizer_config = json.load(f)

    GLMChineseTokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer
    M2EncoderImageProcessor = importlib.import_module("image_processing_m2_encoder").M2EncoderImageProcessor

    # Special-token strings come from the shipped tokenizer config; .get()
    # tolerates configs that omit a token (the tokenizer falls back to its
    # own defaults for None).
    tokenizer = GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = M2EncoderImageProcessor.from_pretrained(model_dir)
    return tokenizer, image_processor
40
+
41
+
42
def softmax(x):
    """Numerically stable softmax over the last axis of *x*."""
    # Shifting by the row max keeps exp() from overflowing for large inputs.
    shifted = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return shifted / shifted.sum(axis=-1, keepdims=True)
46
+
47
+
48
def main():
    """CLI entry point: embed one image and N candidate texts with the ONNX
    encoders, then print the labels ranked by softmax probability as JSON."""
    parser = argparse.ArgumentParser(description="Run M2-Encoder ONNX inference.")
    parser.add_argument("--repo-id", help="Hugging Face repo id to download.")
    parser.add_argument("--model-dir", help="Local model directory. Defaults to this repo root.")
    parser.add_argument("--image", required=True, help="Local image path.")
    parser.add_argument(
        "--text",
        nargs="+",
        required=True,
        help="Candidate text labels. Example: --text 杰尼龟 妙蛙种子 小火龙 皮卡丘",
    )
    args = parser.parse_args()

    model_dir = resolve_model_dir(args)
    tokenizer, image_processor = load_processors(model_dir)

    # Pad/truncate to the fixed 52-token length the text encoder expects.
    text_inputs = tokenizer(
        args.text,
        padding="max_length",
        truncation=True,
        max_length=52,
        return_special_tokens_mask=True,
        return_tensors="np",
    )
    pil_image = Image.open(args.image).convert("RGB")
    image_inputs = image_processor(pil_image, return_tensors="np")

    onnx_dir = os.path.join(model_dir, "onnx")
    providers = ["CPUExecutionProvider"]
    text_session = ort.InferenceSession(
        os.path.join(onnx_dir, "text_encoder.onnx"), providers=providers
    )
    image_session = ort.InferenceSession(
        os.path.join(onnx_dir, "image_encoder.onnx"), providers=providers
    )

    text_feeds = {
        "input_ids": text_inputs["input_ids"],
        "attention_mask": text_inputs["attention_mask"],
    }
    text_embeds = text_session.run(None, text_feeds)[0]
    image_embeds = image_session.run(None, {"pixel_values": image_inputs["pixel_values"]})[0]

    # One image against N texts: scores[0] is the single row of similarities.
    scores = image_embeds @ text_embeds.T
    probs = softmax(scores)

    ranked = [
        {"label": label, "score": float(score), "prob": float(prob)}
        for label, score, prob in zip(args.text, scores[0].tolist(), probs[0].tolist())
    ]
    ranked.sort(key=lambda entry: entry["prob"], reverse=True)
    print(json.dumps({"ranked_results": ranked}, ensure_ascii=False, indent=2))
113
+
114
+
115
+ if __name__ == "__main__":
116
+ main()