Spaces:
Runtime error
Runtime error
import markdown | |
import re | |
import os | |
import math | |
from textwrap import dedent | |
from functools import lru_cache | |
from pymdownx.superfences import fence_div_format, fence_code_format | |
from latex2mathml.converter import convert as tex2mathml | |
from shared_utils.config_loader import get_conf as get_conf | |
pj = os.path.join | |
default_user_name = 'default_user' | |
markdown_extension_configs = { | |
'mdx_math': { | |
'enable_dollar_delimiter': True, | |
'use_gitlab_delimiters': False, | |
}, | |
} | |
code_highlight_configs = { | |
"pymdownx.superfences": { | |
'css_class': 'codehilite', | |
"custom_fences": [ | |
{ | |
'name': 'mermaid', | |
'class': 'mermaid', | |
'format': fence_code_format | |
} | |
] | |
}, | |
"pymdownx.highlight": { | |
'css_class': 'codehilite', | |
'guess_lang': True, | |
# 'auto_title': True, | |
# 'linenums': True | |
} | |
} | |
def text_divide_paragraph(text): | |
""" | |
将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。 | |
""" | |
pre = '<div class="markdown-body">' | |
suf = '</div>' | |
if text.startswith(pre) and text.endswith(suf): | |
return text | |
if '```' in text: | |
# careful input | |
return text | |
elif '</div>' in text: | |
# careful input | |
return text | |
else: | |
# whatever input | |
lines = text.split("\n") | |
for i, line in enumerate(lines): | |
lines[i] = lines[i].replace(" ", " ") | |
text = "</br>".join(lines) | |
return pre + text + suf | |
def tex2mathml_catch_exception(content, *args, **kwargs): | |
try: | |
content = tex2mathml(content, *args, **kwargs) | |
except: | |
content = content | |
return content | |
def replace_math_no_render(match): | |
content = match.group(1) | |
if 'mode=display' in match.group(0): | |
content = content.replace('\n', '</br>') | |
return f"<font color=\"#00FF00\">$$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$$</font>" | |
else: | |
return f"<font color=\"#00FF00\">$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$</font>" | |
def replace_math_render(match): | |
content = match.group(1) | |
if 'mode=display' in match.group(0): | |
if '\\begin{aligned}' in content: | |
content = content.replace('\\begin{aligned}', '\\begin{array}') | |
content = content.replace('\\end{aligned}', '\\end{array}') | |
content = content.replace('&', ' ') | |
content = tex2mathml_catch_exception(content, display="block") | |
return content | |
else: | |
return tex2mathml_catch_exception(content) | |
def markdown_bug_hunt(content): | |
""" | |
解决一个mdx_math的bug(单$包裹begin命令时多余<script>) | |
""" | |
content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', | |
'<script type="math/tex; mode=display">') | |
content = content.replace('</script>\n</script>', '</script>') | |
return content | |
def is_equation(txt): | |
""" | |
判定是否为公式 | 测试1 写出洛伦兹定律,使用tex格式公式 测试2 给出柯西不等式,使用latex格式 测试3 写出麦克斯韦方程组 | |
""" | |
if '```' in txt and '```reference' not in txt: return False | |
if '$' not in txt and '\\[' not in txt: return False | |
mathpatterns = { | |
r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, # $...$ | |
r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$ | |
r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\] | |
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\) | |
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end | |
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$ | |
} | |
matches = [] | |
for pattern, property in mathpatterns.items(): | |
flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII | |
matches.extend(re.findall(pattern, txt, flags)) | |
if len(matches) == 0: return False | |
contain_any_eq = False | |
illegal_pattern = re.compile(r'[^\x00-\x7F]|echo') | |
for match in matches: | |
if len(match) != 3: return False | |
eq_canidate = match[1] | |
if illegal_pattern.search(eq_canidate): | |
return False | |
else: | |
contain_any_eq = True | |
return contain_any_eq | |
def fix_markdown_indent(txt): | |
# fix markdown indent | |
if (' - ' not in txt) or ('. ' not in txt): | |
# do not need to fix, fast escape | |
return txt | |
# walk through the lines and fix non-standard indentation | |
lines = txt.split("\n") | |
pattern = re.compile(r'^\s+-') | |
activated = False | |
for i, line in enumerate(lines): | |
if line.startswith('- ') or line.startswith('1. '): | |
activated = True | |
if activated and pattern.match(line): | |
stripped_string = line.lstrip() | |
num_spaces = len(line) - len(stripped_string) | |
if (num_spaces % 4) == 3: | |
num_spaces_should_be = math.ceil(num_spaces / 4) * 4 | |
lines[i] = ' ' * num_spaces_should_be + stripped_string | |
return '\n'.join(lines) | |
FENCED_BLOCK_RE = re.compile( | |
dedent(r''' | |
(?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]* # opening fence | |
((\{(?P<attrs>[^\}\n]*)\})| # (optional {attrs} or | |
(\.?(?P<lang>[\w#.+-]*)[ ]*)? # optional (.)lang | |
(hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines) | |
\n # newline (end of opening fence) | |
(?P<code>.*?)(?<=\n) # the code block | |
(?P=fence)[ ]*$ # closing fence | |
'''), | |
re.MULTILINE | re.DOTALL | re.VERBOSE | |
) | |
def get_line_range(re_match_obj, txt): | |
start_pos, end_pos = re_match_obj.regs[0] | |
num_newlines_before = txt[:start_pos+1].count('\n') | |
line_start = num_newlines_before | |
line_end = num_newlines_before + txt[start_pos:end_pos].count('\n')+1 | |
return line_start, line_end | |
def fix_code_segment_indent(txt): | |
lines = [] | |
change_any = False | |
txt_tmp = txt | |
while True: | |
re_match_obj = FENCED_BLOCK_RE.search(txt_tmp) | |
if not re_match_obj: break | |
if len(lines) == 0: lines = txt.split("\n") | |
# 清空 txt_tmp 对应的位置方便下次搜索 | |
start_pos, end_pos = re_match_obj.regs[0] | |
txt_tmp = txt_tmp[:start_pos] + ' '*(end_pos-start_pos) + txt_tmp[end_pos:] | |
line_start, line_end = get_line_range(re_match_obj, txt) | |
# 获取公共缩进 | |
shared_indent_cnt = 1e5 | |
for i in range(line_start, line_end): | |
stripped_string = lines[i].lstrip() | |
num_spaces = len(lines[i]) - len(stripped_string) | |
if num_spaces < shared_indent_cnt: | |
shared_indent_cnt = num_spaces | |
# 修复缩进 | |
if (shared_indent_cnt < 1e5) and (shared_indent_cnt % 4) == 3: | |
num_spaces_should_be = math.ceil(shared_indent_cnt / 4) * 4 | |
for i in range(line_start, line_end): | |
add_n = num_spaces_should_be - shared_indent_cnt | |
lines[i] = ' ' * add_n + lines[i] | |
if not change_any: # 遇到第一个 | |
change_any = True | |
if change_any: | |
return '\n'.join(lines) | |
else: | |
return txt | |
# 使用 lru缓存 加快转换速度 | |
def markdown_convertion(txt): | |
""" | |
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。 | |
""" | |
pre = '<div class="markdown-body">' | |
suf = '</div>' | |
if txt.startswith(pre) and txt.endswith(suf): | |
# print('警告,输入了已经经过转化的字符串,二次转化可能出问题') | |
return txt # 已经被转化过,不需要再次转化 | |
find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>' | |
txt = fix_markdown_indent(txt) | |
# txt = fix_code_segment_indent(txt) | |
if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识 | |
# convert everything to html format | |
split = markdown.markdown(text='---') | |
convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'pymdownx.superfences', 'pymdownx.highlight'], | |
extension_configs={**markdown_extension_configs, **code_highlight_configs}) | |
convert_stage_1 = markdown_bug_hunt(convert_stage_1) | |
# 1. convert to easy-to-copy tex (do not render math) | |
convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL) | |
# 2. convert to rendered equation | |
convert_stage_2_2, n = re.subn(find_equation_pattern, replace_math_render, convert_stage_1, flags=re.DOTALL) | |
# cat them together | |
return pre + convert_stage_2_1 + f'{split}' + convert_stage_2_2 + suf | |
else: | |
return pre + markdown.markdown(txt, extensions=['sane_lists', 'tables', 'pymdownx.superfences', 'pymdownx.highlight'], extension_configs=code_highlight_configs) + suf | |
def close_up_code_segment_during_stream(gpt_reply): | |
""" | |
在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的``` | |
Args: | |
gpt_reply (str): GPT模型返回的回复字符串。 | |
Returns: | |
str: 返回一个新的字符串,将输出代码片段的“后面的```”补上。 | |
""" | |
if '```' not in gpt_reply: | |
return gpt_reply | |
if gpt_reply.endswith('```'): | |
return gpt_reply | |
# 排除了以上两个情况,我们 | |
segments = gpt_reply.split('```') | |
n_mark = len(segments) - 1 | |
if n_mark % 2 == 1: | |
return gpt_reply + '\n```' # 输出代码片段中! | |
else: | |
return gpt_reply | |
def format_io(self, y): | |
""" | |
将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。 | |
""" | |
if y is None or y == []: | |
return [] | |
i_ask, gpt_reply = y[-1] | |
# 输入部分太自由,预处理一波 | |
if i_ask is not None: i_ask = text_divide_paragraph(i_ask) | |
# 当代码输出半截的时候,试着补上后个``` | |
if gpt_reply is not None: gpt_reply = close_up_code_segment_during_stream(gpt_reply) | |
# process | |
y[-1] = ( | |
None if i_ask is None else markdown.markdown(i_ask, extensions=['pymdownx.superfences', 'tables', 'pymdownx.highlight'], extension_configs=code_highlight_configs), | |
None if gpt_reply is None else markdown_convertion(gpt_reply) | |
) | |
return y | |