cyberfly committed
Commit c1e1bf3 · 1 Parent(s): fe65777

Add LFM2-350M MNN model files and inference script
Files changed (8)
  1. config.json +22 -0
  2. export_args.json +42 -0
  3. inference.py +188 -0
  4. llm.mnn +3 -0
  5. llm.mnn.json +0 -0
  6. llm.mnn.weight +3 -0
  7. llm_config.json +20 -0
  8. tokenizer.mtok +3 -0
config.json ADDED
@@ -0,0 +1,22 @@
+{
+    "llm_model": "llm.mnn",
+    "llm_weight": "llm.mnn.weight",
+    "backend_type": "cpu",
+    "thread_num": 4,
+    "precision": "low",
+    "memory": "low",
+    "sampler_type": "mixed",
+    "temperature": 0.8,
+    "top_k": 40,
+    "top_p": 0.9,
+    "min_p": 0.05,
+    "tfs_z": 1.0,
+    "typical": 0.95,
+    "repetition_penalty": 1.0,
+    "presence_penalty": 0.0,
+    "frequency_penalty": 0.0,
+    "penalty_window": 0,
+    "n_gram": 8,
+    "ngram_factor": 1.0,
+    "tokenizer_file": "tokenizer.mtok"
+}
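
This is the runtime configuration read by MNN's LLM bindings. A minimal sketch of consuming it, using only the calls that also appear in inference.py below (assumes you run from this model directory):

```python
import MNN.llm as llm

# config.json points at llm.mnn / llm.mnn.weight / tokenizer.mtok and selects
# the "mixed" sampler (temperature 0.8, top_k 40, top_p 0.9, min_p 0.05).
model = llm.create("config.json")
model.load()
print(model.response("Hello!", False))  # second argument False = non-streaming
```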
export_args.json ADDED
@@ -0,0 +1,42 @@
+{
+    "path": "../../../../",
+    "type": null,
+    "tokenizer_path": "../../../../",
+    "eagle_path": null,
+    "lora_path": null,
+    "gptq_path": null,
+    "dst_path": "./model",
+    "verbose": false,
+    "test": null,
+    "export": "mnn",
+    "onnx_slim": false,
+    "quant_bit": 4,
+    "quant_block": 64,
+    "visual_quant_bit": null,
+    "visual_quant_block": null,
+    "lm_quant_bit": 4,
+    "lm_quant_block": 64,
+    "mnnconvert": "../../../build/MNNConvert",
+    "ppl": false,
+    "awq": false,
+    "hqq": true,
+    "omni": false,
+    "transformer_fuse": false,
+    "group_conv_native": false,
+    "smooth": false,
+    "sym": false,
+    "visual_sym": false,
+    "seperate_embed": false,
+    "lora_split": false,
+    "calib_data": null,
+    "act_bit": 16,
+    "embed_bit": 16,
+    "act_sym": false,
+    "quant_config": null,
+    "generate_for_npu": false,
+    "skip_weight": false,
+    "omni_epochs": 20,
+    "omni_lr": 0.005,
+    "omni_wd": 0.0001,
+    "tie_word_embeddings": true
+}
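
These are the recorded export arguments: 4-bit HQQ weight quantization in blocks of 64 (for both the transformer and the LM head), 16-bit activations and embeddings, and tied word embeddings. Assuming the keys map one-to-one onto the CLI flags of MNN's llmexport tool (the script name and flag spellings below are an assumption, not confirmed by this commit), the export was invoked roughly like:

```bash
# Hypothetical reconstruction from export_args.json; flags assumed to mirror the JSON keys.
python llmexport.py \
    --path ../../../../ \
    --dst_path ./model \
    --export mnn \
    --quant_bit 4 --quant_block 64 \
    --lm_quant_bit 4 --lm_quant_block 64 \
    --hqq \
    --tie_word_embeddings \
    --mnnconvert ../../../build/MNNConvert
```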
inference.py ADDED
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+"""MNN LLM Inference & Benchmark script for LFM2-350M model."""
+
+import sys
+import os
+import time
+import argparse
+
+import MNN.llm as llm
+
+
+def run_inference(model, prompt, stream=False):
+    """Run a single inference and return the response + timing context."""
+    model.reset()
+    response = model.response(prompt, stream)
+    if stream:
+        output = ""
+        for chunk in response:
+            print(chunk, end="", flush=True)
+            output += chunk
+        print()
+        return output
+    return response
+
+
+def benchmark(model, prompts, warmup=1, runs=3):
+    """Benchmark prefill and decode performance across multiple prompts."""
+    print("=" * 60)
+    print("BENCHMARK")
+    print("=" * 60)
+
+    # Warmup
+    print(f"\nWarmup ({warmup} run(s))...")
+    for _ in range(warmup):
+        model.reset()
+        model.response(prompts[0], False)
+
+    results = []
+    for idx, prompt in enumerate(prompts):
+        prompt_results = []
+        for _ in range(runs):
+            model.reset()
+            t0 = time.perf_counter()
+            response = model.response(prompt, False)
+            t1 = time.perf_counter()
+            wall_time = t1 - t0
+
+            ctx = model.context
+            ctx.refresh()
+
+            prompt_tokens = ctx.prompt_len
+            gen_tokens = ctx.gen_seq_len
+            prefill_us = ctx.prefill_us
+            decode_us = ctx.decode_us
+
+            prefill_s = prefill_us / 1e6 if prefill_us else 0
+            decode_s = decode_us / 1e6 if decode_us else 0
+
+            prefill_tps = prompt_tokens / prefill_s if prefill_s > 0 else 0
+            decode_tps = gen_tokens / decode_s if decode_s > 0 else 0
+
+            prompt_results.append({
+                "prompt_tokens": prompt_tokens,
+                "gen_tokens": gen_tokens,
+                "wall_time": wall_time,
+                "prefill_s": prefill_s,
+                "decode_s": decode_s,
+                "prefill_tps": prefill_tps,
+                "decode_tps": decode_tps,
+                "response": response,
+            })
+
+        results.append(prompt_results)
+
+        # Print per-prompt summary
+        avg_prefill_tps = sum(r["prefill_tps"] for r in prompt_results) / runs
+        avg_decode_tps = sum(r["decode_tps"] for r in prompt_results) / runs
+        avg_wall = sum(r["wall_time"] for r in prompt_results) / runs
+        prompt_tokens = prompt_results[0]["prompt_tokens"]
+        avg_gen = sum(r["gen_tokens"] for r in prompt_results) / runs
+
+        print(f"\nPrompt {idx + 1}: \"{prompt[:60]}{'...' if len(prompt) > 60 else ''}\"")
+        print(f"  Prompt tokens  : {prompt_tokens}")
+        print(f"  Avg gen tokens : {avg_gen:.1f}")
+        print(f"  Avg wall time  : {avg_wall:.3f} s")
+        print(f"  Avg prefill    : {avg_prefill_tps:.1f} tok/s")
+        print(f"  Avg decode     : {avg_decode_tps:.1f} tok/s")
+
+    # Overall summary
+    all_runs = [r for pr in results for r in pr]
+    overall_prefill = sum(r["prefill_tps"] for r in all_runs) / len(all_runs)
+    overall_decode = sum(r["decode_tps"] for r in all_runs) / len(all_runs)
+    print("\n" + "=" * 60)
+    print(f"Overall avg prefill : {overall_prefill:.1f} tok/s")
+    print(f"Overall avg decode  : {overall_decode:.1f} tok/s")
+    print("=" * 60)
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="MNN LLM Inference & Benchmark")
+    parser.add_argument("--config", default="config.json",
+                        help="Path to MNN config.json (default: config.json)")
+    parser.add_argument("--prompt", default=None,
+                        help="Single prompt for inference")
+    parser.add_argument("--stream", action="store_true",
+                        help="Stream output tokens")
+    parser.add_argument("--benchmark", action="store_true",
+                        help="Run benchmark suite")
+    parser.add_argument("--warmup", type=int, default=1,
+                        help="Warmup runs for benchmark (default: 1)")
+    parser.add_argument("--runs", type=int, default=3,
+                        help="Benchmark runs per prompt (default: 3)")
+    parser.add_argument("--backend", default=None,
+                        choices=["cpu", "metal"],
+                        help="Override backend type")
+    parser.add_argument("--threads", type=int, default=None,
+                        help="Override thread count")
+    parser.add_argument("--max-tokens", type=int, default=128,
+                        help="Max tokens to generate (default: 128)")
+    args = parser.parse_args()
+
+    model_dir = os.path.dirname(os.path.abspath(args.config))
+    config_path = os.path.abspath(args.config)
+
+    print(f"Loading model from: {config_path}")
+    model = llm.create(config_path)
+
+    # CLI overrides take precedence over config.json.
+    if args.backend:
+        model.set_config({"backend_type": args.backend})
+    if args.threads:
+        model.set_config({"thread_num": args.threads})
+    model.set_config({"max_new_tokens": args.max_tokens})
+
+    model.load()
+    print("Model loaded.\n")
+
+    if args.benchmark:
+        bench_prompts = [
+            "Hello!",
+            "What is the capital of France?",
+            "Explain quantum computing in simple terms.",
+            "Write a short poem about the ocean.",
+            "List 5 programming languages and their main use cases.",
+        ]
+        benchmark(model, bench_prompts, warmup=args.warmup, runs=args.runs)
+    elif args.prompt:
+        print(f"Prompt: {args.prompt}\n")
+        response = run_inference(model, args.prompt, stream=args.stream)
+        if not args.stream:
+            print(f"Response:\n{response}")
+
+        ctx = model.context
+        ctx.refresh()
+        print("\n--- Stats ---")
+        print(f"Prompt tokens : {ctx.prompt_len}")
+        print(f"Gen tokens    : {ctx.gen_seq_len}")
+        prefill_s = ctx.prefill_us / 1e6 if ctx.prefill_us else 0
+        decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
+        if prefill_s > 0:
+            print(f"Prefill : {ctx.prompt_len / prefill_s:.1f} tok/s ({prefill_s:.3f}s)")
+        if decode_s > 0:
+            print(f"Decode  : {ctx.gen_seq_len / decode_s:.1f} tok/s ({decode_s:.3f}s)")
+    else:
+        # Interactive mode
+        print("Interactive mode (type 'quit' to exit)\n")
+        while True:
+            try:
+                user_input = input("You: ").strip()
+            except (EOFError, KeyboardInterrupt):
+                print("\nBye!")
+                break
+            if user_input.lower() in ("quit", "exit"):
+                break
+            if not user_input:
+                continue
+            response = run_inference(model, user_input, stream=True)
+            ctx = model.context
+            ctx.refresh()
+            decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
+            if decode_s > 0:
+                print(f"  [{ctx.gen_seq_len} tokens, {ctx.gen_seq_len / decode_s:.1f} tok/s]")
+
+
+if __name__ == "__main__":
+    main()
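
Example invocations (run from the model directory; every flag shown is defined by the argparse setup in main()):

```bash
# Single prompt, streaming tokens as they decode
python inference.py --prompt "What is the capital of France?" --stream

# Benchmark suite: 1 warmup run, 3 timed runs per prompt, CPU backend with 4 threads
python inference.py --benchmark --warmup 1 --runs 3 --backend cpu --threads 4

# Interactive chat loop (the default when neither --prompt nor --benchmark is given)
python inference.py --max-tokens 256
```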
llm.mnn ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:783e5af16be0b4a7bb58e626ff36cb8d496d209d1f484c4ef414619dc1bff749
+size 329296
llm.mnn.json ADDED
The diff for this file is too large to render. See raw diff.
llm.mnn.weight ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efaf91b109d51dab379ebc37be735f5b7173babb37cb9af65268805eb825cce9
+size 221790210
llm_config.json ADDED
@@ -0,0 +1,20 @@
+{
+    "model_type": "lfm2",
+    "hidden_size": 1024,
+    "layer_nums": 16,
+    "attention_mask": "float",
+    "attention_type": "full",
+    "is_mrope": false,
+    "jinja": {
+        "chat_template": "{{- bos_token -}}{%- set system_prompt = \"\" -%}{%- set ns = namespace(system_prompt=\"\") -%}{%- if messages[0][\"role\"] == \"system\" -%} {%- set ns.system_prompt = messages[0][\"content\"] -%} {%- set messages = messages[1:] -%}{%- endif -%}{%- if tools -%} {%- set ns.system_prompt = ns.system_prompt + (\"\n\" if ns.system_prompt else \"\") + \"List of tools: <|tool_list_start|>[\" -%} {%- for tool in tools -%} {%- if tool is not string -%} {%- set tool = tool | tojson -%} {%- endif -%} {%- set ns.system_prompt = ns.system_prompt + tool -%} {%- if not loop.last -%} {%- set ns.system_prompt = ns.system_prompt + \", \" -%} {%- endif -%} {%- endfor -%} {%- set ns.system_prompt = ns.system_prompt + \"]<|tool_list_end|>\" -%}{%- endif -%}{%- if ns.system_prompt -%} {{- \"<|im_start|>system\n\" + ns.system_prompt + \"<|im_end|>\n\" -}}{%- endif -%}{%- for message in messages -%} {{- \"<|im_start|>\" + message[\"role\"] + \"\n\" -}} {%- set content = message[\"content\"] -%} {%- if content is not string -%} {%- set content = content | tojson -%} {%- endif -%} {%- if message[\"role\"] == \"tool\" -%} {%- set content = \"<|tool_response_start|>\" + content + \"<|tool_response_end|>\" -%} {%- endif -%} {{- content + \"<|im_end|>\n\" -}}{%- endfor -%}{%- if add_generation_prompt -%} {{- \"<|im_start|>assistant\n\" -}}{%- endif -%}",
+        "bos": "<|startoftext|>",
+        "eos": "<|im_end|>"
+    },
+    "tie_embeddings": [
+        179847170,
+        213401602,
+        8388608,
+        4,
+        64
+    ]
+}
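
To make the jinja chat template concrete: traced by hand from the template above, a single user message "Hello!" with no system prompt, no tools, and add_generation_prompt enabled should render as

```
<|startoftext|><|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant
```

after which the model generates until the eos token <|im_end|>.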
tokenizer.mtok ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:780f1e05d303486716c28bb13c7f2897383f59a3d4aa6bd02f15dda4301a0389
+size 1670119