import os import gradio as gr from datetime import datetime import pandas as pd from api import doc_generate, summary, title_generate, material_retrieval, daoci_generate with gr.Blocks() as app: with gr.Tab('文章生成'): with gr.Column() as page1: with gr.Row(): # 输入 with gr.Column(): # 时间、地点 with gr.Row(): page1_input_date = gr.Textbox(show_label=False, placeholder='时间(格式示例为:2023年09月08日)') page1_input_site = gr.Textbox(show_label=False, placeholder='地点') # 主要人物 page1_input_persons = gr.Dataframe(label='主要人物', headers=['姓名', '单位', '职务'], interactive=True, col_count=(3, 'fixed')) # 记录、草稿 page1_input_records = gr.TextArea(show_label=False, placeholder='记录&草稿(建议对记录进行编号)') # 文章类型、模型、温度、惩罚系数 with gr.Column(): with gr.Row(): page1_input_type = gr.Dropdown(label='文章类型', choices=['会议纪要', '领导讲话', '体会心得', '工作汇报'], value='会议纪要') page1_input_model = gr.Dropdown(label='模型', choices=['ErnieBot', 'ErnieBot-turbo', 'Bloomz-7B'], value='ErnieBot') with gr.Row(): page1_input_temperature = gr.Slider(label='温度', info='温度参数越大,生成文章的多样性越大', maximum=1.0, value=0.95) page1_input_penalty_score = gr.Slider(label='(重复)惩罚系数', info='惩罚系数越大,文章整体重复性越低', minimum=1.0, maximum=2.0, value=1.0) # 输出 with gr.Column(): page1_output = gr.TextArea(show_label=False, placeholder='文章内容', show_copy_button=True, interactive=False) with gr.Row(): # 按钮 page1_button_commit = gr.Button(value='提交') page1_button_clear = gr.ClearButton(value="清除", components=[ page1_input_date, page1_input_site, page1_input_persons, page1_input_records, page1_output ]) pass page1_example_persons = pd.DataFrame([['张三', '北京大学', '教授'], ['李四', '清华大学', '副教授']], columns=['姓名', '单位', '职务']) page1_example_recoreds = '1. 张教授讲解大模型时代的发展现状和应用前景\n2. 李教授讲解大模型与知识图谱融合领域的现状\n3. 张教授和李教授围绕“大模型应用”展开对话' page1_example = gr.Examples( examples=[['2023年6月18日', '北京大学', page1_example_persons, page1_example_recoreds]], inputs=[page1_input_date, page1_input_site, page1_input_persons, page1_input_records]) # 触发事件 def check_date(date): try: datetime.strptime(date, "%Y年%m月%d日") return True except ValueError: return False def page1_button_commit_click(data): date = data[page1_input_date].strip() site = data[page1_input_site] persons = data[page1_input_persons] records = data[page1_input_records] doc_type = data[page1_input_type] model = data[page1_input_model] temperature = data[page1_input_temperature] penalty_score = data[page1_input_penalty_score] if not check_date(date): raise gr.Error('时间格式错误!') success, result = doc_generate(date, site, persons, records, doc_type, model, temperature, penalty_score) if success: return result else: gr.Error(result) page1_button_commit.click(fn=page1_button_commit_click, inputs={ page1_input_date, page1_input_site, page1_input_persons, page1_input_records, page1_input_type, page1_input_model, page1_input_temperature, page1_input_penalty_score }, outputs=page1_output) def model_change(model): if model in ["ErnieBot"]: return ( gr.Slider.update(visible=True), gr.Slider.update(visible=True) ) elif model in ["ErnieBot-turbo", 'Bloomz-7B']: return ( gr.Slider.update(visible=False), gr.Slider.update(visible=False) ) page1_input_model.change(fn=model_change, inputs=page1_input_model, outputs=[page1_input_temperature, page1_input_penalty_score]) with gr.Tab('文章摘要'): with gr.Column() as page2: # I/O with gr.Row(): # 输入 with gr.Column(): page2_input_title = gr.Textbox(show_label=False, placeholder='标题(可选)') page2_input_content = gr.TextArea(show_label=False, placeholder='文章内容(必填)') page2_input_maxlen = gr.Slider(label='摘要最大长度', minimum=50, maximum=1000, step=50, value=500) # 输出 with gr.Column(): page2_output = gr.TextArea(label='摘要', placeholder='输入标题和文章内容,点击提交按钮,摘要将在此处输出。', show_copy_button=True, interactive=False) # 按钮 with gr.Row(): page2_button_commit = gr.Button(value="提交") page2_button_clear = gr.ClearButton(value="清除", components=[ page2_input_title, page2_input_content, page2_output ]) # 触发事件 def page2_button_commit_click(data): title = data[page2_input_title] content = data[page2_input_content] maxlen = data[page2_input_maxlen] if title is None or title.strip() == '' or len(title) < 200: title = None success, result = summary(content, maxlen, title) if success: return result else: raise gr.Error(result) page2_button_commit.click(fn=page2_button_commit_click, inputs={page2_input_title, page2_input_content, page2_input_maxlen}, outputs=page2_output, scroll_to_output=True, show_progress='minimal') with gr.Tab('标题生成'): with gr.Column() as page3: # I/O with gr.Row(): # 输入 with gr.Column(): page3_input_doc = gr.TextArea(show_label=False, placeholder='文章内容(必填)') # 输出 with gr.Column(): page3_output = gr.Label(label='标题') # 按钮 with gr.Row(): page3_button_commit = gr.Button(value="提交") page3_button_clear = gr.ClearButton(value="清除", components=[ page3_input_doc, page3_output ]) # 触发事件 def page3_button_commit_click(data): doc = data[page3_input_doc] success, result = title_generate(doc) if success: return result else: raise gr.Error(result) page3_button_commit.click(fn=page3_button_commit_click, inputs={page3_input_doc}, outputs=page3_output, scroll_to_output=True, show_progress='minimal') with gr.Tab('素材检索'): with gr.Column() as page4: with gr.Row(): # 输入 with gr.Column(): page4_input_doc = gr.TextArea(show_label=False, placeholder='输入一段文字,根据文字提取关键信息,并搜索相关素材') page4_input_kw_num = gr.Slider(label='关键词最大个数', minimum=1, maximum=10, step=1, value=5) page4_input_m_num = gr.Slider(label='每个关键词检索的素材数量', minimum=1, maximum=10, step=1, value=3) # 输出 with gr.Column(): page4_output_kw = gr.Textbox(label='关键词', show_copy_button=True) page4_output_m = gr.Markdown(show_label=False, scroll_to_output=True, show_progress=True, show_copy_button=True) pass with gr.Row(): # 按钮 page4_button_commit = gr.Button(value="提交") page4_button_clear = gr.ClearButton(value="清除", components=[ page4_input_doc, page4_output_kw, page4_output_m ]) pass page4_example_doc = '今年是习近平总书记提出共建“一带一路”倡议9周年。9' \ '年来,共建“一带一路”坚持共商共建共享原则,秉持开放、绿色、廉洁理念,努力实现高标准、可持续、惠民生目标,取得了实打实、沉甸甸的成就。' \ '其中,中欧班列作为共建“一带一路”的旗舰项目和标志性品牌,自开行以来得到了国际社会的广泛赞誉和积极参与,成为了广受欢迎的国际公共产品。' \ '为总结发展成效、阐释重要贡献、展望发展愿景,推进“一带一路”建设工作领导小组办公室组织编写了《中欧班列发展报告(2021)》,并于2022年8月18日正式发布。' page4_example = gr.Examples(examples=[[page4_example_doc]], inputs=[page4_input_doc]) # 触发事件 def page4_button_commit_click(data): doc = data[page4_input_doc] kw_num = data[page4_input_kw_num] m_num = data[page4_input_m_num] success, result = material_retrieval(doc, kw_num, m_num) if success: return result else: raise gr.Error(result) page4_button_commit.click(fn=page4_button_commit_click, inputs={page4_input_doc, page4_input_kw_num, page4_input_m_num}, outputs=[page4_output_kw, page4_output_m], scroll_to_output=True, show_progress='minimal') with gr.Tab('悼词生成'): with gr.Column() as page5: with gr.Row(): # 输入 with gr.Column(): # 姓名、性别 with gr.Row(): page5_input_name = gr.Textbox(label='姓名') page5_input_gender = gr.Dropdown(label='性别', choices=['男', '女'], value='男') # 年龄、政治面貌 with gr.Row(): page5_input_age = gr.Textbox(label='年龄', placeholder='输入数字') page5_input_political = gr.Dropdown(label='政治面貌', choices=['中共党员', '中共预备党员', '共青团员', '民革党员', '民盟盟员', '民建会员', '民进会员', '农工党党员', '致公党党员', '九三学社社员', '台盟盟员', '无党派人士', '群众'], value='群众') # 去世时间、去世原因 with gr.Row(): page5_input_dead_time = gr.Textbox(label='去世时间', placeholder='格式:XXXX年XX月XX日') page5_input_dead_reason = gr.Textbox(label='去世原因') # 生平经历 page5_input_experience = gr.Dataframe(label='生平经历', info='时间格式为:XXXX年XX月', headers=['起始时间', '结束时间', '所属单位', '所任职务', '工作职责或主要事迹'], interactive=True, col_count=(5, 'fixed')) # 模型、温度、惩罚系数 with gr.Column(): with gr.Row(): page5_input_model = gr.Dropdown(label='模型', choices=['ErnieBot', 'ErnieBot-turbo', 'Bloomz-7B'], value='ErnieBot') page5_input_temperature = gr.Slider(label='温度', info='温度参数越大,生成文章的多样性越大', maximum=1.0, value=0.95) page5_input_penalty_score = gr.Slider(label='(重复)惩罚系数', info='惩罚系数越大,文章整体重复性越低', minimum=1.0, maximum=2.0, value=1.0) # 输出 with gr.Column(): page5_output = gr.TextArea(show_label=False, placeholder='悼词', show_copy_button=True, interactive=False) with gr.Row(): # 按钮 page5_button_commit = gr.Button(value='提交') page5_button_clear = gr.ClearButton(value="清除", components=[ page5_input_name, page5_input_gender, page5_input_age, page5_input_dead_time, page5_input_dead_reason, page5_input_experience, page5_output ]) pass page5_example_experience = pd.DataFrame( [['2012年9月', '2017年6月', '阿里巴巴', '算法工程师', '参与研发工作'], ['2018年2月', '2023年3月', '北京大学', '辅导员', '负责学生的思想政治教育、学生日常管理、就业指导、心理健康以及学生党团建设']], columns=['起始时间', '结束时间', '所属单位', '所任职务', '工作职责或主要事迹']) page5_example = gr.Examples( examples=[['张三', '男', '36', '共青团员', '2023年4月16日', '车祸', page5_example_experience]], inputs=[page5_input_name, page5_input_gender, page5_input_age, page5_input_political, page5_input_dead_time, page5_input_dead_reason, page5_input_experience]) # 触发事件 def page5_check_date(date, format_str): try: datetime.strptime(date, format_str) return True except ValueError: return False def page5_button_commit_click(data): name = data[page5_input_name].strip() gender = data[page5_input_gender] age = data[page5_input_age].strip() political = data[page5_input_political] dead_time = data[page5_input_dead_time] dead_reason = data[page5_input_dead_reason] experience = data[page5_input_experience] model = data[page5_input_model] temperature = data[page5_input_temperature] penalty_score = data[page5_input_penalty_score] if not age.isdigit(): raise gr.Error('年龄必须是数字!') if not page5_check_date(dead_time, "%Y年%m月%d日"): raise gr.Error('去世时间格式错误!') for i in range(len(experience)): if not page5_check_date(experience['起始时间'][i], "%Y年%m月"): raise gr.Error('起始时间格式错误!') if not page5_check_date(experience['结束时间'][i], "%Y年%m月"): raise gr.Error('结束时间格式错误!') success, result = daoci_generate(name, gender, age, political, dead_time, dead_reason, experience, model, temperature, penalty_score) if success: return result else: gr.Error(result) page5_button_commit.click(fn=page5_button_commit_click, inputs={ page5_input_name, page5_input_gender, page5_input_age, page5_input_political, page5_input_dead_time, page5_input_dead_reason, page5_input_experience, page5_input_model, page5_input_temperature, page5_input_penalty_score }, outputs=page5_output) def model_change(model): if model in ["ErnieBot"]: return ( gr.Slider.update(visible=True), gr.Slider.update(visible=True) ) elif model in ["ErnieBot-turbo", 'Bloomz-7B']: return ( gr.Slider.update(visible=False), gr.Slider.update(visible=False) ) page5_input_model.change(fn=model_change, inputs=page5_input_model, outputs=[page5_input_temperature, page5_input_penalty_score]) with gr.Tab('设置') as setting: pass with gr.Tab('关于') as setting: pass if __name__ == '__main__': username = os.getenv('username') password = os.getenv('password') app.launch(auth=(username, password))