# HF Deploy Script
# Initial deployment: diffusion-chatbot
# Commit: a919dff
import os
import math
import copy
import json
import torch
import torch.nn.functional as F
from flask import Flask, request, jsonify, Response
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Flask application that serves the HTTP endpoints defined below.
app = Flask(__name__)

# Populated by load_model(); request handlers answer 503 while `model` or
# `tokenizer` is still None.
model = tokenizer = device = None
def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that ``argmax`` performs temperature sampling.

    For ``temperature == 0`` the input is returned unchanged (greedy decoding).
    Otherwise the result is ``exp(logits) / (-log(u)) ** temperature`` with
    ``u ~ Uniform[0, 1)``, computed in float64.  Taking the argmax of that is
    equivalent to sampling from ``softmax(logits / temperature)``.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Sampling temperature; ``0`` disables the noise.

    Returns:
        ``logits`` itself when ``temperature == 0``, otherwise a new float64
        tensor with the same shape.
    """
    if temperature != 0:
        scores = logits.double()
        uniform = torch.rand_like(scores)
        denominator = torch.pow(-uniform.log(), temperature)
        return scores.exp().div(denominator)
    return logits
def get_num_transfer_tokens(mask_index, steps):
    """Split each row's masked-token count as evenly as possible over ``steps``.

    Args:
        mask_index: Tensor whose dim 1 is summed per row; truthy entries mark
            masked slots (callers pass a (B, L) boolean mask).
        steps: Number of denoising steps to spread the work across.

    Returns:
        Tensor of shape (B, steps).  Row ``i`` sums to the number of masked
        slots in that row; the remainder of the integer division is handed
        out one token at a time to the earliest steps.
    """
    totals = mask_index.sum(dim=1, keepdim=True)          # (B, 1)
    per_step, leftover = totals // steps, totals % steps  # both (B, 1)
    step_idx = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    # Step t gets one extra token while t < leftover for that row.
    bonus = (step_idx < leftover).long()
    return per_step + bonus
def build_staircase_attention_mask(x, block_size, pad_id):
    """Build a block-causal ("staircase") attention mask plus position ids.

    Column ``j`` of ``x`` belongs to block ``j // block_size``.  A query may
    attend to every key whose block index is less than or equal to its own.
    Positions holding ``pad_id`` neither attend nor are attended to.

    Args:
        x: Token ids of shape (B, T).
        block_size: Width of each attention block, in columns.
        pad_id: Token id treated as padding.

    Returns:
        ``(attn, position_ids)``.  ``attn`` is a bool tensor of shape
        (B, 1, T, T) indexed ``[batch, 0, query, key]``.  ``position_ids`` is
        a long tensor of shape (B, T) that counts only non-pad tokens; pad
        slots get position 0.
    """
    batch, length = x.shape
    is_token = x.ne(pad_id)                                    # (B, T)

    # Running count of real tokens, shifted to start at 0; pads pinned to 0.
    running = is_token.long().cumsum(dim=-1)
    position_ids = torch.where(is_token, running - 1, torch.zeros_like(running)).long()

    blocks = torch.div(
        torch.arange(length, device=x.device), block_size, rounding_mode="floor"
    )                                                          # (T,)
    # causal[q, k]: key's block does not come after the query's block.
    causal = blocks.unsqueeze(1) >= blocks.unsqueeze(0)        # (T, T)
    # both_real[b, q, k]: neither the query nor the key is padding.
    both_real = is_token.unsqueeze(2) & is_token.unsqueeze(1)  # (B, T, T)
    attn = (causal.unsqueeze(0) & both_real).unsqueeze(1)      # (B, 1, T, T)
    return attn, position_ids
def diffusion_step_block(logits, x_block, mask_block, num_transfer, temperature, remasking):
    """Commit the highest-ranked predictions for one denoising step of a block.

    Args:
        logits: Model scores of shape (B, L, V) for the block positions.
        x_block: Current token ids of shape (B, L).
        mask_block: Bool tensor of shape (B, L); True where ``x_block`` is
            still masked.
        num_transfer: Tensor of shape (B,); the number of masked slots to
            commit in each row this step.
        temperature: Gumbel-noise temperature; 0 means greedy argmax.
        remasking: ``"low_confidence"`` ranks candidate slots by the model's
            probability of the predicted token.  ``"random"`` ranks them
            uniformly at random.

    Returns:
        ``x_block`` itself when nothing is masked.  Otherwise a new (B, L)
        tensor equal to ``x_block``, except that up to ``num_transfer[i]``
        masked slots of row ``i`` are replaced by their predicted token.

    Raises:
        ValueError: if ``remasking`` is not a supported strategy.
    """
    batch, length, _ = logits.shape
    if not mask_block.any():
        return x_block
    # Greedy decoding needs no noise, so skip the helper (it would return the
    # logits unchanged anyway).
    noisy = logits if temperature == 0 else add_gumbel_noise(logits, temperature)
    x0 = noisy.argmax(dim=-1)
    if remasking == "low_confidence":
        # Rank in float64.  With bf16 logits the probabilities are quantised to
        # ~3 significant digits, so distinct confidences collapse into ties and
        # top-k ends up picking essentially arbitrary positions.
        probs = F.softmax(logits.double(), dim=-1)
        conf = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == "random":
        conf = torch.rand((batch, length), device=logits.device)
    else:
        raise ValueError(remasking)
    x0 = torch.where(mask_block, x0, x_block)
    # Already-committed slots must never be selected again.
    conf = conf.masked_fill(~mask_block, float("-inf"))
    available = mask_block.sum(dim=1)
    commit = torch.zeros_like(x_block, dtype=torch.bool)
    for i in range(batch):
        k = min(int(num_transfer[i]), int(available[i]))
        if k > 0:
            _, idx = torch.topk(conf[i], k)
            commit[i, idx] = True
    out = x_block.clone()
    out[commit] = x0[commit]
    return out
@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    capture_interval=0,
):
    """Generate tokens with block-wise masked diffusion.

    New tokens are appended in blocks aligned to multiples of ``block_size``.
    The first block is shorter if the prompt ends mid-block.  Each block starts
    fully masked and is denoised over at most ``steps_per_block`` steps; every
    step commits the top-ranked predictions via ``diffusion_step_block``.  The
    prefix is encoded once per block, and its ``past_key_values`` are reused
    (deep-copied) for each step of that block.

    Args:
        model: Callable taking token ids plus ``attention_mask``,
            ``position_ids``, ``past_key_values`` and ``use_cache``.  It returns
            an object with ``.logits`` / ``.past_key_values``, and exposes
            ``.device``.
        tokenizer: Provides ``mask_token_id``, ``pad_token_id`` and
            ``eos_token_id``.
        prompt: Token ids as a tensor, a flat list, or a list of lists.
            Ragged rows are right-padded with the pad id.
        steps: Total denoising-step budget, split evenly across blocks.
        max_new_tokens: Number of tokens to append.
        block_size: Attention/denoising block width; must be >= 1.
        temperature: Gumbel-noise temperature (0 = greedy).
        cfg_scale: Classifier-free guidance strength; 0 disables it.
        remasking: ``"low_confidence"`` or ``"random"``.
        capture_interval: If > 0, snapshot the whole sequence every
            ``capture_interval`` executed denoising steps.

    Returns:
        The token tensor (prompt followed by generated tokens).  When
        ``capture_interval > 0``, returns ``(tokens, snapshots)`` instead.

    Raises:
        ValueError: if ``block_size < 1``.  An unknown ``remasking`` value
            surfaces as ValueError from ``diffusion_step_block``.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    device = model.device
    mask_id = tokenizer.mask_token_id
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_id if eos_id is not None else mask_id

    if isinstance(prompt, torch.Tensor):
        x = prompt.to(device).long()
    elif isinstance(prompt[0], (list, tuple)):
        max_len = max(len(p) for p in prompt)
        x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
        for i, p in enumerate(prompt):
            x[i, : len(p)] = torch.tensor(p, device=device)
    else:
        x = torch.tensor(prompt, device=device).long()
    if x.dim() == 1:
        x = x.unsqueeze(0)

    B = x.size(0)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    # max(1, ...): max_new_tokens == 0 would otherwise divide by zero here,
    # and steps == 0 would make get_num_transfer_tokens divide by zero.
    num_blocks = max(1, math.ceil(max_new_tokens / block_size))
    steps_per_block = max(1, math.ceil(steps / num_blocks))
    generated = 0
    intermediates = []
    total_step = 0

    while generated < max_new_tokens and not finished.all():
        T_prefix = x.size(1)
        # Grow only up to the next multiple of block_size, so generated blocks
        # line up with the column-based block ids of the staircase mask.
        room = block_size - T_prefix % block_size
        cur_len = min(room, max_new_tokens - generated)

        attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
        cond_past = model(
            x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True
        ).past_key_values
        uncond_past = None
        if cfg_scale > 0:
            # Unconditional branch: the entire prefix is replaced by mask tokens.
            un_x = torch.full_like(x, mask_id)
            uncond_past = model(
                un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True
            ).past_key_values

        block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
        block[finished] = pad_id  # rows that already hit EOS only get padding
        x = torch.cat([x, block], dim=1)
        T_total = x.size(1)

        block_mask = x[:, T_prefix:T_total] == mask_id
        num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
        # In every row the steps with a non-zero quota form a prefix.  Running
        # the zero-quota tail would only burn forward passes without
        # committing anything.
        eff_steps = int((num_transfer > 0).any(dim=0).sum())

        full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
        attn_blk = full_attn[:, :, T_prefix:T_total, :]
        pos_blk = full_pos[:, T_prefix:T_total]

        for t in range(eff_steps):
            x_blk = x[:, T_prefix:T_total]
            m_blk = x_blk == mask_id
            # Each step gets a fresh copy of the prefix cache.  Presumably the
            # forward pass extends the cache object it is given; confirm
            # against the model implementation before removing the deepcopy.
            logits = model(
                x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                past_key_values=copy.deepcopy(cond_past), use_cache=False,
            ).logits
            if cfg_scale > 0:
                un_logits = model(
                    x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                    past_key_values=copy.deepcopy(uncond_past), use_cache=False,
                ).logits
                logits = un_logits + (cfg_scale + 1.0) * (logits - un_logits)
            x[:, T_prefix:T_total] = diffusion_step_block(
                logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
            )
            if capture_interval > 0 and total_step % capture_interval == 0:
                intermediates.append(x.clone())
            total_step += 1

        # Look for EOS only after the block is fully denoised.  Stopping as
        # soon as an EOS is committed would leave the block's other masked
        # slots (including ones *before* the EOS) unfilled.
        if eos_id is not None:
            finished |= (x[:, T_prefix:T_total] == eos_id).any(dim=1)
        generated += cur_len

    if capture_interval > 0:
        return x, intermediates
    return x
