# llama-python-streamingllm / mods/btn_status_bar.py
# Uploaded by Limour ("Upload 2 files", commit b9cb0bd, verified)
import numpy as np
def init(cfg):
    """Prepare the status-bar schema and register the status-bar button handler.

    Reads the model, chat template, session lock and status-bar schema from
    ``cfg``. Every entry of ``cfg['btn_status_bar_list']`` is preprocessed in
    place:

    * ``key`` and ``desc`` are formatted through ``cfg['text_format']`` with
      the character/user role values and then tokenized (``key_t`` / ``desc_t``).
    * Each spec in the entry's ``combine`` list gets a tokenized ``prefix_t``
      (when it has a prefix) and a constrained-decoding callable
      ``fn(eval_t, sample_t) -> str``, chosen by its ``type``
      (``'int'``, ``'set'`` or ``'str'``).

    Finally, the streaming handler is stored in ``cfg['btn_status_bar_fn']``
    and chained onto ``cfg['btn_status_bar'].click`` between ``btn_start``
    and ``btn_finish``.

    Raises:
        ValueError: if a value spec has an unknown ``type``, a ``'set'`` spec
            has no options, or the digits ``0123456789`` do not tokenize to
            exactly ten tokens.
    """
    import warnings

    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']

    # ========== tokenization helper ==========
    def str_tokenize(s):
        """Tokenize ``chat_template.nl + s`` and drop a leading newline token.

        The leading token is dropped only if it is in ``chat_template.onenl``.
        Presumably the newline is prepended so that ``s`` tokenizes the way it
        would at the start of a new line.
        """
        tokens = model.tokenize((chat_template.nl + s).encode('utf-8'),
                                add_bos=False, special=False)
        if tokens and tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    def format_roles(text):
        # Format a schema string with the current character / user role values.
        return text_format(text,
                           char=cfg['role_char'].value,
                           user=cfg['role_usr'].value)

    # ========== preprocess key / desc ==========
    for entry in cfg['btn_status_bar_list']:
        entry['key'] = format_roles(entry['key'])
        entry['key_t'] = str_tokenize(entry['key'])
        entry['desc'] = format_roles(entry['desc'])
        if entry['desc']:
            entry['desc_t'] = str_tokenize(entry['desc'])

    # ========== logits mask factory ==========
    def btn_status_bar_fn_mask():
        """Return an additive float32 logits mask with every token at -inf."""
        return np.full((model.scores.shape[-1],), -np.inf, dtype=np.single)

    # ========== value decoder: integer (+ optional unit) ==========
    def btn_status_bar_fn_int(unit: str):
        """Build a decoder that only allows digits, EOS and the unit's first token."""
        t_int = str_tokenize('0123456789')
        if len(t_int) != 10:
            # Each digit must be its own token for the mask below to work.
            raise ValueError(
                f'digits 0-9 must tokenize to 10 tokens, got {len(t_int)}')
        fn_int_mask = btn_status_bar_fn_mask()
        fn_int_mask[chat_template.eos] = 0
        fn_int_mask[t_int] = 0
        if unit:
            unit_t = str_tokenize(unit)
            fn_int_mask[unit_t[0]] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_int_mask

        def inner(eval_t, sample_t):
            digits = []
            while True:
                token = sample_t(logits_processor)
                # Anything that is not a digit (EOS or the unit) ends the number.
                if token in chat_template.eos:
                    break
                if unit and token == unit_t[0]:
                    break
                # A digit: keep it and continue.
                digits.append(token)
                eval_t([token])
            if unit:
                eval_t(unit_t)  # append the full unit
                digits.extend(unit_t)
            return model.str_detokenize(digits)

        return inner

    # ========== value decoder: one of a fixed set ==========
    def btn_status_bar_fn_set(value):
        """Build a decoder that picks one option by sampling its first token only.

        Options are distinguished by their first token. If two options share
        it, the later one wins (as before) and a warning is emitted, because
        the earlier option can never be generated.
        """
        value_t = {}
        for option in value:
            option_t = str_tokenize(option)
            first = option_t[0]
            if first in value_t and value_t[first][1] != option:
                warnings.warn(
                    f'status bar set options {value_t[first][1]!r} and '
                    f'{option!r} start with the same token; only {option!r} '
                    f'can be generated',
                    stacklevel=2)
            value_t[first] = (option_t, option)
        if not value_t:
            # An all -inf mask would leave nothing valid to sample.
            raise ValueError('a "set" status bar value needs at least one option')
        fn_set_mask = btn_status_bar_fn_mask()
        fn_set_mask[list(value_t)] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_set_mask

        def inner(eval_t, sample_t):
            token = sample_t(logits_processor)
            option_t, option = value_t[token]
            eval_t(option_t)  # commit the whole option, not just its first token
            return option

        return inner

    # ========== value decoder: free text up to end of line ==========
    def btn_status_bar_fn_str():
        """Build a decoder that samples unconstrained text until EOS or a line break."""
        def inner(eval_t, sample_t):
            generated = []
            text = ''
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                generated.append(token)
                text = model.str_detokenize(generated)
                # The line-ending token is not evaluated; presumably the next
                # key's separator token supplies the line break instead.
                if text.endswith(('\n', '\r')):
                    break
                eval_t([token])
            return text.strip()

        return inner

    # ========== attach a decoder to every value spec ==========
    for entry in cfg['btn_status_bar_list']:
        for spec in entry['combine']:
            if spec['prefix']:
                spec['prefix_t'] = str_tokenize(spec['prefix'])
            kind = spec['type']
            if kind == 'int':
                spec['fn'] = btn_status_bar_fn_int(spec['unit'])
            elif kind == 'set':
                spec['fn'] = btn_status_bar_fn_set(spec['value'])
            elif kind == 'str':
                spec['fn'] = btn_status_bar_fn_str()
            else:
                # The handler calls spec['fn'] unconditionally, so an unknown
                # type would only surface later as a KeyError; fail fast here.
                raise ValueError(f'unknown status bar value type: {kind!r}')

    # ========== separators ==========
    # Every key after the first is prefixed with the last token of
    # chat_template.im_end_nl (presumably the newline token).
    for index, entry in enumerate(cfg['btn_status_bar_list']):
        if index:
            entry['key_t'] = chat_template.im_end_nl[-1:] + entry['key_t']
    # No `del` of loop variables here: the old `del x` / `del y` raised
    # UnboundLocalError for an empty schema, and the handler uses its own names.

    # ========== status bar handler ==========
    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status bar, yielding ``(rows, model.venv_info)``.

        ``rows`` is a list of ``[key, value]`` pairs that is extended and
        updated as each value is decoded; the same list object is yielded
        every time. Generation runs while holding ``cfg['session_lock']``.

        The ``_usr``, ``_char``, ``_rag`` and ``_max_tokens`` arguments are
        not used here. Presumably they exist so the signature matches
        ``cfg['setting']``.

        Raises:
            RuntimeError: if ``cfg['session_active']`` is falsy.
        """
        with lock:
            if not cfg['session_active']:
                raise RuntimeError
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            # Evaluation / sampling bound to this request's settings.
            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor
                )

            # Generate inside a separate 'status' environment.
            model.venv_create('status')
            eval_t(chat_template('状态'))  # opening marker of the status turn

            rows = []
            for entry in cfg['btn_status_bar_list']:
                # ----- attribute -----
                rows.append([entry['key'], ''])
                eval_t(entry['key_t'])
                if entry['desc']:
                    eval_t(entry['desc_t'])
                yield rows, model.venv_info
                # ----- values -----
                for spec in entry['combine']:
                    if spec['prefix']:
                        # While the displayed value is still empty, leading ':'
                        # is stripped from the shown prefix; the model still
                        # sees the full prefix.
                        if rows[-1][-1]:
                            rows[-1][-1] += spec['prefix']
                        else:
                            rows[-1][-1] += spec['prefix'].lstrip(':')
                        eval_t(spec['prefix_t'])
                    rows[-1][-1] += spec['fn'](eval_t, sample_t)
                    yield rows, model.venv_info
            eval_t(chat_template.im_end_nl)  # closing marker

            # Clean up previously generated status bars (keep_last=1).
            model.venv_remove('status', keep_last=1)
            yield rows, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info]
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])
    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )