import re, os, gradio as gr, torch
from transformers import AutoModel, AutoTokenizer  # key: AutoModel (not AutoModelForCausalLM)

# Use the Qwen checkpoint directly
MODEL_ID = "kudui/qwen-s1k-overfit10"
REV = None
print(f"[INFO] Using MODEL_ID={MODEL_ID}")

VER = "v4-dream"

# System prompt (Chinese). Translation: "Output only the final answer, which must be
# a single number, with no explanation or units. Do not print an extra digit at the
# end. If the question contains decimals, output with the longest decimal precision
# that appears in the question (truncate/zero-pad, do not round)."
SYS = (
    "只输出最终答案且必须是一个数字,不要任何解释或单位。"
    "禁止在末尾多打一位。若题目中出现小数,请按题面中最长的小数位数输出(截断/补零,不四舍五入)。"
)

def collapse_digit_seps(s: str) -> str:
    # Normalize fullwidth punctuation, then drop thousands separators and
    # thin/no-break spaces between digits; finally collapse runs of whitespace.
    s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(", "(").replace(")", ")")
    s = re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
    return re.sub(r'\s+', ' ', s).strip()
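
# Illustrative checks of the normalization above:
#   collapse_digit_seps("1,234。56")  -> "1234.56"   (fullwidth comma/period normalized, separator dropped)
#   collapse_digit_seps("1 000 000") -> "1000000"   (spaces between digits removed)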

def max_decimals_in_text(s: str) -> int:
    # Longest run of decimal places found anywhere in the text (0 if none).
    decs = [len(m) for m in re.findall(r'\d+\.(\d+)', s)]
    return max(decs) if decs else 0
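
# Illustrative: max_decimals_in_text("330.76 + 0.5") -> 2, max_decimals_in_text("3 + 4") -> 0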

def extract_number(s: str):
    # First signed decimal (optionally scientific-notation) number in s, or None.
    m = re.search(r'[-+]?\d*\.?\d+(?:e[-+]?\d+)?', s, re.I)
    return m.group(0) if m else None
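
# Illustrative: extract_number("Answer: -3.5e2 m") -> "-3.5e2", extract_number("n/a") -> None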
| print(f"Loading tokenizer from {MODEL_ID}...") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV) | |
| print(f"Loading Dream model from {MODEL_ID}...") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = AutoModel.from_pretrained( | |
| MODEL_ID, | |
| trust_remote_code=True, | |
| revision=REV, | |
| torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32, | |
| ).to(device).eval() | |
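
# The checkpoint's remote code is expected to expose `diffusion_generate`
# (used in predict below); fail fast with a clear error if it does not.
assert hasattr(model, "diffusion_generate"), "Model lacks diffusion_generate; check MODEL_ID/revision."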

def predict(q: str):
    if not q.strip():
        return ""
    qn = collapse_digit_seps(q)
    want_dec = max_decimals_in_text(qn)  # decimal places to enforce in the answer
    messages = [{"role": "system", "content": SYS}, {"role": "user", "content": qn}]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    with torch.no_grad():
        out = model.diffusion_generate(  # key: diffusion_generate (not generate)
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=24,  # small budget: more stable and cheaper
            steps=32,           # 32/64/128; more steps are more stable
            temperature=0.0,
            top_p=None,
            return_dict_in_generate=True,
        )
    raw = tokenizer.decode(out.sequences[0][input_ids.shape[1]:].tolist(), skip_special_tokens=True).strip()
    raw = collapse_digit_seps(raw)
    num = extract_number(raw)
    if num is None:
        return raw
    if 'e' not in num.lower():  # leave scientific notation untouched
        if want_dec == 0:
            try:
                num = str(int(float(num)))
            except Exception:
                pass
        else:
            # Truncate/zero-pad to want_dec decimal places (no rounding),
            # e.g. "330.7" -> "330.70" and "330.765" -> "330.76" for want_dec=2.
            sgn = '-' if num.startswith('-') else ''
            body = num.lstrip('+-')
            if '.' in body:
                intp, frac = body.split('.', 1)
                frac = (frac + '0' * want_dec)[:want_dec]
            else:
                intp, frac = body, '0' * want_dec
            num = sgn + intp + '.' + frac
    return num
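
# Optional smoke test before building the UI; SMOKE_TEST is a hypothetical
# env toggle, and the printed output depends entirely on the checkpoint.
if os.environ.get("SMOKE_TEST"):
    print(predict("Compute: 330.76 + 0.00"))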

with gr.Blocks() as demo:
    gr.Markdown(f"# Dream S1K Demo\n**Model:** `{MODEL_ID}` • **Version:** `{VER}`")
    inp = gr.Textbox(label="Question", lines=4, placeholder="e.g. Compute: 330.76 + 0.00")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Answer")
    btn.click(fn=predict, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()