@torch.no_grad()
def generate_stream(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    capture_interval=10,
):
    """Run block-wise masked diffusion and yield decoded progress updates.

    The denoising schedule matches the non-streaming generator in this
    module.  Blocks are aligned to ``block_size``, and each block is denoised
    step by step with ``diffusion_step_block`` against a reused prefix cache.

    Args:
        model: Callable taking token ids plus ``attention_mask``,
            ``position_ids``, ``past_key_values`` and ``use_cache``.  It returns
            an object with ``.logits`` / ``.past_key_values``, and exposes
            ``.device``.
        tokenizer: Provides ``mask_token_id``, ``pad_token_id``,
            ``eos_token_id`` and ``decode``.
        prompt: Token ids as a tensor, a flat list, or a list of lists.
        steps: Total denoising-step budget, split evenly across blocks.
        max_new_tokens: Number of tokens to append.
        block_size: Attention/denoising block width; must be >= 1.
        temperature: Gumbel-noise temperature (0 = greedy).
        cfg_scale: Classifier-free guidance strength; 0 disables it.
        remasking: ``"low_confidence"`` or ``"random"``.
        capture_interval: Emit an ``"intermediate"`` update every
            ``capture_interval`` executed steps.  A value <= 0 emits none.

    Yields:
        Zero or more intermediate dicts with keys ``type``, ``step``, ``text``
        and ``total_steps`` (the requested ``steps``).  Then exactly one dict
        with ``type == "final"``, ``text`` and ``total_steps`` (the number of
        steps actually executed).  Only batch row 0 is decoded.

    Raises:
        ValueError: if ``block_size < 1``.  An unknown ``remasking`` value
            surfaces as ValueError from ``diffusion_step_block``.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    device = model.device
    mask_id = tokenizer.mask_token_id
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_id if eos_id is not None else mask_id

    if isinstance(prompt, torch.Tensor):
        x = prompt.to(device).long()
    elif isinstance(prompt[0], (list, tuple)):
        max_len = max(len(p) for p in prompt)
        x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
        for i, p in enumerate(prompt):
            x[i, : len(p)] = torch.tensor(p, device=device)
    else:
        x = torch.tensor(prompt, device=device).long()
    if x.dim() == 1:
        x = x.unsqueeze(0)

    B = x.size(0)
    prompt_len = x.size(1)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    # max(1, ...): max_new_tokens == 0 would otherwise divide by zero here,
    # and steps == 0 would make get_num_transfer_tokens divide by zero.
    num_blocks = max(1, math.ceil(max_new_tokens / block_size))
    steps_per_block = max(1, math.ceil(steps / num_blocks))
    generated = 0
    total_step = 0

    def decode_new(seq):
        # Decode the generated region of batch row 0.
        new_tokens = seq[0, prompt_len:prompt_len + max_new_tokens].tolist()
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

    while generated < max_new_tokens and not finished.all():
        T_prefix = x.size(1)
        # Grow only up to the next multiple of block_size, so generated blocks
        # line up with the column-based block ids of the staircase mask.
        room = block_size - T_prefix % block_size
        cur_len = min(room, max_new_tokens - generated)

        attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
        cond_past = model(
            x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True
        ).past_key_values
        uncond_past = None
        if cfg_scale > 0:
            # Unconditional branch: the entire prefix is replaced by mask tokens.
            un_x = torch.full_like(x, mask_id)
            uncond_past = model(
                un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True
            ).past_key_values

        block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
        block[finished] = pad_id  # rows that already hit EOS only get padding
        x = torch.cat([x, block], dim=1)
        T_total = x.size(1)

        block_mask = x[:, T_prefix:T_total] == mask_id
        num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
        # In every row the steps with a non-zero quota form a prefix; skip the
        # zero-quota tail, which would commit nothing.
        eff_steps = int((num_transfer > 0).any(dim=0).sum())

        full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
        attn_blk = full_attn[:, :, T_prefix:T_total, :]
        pos_blk = full_pos[:, T_prefix:T_total]

        for t in range(eff_steps):
            x_blk = x[:, T_prefix:T_total]
            m_blk = x_blk == mask_id
            # Fresh copy of the prefix cache per step.  Presumably the forward
            # pass extends the cache it is given; confirm before removing.
            logits = model(
                x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                past_key_values=copy.deepcopy(cond_past), use_cache=False,
            ).logits
            if cfg_scale > 0:
                un_logits = model(
                    x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                    past_key_values=copy.deepcopy(uncond_past), use_cache=False,
                ).logits
                logits = un_logits + (cfg_scale + 1.0) * (logits - un_logits)
            x[:, T_prefix:T_total] = diffusion_step_block(
                logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
            )
            # Guard on > 0: the modulo would raise ZeroDivisionError for 0.
            if capture_interval > 0 and total_step % capture_interval == 0:
                yield {
                    "type": "intermediate",
                    "step": total_step,
                    "text": decode_new(x),
                    "total_steps": steps
                }
            total_step += 1

        # Look for EOS only after the block is fully denoised.  Stopping as
        # soon as an EOS is committed would leave the block's other masked
        # slots (including ones *before* the EOS) unfilled.
        if eos_id is not None:
            finished |= (x[:, T_prefix:T_total] == eos_id).any(dim=1)
        generated += cur_len

    yield {
        "type": "final",
        "text": decode_new(x),
        "total_steps": total_step
    }
def load_model():
    """Load the model and tokenizer into the module-level globals.

    The checkpoint name comes from the ``MODEL_NAME`` environment variable,
    defaulting to ``dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1``.  Weights are
    loaded in bfloat16, placed on CUDA when available (CPU otherwise) and put
    in eval mode.  Both loads pass ``trust_remote_code=True``.
    """
    global model, tokenizer, device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    model_name = os.getenv("MODEL_NAME", "dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1")
    print(f"Loading model {model_name} on {device}...")
    weights = AutoModelForMaskedLM.from_pretrained(
        model_name, dtype=torch.bfloat16, trust_remote_code=True
    )
    model = weights.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    print("Model loaded successfully!")
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always "healthy", plus whether the model global is set."""
    payload = {"status": "healthy", "model_loaded": model is not None}
    return jsonify(payload)
