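# Hugging Face file-page header captured together with the source (not part of the module):
# uploader: hahahafofo | commit: "init" (79ec61a) | size: 6.05 kB | scan: no virus
# page links: raw, history, blame, contribute, delete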
import re
import gradio as gr
from typing import List
from models import models
from loguru import logger
import re
# Prompt skeleton with two placeholders: the instruction ({query_str}) goes on
# the first line and the text it applies to ({context_str}) on the second.
PROMPT_TEMPLATE = (
    "使用中文{query_str}:\n"
    "{context_str}\n"
)
def get_text_lines(input_txt: str) -> List[str]:
    """Return the non-blank lines of *input_txt*, each stripped of surrounding whitespace."""
    stripped_lines = (raw_line.strip() for raw_line in input_txt.splitlines())
    return [text for text in stripped_lines if text]
# Characters after which split_in_line is allowed to break a line: sentence
# terminators and closing quotes/brackets.  The full-width forms are written
# as escapes so they cannot be silently folded into their ASCII look-alikes
# (which is how duplicate ASCII entries end up in a literal like this one).
stop_chars_set = {
    '.', '!', '?', '。', '…', ';', ':',
    '\uff01', '\uff1f', '\uff1b', '\uff1a', '\uff09',  # full-width ! ? ; : )
    '”', '’', ')', '】', '》', '」', '』', '〕', '〉',
    '〗', '〞', '〟', '»', '"', "'", ']', '}',
}


def split_in_line(input_txt: str, limit_length: int) -> List[str]:
    """Split one long line into chunks of at most *limit_length* characters.

    The text is first cut into sentences, each ending right after a character
    from ``stop_chars_set``.  Consecutive sentences are then packed greedily
    into chunks that never exceed the limit.  A single sentence longer than
    the limit is cut at fixed ``limit_length`` intervals, since there is no
    better break point.

    Args:
        input_txt: The text to split.  May be empty.
        limit_length: Maximum number of characters per chunk.  Coerced with
            ``int()``.  A non-positive value disables packing and returns one
            chunk per sentence.

    Returns:
        The chunks in order.  Joining them with ``""`` reproduces
        *input_txt* exactly, and no chunk is empty.
    """
    if not input_txt:
        return []
    limit_length = int(limit_length)

    # Pass 1: sentence boundaries.  Each sentence keeps its terminator.
    sentences: List[str] = []
    start = 0
    for pos, char in enumerate(input_txt):
        if char in stop_chars_set:
            sentences.append(input_txt[start:pos + 1])
            start = pos + 1
    if start < len(input_txt):
        # Trailing text without a terminator is still a sentence.
        sentences.append(input_txt[start:])

    if limit_length <= 0:
        return sentences

    # Pass 2: greedy packing.  Flush *before* a piece would overflow the
    # current chunk so the limit is a true upper bound.
    outputs: List[str] = []
    parts: List[str] = []
    size = 0
    for sentence in sentences:
        # Yields the sentence itself when it fits, fixed-width slices otherwise.
        for offset in range(0, len(sentence), limit_length):
            piece = sentence[offset:offset + limit_length]
            if parts and size + len(piece) > limit_length:
                outputs.append("".join(parts))
                parts, size = [], 0
            parts.append(piece)
            size += len(piece)
    if parts:
        outputs.append("".join(parts))
    return outputs
def get_text_limit_length(input_txt: str, max_length: int = 2048) -> List[str]:
    """Break *input_txt* into non-blank lines and split the over-long ones.

    Lines whose length is at most *max_length* are kept as they are; longer
    lines are handed to ``split_in_line`` and replaced by the pieces it
    returns.  The original order is preserved.
    """
    chunks: List[str] = []
    for paragraph in get_text_lines(input_txt):
        if len(paragraph) > max_length:
            pieces = split_in_line(paragraph, max_length)
            logger.debug(f"split in line: {len(pieces)}")
            chunks.extend(pieces)
            continue
        chunks.append(paragraph)
    return chunks
def split_input_text(input_txt, strip_input_lines=0, max_length=2048):
    """Normalise raw input text into chunks separated by three newlines.

    Args:
        input_txt: The raw text to segment.
        strip_input_lines: When positive, every run of at least this many
            consecutive line breaks is removed, joining the surrounding text.
            ``0`` leaves the line breaks alone.
        max_length: Maximum chunk length passed to ``get_text_limit_length``.

    Returns:
        The chunks joined with ``"\\n\\n\\n"``.
    """
    # The count is spliced into a regex quantifier, which must be an integer;
    # the value may arrive as a float (e.g. from a UI slider).
    min_breaks = int(strip_input_lines)
    if min_breaks > 0:
        # Treat "\r\n" as ONE line break.  A plain [\r\n] character class
        # counts it as two, so a threshold of 2 would delete every single
        # Windows-style line break.
        pattern = r'(?:\r\n|\r|\n){' + str(min_breaks) + r',}'
        logger.debug(f"strip input txt: {pattern}")
        input_txt = re.sub(pattern, '', input_txt)
    lines = get_text_limit_length(input_txt, max_length)
    logger.debug(f"split input txt: {len(lines)}")
    return "\n\n\n".join(lines)
def gen_keyword_summary(input_txt, keyword_prompt, summary_prompt, max_length=2048):
    """Build a summary prompt that lists keywords extracted from every chunk.

    Args:
        input_txt: Chunked text, chunks separated by ``"\\n\\n\\n"``.
        keyword_prompt: Instruction sent to the model for keyword extraction.
        summary_prompt: Summary instruction appended after the keyword list.
        max_length: Forwarded to ``models.llm_model.generate_answer``.

    Returns:
        A prompt string containing the de-duplicated, whitespace-separated
        keywords followed by *summary_prompt*.
    """
    keywords_output = []
    for line in input_txt.split("\n\n\n"):
        if not line.strip():
            # Nothing to extract from a blank chunk; skip the model call.
            continue
        # Element [0] of the return value is used as the answer text.
        keywords = models.llm_model.generate_answer(
            keyword_prompt,
            line,
            history=None,
            max_length=max_length,
            prompt_template=PROMPT_TEMPLATE
        )[0]
        logger.debug(f"text len: {len(line)} ==> {keywords}")
        # str.split() without arguments never yields empty or padded tokens,
        # so no further stripping/filtering is needed.
        keywords_output.extend(keywords.split())
    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the keyword order (and hence the generated prompt) vary per run.
    unique_keywords = list(dict.fromkeys(keywords_output))
    return f"保留关键词:{' '.join(unique_keywords)},{summary_prompt}"
def gen_summary(input_txt, summary_prompt, max_length=2048):
    """Summarise chunked text with a rolling summary.

    Each chunk is summarised together with the summary produced for the
    previous chunk, so context is carried forward through the document.

    Args:
        input_txt: Chunked text, chunks separated by ``"\\n\\n\\n"``.
        summary_prompt: Instruction sent to the model for every chunk.
        max_length: Forwarded to ``models.llm_model.generate_answer``.

    Returns:
        The per-chunk running summaries joined with ``"\\n\\n\\n"``.
    """
    output_summary = []
    summary = ""
    for line in input_txt.split("\n\n\n"):
        if not line.strip():
            # Nothing to summarise in a blank chunk; skip the model call.
            continue
        previous_length = len(summary)
        # `summary` is empty for the first chunk, so prepending it always is
        # correct.  The earlier special case keyed on the wrong index and made
        # the second chunk lose the running summary.
        # Element [0] of the return value is used as the answer text.
        summary = models.llm_model.generate_answer(
            summary_prompt,
            f"{summary}{line}",
            history=None,
            max_length=max_length,
            prompt_template=PROMPT_TEMPLATE
        )[0]
        logger.debug(f"summary: {previous_length} + text: {len(line)} ==> {summary}")
        output_summary.append(summary)
    return "\n\n\n".join(output_summary)
def summary_ui():
    """Build the summarisation controls and wire the three action buttons.

    Layout: a settings row (sliders plus prompt boxes), a row with the
    input / segmented / summary text areas, and a row of buttons.
    The buttons call ``split_input_text``, ``gen_keyword_summary`` and
    ``gen_summary`` respectively.
    """
    with gr.Row():
        with gr.Column(scale=1):
            line_max_length = gr.Slider(minimum=512, maximum=4096, step=1, value=1024, label="每行最大长度")
            # minimum must be 0: the label documents 0 as "do not strip" and
            # 0 is the default, so it has to be a selectable value.
            strip_input_lines = gr.Slider(
                label="去除输入文本连续的空行(0:不除去)",
                minimum=0,
                maximum=10,
                step=1,
                value=0
            )
        with gr.Column(scale=4):
            keyword_prompt = gr.Textbox(
                lines=1,
                label="抽取关键词",
                value="抽取以下内容的人物和地点:",
                placeholder="请输入抽取关键词的Prompt"
            )
            summary_prompt = gr.Textbox(
                lines=2,
                label="生成摘要",
                value="生成以下内容的摘要:",
                placeholder="请输入生成摘要的Prompt"
            )
            keyword_summary_prompt = gr.Textbox(lines=4, label="关键词+摘要", placeholder="请输入关键词+摘要的Prompt")
    with gr.Row():
        input_text = gr.Textbox(lines=20, max_lines=60, label="输入文本", placeholder="请输入文本")
        split_text = gr.Textbox(lines=20, max_lines=60, label="分段文本", placeholder="请输入分段文本")
        summary = gr.Textbox(lines=20, max_lines=60, label="生成摘要", placeholder="请输入生成摘要的Prompt")
    with gr.Row():
        btn_split = gr.Button("分段")
        btn_keyword = gr.Button("提取关键词")
        btn_summary = gr.Button("生成摘要")
    # Segment the raw input into length-limited chunks.
    btn_split.click(
        split_input_text,
        inputs=[input_text, strip_input_lines, line_max_length],
        outputs=[split_text]
    )
    # Summarise the chunks using the combined keyword + summary prompt.
    btn_summary.click(
        gen_summary,
        inputs=[split_text, keyword_summary_prompt, line_max_length],
        outputs=[summary]
    )
    # Extract keywords and write the combined prompt into its textbox.
    btn_keyword.click(
        gen_keyword_summary,
        inputs=[split_text, keyword_prompt, summary_prompt, line_max_length],
        outputs=[keyword_summary_prompt]
    )