# dream-s1k-demo / app.py
# 况兑
# Remove Dream guard logic, force Qwen model
# 8d0f2c1
import re, os, gradio as gr, torch
from transformers import AutoModel, AutoTokenizer # 关键:AutoModel(不是 AutoModelForCausalLM)
# Checkpoint identifier: the Qwen model is used directly.
MODEL_ID = "kudui/qwen-s1k-overfit10"
# Revision handed to from_pretrained(); None means no revision is pinned.
REV = None
print(f"[INFO] Using MODEL_ID={MODEL_ID}")
# Version tag of this demo script.
VER = "v4-dream"
# System prompt (runtime text, intentionally left in Chinese). It asks for a
# bare numeric answer with no explanation or unit, and for decimals to be
# truncated / zero-padded to the longest decimal precision in the question.
SYS = (
    "只输出最终答案且必须是一个数字,不要任何解释或单位。"
    + "禁止在末尾多打一位。若题目中出现小数,请按题面中最长的小数位数输出(截断/补零,不四舍五入)。"
)
def collapse_digit_seps(s: str) -> str:
    """Normalize punctuation and remove digit-group separators from *s*.

    Steps, in order:

    1. Map full-width / CJK punctuation (comma, ideographic full stop,
       colon, parentheses) to the ASCII equivalents.
    2. Delete commas and whitespace-like characters that sit directly
       between two digits, so ``"1,234 567"`` becomes ``"1234567"``.
    3. Collapse every remaining whitespace run to one space and strip the
       ends.

    Args:
        s: Arbitrary text (a question or a model answer).

    Returns:
        The normalized text.
    """
    # Code points are spelled out explicitly so the mapping cannot silently
    # degrade into identity replacements if the file is re-encoded.
    fullwidth_to_ascii = {
        0xFF0C: ",",  # FULLWIDTH COMMA
        0x3002: ".",  # IDEOGRAPHIC FULL STOP
        0xFF1A: ":",  # FULLWIDTH COLON
        0xFF08: "(",  # FULLWIDTH LEFT PARENTHESIS
        0xFF09: ")",  # FULLWIDTH RIGHT PARENTHESIS
    }
    s = s.translate(fullwidth_to_ascii)
    # Separators only count when flanked by digits on both sides.
    s = re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
    return re.sub(r'\s+', ' ', s).strip()
def max_decimals_in_text(s: str) -> int:
    """Return the longest fractional-digit count among decimals in *s*.

    A decimal is one or more digits, a dot, then one or more digits; only
    the digits after the dot are counted. Returns 0 when *s* contains no
    such number.
    """
    fractional_parts = re.findall(r'\d+\.(\d+)', s)
    return max(map(len, fractional_parts), default=0)
def extract_number(s: str):
    """Return the first numeric token found in *s*, or None if there is none.

    The token may carry a leading sign, a decimal point and a
    case-insensitive exponent suffix; it is returned exactly as written.
    """
    found = re.search(r'[-+]?\d*\.?\d+(?:e[-+]?\d+)?', s, flags=re.IGNORECASE)
    if found is None:
        return None
    return found.group(0)
# Load the tokenizer first, then the model, announcing each step on stdout.
print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    revision=REV,
    trust_remote_code=True,
)
print(f"Loading Dream model from {MODEL_ID}...")
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    revision=REV,
    # bfloat16 on GPU, float32 on CPU.
    torch_dtype=torch.float32 if device != "cuda" else torch.bfloat16,
)
model = model.to(device)
model = model.eval()
def _format_number(num: str, want_dec: int) -> str:
    """Render *num* with exactly *want_dec* fractional digits, never rounding.

    Works purely on the digit string, so arbitrarily large integers are
    preserved exactly (no round-trip through ``float``).

    Args:
        num: A numeric token as produced by ``extract_number`` (optional
            sign, optional integer digits, optional ``.digits``, optional
            exponent).
        want_dec: Number of fractional digits to emit; ``0`` or less means
            an integer with the fractional part dropped.

    Returns:
        The formatted number. Tokens with an exponent are returned as-is.
    """
    if 'e' in num.lower():
        return num
    sign = '-' if num.startswith('-') else ''
    int_part, _, frac = num.lstrip('+-').partition('.')
    # A token such as ".5" has no integer digits; treat it as "0.5".
    int_part = int_part or '0'
    if want_dec <= 0:
        # int() on the digit string drops leading zeros and turns "-0"
        # into "0" without losing precision on long integers.
        return str(int(sign + int_part))
    # Truncate or zero-pad the fraction to the requested width.
    frac = (frac + '0' * want_dec)[:want_dec]
    return f"{sign}{int_part}.{frac}"


def predict(q: str):
    """Answer the question *q* with a single number generated by the model.

    The question is normalized with ``collapse_digit_seps``, sent to the
    model together with the ``SYS`` system prompt, and the first number in
    the decoded completion is reformatted to the longest decimal precision
    that appears in the question.

    Args:
        q: The user's question. ``None`` or blank input is accepted.

    Returns:
        ``""`` for blank input; the normalized raw completion when it
        contains no number; otherwise the formatted number as a string.
    """
    if not q or not q.strip():
        return ""
    question = collapse_digit_seps(q)
    want_dec = max_decimals_in_text(question)
    messages = [
        {"role": "system", "content": SYS},
        {"role": "user", "content": question},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    # NOTE(review): diffusion_generate is not a stock transformers method;
    # presumably it comes from the checkpoint's remote code
    # (trust_remote_code=True at load time). Confirm that the checkpoint in
    # MODEL_ID actually exposes it.
    with torch.no_grad():
        out = model.diffusion_generate(  # key: diffusion_generate, not generate
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=24,  # author's note: more stable and cheaper
            steps=32,  # author's note: 32/64/128, larger is more stable
            temperature=0.0,
            top_p=None,
            return_dict_in_generate=True,
        )
    # Decode only the tokens that follow the prompt.
    prompt_len = input_ids.shape[1]
    completion_ids = out.sequences[0][prompt_len:].tolist()
    raw = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    raw = collapse_digit_seps(raw)
    num = extract_number(raw)
    if num is None:
        # No number found: surface the model's text unchanged.
        return raw
    return _format_number(num, want_dec)
# Gradio UI: one question box, a Run button, and an answer box fed by predict().
demo = gr.Blocks()
with demo:
    gr.Markdown(f"# Dream S1K Demo \n**Model:** `{MODEL_ID}` • **Version:** `{VER}`")
    inp = gr.Textbox(
        label="Question",
        lines=4,
        placeholder="例如:Compute: 330.76 + 0.00",
    )
    btn = gr.Button("Run")
    out = gr.Textbox(label="Answer")
    btn.click(fn=predict, inputs=inp, outputs=out)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()