jhansss committed on
Commit 6bf86b0 · 1 Parent(s): 157f247

Add MiniMax-Text-01 support

config/interface/options.yaml CHANGED
@@ -15,12 +15,14 @@ llm_models:
     name: Gemini 2.5 Flash
   - id: google/gemma-2-2b
     name: Gemma 2 2B
-  - id: MiniMaxAI/MiniMax-M1-80k
-    name: MiniMax M1 80k
+  - id: MiniMaxAI/MiniMax-Text-01
+    name: MiniMax Text 01
   - id: meta-llama/Llama-3.1-8B-Instruct
     name: Llama 3.1 8B Instruct
   - id: meta-llama/Llama-3.2-3B-Instruct
     name: Llama 3.2 3B Instruct
+  - id: Qwen/Qwen3-8B
+    name: Qwen3 8B
 
 svs_models:
   - id: mandarin-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
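For context, the id strings added above have to match a model registered in modules/llm; a rough sketch of how that lookup presumably works follows. Only register_llm_model appears in this commit, so MODEL_REGISTRY and get_llm_model are hypothetical names used purely for illustration.

# Hypothetical sketch of modules/llm/registry.py; only register_llm_model
# is confirmed by this commit, the rest is assumed for illustration.
MODEL_REGISTRY = {}

def register_llm_model(model_id):
    # map the HF model id used in options.yaml to the decorated class
    def wrapper(cls):
        MODEL_REGISTRY[model_id] = cls
        return cls
    return wrapper

def get_llm_model(model_id, **kwargs):
    # assumed helper: look up the registered class for this id and instantiate it
    return MODEL_REGISTRY[model_id](model_id=model_id, **kwargs)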
modules/llm/minimax.py ADDED
@@ -0,0 +1,111 @@
+# Ref: https://github.com/MiniMax-AI/MiniMax-01
+
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+    QuantoConfig,
+)
+
+from .base import AbstractLLMModel
+from .registry import register_llm_model
+
+
+@register_llm_model("MiniMaxAI/MiniMax-Text-01")
+class MiniMaxLLM(AbstractLLMModel):
+    def __init__(
+        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
+    ):
+        super().__init__(model_id, device, cache_dir, **kwargs)
+        assert device == "cuda", "MiniMax model only supports CUDA device"
+
+        # load hf config
+        hf_config = AutoConfig.from_pretrained(
+            "MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, cache_dir=cache_dir
+        )
+
+        # quantization config, int8 is recommended
+        quantization_config = QuantoConfig(
+            weights="int8",
+            modules_to_not_convert=[
+                "lm_head",
+                "embed_tokens",
+            ]
+            + [
+                f"model.layers.{i}.coefficient"
+                for i in range(hf_config.num_hidden_layers)
+            ]
+            + [
+                f"model.layers.{i}.block_sparse_moe.gate"
+                for i in range(hf_config.num_hidden_layers)
+            ],
+        )
+
+        # assume 8 GPUs
+        world_size = 8
+        layers_per_device = hf_config.num_hidden_layers // world_size
+        # set device map
+        device_map = {
+            "model.embed_tokens": "cuda:0",
+            "model.norm": f"cuda:{world_size - 1}",
+            "lm_head": f"cuda:{world_size - 1}",
+        }
+        for i in range(world_size):
+            for j in range(layers_per_device):
+                device_map[f"model.layers.{i * layers_per_device + j}"] = f"cuda:{i}"
+
+        # load tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "MiniMaxAI/MiniMax-Text-01", cache_dir=cache_dir
+        )
+
+        # load bfloat16 model, move to device, and apply quantization
+        self.quantized_model = AutoModelForCausalLM.from_pretrained(
+            "MiniMaxAI/MiniMax-Text-01",
+            torch_dtype="bfloat16",
+            device_map=device_map,
+            quantization_config=quantization_config,
+            trust_remote_code=True,
+            offload_buffers=True,
+            cache_dir=cache_dir,
+        )
+
+    def generate(
+        self,
+        prompt: str,
+        system_prompt: str = "You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model.",
+        max_new_tokens: int = 20,
+        **kwargs,
+    ) -> str:
+        messages = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": system_prompt,
+                    }
+                ],
+            },
+            {"role": "user", "content": [{"type": "text", "text": prompt}]},
+        ]
+        text = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        # tokenize and move to device
+        model_inputs = self.tokenizer(text, return_tensors="pt").to("cuda")
+        generation_config = GenerationConfig(
+            max_new_tokens=max_new_tokens,
+            eos_token_id=200020,
+            use_cache=True,
+        )
+        generated_ids = self.quantized_model.generate(
+            **model_inputs, generation_config=generation_config
+        )
+        generated_ids = [
+            output_ids[len(input_ids) :]
+            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return response
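A minimal usage sketch of the new class, assuming the repo root is on PYTHONPATH and a machine with the 8 CUDA GPUs the hard-coded device map expects; the prompt text and token budget below are illustrative only.

# Usage sketch (assumes 8 CUDA GPUs; the constructor asserts device == "cuda").
from modules.llm.minimax import MiniMaxLLM

llm = MiniMaxLLM(
    model_id="MiniMaxAI/MiniMax-Text-01",
    device="cuda",
    cache_dir="cache",
)
# prompt is a placeholder; max_new_tokens overrides the default of 20
print(llm.generate("Introduce yourself in one sentence.", max_new_tokens=64))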