@app.route('/generate', methods=['POST'])
def generate_text():
    """Run one chat turn through ``generate`` and return the completion as JSON.

    Request body (JSON object): ``prompt`` (required), plus optional
    ``steps``, ``max_new_tokens``, ``block_size``, ``temperature``,
    ``cfg_scale``, ``remasking`` and ``system_prompt``.

    Responds 503 until the model is loaded and 400 for a missing prompt or
    malformed parameters.  Otherwise it returns the generated text and the
    effective parameters.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    # silent=True returns None for a missing/unparsable JSON body, so such
    # requests get the same 400 as a body without a prompt.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        return jsonify({"error": "Missing 'prompt' field"}), 400
    prompt = data['prompt']
    # Request values are untrusted: coerce them here so bad input becomes a
    # 400 rather than an exception deep inside generation.
    try:
        steps = int(data.get('steps', 256))
        max_new_tokens = int(data.get('max_new_tokens', 256))
        block_size = int(data.get('block_size', 32))
        temperature = float(data.get('temperature', 0.0))
        cfg_scale = float(data.get('cfg_scale', 0.0))
    except (TypeError, ValueError):
        return jsonify({"error": "Numeric parameters must be numbers"}), 400
    if steps < 1 or max_new_tokens < 1 or block_size < 1:
        return jsonify({"error": "'steps', 'max_new_tokens' and 'block_size' must be >= 1"}), 400
    remasking = data.get('remasking', 'low_confidence')
    if remasking not in ('low_confidence', 'random'):
        return jsonify({"error": "'remasking' must be 'low_confidence' or 'random'"}), 400
    system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
    output = generate(
        model,
        tokenizer,
        input_ids,
        steps=steps,
        max_new_tokens=max_new_tokens,
        block_size=block_size,
        temperature=temperature,
        cfg_scale=cfg_scale,
        remasking=remasking,
    )
    prompt_len = len(encoded)
    new_tokens = output[0, prompt_len:prompt_len + max_new_tokens].tolist()
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return jsonify({
        "prompt": prompt,
        "generated_text": generated_text,
        "parameters": {
            "steps": steps,
            "max_new_tokens": max_new_tokens,
            "block_size": block_size,
            "temperature": temperature,
            "cfg_scale": cfg_scale,
            "remasking": remasking
        }
    })
@app.route('/generate_stream', methods=['POST'])
def generate_text_stream():
    """Generate one chat turn and return the final text plus snapshots as JSON.

    Despite the route name, the response is not streamed.  The whole
    generation runs first, then every ``capture_interval``-th denoising step
    is returned in ``intermediate_states``.  Accepts the same body as
    ``/generate`` plus ``capture_interval`` (default 10; <= 0 disables
    snapshots).

    Responds 503 until the model is loaded and 400 for a missing prompt or
    malformed parameters.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    # silent=True returns None for a missing/unparsable JSON body, so such
    # requests get the same 400 as a body without a prompt.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        return jsonify({"error": "Missing 'prompt' field"}), 400
    prompt = data['prompt']
    # Request values are untrusted: coerce them so bad input becomes a 400.
    try:
        steps = int(data.get('steps', 256))
        max_new_tokens = int(data.get('max_new_tokens', 256))
        block_size = int(data.get('block_size', 32))
        temperature = float(data.get('temperature', 0.0))
        cfg_scale = float(data.get('cfg_scale', 0.0))
        capture_interval = int(data.get('capture_interval', 10))
    except (TypeError, ValueError):
        return jsonify({"error": "Numeric parameters must be numbers"}), 400
    if steps < 1 or max_new_tokens < 1 or block_size < 1:
        return jsonify({"error": "'steps', 'max_new_tokens' and 'block_size' must be >= 1"}), 400
    remasking = data.get('remasking', 'low_confidence')
    if remasking not in ('low_confidence', 'random'):
        return jsonify({"error": "'remasking' must be 'low_confidence' or 'random'"}), 400
    system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
    result = generate(
        model,
        tokenizer,
        input_ids,
        steps=steps,
        max_new_tokens=max_new_tokens,
        block_size=block_size,
        temperature=temperature,
        cfg_scale=cfg_scale,
        remasking=remasking,
        capture_interval=capture_interval,
    )
    # generate() returns a bare tensor (no snapshot list) when
    # capture_interval <= 0, so unconditional tuple-unpacking would fail.
    output, intermediates = result if capture_interval > 0 else (result, [])

    prompt_len = len(encoded)

    def decode_new(seq):
        # Decode the generated region of batch row 0.
        new_tokens = seq[0, prompt_len:prompt_len + max_new_tokens].tolist()
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Snapshot i was taken at denoising step i * capture_interval.
    intermediate_states = [
        {"step": i * capture_interval, "text": decode_new(snapshot)}
        for i, snapshot in enumerate(intermediates)
    ]
    generated_text = decode_new(output)
    return jsonify({
        "prompt": prompt,
        "generated_text": generated_text,
        "intermediate_states": intermediate_states,
        "parameters": {
            "steps": steps,
            "max_new_tokens": max_new_tokens,
            "block_size": block_size,
            "temperature": temperature,
            "cfg_scale": cfg_scale,
            "remasking": remasking,
            "capture_interval": capture_interval
        }
    })
