Spaces:
Running
Running
| import numpy as np | |
| import os | |
| import re | |
| import datetime | |
| import arxiv | |
| import openai, tenacity | |
| import base64, requests | |
| import argparse | |
| import configparser | |
| import fitz, io, os | |
| from PIL import Image | |
| import gradio | |
| import markdown | |
| class Paper: | |
| def __init__(self, path, title='', url='', abs='', authers=[], sl=[]): | |
| # 初始化函数,根据pdf路径初始化Paper对象 | |
| self.url = url # 文章链接 | |
| self.path = path # pdf路径 | |
| self.sl = sl | |
| self.section_names = [] # 段落标题 | |
| self.section_texts = {} # 段落内容 | |
| if title == '': | |
| self.pdf = fitz.open(self.path) # pdf文档 | |
| self.title = self.get_title() | |
| self.parse_pdf() | |
| else: | |
| self.title = title | |
| self.authers = authers | |
| self.abs = abs | |
| self.roman_num = ["I", "II", 'III', "IV", "V", "VI", "VII", "VIII", "IIX", "IX", "X"] | |
| self.digit_num = [str(d+1) for d in range(10)] | |
| self.first_image = '' | |
| def parse_pdf(self): | |
| self.pdf = fitz.open(self.path) # pdf文档 | |
| self.text_list = [page.get_text() for page in self.pdf] | |
| self.all_text = ' '.join(self.text_list) | |
| self.section_page_dict = self._get_all_page_index() # 段落与页码的对应字典 | |
| print("section_page_dict", self.section_page_dict) | |
| self.section_text_dict = self._get_all_page() # 段落与内容的对应字典 | |
| self.section_text_dict.update({"title": self.title}) | |
| self.pdf.close() | |
| def get_image_path(self, image_path=''): | |
| """ | |
| 将PDF中的第一张图保存到image.png里面,存到本地目录,返回文件名称,供gitee读取 | |
| :param filename: 图片所在路径,"C:\\Users\\Administrator\\Desktop\\nwd.pdf" | |
| :param image_path: 图片提取后的保存路径 | |
| :return: | |
| """ | |
| # open file | |
| max_size = 0 | |
| image_list = [] | |
| with fitz.Document(self.path) as my_pdf_file: | |
| # 遍历所有页面 | |
| for page_number in range(1, len(my_pdf_file) + 1): | |
| # 查看独立页面 | |
| page = my_pdf_file[page_number - 1] | |
| # 查看当前页所有图片 | |
| images = page.get_images() | |
| # 遍历当前页面所有图片 | |
| for image_number, image in enumerate(page.get_images(), start=1): | |
| # 访问图片xref | |
| xref_value = image[0] | |
| # 提取图片信息 | |
| base_image = my_pdf_file.extract_image(xref_value) | |
| # 访问图片 | |
| image_bytes = base_image["image"] | |
| # 获取图片扩展名 | |
| ext = base_image["ext"] | |
| # 加载图片 | |
| image = Image.open(io.BytesIO(image_bytes)) | |
| image_size = image.size[0] * image.size[1] | |
| if image_size > max_size: | |
| max_size = image_size | |
| image_list.append(image) | |
| for image in image_list: | |
| image_size = image.size[0] * image.size[1] | |
| if image_size == max_size: | |
| image_name = f"image.{ext}" | |
| im_path = os.path.join(image_path, image_name) | |
| print("im_path:", im_path) | |
| max_pix = 480 | |
| origin_min_pix = min(image.size[0], image.size[1]) | |
| if image.size[0] > image.size[1]: | |
| min_pix = int(image.size[1] * (max_pix/image.size[0])) | |
| newsize = (max_pix, min_pix) | |
| else: | |
| min_pix = int(image.size[0] * (max_pix/image.size[1])) | |
| newsize = (min_pix, max_pix) | |
| image = image.resize(newsize) | |
| image.save(open(im_path, "wb")) | |
| return im_path, ext | |
| return None, None | |
| # 定义一个函数,根据字体的大小,识别每个章节名称,并返回一个列表 | |
| def get_chapter_names(self,): | |
| # # 打开一个pdf文件 | |
| doc = fitz.open(self.path) # pdf文档 | |
| text_list = [page.get_text() for page in doc] | |
| all_text = '' | |
| for text in text_list: | |
| all_text += text | |
| # # 创建一个空列表,用于存储章节名称 | |
| chapter_names = [] | |
| for line in all_text.split('\n'): | |
| line_list = line.split(' ') | |
| if '.' in line: | |
| point_split_list = line.split('.') | |
| space_split_list = line.split(' ') | |
| if 1 < len(space_split_list) < 5: | |
| if 1 < len(point_split_list) < 5 and (point_split_list[0] in self.roman_num or point_split_list[0] in self.digit_num): | |
| print("line:", line) | |
| chapter_names.append(line) | |
| return chapter_names | |
| def get_title(self): | |
| doc = self.pdf # 打开pdf文件 | |
| max_font_size = 0 # 初始化最大字体大小为0 | |
| max_string = "" # 初始化最大字体大小对应的字符串为空 | |
| max_font_sizes = [0] | |
| for page in doc: # 遍历每一页 | |
| text = page.get_text("dict") # 获取页面上的文本信息 | |
| blocks = text["blocks"] # 获取文本块列表 | |
| for block in blocks: # 遍历每个文本块 | |
| if block["type"] == 0: # 如果是文字类型 | |
| font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小 | |
| max_font_sizes.append(font_size) | |
| if font_size > max_font_size: # 如果字体大小大于当前最大值 | |
| max_font_size = font_size # 更新最大值 | |
| max_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串 | |
| max_font_sizes.sort() | |
| print("max_font_sizes", max_font_sizes[-10:]) | |
| cur_title = '' | |
| for page in doc: # 遍历每一页 | |
| text = page.get_text("dict") # 获取页面上的文本信息 | |
| blocks = text["blocks"] # 获取文本块列表 | |
| for block in blocks: # 遍历每个文本块 | |
| if block["type"] == 0: # 如果是文字类型 | |
| cur_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串 | |
| font_flags = block["lines"][0]["spans"][0]["flags"] # 获取第一行第一段文字的字体特征 | |
| font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小 | |
| # print(font_size) | |
| if abs(font_size - max_font_sizes[-1]) < 0.3 or abs(font_size - max_font_sizes[-2]) < 0.3: | |
| # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags) | |
| if len(cur_string) > 4 and "arXiv" not in cur_string: | |
| # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags) | |
| if cur_title == '' : | |
| cur_title += cur_string | |
| else: | |
| cur_title += ' ' + cur_string | |
| # break | |
| title = cur_title.replace('\n', ' ') | |
| return title | |
| def _get_all_page_index(self): | |
| # 定义需要寻找的章节名称列表 | |
| section_list = self.sl | |
| # 初始化一个字典来存储找到的章节和它们在文档中出现的页码 | |
| section_page_dict = {} | |
| # 遍历每一页文档 | |
| for page_index, page in enumerate(self.pdf): | |
| # 获取当前页面的文本内容 | |
| cur_text = page.get_text() | |
| # 遍历需要寻找的章节名称列表 | |
| for section_name in section_list: | |
| # 将章节名称转换成大写形式 | |
| section_name_upper = section_name.upper() | |
| # 如果当前页面包含"Abstract"这个关键词 | |
| if "Abstract" == section_name and section_name in cur_text: | |
| # 将"Abstract"和它所在的页码加入字典中 | |
| section_page_dict[section_name] = page_index | |
| # 如果当前页面包含章节名称,则将章节名称和它所在的页码加入字典中 | |
| else: | |
| if section_name + '\n' in cur_text: | |
| section_page_dict[section_name] = page_index | |
| elif section_name_upper + '\n' in cur_text: | |
| section_page_dict[section_name] = page_index | |
| # 返回所有找到的章节名称及它们在文档中出现的页码 | |
| return section_page_dict | |
| def _get_all_page(self): | |
| """ | |
| 获取PDF文件中每个页面的文本信息,并将文本信息按照章节组织成字典返回。 | |
| Returns: | |
| section_dict (dict): 每个章节的文本信息字典,key为章节名,value为章节文本。 | |
| """ | |
| text = '' | |
| text_list = [] | |
| section_dict = {} | |
| # # 先处理Abstract章节 | |
| # for page_index, page in enumerate(self.pdf): | |
| # cur_text = page.get_text() | |
| # # 如果该页面是Abstract章节所在页面 | |
| # if page_index == list(self.section_page_dict.values())[0]: | |
| # abs_str = "Abstract" | |
| # # 获取Abstract章节的起始位置 | |
| # first_index = cur_text.find(abs_str) | |
| # # 查找下一个章节的关键词,这里是Introduction | |
| # intro_str = "Introduction" | |
| # if intro_str in cur_text: | |
| # second_index = cur_text.find(intro_str) | |
| # elif intro_str.upper() in cur_text: | |
| # second_index = cur_text.find(intro_str.upper()) | |
| # # 将Abstract章节内容加入字典中 | |
| # section_dict[abs_str] = cur_text[first_index+len(abs_str)+1:second_index].replace('-\n', | |
| # '').replace('\n', ' ').split('I.')[0].split("II.")[0] | |
| # 再处理其他章节: | |
| text_list = [page.get_text() for page in self.pdf] | |
| for sec_index, sec_name in enumerate(self.section_page_dict): | |
| print(sec_index, sec_name, self.section_page_dict[sec_name]) | |
| if sec_index <= 0: | |
| continue | |
| else: | |
| # 直接考虑后面的内容: | |
| start_page = self.section_page_dict[sec_name] | |
| if sec_index < len(list(self.section_page_dict.keys()))-1: | |
| end_page = self.section_page_dict[list(self.section_page_dict.keys())[sec_index+1]] | |
| else: | |
| end_page = len(text_list) | |
| print("start_page, end_page:", start_page, end_page) | |
| cur_sec_text = '' | |
| if end_page - start_page == 0: | |
| if sec_index < len(list(self.section_page_dict.keys()))-1: | |
| next_sec = list(self.section_page_dict.keys())[sec_index+1] | |
| if text_list[start_page].find(sec_name) == -1: | |
| start_i = text_list[start_page].find(sec_name.upper()) | |
| else: | |
| start_i = text_list[start_page].find(sec_name) | |
| if text_list[start_page].find(next_sec) == -1: | |
| end_i = text_list[start_page].find(next_sec.upper()) | |
| else: | |
| end_i = text_list[start_page].find(next_sec) | |
| cur_sec_text += text_list[start_page][start_i:end_i] | |
| else: | |
| for page_i in range(start_page, end_page): | |
| # print("page_i:", page_i) | |
| if page_i == start_page: | |
| if text_list[start_page].find(sec_name) == -1: | |
| start_i = text_list[start_page].find(sec_name.upper()) | |
| else: | |
| start_i = text_list[start_page].find(sec_name) | |
| cur_sec_text += text_list[page_i][start_i:] | |
| elif page_i < end_page: | |
| cur_sec_text += text_list[page_i] | |
| elif page_i == end_page: | |
| if sec_index < len(list(self.section_page_dict.keys()))-1: | |
| next_sec = list(self.section_page_dict.keys())[sec_index+1] | |
| if text_list[start_page].find(next_sec) == -1: | |
| end_i = text_list[start_page].find(next_sec.upper()) | |
| else: | |
| end_i = text_list[start_page].find(next_sec) | |
| cur_sec_text += text_list[page_i][:end_i] | |
| section_dict[sec_name] = cur_sec_text.replace('-\n', '').replace('\n', ' ') | |
| return section_dict | |
| # 定义Reader类 | |
| class Reader: | |
| # 初始化方法,设置属性 | |
| def __init__(self, key_word='', query='', filter_keys='', | |
| root_path='./', | |
| gitee_key='', | |
| sort=arxiv.SortCriterion.SubmittedDate, user_name='defualt', language='cn', key=''): | |
| self.key = str(key) # OpenAI key | |
| self.user_name = user_name # 读者姓名 | |
| self.key_word = key_word # 读者感兴趣的关键词 | |
| self.query = query # 读者输入的搜索查询 | |
| self.sort = sort # 读者选择的排序方式 | |
| self.language = language # 读者选择的语言 | |
| self.filter_keys = filter_keys # 用于在摘要中筛选的关键词 | |
| self.root_path = root_path | |
| self.file_format = 'md' # or 'txt',如果为图片,则必须为'md' | |
| self.save_image = False | |
| if self.save_image: | |
| self.gitee_key = self.config.get('Gitee', 'api') | |
| else: | |
| self.gitee_key = '' | |
| def get_arxiv(self, max_results=30): | |
| search = arxiv.Search(query=self.query, | |
| max_results=max_results, | |
| sort_by=self.sort, | |
| sort_order=arxiv.SortOrder.Descending, | |
| ) | |
| return search | |
| def filter_arxiv(self, max_results=30): | |
| search = self.get_arxiv(max_results=max_results) | |
| print("all search:") | |
| for index, result in enumerate(search.results()): | |
| print(index, result.title, result.updated) | |
| filter_results = [] | |
| filter_keys = self.filter_keys | |
| print("filter_keys:", self.filter_keys) | |
| # 确保每个关键词都能在摘要中找到,才算是目标论文 | |
| for index, result in enumerate(search.results()): | |
| abs_text = result.summary.replace('-\n', '-').replace('\n', ' ') | |
| meet_num = 0 | |
| for f_key in filter_keys.split(" "): | |
| if f_key.lower() in abs_text.lower(): | |
| meet_num += 1 | |
| if meet_num == len(filter_keys.split(" ")): | |
| filter_results.append(result) | |
| # break | |
| print("filter_results:", len(filter_results)) | |
| print("filter_papers:") | |
| for index, result in enumerate(filter_results): | |
| print(index, result.title, result.updated) | |
| return filter_results | |
| def validateTitle(self, title): | |
| # 将论文的乱七八糟的路径格式修正 | |
| rstr = r"[\/\\\:\*\?\"\<\>\|]" # '/ \ : * ? " < > |' | |
| new_title = re.sub(rstr, "_", title) # 替换为下划线 | |
| return new_title | |
| def download_pdf(self, filter_results): | |
| # 先创建文件夹 | |
| date_str = str(datetime.datetime.now())[:13].replace(' ', '-') | |
| key_word = str(self.key_word.replace(':', ' ')) | |
| path = self.root_path + 'pdf_files/' + self.query.replace('au: ', '').replace('title: ', '').replace('ti: ', '').replace(':', ' ')[:25] + '-' + date_str | |
| try: | |
| os.makedirs(path) | |
| except: | |
| pass | |
| print("All_paper:", len(filter_results)) | |
| # 开始下载: | |
| paper_list = [] | |
| for r_index, result in enumerate(filter_results): | |
| try: | |
| title_str = self.validateTitle(result.title) | |
| pdf_name = title_str+'.pdf' | |
| # result.download_pdf(path, filename=pdf_name) | |
| self.try_download_pdf(result, path, pdf_name) | |
| paper_path = os.path.join(path, pdf_name) | |
| print("paper_path:", paper_path) | |
| paper = Paper(path=paper_path, | |
| url=result.entry_id, | |
| title=result.title, | |
| abs=result.summary.replace('-\n', '-').replace('\n', ' '), | |
| authers=[str(aut) for aut in result.authors], | |
| ) | |
| # 下载完毕,开始解析: | |
| paper.parse_pdf() | |
| paper_list.append(paper) | |
| except Exception as e: | |
| print("download_error:", e) | |
| pass | |
| return paper_list | |
| def try_download_pdf(self, result, path, pdf_name): | |
| result.download_pdf(path, filename=pdf_name) | |
| def upload_gitee(self, image_path, image_name='', ext='png'): | |
| """ | |
| 上传到码云 | |
| :return: | |
| """ | |
| with open(image_path, 'rb') as f: | |
| base64_data = base64.b64encode(f.read()) | |
| base64_content = base64_data.decode() | |
| date_str = str(datetime.datetime.now())[:19].replace(':', '-').replace(' ', '-') + '.' + ext | |
| path = image_name+ '-' +date_str | |
| payload = { | |
| "access_token": self.gitee_key, | |
| "owner": self.config.get('Gitee', 'owner'), | |
| "repo": self.config.get('Gitee', 'repo'), | |
| "path": self.config.get('Gitee', 'path'), | |
| "content": base64_content, | |
| "message": "upload image" | |
| } | |
| # 这里需要修改成你的gitee的账户和仓库名,以及文件夹的名字: | |
| url = f'https://gitee.com/api/v5/repos/'+self.config.get('Gitee', 'owner')+'/'+self.config.get('Gitee', 'repo')+'/contents/'+self.config.get('Gitee', 'path')+'/'+path | |
| rep = requests.post(url, json=payload).json() | |
| print("rep:", rep) | |
| if 'content' in rep.keys(): | |
| image_url = rep['content']['download_url'] | |
| else: | |
| image_url = r"https://gitee.com/api/v5/repos/"+self.config.get('Gitee', 'owner')+'/'+self.config.get('Gitee', 'repo')+'/contents/'+self.config.get('Gitee', 'path')+'/' + path | |
| return image_url | |
| def summary_with_chat(self, paper_list, key): | |
| htmls = [] | |
| for paper_index, paper in enumerate(paper_list): | |
| # 第一步先用title,abs,和introduction进行总结。 | |
| text = '' | |
| text += 'Title:' + paper.title | |
| text += 'Url:' + paper.url | |
| text += 'Abstrat:' + paper.abs | |
| # intro | |
| text += list(paper.section_text_dict.values())[0] | |
| max_token = 2500 * 4 | |
| text = text[:max_token] | |
| chat_summary_text = self.chat_summary(text=text, key=str(key)) | |
| htmls.append(chat_summary_text) | |
| # TODO 往md文档中插入论文里的像素最大的一张图片,这个方案可以弄的更加智能一些: | |
| first_image, ext = paper.get_image_path() | |
| if first_image is None or self.gitee_key == '': | |
| pass | |
| else: | |
| image_title = self.validateTitle(paper.title) | |
| image_url = self.upload_gitee(image_path=first_image, image_name=image_title, ext=ext) | |
| htmls.append("\n") | |
| htmls.append("") | |
| htmls.append("\n") | |
| # 第二步总结方法: | |
| # TODO,由于有些文章的方法章节名是算法名,所以简单的通过关键词来筛选,很难获取,后面需要用其他的方案去优化。 | |
| method_key = '' | |
| for parse_key in paper.section_text_dict.keys(): | |
| if 'method' in parse_key.lower() or 'approach' in parse_key.lower(): | |
| method_key = parse_key | |
| break | |
| if method_key != '': | |
| text = '' | |
| method_text = '' | |
| summary_text = '' | |
| summary_text += "<summary>" + chat_summary_text | |
| # methods | |
| method_text += paper.section_text_dict[method_key] | |
| # TODO 把这个变成tenacity的自动判别! | |
| max_token = 2500 * 4 | |
| text = summary_text + "\n <Methods>:\n" + method_text | |
| text = text[:max_token] | |
| chat_method_text = self.chat_method(text=text, key=str(key)) | |
| htmls.append(chat_method_text) | |
| else: | |
| chat_method_text = '' | |
| htmls.append("\n") | |
| # 第三步总结全文,并打分: | |
| conclusion_key = '' | |
| for parse_key in paper.section_text_dict.keys(): | |
| if 'conclu' in parse_key.lower(): | |
| conclusion_key = parse_key | |
| break | |
| text = '' | |
| conclusion_text = '' | |
| summary_text = '' | |
| summary_text += "<summary>" + chat_summary_text + "\n <Method summary>:\n" + chat_method_text | |
| if conclusion_key != '': | |
| # conclusion | |
| conclusion_text += paper.section_text_dict[conclusion_key] | |
| max_token = 2500 * 4 | |
| text = summary_text + "\n <Conclusion>:\n" + conclusion_text | |
| else: | |
| text = summary_text | |
| text = text[:max_token] | |
| chat_conclusion_text = self.chat_conclusion(text=text, key=str(key)) | |
| htmls.append(chat_conclusion_text) | |
| htmls.append("\n") | |
| md_text = "\n".join(htmls) | |
| return markdown.markdown(md_text) | |
| def chat_conclusion(self, text, key): | |
| openai.api_key = key | |
| response = openai.ChatCompletion.create( | |
| model="gpt-3.5-turbo", | |
| # prompt需要用英语替换,少占用token。 | |
| messages=[ | |
| {"role": "system", "content": "你是一个["+self.key_word+"]领域的审稿人,你需要严格评审这篇文章"}, # chatgpt 角色 | |
| {"role": "assistant", "content": "这是一篇英文文献的<summary>和<conclusion>部分内容,其中<summary>你已经总结好了,但是<conclusion>部分,我需要你帮忙归纳下面问题:"+text}, # 背景知识,可以参考OpenReview的审稿流程 | |
| {"role": "user", "content": """ | |
| 8. 做出如下总结: | |
| - (1):这篇工作的意义如何? | |
| - (2):从创新点、性能、工作量这三个维度,总结这篇文章的优点和缺点。 | |
| ....... | |
| 按照后面的格式输出: | |
| 8. Conclusion: | |
| - (1):xxx; | |
| - (2):创新点: xxx; 性能: xxx; 工作量: xxx; | |
| 务必使用中文回答(专有名词需要用英文标注),语句尽量简洁且学术,不要和之前的<summary>内容重复,数值使用原文数字, 务必严格按照格式,将对应内容输出到xxx中,.......代表按照实际需求填写,如果没有可以不用写. | |
| """}, | |
| ] | |
| ) | |
| result = '' | |
| for choice in response.choices: | |
| result += choice.message.content | |
| print("conclusion_result:\n", result) | |
| return result | |
| def chat_method(self, text, key): | |
| openai.api_key = key | |
| response = openai.ChatCompletion.create( | |
| model="gpt-3.5-turbo", | |
| messages=[ | |
| {"role": "system", "content": "你是一个["+self.key_word+"]领域的科研人员,善于使用精炼的语句总结论文"}, # chatgpt 角色 | |
| {"role": "assistant", "content": "这是一篇英文文献的<summary>和<Method>部分内容,其中<summary>你已经总结好了,但是<Methods>部分,我需要你帮忙阅读并归纳下面问题:"+text}, # 背景知识 | |
| {"role": "user", "content": """ | |
| 7. 详细描述这篇文章的方法思路。比如说它的步骤是: | |
| - (1):... | |
| - (2):... | |
| - (3):... | |
| - ....... | |
| 按照后面的格式输出: | |
| 7. Methods: | |
| - (1):xxx; | |
| - (2):xxx; | |
| - (3):xxx; | |
| ....... | |
| 务必使用中文回答(专有名词需要用英文标注),语句尽量简洁且学术,不要和之前的<summary>内容重复,数值使用原文数字, 务必严格按照格式,将对应内容输出到xxx中,按照\n换行,.......代表按照实际需求填写,如果没有可以不用写. | |
| """}, | |
| ] | |
| ) | |
| result = '' | |
| for choice in response.choices: | |
| result += choice.message.content | |
| print("method_result:\n", result) | |
| return result | |
| def chat_summary(self, text, key): | |
| openai.api_key = key | |
| response = openai.ChatCompletion.create( | |
| model="gpt-3.5-turbo", | |
| messages=[ | |
| {"role": "system", "content": "你是一个["+self.key_word+"]领域的科研人员,善于使用精炼的语句总结论文"}, # chatgpt 角色 | |
| {"role": "assistant", "content": "这是一篇英文文献的标题,作者,链接,Abstract和Introduction部分内容,我需要你帮忙阅读并归纳下面问题:"+text}, # 背景知识 | |
| {"role": "user", "content": """ | |
| 1. 标记出这篇文献的标题(加上中文翻译) | |
| 2. 列举所有的作者姓名 (使用英文) | |
| 3. 标记第一作者的单位(只输出中文翻译) | |
| 4. 标记出这篇文章的关键词(使用英文) | |
| 5. 论文链接,Github代码链接(如果有的话,没有的话请填写Github:None) | |
| 6. 按照下面四个点进行总结: | |
| - (1):这篇文章的研究背景是什么? | |
| - (2):过去的方法有哪些?它们存在什么问题?本文和过去的研究有哪些本质的区别?Is the approach well motivated? | |
| - (3):本文提出的研究方法是什么? | |
| - (4):本文方法在什么任务上,取得了什么性能?性能能否支持他们的目标? | |
| 按照后面的格式输出: | |
| 1. Title: xxx | |
| 2. Authors: xxx | |
| 3. Affiliation: xxx | |
| 4. Keywords: xxx | |
| 5. Urls: xxx or xxx , xxx | |
| 6. Summary: | |
| - (1):xxx; | |
| - (2):xxx; | |
| - (3):xxx; | |
| - (4):xxx. | |
| 务必使用中文回答(专有名词需要用英文标注),语句尽量简洁且学术,不要有太多重复的信息,数值使用原文数字, 务必严格按照格式,将对应内容输出到xxx中,按照\n换行. | |
| """}, | |
| ] | |
| ) | |
| result = '' | |
| for choice in response.choices: | |
| result += choice.message.content | |
| print("summary_result:\n", result) | |
| return result | |
| def export_to_markdown(self, text, file_name, mode='w'): | |
| # 使用markdown模块的convert方法,将文本转换为html格式 | |
| # html = markdown.markdown(text) | |
| # 打开一个文件,以写入模式 | |
| with open(file_name, mode, encoding="utf-8") as f: | |
| # 将html格式的内容写入文件 | |
| f.write(text) | |
| # 定义一个方法,打印出读者信息 | |
| def show_info(self): | |
| print(f"Key word: {self.key_word}") | |
| print(f"Query: {self.query}") | |
| print(f"Sort: {self.sort}") | |
| def upload_pdf(key, text, file): | |
| # 检查两个输入都不为空 | |
| if not key or not text or not file: | |
| return "两个输入都不能为空,请输入字符并上传 PDF 文件!" | |
| # 判断PDF文件 | |
| if file and file.name.split(".")[-1].lower() != "pdf": | |
| return '请勿上传非 PDF 文件!' | |
| else: | |
| section_list = text.split(',') | |
| paper_list = [Paper(path=file, sl=section_list)] | |
| # 创建一个Reader对象 | |
| reader = Reader() | |
| sum_info = reader.summary_with_chat(paper_list=paper_list, key=key) | |
| return sum_info | |
| # 标题 | |
| title = "ChatPaper" | |
| # 描述 | |
| description = '''<div align='center'> | |
| Use ChatGPT to summary the papers. | |
| Star our Github [ChatPaper](https://github.com/kaixindelele/ChatPaper) | |
| </div> | |
| ''' | |
| # 创建Gradio界面 | |
| ip = [ | |
| gradio.inputs.Textbox(label="请输入你的API-key(必填)", default="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"), | |
| gradio.inputs.Textbox(label="请输入论文大标题索引(用英文逗号隔开,必填)", default="'Abstract,Introduction,Related Work,Background,Preliminary,Problem Formulation,Methods,Methodology,Method,Approach,Approaches,Materials and Methods,Experiment Settings,Experiment,Experimental Results,Evaluation,Experiments,Results,Findings,Data Analysis,Discussion,Results and Discussion,Conclusion,References'"), | |
| gradio.inputs.File(label="请上传论文PDF(必填)") | |
| ] | |
| interface = gradio.Interface(fn=upload_pdf, inputs=ip, outputs="html", title=title, description=description) | |
| # 运行Gradio应用程序 | |
| interface.launch() |