import gradio as gr
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer
from argparse import Namespace
import torch
from tsne import TSNE_Plot
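# Note: TSNE_Plot is assumed to come from a tsne.py module shipped alongside this app
# (not the PyPI "tsne" package); it wraps the t-SNE projection and plotting used below.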
tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
model_args = Namespace(do_mlm=None,
                       pooler_type="cls",
                       temp=0.05,
                       mlp_only_train=False,
                       init_embeddings_model=None)
model = AutoModel.from_pretrained("silk-road/luotuo-bert",
                                  trust_remote_code=True,
                                  model_args=model_args)
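
# Luotuo-BERT is loaded through its custom remote code; when called with sent_emb=True
# (as in generate_image below), pooler_output holds one sentence embedding per input.
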
def divide_str(s, sep=['\n', '.', '。']):
    mid_len = len(s) // 2       # midpoint of the string
    best_sep_pos = len(s) + 1   # position of the separator closest to the midpoint so far
    best_sep = None             # separator closest to the midpoint so far
    for curr_sep in sep:
        sep_pos = s.rfind(curr_sep, 0, mid_len)  # search leftward from the midpoint
        if sep_pos > 0 and abs(sep_pos - mid_len) < abs(best_sep_pos - mid_len):
            best_sep_pos = sep_pos
            best_sep = curr_sep
    if not best_sep:            # no separator found
        return s, ''
    return s[:best_sep_pos + 1], s[best_sep_pos + 1:]
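
# For example, divide_str("hello. world") returns ("hello.", " world"): the left half
# ends at the separator found nearest the midpoint when scanning the first half.
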
def strong_divide(s):
    left, right = divide_str(s)
    if right != '':
        return left, right
    whole_sep = ['\n', '.', ',', '、', ';', ',', ';',
                 ':', '!', '?', '(', ')', '”', '“',
                 '’', '‘', '[', ']', '{', '}', '<', '>',
                 '/', "'", '|', '-', '=', '+', '*', '%',
                 '$', '#', '@', '&', '^', '_', '`', '~',
                 '·', '…']
    left, right = divide_str(s, sep=whole_sep)
    if right != '':
        return left, right
    # No separator at all: fall back to cutting the string at its midpoint.
    mid_len = len(s) // 2
    return s[:mid_len], s[mid_len:]
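
# For example, strong_divide('hello, world and more') splits at the comma into
# ('hello,', ' world and more'); a string with no separators at all is cut at its midpoint.
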
def generate_image(text_input):
    # Split the input into lines, one lyric per line.
    text_input = text_input.split('\n')
    label = []
    for idx, i in enumerate(text_input):
        if '#' in i:
            # Text after '#' is the user-supplied label (e.g. the song title).
            label.append(i[i.find('#') + 1:])
            text_input[idx] = i[:i.find('#')]
        else:
            label.append('No.{}'.format(idx))

    # Split every line into two halves and embed each half separately.
    divided_text = [strong_divide(i) for i in text_input]
    text_left, text_right = [i[0] for i in divided_text], [i[1] for i in divided_text]

    inputs = tokenizer(text_left, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings_left = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output

    inputs = tokenizer(text_right, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings_right = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output

    merged_list = text_left + text_right
    merged_embed = torch.cat((embeddings_left, embeddings_right), dim=0)
    tsne_plot = TSNE_Plot(merged_list, merged_embed, label=label * 2, n_annotation_positions=len(merged_list))
    fig = tsne_plot.tsne_plot(n_sentence=len(merged_list), return_fig=True)
    return fig
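
# Each input line contributes two points (its first and second half). TSNE_Plot projects
# the merged embeddings to 2-D, so the two halves of the same song should land close
# together when the embeddings capture song-level meaning.
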
with gr.Blocks() as demo:
    # The Chinese placeholder asks for one lyric per line; an optional song title
    # can be appended after '#'. The example is a lyric tagged "河北墨麒麟".
    name = gr.Textbox(lines=20,
                      placeholder='在此输入歌词,每一行为一个输入,如果需要输入歌词对应的歌名,请用#隔开\n例如:听雷声 滚滚 他默默 闭紧嘴唇 停止吟唱暮色与想念 他此刻沉痛而危险 听雷声 滚滚 他渐渐 感到胸闷 乌云阻拦明月涌河湾 他起身独立向荒原#河北墨麒麟')
    output = gr.Plot()
    btn = gr.Button("Generate")
    btn.click(fn=generate_image, inputs=name, outputs=output, api_name="generate-image")

demo.launch(debug=True)