# v2/src/model_io.py
"""
Model loading and prompt construction for Qwen3-30B-A3B-Thinking.
Provides:
- load_model_and_tokenizer(): robust loader with dtype/device auto-handling
- build_thinking_prompt(problem, enable_thinking): chat-template wrapper
- generate(): a simple generation helper
"""
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from configs.model import MODEL_CONFIG, GEN_CONFIG

def load_model_and_tokenizer(
    model_dir: Optional[Path] = None,
    dtype: Optional[str] = None,
    device_map: str = "auto",
    verbose: bool = True,
):
"""
Load Qwen3-30B-A3B-Thinking-2507 from local dir.
Args:
model_dir: override MODEL_CONFIG["local_dir"]
dtype: "bfloat16" | "float16" | "auto"; overrides MODEL_CONFIG
device_map: HuggingFace device_map, default "auto"
"""
mdir = Path(model_dir) if model_dir else Path(MODEL_CONFIG["local_dir"])
dt = dtype or MODEL_CONFIG["load_dtype"]
torch_dtype = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
"auto": "auto",
    }.get(dt, torch.bfloat16)  # unrecognized dtype strings fall back to bfloat16
if verbose:
print(f"[model_io] Loading tokenizer: {mdir}")
tokenizer = AutoTokenizer.from_pretrained(
mdir, trust_remote_code=MODEL_CONFIG["trust_remote_code"]
)
if verbose:
print(f"[model_io] Loading model dtype={dt} device_map={device_map}")
model = AutoModelForCausalLM.from_pretrained(
mdir,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code=MODEL_CONFIG["trust_remote_code"],
)
model.eval()
# Validate architecture matches config
cfg = model.config
assert cfg.num_hidden_layers == MODEL_CONFIG["num_layers"], \
f"num_layers mismatch: model has {cfg.num_hidden_layers}, config says {MODEL_CONFIG['num_layers']}"
if verbose:
ne = getattr(cfg, "num_experts", None) or getattr(cfg, "n_routed_experts", None)
nk = getattr(cfg, "num_experts_per_tok", None) or getattr(cfg, "top_k", None)
print(f"[model_io] layers={cfg.num_hidden_layers}, experts={ne}, top_k={nk}")
return model, tokenizer
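
# Example override (a sketch, not exercised elsewhere in this file): pin the
# whole model to GPU 0 in half precision.
#   model, tokenizer = load_model_and_tokenizer(dtype="float16", device_map={"": 0})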

def build_thinking_prompt(
    tokenizer,
    problem: str,
    system_prompt: Optional[str] = None,
    enable_thinking: bool = True,
) -> str:
"""
Construct chat-template prompt. Qwen3-Thinking uses enable_thinking
to insert the <think> channel.
"""
sys_msg = system_prompt or MODEL_CONFIG["default_system_prompt"]
messages = [
{"role": "system", "content": sys_msg},
{"role": "user", "content": f"Problem: {problem}\n\nSolve step by step."},
]
# Some Qwen3 chat templates accept enable_thinking kwarg
try:
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
enable_thinking=enable_thinking,
)
except TypeError:
# Fallback: plain chat template (no thinking switch)
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
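
# Hypothetical companion helper, a minimal sketch assuming the Qwen3-Thinking
# convention that the reasoning channel closes with a </think> tag (the opening
# <think> is typically emitted by the chat template itself). The name
# split_thinking is illustrative, not part of the original API.
def split_thinking(completion: str) -> "tuple[str, str]":
    """Split a raw completion into (thinking, answer) at the last </think>."""
    tag = "</think>"
    if tag in completion:
        thinking, _, answer = completion.rpartition(tag)
        return thinking.strip(), answer.strip()
    # No closing tag found: treat the whole completion as the answer.
    return "", completion.strip()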

def generate(
    model, tokenizer, prompt: str,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    do_sample: Optional[bool] = None,
) -> str:
"""Generate a completion. Returns only newly generated text (no prompt)."""
    max_new = max_new_tokens if max_new_tokens is not None else GEN_CONFIG["max_new_tokens"]
t = temperature if temperature is not None else GEN_CONFIG["temperature"]
p = top_p if top_p is not None else GEN_CONFIG["top_p"]
ds = do_sample if do_sample is not None else GEN_CONFIG["do_sample"]
enc = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Only pass sampling knobs when sampling; transformers warns if
    # temperature/top_p accompany do_sample=False.
    gen_kwargs = dict(
        max_new_tokens=max_new,
        do_sample=ds,
        pad_token_id=tokenizer.eos_token_id,
    )
    if ds:
        gen_kwargs.update(temperature=t, top_p=p)
    with torch.no_grad():
        out = model.generate(**enc, **gen_kwargs)
gen_ids = out[0, enc["input_ids"].shape[1]:]
return tokenizer.decode(gen_ids, skip_special_tokens=True)
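
# Optional smoke test, a sketch assuming this file lives at src/model_io.py and
# the repo root is on PYTHONPATH (e.g. `python -m src.model_io`). Greedy
# decoding keeps the check deterministic.
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    prompt = build_thinking_prompt(tokenizer, "What is 12 * 13?")
    print(generate(model, tokenizer, prompt, max_new_tokens=512, do_sample=False))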