@app.route('/generate_sse', methods=['POST'])
def generate_text_sse():
    """Stream denoising progress for one chat turn as Server-Sent Events.

    Accepts the same JSON body as ``/generate`` plus ``capture_interval``
    (default 10, must be >= 1).  Each update from ``generate_stream`` is sent
    as a ``data: <json>`` frame followed by a blank line.  If generation
    fails after streaming has begun, one ``{"type": "error", ...}`` frame is
    sent instead of the remaining updates.

    Responds 503 until the model is loaded and 400 for a missing prompt or
    malformed parameters.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    # silent=True returns None for a missing/unparsable JSON body, so such
    # requests get the same 400 as a body without a prompt.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        return jsonify({"error": "Missing 'prompt' field"}), 400
    prompt = data['prompt']
    # Validate everything up front: once the stream starts, a 400 is no
    # longer possible.
    try:
        steps = int(data.get('steps', 256))
        max_new_tokens = int(data.get('max_new_tokens', 256))
        block_size = int(data.get('block_size', 32))
        temperature = float(data.get('temperature', 0.0))
        cfg_scale = float(data.get('cfg_scale', 0.0))
        capture_interval = int(data.get('capture_interval', 10))
    except (TypeError, ValueError):
        return jsonify({"error": "Numeric parameters must be numbers"}), 400
    if steps < 1 or max_new_tokens < 1 or block_size < 1 or capture_interval < 1:
        return jsonify({
            "error": "'steps', 'max_new_tokens', 'block_size' and 'capture_interval' must be >= 1"
        }), 400
    remasking = data.get('remasking', 'low_confidence')
    if remasking not in ('low_confidence', 'random'):
        return jsonify({"error": "'remasking' must be 'low_confidence' or 'random'"}), 400
    system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)

    def stream():
        try:
            for state in generate_stream(
                model,
                tokenizer,
                input_ids,
                steps=steps,
                max_new_tokens=max_new_tokens,
                block_size=block_size,
                temperature=temperature,
                cfg_scale=cfg_scale,
                remasking=remasking,
                capture_interval=capture_interval,
            ):
                yield f"data: {json.dumps(state)}\n\n"
        except Exception:
            # The response status has already gone out, so the failure can
            # only be reported in-band; keep the traceback in the server log.
            app.logger.exception("SSE generation failed")
            error_event = json.dumps({"type": "error", "error": "Generation failed"})
            yield f"data: {error_event}\n\n"

    return Response(
        stream(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',
        }
    )
if __name__ == '__main__':
    # Load the weights before serving; the port comes from $PORT, default 5000.
    load_model()
    port = int(os.getenv('PORT', 5000))
    app.run(host='0.0.0.0', port=port)