yiyixin committed on
Commit
28c2a3d
1 Parent(s): 060c872
Files changed (17)
  1. Dockerfile +8 -3
  2. LICENSE +25 -0
  3. README.md +52 -11
  4. app.py +0 -10
  5. backend.py +81 -0
  6. base_class.py +106 -0
  7. chat_pdf.py +83 -0
  8. chatbot.py +320 -0
  9. config.py +15 -0
  10. embedding_model.py +47 -0
  11. frontend.py +87 -0
  12. pdf_parser.py +148 -0
  13. requirements.txt +7 -1
  14. scipdf_utils.py +424 -0
  15. serve_grobid.sh +1 -1
  16. similarity_metric.py +13 -0
  17. utils.py +24 -0
Dockerfile CHANGED
@@ -14,13 +14,18 @@ RUN pip install --no-cache-dir -r requirements.txt
14
  # Copy the current directory contents into the container at /app
15
  COPY . /app
16
 
17
- # Expose port 7860 for the Flask app to listen on
18
  EXPOSE 7860
19
 
 
20
  RUN apt-get update && \
21
  apt-get install -y wget unzip
22
  RUN chmod +x serve_grobid.sh
23
  RUN ./serve_grobid.sh
24
 
25
- # Run the command to start the Flask app
26
- CMD ["python", "app.py"]
14
  # Copy the current directory contents into the container at /app
15
  COPY . /app
16
 
17
+ # Expose port 7860 for the Streamlit app to listen on
18
  EXPOSE 7860
19
 
20
+ # Start the GROBID server
21
  RUN apt-get update && \
22
  apt-get install -y wget unzip
23
  RUN chmod +x serve_grobid.sh
24
  RUN ./serve_grobid.sh
25
 
26
+ # Start the backend server
27
+ RUN python backend.py &
28
+
29
+ # Run the command to start the Streamlit app
30
+ CMD ["streamlit", "run", "frontend.py", "--server.address", "0.0.0.0", "--server.port", "7860"]
31
+
LICENSE ADDED
@@ -0,0 +1,25 @@
1
+ BSD 2-Clause License
2
+
3
+ Copyright (c) 2023, Yixin Liu
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,11 +1,52 @@
1
- ---
2
- title: ChatPaper
3
- emoji: 📊
4
- colorFrom: purple
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- license: bsd-2-clause
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # ChatPaper
2
+ Yet another paper reading assistant, similar to [ChatPDF](https://www.chatpdf.com/).
3
+
4
+ ## Setup
5
+
6
+ 1. Install dependencies (tested on Python 3.9)
7
+
8
+ ```bash
9
+ pip install -r requirements.txt
10
+ ```
11
+
12
+ 2. Set up a local GROBID server
13
+
14
+ ```bash
15
+ bash serve_grobid.sh
16
+ ```
17
+
18
+ 3. Start the backend
19
+
20
+ ```bash
21
+ python backend.py --port 5000 --host localhost
22
+ ```
23
+
24
+ 4. Start the frontend
25
+
26
+ ```bash
27
+ streamlit run frontend.py --server.port 8502 --server.address localhost
28
+ ```
29
+
30
+ ## Demo Example
31
+
32
+ - Prepare an [OpenAI API key](https://platform.openai.com/account/api-keys) and then upload a PDF to start chatting with the paper.
33
+
34
+ ![image-20230318232056584](https://s2.loli.net/2023/03/19/SbsuLQJpdqePoZV.png)
35
+
36
+ ## Implementation Details
37
+
38
+ - Greedy Dynamic Context: Because of the max token limit, we select only the most relevant paragraphs of the PDF for each user query. We split the text that flows into and out of the chatbot into four parts: system_prompt (S), dynamic_source (D), user_query (Q), and model_answer (A). Upon each query, we first rank all paragraphs with a sentence-embedding model, computing the similarity distance between the query embedding and every source embedding. We then compose the dynamic_source greedily, adding the most relevant paragraphs one by one while maintaining D <= MAX_TOKEN_LIMIT - Q - S - A - SOME_OVERHEAD.
39
+
40
+ - Context Truncating: When the context grows too long, we currently just pop out the first QA pair.
41
+
42
+ ## TODO
43
+
44
+ - [ ] **Context Condense**: how to deal with a long context? Maybe we can tune a soft prompt to condense the context.
45
+ - [ ] **Popping context out based on similarity**
46
+
47
+ ## References
48
+
49
+ 1. SciPDF Parser: https://github.com/titipata/scipdf_parser
50
+ 2. St-chat: https://github.com/AI-Yash/st-chat
51
+ 3. Sentence-transformers: https://github.com/UKPLab/sentence-transformers
52
+ 4. ChatGPT Chatbot Wrapper: https://github.com/acheong08/ChatGPT
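
As a rough illustration of the greedy dynamic-context step described under Implementation Details, the selection loop can be sketched as follows (a minimal sketch with hypothetical helper names; the actual logic lives in `ChatPDF._get_related_context` in chat_pdf.py):

```python
# Minimal sketch of the greedy context selection (hypothetical helpers:
# `rank_by_similarity` stands in for the sentence-embedding ranking and
# `encode_length` for the tiktoken-based token counter).
def compose_dynamic_source(query, paragraphs, rank_by_similarity, encode_length, budget):
    """Greedily add the most similar paragraphs until the token budget is reached."""
    selected, used = [], 0
    for idx in rank_by_similarity(query, paragraphs):   # most similar paragraph first
        cost = encode_length(paragraphs[idx])
        if used + cost > budget:                        # budget = MAX_TOKEN_LIMIT - Q - S - A - overhead
            break
        selected.append(paragraphs[idx])
        used += cost
    return "\n\n".join(selected)
```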
app.py DELETED
@@ -1,10 +0,0 @@
1
- from flask import Flask
2
-
3
- app = Flask(__name__)
4
-
5
- @app.route('/')
6
- def hello_world():
7
- return 'Hello, World!'
8
-
9
- if __name__ == '__main__':
10
- app.run(debug=True, host='0.0.0.0', port=7860)
backend.py ADDED
@@ -0,0 +1,81 @@
1
+ from flask import jsonify, Flask, request
2
+ from embedding_model import HuggingfaceSentenceTransformerModel
3
+ from similarity_metric import CosineSimilarity
4
+ from pdf_parser import GrobidSciPDFPaser
5
+ from chatbot import OpenAIChatbot
6
+ from chat_pdf import ChatPDF
7
+ from config import DEFAULT_ENGINE, MAX_TOKEN_MODEL_MAP, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_PRESENCE_PENALTY, DEFAULT_FREQUENCY_PENALTY, DEFAULT_REPLY_COUNT
8
+ app = Flask(__name__)
9
+ chatpdf_pool = {}
10
+
11
+ embedding_model = HuggingfaceSentenceTransformerModel()
12
+ simi_metric = CosineSimilarity()
13
+
14
+
15
+ @app.route("/query/", methods=['POST', 'GET'])
16
+ def query():
17
+ api_key = request.headers.get('Api-Key')
18
+ pdf_link = request.json['pdf_link']
19
+ user_stamp = request.json['user_stamp']
20
+ user_query = request.json['user_query']
21
+ print(
22
+ "api_key", api_key,
23
+ "pdf_link", pdf_link,
24
+ "user_stamp", user_stamp,
25
+ "user_query", user_query
26
+ )
27
+
28
+ chat_pdf = None
29
+ if user_stamp not in chatpdf_pool:
30
+ print(f"User {user_stamp} not in pool, creating new chatpdf")
31
+ # Initialize the ChatPDF
32
+ bot = OpenAIChatbot(
33
+ api_key=api_key,
34
+ engine=DEFAULT_ENGINE,
35
+ proxy=None,
36
+ max_tokens=4000,
37
+ temperature=DEFAULT_TEMPERATURE,
38
+ top_p=DEFAULT_TOP_P,
39
+ presence_penalty=DEFAULT_PRESENCE_PENALTY,
40
+ frequency_penalty=DEFAULT_FREQUENCY_PENALTY,
41
+ reply_count=DEFAULT_REPLY_COUNT
42
+ )
43
+
44
+ pdf = GrobidSciPDFPaser(
45
+ pdf_link=pdf_link
46
+ )
47
+ chat_pdf = ChatPDF(
48
+ pdf=pdf,
49
+ bot=bot,
50
+ embedding_model=embedding_model,
51
+ similarity_metric=simi_metric,
52
+ user_stamp=user_stamp
53
+ )
54
+ chatpdf_pool[user_stamp] = chat_pdf
55
+ else:
56
+ print("user_stamp", user_stamp, "already exists")
57
+ chat_pdf = chatpdf_pool[user_stamp]
58
+
59
+ try:
60
+ response = chat_pdf.chat(user_query)
61
+ code = 200
62
+ json_dict = {
63
+ "code": code,
64
+ "response": response
65
+ }
66
+ except Exception as e:
67
+ code = 500
68
+ json_dict = {
69
+ "code": code,
70
+ "response": str(e)
71
+ }
72
+ return jsonify(json_dict)
73
+
74
+
75
+ # @app.route("/", methods=['GET'])
76
+ # def index():
77
+ # return "Hello World!"
78
+
79
+
80
+ if __name__ == '__main__':
81
+ app.run(host='localhost', port=5000, debug=False)
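
For reference, a request against this endpoint could look roughly like the following (a sketch; the PDF path and API key are placeholders, and the payload keys mirror what `query()` reads from `request.json` and the `Api-Key` header):

```python
import requests

resp = requests.post(
    "http://localhost:5000/query/",
    headers={"Api-Key": "sk-..."},            # read via request.headers.get('Api-Key')
    json={
        "pdf_link": "./files/example.pdf",    # placeholder: local path or URL to a PDF
        "user_stamp": "20230318_232056",      # keys the ChatPDF instance in chatpdf_pool
        "user_query": "What is the main contribution of this paper?",
    },
    timeout=300,
)
print(resp.json())  # {"code": 200, "response": "..."} on success, code 500 on error
```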
base_class.py ADDED
@@ -0,0 +1,106 @@
1
+ import abc
2
+ import pandas as pd
3
+ import pickle
4
+
5
+
6
+ class SimilarityAlg(metaclass=abc.ABCMeta):
7
+ """Similarity Algorithm to compute similarity between query_embedding and embeddings"""
8
+
9
+ def __init__(self) -> None:
10
+ pass
11
+
12
+ @abc.abstractmethod
13
+ def __call__(self, query_embedding, embeddings) -> None:
14
+ pass
15
+
16
+
17
+ class Embedding_Model(metaclass=abc.ABCMeta):
18
+ """Embedding Model to compute embedding of a text"""
19
+
20
+ def __init__(self, model_name) -> None:
21
+ """Initialize the embedding model"""
22
+ embedding_cache_path = f"/data/ckpt/embedding_cache_{model_name}.pkl"
23
+ self.embedding_cache_path = embedding_cache_path
24
+
25
+ # load the cache if it exists, and save a copy to disk
26
+ try:
27
+ embedding_cache = pd.read_pickle(embedding_cache_path)
28
+ except FileNotFoundError:
29
+ embedding_cache = {}
30
+ with open(embedding_cache_path, "wb") as embedding_cache_file:
31
+ pickle.dump(embedding_cache, embedding_cache_file)
32
+ self.embedding_cache = embedding_cache
33
+ self.model_name = model_name
34
+
35
+ @abc.abstractmethod
36
+ def __call__(self, text) -> None:
37
+ """Compute the embedding of the text"""
38
+ pass
39
+
40
+
41
+ class AbstractPDFParser(metaclass=abc.ABCMeta):
42
+ """ PDF parser to parse a PDF file"""
43
+
44
+ def __init__(self, db_name) -> None:
45
+ """Initialize the pdf database"""
46
+ db_cache_path = f"/data/ckpt/pdf_parser_{db_name}.pkl"
47
+ self.db_cache_path = db_cache_path
48
+
49
+ # load the cache if it exists, and save a copy to disk
50
+ try:
51
+ db_cache = pd.read_pickle(db_cache_path)
52
+ except FileNotFoundError:
53
+ db_cache = {}
54
+ with open(db_cache_path, "wb") as cache_file:
55
+ pickle.dump(db_cache, cache_file)
56
+ self.db_cache = db_cache
57
+ self.db_name = db_name
58
+
59
+ @abc.abstractmethod
60
+ def parse_pdf(self,) -> None:
61
+ """Parse the PDF file"""
62
+ pass
63
+
64
+ @abc.abstractmethod
65
+ def _get_metadata(self, ) -> None:
66
+ """Get the metadata of the PDF file"""
67
+ pass
68
+
69
+ def get_paragraphs(self, ) -> None:
70
+ """Get the paragraphs of the PDF file"""
71
+ pass
72
+
73
+ @abc.abstractmethod
74
+ def get_split_paragraphs(self, ) -> None:
75
+ """
76
+ Get the split paragraphs of the PDF file
77
+ Return:
78
+ split_paragraphs: dict of metadata and corresponding list of split paragraphs
79
+ """
80
+ pass
81
+
82
+ def _determine_metadata_of_paragraph(self, paragraph) -> None:
83
+ """
84
+ Determine the metadata of a paragraph
85
+ Return:
86
+ metadata: metadata of the paragraph
87
+ """
88
+ pass
89
+
90
+ @abc.abstractmethod
91
+ def _determine_optimal_split_of_pargraphs(self, ) -> None:
92
+ """
93
+ Determine the optimal split of paragraphs
94
+ Return:
95
+ split_paragraphs: dict of metadata and corresponding list of split paragraphs
96
+ """
97
+ pass
98
+
99
+
100
+ class ChatbotEngine(metaclass=abc.ABCMeta):
101
+ def __init__(self,) -> None:
102
+ pass
103
+
104
+ @abc.abstractmethod
105
+ def query(self, user_query):
106
+ pass
chat_pdf.py ADDED
@@ -0,0 +1,83 @@
1
+
2
+ class ChatPDF():
3
+ """ChatPDF enables us to chat with a PDF file
4
+ """
5
+
6
+ def __init__(self, pdf, bot, embedding_model, similarity_metric, expect_answer_token_length_max=100, expect_q_token_length_max=100, user_stamp=None) -> None:
7
+ self.pdf = pdf
8
+ self.bot = bot
9
+ self.embedding_model = embedding_model
10
+ self.similarity_metric = similarity_metric
11
+ self.user_stamp = user_stamp
12
+
13
+ self.system_task_prompt = f"You are a helpful PDF file. Your task is to provide information and answer any questions related to the topic of {self.pdf.metadata['title']}. You should use the sections of the PDF as your source of information and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find relevant information in the given sections, you will need to let the user know that the source does not contain relevant information but still try to provide an answer based on your general knowledge. You must refer to the corresponding section name and page that you refer to when answering. The following is the related information about the PDF file that will help you answer users' questions:\n\n"
14
+ self.system_information_prompt = "Title:\n" + self.pdf.metadata['title'] + "\n\nAbstract:\n" + self.pdf.metadata["abstract"] + \
15
+ "\n\nFiltered paragraphs from each sections (the section titles are enclosed in asterisks**):\n\n"
16
+
17
+ self.system_token_length = self.bot.encode_length(
18
+ self.system_task_prompt) + self.bot.encode_length(self.system_information_prompt)
19
+ self.expect_answer_token_length_max = expect_answer_token_length_max
20
+ self.expect_q_token_length_max = expect_q_token_length_max
21
+ self.context_max_length = self.bot.max_tokens - self.system_token_length - \
22
+ self.bot.overhead_token - self.expect_answer_token_length_max - \
23
+ self.expect_q_token_length_max
24
+
25
+ def _get_related_context(self, user_query):
26
+ all_contextes = [user_query]+self.pdf.flattn_paragraphs
27
+ rank_indices = self.rank_indices(all_contextes, 0,)
28
+ rank_indices = list(rank_indices)
29
+ rank_indices.remove(0)
30
+ inital_context = ":\n\n".join(self.pdf.section_names_with_page_index)
31
+ context_dict = {section_name: []
32
+ for section_name in self.pdf.section_names}
33
+ inital_context_token_length = self.bot.encode_length(inital_context)
34
+ running_length = inital_context_token_length
35
+ for idx in rank_indices:
36
+ text_to_insert = all_contextes[idx]
37
+ text_to_insert_token_length = self.bot.encode_length(
38
+ text_to_insert)
39
+ if running_length + text_to_insert_token_length < self.context_max_length:
40
+ running_length += text_to_insert_token_length
41
+ section = self.pdf.content2section[text_to_insert]
42
+ context_dict[section].append(text_to_insert)
43
+ else:
44
+ break
45
+ composed_context = ""
46
+ for i, section_name in enumerate(self.pdf.section_names):
47
+ if len(context_dict[section_name]) > 0:
48
+ section_name_with_page_index = self.pdf.section_names_with_page_index[i]
49
+ composed_context += "**"+section_name_with_page_index + "**" + \
50
+ ":\n" + "\n".join(context_dict[section_name]) + "\n\n"
51
+ return composed_context
52
+
53
+ def chat(self, user_query):
54
+ """Chat with the PDF file
55
+ """
56
+ context_data = self._get_related_context(user_query)
57
+ dynamic_system_context = self.system_task_prompt + \
58
+ self.system_information_prompt + context_data
59
+ print(
60
+ "************ Start of Composed Context ************\n",
61
+ dynamic_system_context,
62
+ "\n************ End of Composed Context ************\n"
63
+ )
64
+ response = self.bot.query(
65
+ context=dynamic_system_context, questions=user_query, convo_id=self.user_stamp)
66
+ return response
67
+
68
+ def rank_indices(
69
+ self,
70
+ strings: list[str],
71
+ index_of_source_string: int,
72
+ ) -> list[int]:
73
+ """Rank the indices of the strings in the list based on their similarity to the source string."""
74
+ # get embeddings for all strings
75
+ embeddings = [self.embedding_model(string) for string in strings]
76
+ # get the embedding of the source string
77
+ query_embedding = embeddings[index_of_source_string]
78
+ # get distances between the source embedding and other embeddings (function from embeddings_utils.py)
79
+ distances = self.similarity_metric(query_embedding, embeddings)
80
+ # get rank of indices based on distances
81
+ import numpy as np
82
+ indices_of_nearest_neighbors = np.argsort(distances, )
83
+ return indices_of_nearest_neighbors
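
The token budget computed in `__init__` (`context_max_length`) can be illustrated with rough numbers (illustrative values only; the real `system_token_length` depends on the paper's title and abstract):

```python
# Illustrative arithmetic for ChatPDF.context_max_length (values are made up):
max_tokens = 4000            # bot.max_tokens, as configured in backend.py
system_token_length = 600    # encoded length of the two system prompts
overhead_token = 96          # bot.overhead_token default
expect_answer = 100          # expect_answer_token_length_max
expect_q = 100               # expect_q_token_length_max

context_max_length = (max_tokens - system_token_length - overhead_token
                      - expect_answer - expect_q)
print(context_max_length)    # 3104 tokens left for the ranked paragraphs
```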
chatbot.py ADDED
@@ -0,0 +1,320 @@
1
+ from base_class import ChatbotEngine
2
+ import os
3
+ import openai
4
+ import json
5
+ import os
6
+ import requests
7
+ import tiktoken
8
+ from config import MAX_TOKEN_MODEL_MAP
9
+ from utils import get_filtered_keys_from_object
10
+
11
+
12
+ class ChatbotWrapper:
13
+ """
14
+ Wrapper of Official ChatGPT API,
15
+ # based on https://github.com/ChatGPT-Hackers/revChatGPT
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ api_key: str,
21
+ engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
22
+ proxy: str = None,
23
+ max_tokens: int = 3000,
24
+ temperature: float = 0.5,
25
+ top_p: float = 1.0,
26
+ presence_penalty: float = 0.0,
27
+ frequency_penalty: float = 0.0,
28
+ reply_count: int = 1,
29
+ system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
30
+ overhead_token=96,
31
+ ) -> None:
32
+ """
33
+ Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
34
+ """
35
+ self.engine = engine
36
+ self.session = requests.Session()
37
+ self.api_key = api_key
38
+ self.system_prompt = system_prompt
39
+ self.max_tokens = max_tokens
40
+ self.temperature = temperature
41
+ self.top_p = top_p
42
+ self.presence_penalty = presence_penalty
43
+ self.frequency_penalty = frequency_penalty
44
+ self.reply_count = reply_count
45
+ self.max_limit = MAX_TOKEN_MODEL_MAP[self.engine]
46
+ self.overhead_token = overhead_token
47
+
48
+ if proxy:
49
+ self.session.proxies = {
50
+ "http": proxy,
51
+ "https": proxy,
52
+ }
53
+
54
+ self.conversation: dict = {
55
+ "default": [
56
+ {
57
+ "role": "system",
58
+ "content": system_prompt,
59
+ },
60
+ ],
61
+ }
62
+
63
+ if max_tokens > self.max_limit - self.overhead_token:
64
+ raise Exception(
65
+ f"Max tokens cannot be greater than {self.max_limit- self.overhead_token}")
66
+
67
+ if self.get_token_count("default") > self.max_tokens:
68
+ raise Exception("System prompt is too long")
69
+
70
+ def add_to_conversation(
71
+ self,
72
+ message: str,
73
+ role: str,
74
+ convo_id: str = "default",
75
+ ) -> None:
76
+ """
77
+ Add a message to the conversation
78
+ """
79
+ self.conversation[convo_id].append({"role": role, "content": message})
80
+
81
+ def __truncate_conversation(self, convo_id: str = "default") -> None:
82
+ """
83
+ Truncate the conversation
84
+ """
85
+ # TODO: context condense with soft prompt tuning
86
+ while True:
87
+ if (
88
+ self.get_token_count(convo_id) > self.max_tokens
89
+ and len(self.conversation[convo_id]) > 1
90
+ ):
91
+ # Don't remove the first message and remove the first QA pair
92
+ self.conversation[convo_id].pop(1)
93
+ self.conversation[convo_id].pop(1)
94
+ # TODO: optimal pop out based on similarity distance
95
+ else:
96
+ break
97
+
98
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
99
+ def get_token_count(self, convo_id: str = "default") -> int:
100
+ """
101
+ Get token count
102
+ """
103
+ if self.engine not in ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"]:
104
+ raise NotImplementedError(f"Unsupported engine {self.engine}")
105
+
106
+ encoding = tiktoken.encoding_for_model(self.engine)
107
+
108
+ num_tokens = 0
109
+ for message in self.conversation[convo_id]:
110
+ # every message follows <im_start>{role/name}\n{content}<im_end>\n
111
+ num_tokens += 4
112
+ for key, value in message.items():
113
+ num_tokens += len(encoding.encode(value))
114
+ if key == "name": # if there's a name, the role is omitted
115
+ num_tokens += 1 # role is always required and always 1 token
116
+ num_tokens += 2 # every reply is primed with <im_start>assistant
117
+ return num_tokens
118
+
119
+ def get_max_tokens(self, convo_id: str) -> int:
120
+ """
121
+ Get max tokens
122
+ """
123
+ return self.max_tokens - self.get_token_count(convo_id)
124
+
125
+ def ask_stream(
126
+ self,
127
+ prompt: str,
128
+ role: str = "user",
129
+ convo_id: str = "default",
130
+ dynamic_system_prompt=None,
131
+ **kwargs,
132
+ ) -> str:
133
+ """
134
+ Ask a question
135
+ """
136
+ # Make conversation if it doesn't exist
137
+ if convo_id not in self.conversation:
138
+ self.reset(convo_id=convo_id, system_prompt=dynamic_system_prompt)
139
+
140
+ # adjust system prompt
141
+ assert dynamic_system_prompt is not None
142
+ self.conversation[convo_id][0]["content"] = dynamic_system_prompt
143
+
144
+ self.add_to_conversation(prompt, "user", convo_id=convo_id)
145
+ print(" total tokens:")
146
+ print(self.get_token_count(convo_id))
147
+ self.__truncate_conversation(convo_id=convo_id)
148
+ # Get response
149
+ response = self.session.post(
150
+ os.environ.get(
151
+ "API_URL") or "https://api.openai.com/v1/chat/completions",
152
+ headers={
153
+ "Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
154
+ json={
155
+ "model": self.engine,
156
+ "messages": self.conversation[convo_id],
157
+ "stream": True,
158
+ # kwargs
159
+ "temperature": kwargs.get("temperature", self.temperature),
160
+ "top_p": kwargs.get("top_p", self.top_p),
161
+ "presence_penalty": kwargs.get(
162
+ "presence_penalty",
163
+ self.presence_penalty,
164
+ ),
165
+ "frequency_penalty": kwargs.get(
166
+ "frequency_penalty",
167
+ self.frequency_penalty,
168
+ ),
169
+ "n": kwargs.get("n", self.reply_count),
170
+ "user": role,
171
+ "max_tokens": self. get_max_tokens(convo_id=convo_id),
172
+ },
173
+ stream=True,
174
+ )
175
+ if response.status_code != 200:
176
+ raise Exception(
177
+ f"Error: {response.status_code} {response.reason} {response.text}",
178
+ )
179
+ response_role: str = None
180
+ full_response: str = ""
181
+ for line in response.iter_lines():
182
+ if not line:
183
+ continue
184
+ # Remove "data: "
185
+ line = line.decode("utf-8")[6:]
186
+ if line == "[DONE]":
187
+ break
188
+ resp: dict = json.loads(line)
189
+ choices = resp.get("choices")
190
+ if not choices:
191
+ continue
192
+ delta = choices[0].get("delta")
193
+ if not delta:
194
+ continue
195
+ if "role" in delta:
196
+ response_role = delta["role"]
197
+ if "content" in delta:
198
+ content = delta["content"]
199
+ full_response += content
200
+ yield content
201
+ self.add_to_conversation(
202
+ full_response, response_role, convo_id=convo_id)
203
+
204
+ def ask(
205
+ self,
206
+ prompt: str,
207
+ role: str = "user",
208
+ convo_id: str = "default",
209
+ dynamic_system_prompt: str = None,
210
+ **kwargs,
211
+ ) -> str:
212
+ """
213
+ Non-streaming ask
214
+ """
215
+ response = self.ask_stream(
216
+ prompt=prompt,
217
+ role=role,
218
+ convo_id=convo_id,
219
+ dynamic_system_prompt=dynamic_system_prompt,
220
+ **kwargs,
221
+ )
222
+ full_response: str = "".join(response)
223
+ return full_response
224
+
225
+ def rollback(self, n: int = 1, convo_id: str = "default") -> None:
226
+ """
227
+ Rollback the conversation
228
+ """
229
+ for _ in range(n):
230
+ self.conversation[convo_id].pop()
231
+
232
+ def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
233
+ """
234
+ Reset the conversation
235
+ """
236
+ self.conversation[convo_id] = [
237
+ {"role": "system", "content": system_prompt or self.system_prompt},
238
+ ]
239
+
240
+ def save(self, file: str, *keys: str) -> None:
241
+ """
242
+ Save the Chatbot configuration to a JSON file
243
+ """
244
+ with open(file, "w", encoding="utf-8") as f:
245
+ json.dump(
246
+ {
247
+ key: self.__dict__[key]
248
+ for key in get_filtered_keys_from_object(self, *keys)
249
+ },
250
+ f,
251
+ indent=2,
252
+ # saves session.proxies dict as session
253
+ default=lambda o: o.__dict__["proxies"],
254
+ )
255
+
256
+ def load(self, file: str, *keys: str) -> None:
257
+ """
258
+ Load the Chatbot configuration from a JSON file
259
+ """
260
+ with open(file, encoding="utf-8") as f:
261
+ # load json, if session is in keys, load proxies
262
+ loaded_config = json.load(f)
263
+ keys = get_filtered_keys_from_object(self, *keys)
264
+
265
+ if "session" in keys and loaded_config["session"]:
266
+ self.session.proxies = loaded_config["session"]
267
+ keys = keys - {"session"}
268
+ self.__dict__.update({key: loaded_config[key] for key in keys})
269
+
270
+
271
+ class OpenAIChatbot(ChatbotEngine):
272
+ def __init__(self, api_key: str,
273
+ engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
274
+ proxy: str = None,
275
+ max_tokens: int = 3000,
276
+ temperature: float = 0.5,
277
+ top_p: float = 1.0,
278
+ presence_penalty: float = 0.0,
279
+ frequency_penalty: float = 0.0,
280
+ reply_count: int = 1,
281
+ system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
282
+ overhead_token=96) -> None:
283
+ openai.api_key = api_key
284
+ self.api_key = api_key
285
+ self.engine = engine
286
+ self.proxy = proxy
287
+ self.max_tokens = max_tokens
288
+ self.temperature = temperature
289
+ self.top_p = top_p
290
+ self.presence_penalty = presence_penalty
291
+ self.frequency_penalty = frequency_penalty
292
+ self.reply_count = reply_count
293
+ self.system_prompt = system_prompt
294
+
295
+ self.bot = ChatbotWrapper(
296
+ api_key=self.api_key,
297
+ engine=self.engine,
298
+ proxy=self.proxy,
299
+ max_tokens=self.max_tokens,
300
+ temperature=self.temperature,
301
+ top_p=self.top_p,
302
+ presence_penalty=self.presence_penalty,
303
+ frequency_penalty=self.frequency_penalty,
304
+ reply_count=self.reply_count,
305
+ system_prompt=self.system_prompt,
306
+ overhead_token=overhead_token
307
+ )
308
+ self.overhead_token = overhead_token
309
+ import tiktoken
310
+ self.encoding = tiktoken.encoding_for_model(self.engine)
311
+
312
+ def encode_length(self, text: str) -> int:
313
+ return len(self.encoding.encode(text))
314
+
315
+ def query(self, questions: str,
316
+ role: str = "user",
317
+ convo_id: str = "default",
318
+ context: str = None,
319
+ **kwargs,):
320
+ return self.bot.ask(prompt=questions, role=role, convo_id=convo_id, dynamic_system_prompt=context, **kwargs)
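
Standalone usage of the wrapper could look roughly like this (a sketch; it needs a valid OpenAI API key and network access, and the context string is a placeholder for the dynamic system prompt that ChatPDF builds):

```python
from chatbot import OpenAIChatbot

bot = OpenAIChatbot(api_key="sk-...", max_tokens=3000)
answer = bot.query(
    questions="Summarize the abstract in one sentence.",
    convo_id="demo",
    context="You are a helpful PDF file. ...",  # placeholder dynamic system prompt
)
print(answer)
print(bot.encode_length(answer))  # token count via tiktoken
```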
config.py ADDED
@@ -0,0 +1,15 @@
1
+
2
+
3
+ MAX_TOKEN_MODEL_MAP = {
4
+ "gpt-3.5-turbo": 4096,
5
+ }
6
+
7
+ PDF_SAVE_DIR = "./files/"
8
+
9
+
10
+ DEFAULT_ENGINE = "gpt-3.5-turbo"
11
+ DEFAULT_TEMPERATURE = 0.5
12
+ DEFAULT_TOP_P = 1
13
+ DEFAULT_PRESENCE_PENALTY = 0
14
+ DEFAULT_FREQUENCY_PENALTY = 0
15
+ DEFAULT_REPLY_COUNT = 1
embedding_model.py ADDED
@@ -0,0 +1,47 @@
1
+
2
+ from base_class import Embedding_Model
3
+ import pickle
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ from openai.embeddings_utils import (
7
+ get_embedding,
8
+ )
9
+
10
+
11
+ class HuggingfaceSentenceTransformerModel(Embedding_Model):
12
+ EMBEDDING_MODEL = "distiluse-base-multilingual-cased-v2"
13
+
14
+ def __init__(self, model_name=EMBEDDING_MODEL) -> None:
15
+ super().__init__(model_name)
16
+
17
+ self.model = SentenceTransformer(model_name)
18
+
19
+ def __call__(self, text) -> None:
20
+ return self.model.encode(text)
21
+
22
+
23
+ class OpenAIEmbeddingModel(Embedding_Model):
24
+ # constants
25
+ EMBEDDING_MODEL = "text-embedding-ada-002"
26
+ # establish a cache of embeddings to avoid recomputing
27
+ # cache is a dict of tuples (text, model) -> embedding, saved as a pickle file
28
+
29
+ def __init__(self, model_name=EMBEDDING_MODEL) -> None:
30
+ super().__init__(model_name)
31
+ self.model_name = model_name
32
+
33
+ # define a function to retrieve embeddings from the cache if present, and otherwise request via the API
34
+ def embedding_from_string(self,
35
+ string: str,
36
+ ) -> list:
37
+ """Return embedding of given string, using a cache to avoid recomputing."""
38
+ model = self.model_name
39
+ if (string, model) not in self.embedding_cache.keys():
40
+ self.embedding_cache[(string, model)] = get_embedding(
41
+ string, model)
42
+ with open(self.embedding_cache_path, "wb") as embedding_cache_file:
43
+ pickle.dump(self.embedding_cache, embedding_cache_file)
44
+ return self.embedding_cache[(string, model)]
45
+
46
+ def __call__(self, text) -> None:
47
+ return self.embedding_from_string(text)
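
Combined with `CosineSimilarity` from similarity_metric.py, the sentence-transformer model can rank paragraphs against a query much like `ChatPDF.rank_indices` does. A minimal sketch (it assumes the `/data/ckpt` cache directory used by `Embedding_Model` exists, and the model is downloaded on first use):

```python
import numpy as np
from embedding_model import HuggingfaceSentenceTransformerModel
from similarity_metric import CosineSimilarity

embed = HuggingfaceSentenceTransformerModel()   # needs /data/ckpt for its cache file
metric = CosineSimilarity()

query = "What datasets are used?"
paragraphs = ["We evaluate on CIFAR-10 ...", "Related work on transformers ..."]

embeddings = [embed(text) for text in [query] + paragraphs]
distances = metric(embeddings[0], embeddings[1:])
print(np.argsort(distances))  # paragraph indices, closest to the query first
```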
frontend.py ADDED
@@ -0,0 +1,87 @@
1
+ import datetime
2
+ import os
3
+ import streamlit as st
4
+ from streamlit_chat import message
5
+ import requests
6
+ from config import PDF_SAVE_DIR
7
+
8
+ st.set_page_config(
9
+ page_title="ChatPaper - Demo",
10
+ page_icon=":robot:"
11
+ )
12
+
13
+ pdf_uploaded = False
14
+
15
+ if pdf_uploaded is False:
16
+ st.sidebar.markdown("## Upload a PDF")
17
+ pdf_uploader = st.sidebar.file_uploader("Upload a PDF", type="pdf", )
18
+
19
+ st.sidebar.markdown("## API Key")
20
+ api_key = st.sidebar.text_input(
21
+ "OpenAI API Key", value="", label_visibility="hidden", help="Please enter your API key.")
22
+
23
+
24
+ def get_text():
25
+ input_text = st.text_input(
26
+ "User: ", "", help="Please ask any questions about the paper.")
27
+ return input_text
28
+
29
+
30
+ st.header("ChatPaper - Demo")
31
+
32
+ API_URL = "http://localhost:5000/query/"
33
+ header = {"api_key": ""}
34
+
35
+ if 'generated' not in st.session_state:
36
+ st.session_state['generated'] = []
37
+
38
+ if 'past' not in st.session_state:
39
+ st.session_state['past'] = []
40
+
41
+ if "user_stamp" not in st.session_state:
42
+ import datetime
43
+ user_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
44
+ st.session_state['user_stamp'] = user_stamp
45
+
46
+
47
+ if pdf_uploader is not None:
48
+ if api_key:
49
+ header['api_key'] = api_key
50
+ pdf_name = pdf_uploader.name.replace(' ', '_')
51
+
52
+ file_name = f"{st.session_state.user_stamp}_{pdf_name}"
53
+
54
+ # check PDF_SAVE_DIR
55
+ if not os.path.exists(PDF_SAVE_DIR):
56
+ os.makedirs(PDF_SAVE_DIR)
57
+
58
+ filepath = os.path.join(PDF_SAVE_DIR, file_name)
59
+ with open(filepath, "wb") as f:
60
+ f.write(pdf_uploader.getbuffer())
61
+ user_query = get_text()
62
+
63
+ if user_query:
64
+ st.session_state.past.append(user_query)
65
+ query_data = {"pdf_link": filepath,
66
+ "user_stamp": st.session_state.user_stamp, "user_query": user_query}
67
+ print(query_data)
68
+ response = requests.post(
69
+ API_URL, headers=header, json=query_data, timeout=300)
70
+ output = response.json()
71
+ code = output['code']
72
+ response = output['response']
73
+ if code == 200:
74
+ st.session_state.generated.append(response)
75
+
76
+ if st.session_state['generated']:
77
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
78
+ message(st.session_state["generated"][i],
79
+ key=str(i), avatar_style="fun-emoji")
80
+ message(st.session_state['past'][i], is_user=True, key=str(
81
+ i) + '_user', avatar_style="personas")
82
+ else:
83
+ st.markdown(
84
+ "<span style='color:red'>Please enter your API key.</span>", unsafe_allow_html=True)
85
+ else:
86
+ st.markdown("<span style='color:red'>Please upload a PDF file.</span>",
87
+ unsafe_allow_html=True)
pdf_parser.py ADDED
@@ -0,0 +1,148 @@
1
+ from base_class import AbstractPDFParser
2
+ import pickle
3
+ from scipdf_utils import parse_pdf_to_dict
4
+
5
+
6
+ class GrobidSciPDFPaser(AbstractPDFParser):
7
+ import pysbd
8
+ seg_en = pysbd.Segmenter(language="en", clean=False)
9
+ seg_chinese = pysbd.Segmenter(language="zh", clean=False)
10
+
11
+ def __init__(self, pdf_link, db_name="grobid_scipdf", short_thereshold=30) -> None:
12
+ """Initialize the PDF parser
13
+
14
+ Args:
15
+ pdf_link: link to the PDF file, the pdf link can be a web link or local file path
16
+ metadata: metadata of the PDF file, like authors, title, abstract, etc.
17
+ paragraphs: list of paragraphs of the PDF file, all paragraphs are concatenated together
18
+ split_paragraphs: dict of section name and corresponding list of split paragraphs
19
+ """
20
+ super().__init__(db_name=db_name)
21
+ self.db_name = db_name
22
+ self.pdf_link = pdf_link
23
+ self.pdf = None
24
+ self.metadata = {}
25
+ self.flattn_paragraphs = None
26
+ self.split_paragraphs = None
27
+ self.short_thereshold = short_thereshold
28
+ self.parse_pdf()
29
+
30
+ def _contact_too_short_paragraphs(self, ):
31
+ """Contact too short paragraphs or discard them"""
32
+ for i, section in enumerate(self.split_paragraphs):
33
+ # section_name = section['heading']
34
+ paragraphs = section['texts']
35
+ new_paragraphs = []
36
+ for paragraph in paragraphs:
37
+ if len(paragraph) <= self.short_thereshold and len(paragraph.strip()) != 0:
38
+ if len(new_paragraphs) != 0:
39
+ new_paragraphs[-1] += paragraph
40
+ else:
41
+ new_paragraphs.append(paragraph)
42
+ else:
43
+ new_paragraphs.append(paragraph)
44
+ self.split_paragraphs[i]['texts'] = new_paragraphs
45
+
46
+ @staticmethod
47
+ def _find_largest_font_string(file_name, search_string):
48
+ search_string = search_string.strip()
49
+ max_font_size = -1
50
+ page_number = -1
51
+ import PyPDF2
52
+ from pdfminer.high_level import extract_pages
53
+ from pdfminer.layout import LTTextContainer, LTChar
54
+ try:
55
+ with open(file_name, 'rb') as file:
56
+ pdf_reader = PyPDF2.PdfReader(file)
57
+
58
+ for index, page_layout in enumerate(extract_pages(file_name)):
59
+ for element in page_layout:
60
+ if isinstance(element, LTTextContainer):
61
+ for text_line in element:
62
+ if search_string in text_line.get_text():
63
+ for character in text_line:
64
+ if isinstance(character, LTChar):
65
+ if character.size > max_font_size:
66
+ max_font_size = character.size
67
+ page_number = index
68
+ return page_number + 1 if page_number != -1 else -1
69
+ except Exception as e:
70
+ return -1
71
+
72
+
73
+ def _find_section_page(self, section_name) -> None:
74
+ return GrobidSciPDFPaser._find_largest_font_string(self.pdf_link, section_name)
75
+
76
+ def _retrive_or_parse(self, ):
77
+ """Return pdf dict from cache if present, otherwise parse the pdf"""
78
+ db_name = self.db_name
79
+ if (self.pdf_link, db_name) not in self.db_cache.keys():
80
+ self.db_cache[(self.pdf_link, db_name)
81
+ ] = parse_pdf_to_dict(self.pdf_link)
82
+ with open(self.db_cache_path, "wb") as db_cache_file:
83
+ pickle.dump(self.db_cache, db_cache_file)
84
+ return self.db_cache[(self.pdf_link, db_name)]
85
+
86
+ @staticmethod
87
+ def _check_chinese(text) -> None:
88
+ return any(u'\u4e00' <= char <= u'\u9fff' for char in text)
89
+
90
+ def parse_pdf(self) -> None:
91
+ """Parse the PDF file
92
+ """
93
+ article_dict = self._retrive_or_parse()
94
+ self.article_dict = article_dict
95
+ self._get_metadata()
96
+ self.split_paragraphs = self.get_split_paragraphs()
97
+ self._contact_too_short_paragraphs()
98
+
99
+ self.flattn_paragraphs = self.get_paragraphs()
100
+
101
+ def get_paragraphs(self) -> None:
102
+ """Get the paragraphs of the PDF file
103
+ """
104
+ paragraphs = []
105
+ self.content2section = {}
106
+ for section in self.split_paragraphs:
107
+ # paragraphs+=[section["heading"]]
108
+ paragraphs += section["texts"]
109
+ for para in section["texts"]:
110
+ self.content2section[para] = section["heading"]
111
+ return paragraphs
112
+
113
+ def _get_metadata(self) -> None:
114
+ for meta in ['authors', "pub_date", "abstract", "references", "doi", 'title',]:
115
+ self.metadata[meta] = self.article_dict[meta]
116
+ self.section_names = [section["heading"]
117
+ for section in self.article_dict['sections']]
118
+ self.section_names2page = {}
119
+ for section_name in self.section_names:
120
+ section_page_index = self._find_section_page(section_name)
121
+ self.section_names2page.update({section_name: section_page_index})
122
+ self.section_names_with_page_index = [section_name + " (Page {})".format(
123
+ self.section_names2page[section_name]) for section_name in self.section_names]
124
+
125
+ def get_split_paragraphs(self, ) -> None:
126
+ section_pair_list = []
127
+ for section in self.article_dict['sections']:
128
+ section_pair_list.append({
129
+ "heading": section["heading"],
130
+ "texts": section["all_paragraphs"],
131
+ })
132
+ return section_pair_list
133
+
134
+ @staticmethod
135
+ def _determine_optimal_split_of_pargraphs(section_pair_list) -> None:
136
+ """
137
+ split based on the some magic rules
138
+ """
139
+ import pysbd
140
+ for section_pair in section_pair_list:
141
+ if GrobidSciPDFPaser._check_chinese(section_pair["text"]):
142
+ seg = GrobidSciPDFPaser.seg_chinese
143
+ else:
144
+ seg = GrobidSciPDFPaser.seg_en
145
+ section_pair["texts"] = seg.segment(section_pair["texts"])
146
+ section_pair["texts"] = [
147
+ para for para in section_pair["text"] if len(para) > 2]
148
+ return section_pair_list
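
With the GROBID server from serve_grobid.sh running, the parser can also be used on its own, roughly like this (a sketch; the PDF path is a placeholder and the `/data/ckpt` cache directory is assumed to exist):

```python
from pdf_parser import GrobidSciPDFPaser

pdf = GrobidSciPDFPaser(pdf_link="./files/example.pdf")  # placeholder: local path or URL
print(pdf.metadata["title"])
print(pdf.section_names_with_page_index)  # e.g. ["Introduction (Page 1)", ...]
print(len(pdf.flattn_paragraphs))         # paragraphs later ranked against the query
```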
requirements.txt CHANGED
@@ -1 +1,7 @@
1
- flask
1
+ Flask==2.2.3
2
+ streamlit==1.20.0
3
+ streamlit-chat==0.0.2.2
4
+ torch
5
+ sentence_transformers
6
+ bs4
7
+ openai
scipdf_utils.py ADDED
@@ -0,0 +1,424 @@
1
+ import re
2
+ import os
3
+ import os.path as op
4
+ from glob import glob
5
+ import urllib
6
+ import subprocess
7
+ import requests
8
+ from bs4 import BeautifulSoup, NavigableString
9
+
10
+
11
+ # or https://cloud.science-miner.com/grobid/ for cloud service
12
+ GROBID_URL = "http://localhost:8070"
13
+ DIR_PATH = op.dirname(op.abspath(__file__))
14
+ PDF_FIGURES_JAR_PATH = op.join(
15
+ DIR_PATH, "pdffigures2", "pdffigures2-assembly-0.0.12-SNAPSHOT.jar"
16
+ )
17
+
18
+
19
+ def list_pdf_paths(pdf_folder: str):
20
+ """
21
+ list of pdf paths in pdf folder
22
+ """
23
+ return glob(op.join(pdf_folder, "*", "*", "*.pdf"))
24
+
25
+
26
+ def validate_url(path: str):
27
+ """
28
+ Validate a given ``path`` if it is URL or not
29
+ """
30
+ regex = re.compile(
31
+ r"^(?:http|ftp)s?://" # http:// or https://
32
+ # domain...
33
+ r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"
34
+ r"localhost|" # localhost...
35
+ r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
36
+ r"(?::\d+)?" # optional port
37
+ r"(?:/?|[/?]\S+)$",
38
+ re.IGNORECASE,
39
+ )
40
+ return re.match(regex, path) is not None
41
+
42
+
43
+ def parse_pdf(
44
+ pdf_path: str,
45
+ fulltext: bool = True,
46
+ soup: bool = False,
47
+ grobid_url: str = GROBID_URL,
48
+ ):
49
+ """
50
+ Function to parse PDF to XML or BeautifulSoup using GROBID tool
51
+
52
+ You can see http://grobid.readthedocs.io/en/latest/Install-Grobid/ on how to run GROBID locally
53
+ After loading GROBID zip file, you can run GROBID by using the following
54
+ >> ./gradlew run
55
+
56
+ Parameters
57
+ ==========
58
+ pdf_path: str or bytes, path or URL to publication or article or bytes string of PDF
59
+ fulltext: bool, option for parsing, if True, parse full text of the article
60
+ if False, parse only header
61
+ grobid_url: str, url to GROBID parser, default at 'http://localhost:8070'
62
+ This could be changed to "https://cloud.science-miner.com/grobid/" for the cloud service
63
+ soup: bool, if True, return BeautifulSoup of the article
64
+
65
+ Output
66
+ ======
67
+ parsed_article: if soup is False, return parsed XML in text format,
68
+ else return BeautifulSoup of the XML
69
+ Example
70
+ =======
71
+ >> parsed_article = parse_pdf(pdf_path, fulltext=True, soup=True)
72
+ """
73
+ # GROBID URL
74
+ if fulltext:
75
+ url = "%s/api/processFulltextDocument" % grobid_url
76
+ else:
77
+ url = "%s/api/processHeaderDocument" % grobid_url
78
+
79
+ if isinstance(pdf_path, str):
80
+ if validate_url(pdf_path) and op.splitext(pdf_path)[-1].lower() != ".pdf":
81
+ print("The input URL has to end with ``.pdf``")
82
+ parsed_article = None
83
+ elif validate_url(pdf_path) and op.splitext(pdf_path)[-1] == ".pdf":
84
+ page = urllib.request.urlopen(pdf_path).read()
85
+ parsed_article = requests.post(url, files={"input": page}).text
86
+ elif op.exists(pdf_path):
87
+ parsed_article = requests.post(
88
+ url, files={"input": open(pdf_path, "rb")}
89
+ ).text
90
+ else:
91
+ parsed_article = None
92
+ elif isinstance(pdf_path, bytes):
93
+ # assume that incoming is byte string
94
+ parsed_article = requests.post(url, files={"input": pdf_path}).text
95
+ else:
96
+ parsed_article = None
97
+
98
+ if soup and parsed_article is not None:
99
+ parsed_article = BeautifulSoup(parsed_article, "lxml")
100
+ return parsed_article
101
+
102
+
103
+ def parse_authors(article):
104
+ """
105
+ Parse authors from a given BeautifulSoup of an article
106
+ """
107
+ author_names = article.find("sourcedesc").findAll("persname")
108
+ authors = []
109
+ for author in author_names:
110
+ firstname = author.find("forename", {"type": "first"})
111
+ firstname = firstname.text.strip() if firstname is not None else ""
112
+ middlename = author.find("forename", {"type": "middle"})
113
+ middlename = middlename.text.strip() if middlename is not None else ""
114
+ lastname = author.find("surname")
115
+ lastname = lastname.text.strip() if lastname is not None else ""
116
+ if middlename != "":
117
+ authors.append(firstname + " " + middlename + " " + lastname)
118
+ else:
119
+ authors.append(firstname + " " + lastname)
120
+ authors = "; ".join(authors)
121
+ return authors
122
+
123
+
124
+ def parse_date(article):
125
+ """
126
+ Parse date from a given BeautifulSoup of an article
127
+ """
128
+ pub_date = article.find("publicationstmt")
129
+ year = pub_date.find("date")
130
+ year = year.attrs.get("when") if year is not None else ""
131
+ return year
132
+
133
+
134
+ def parse_abstract(article):
135
+ """
136
+ Parse abstract from a given BeautifulSoup of an article
137
+ """
138
+ div = article.find("abstract")
139
+ abstract = ""
140
+ for p in list(div.children):
141
+ if not isinstance(p, NavigableString) and len(list(p)) > 0:
142
+ abstract += " ".join(
143
+ [elem.text for elem in p if not isinstance(
144
+ elem, NavigableString)]
145
+ )
146
+ return abstract
147
+
148
+
149
+ def calculate_number_of_references(div):
150
+ """
151
+ For a given section, calculate number of references made in the section
152
+ """
153
+ n_publication_ref = len(
154
+ [ref for ref in div.find_all("ref") if ref.attrs.get("type") == "bibr"]
155
+ )
156
+ n_figure_ref = len(
157
+ [ref for ref in div.find_all(
158
+ "ref") if ref.attrs.get("type") == "figure"]
159
+ )
160
+ return {"n_publication_ref": n_publication_ref, "n_figure_ref": n_figure_ref}
161
+
162
+
163
+ def parse_sections(article, as_list: bool = False):
164
+ """
165
+ Parse list of sections from a given BeautifulSoup of an article
166
+
167
+ Parameters
168
+ ==========
169
+ as_list: bool, if True, output text as a list of paragraph instead
170
+ of joining it together as one single text
171
+ """
172
+ article_text = article.find("text")
173
+ divs = article_text.find_all(
174
+ "div", attrs={"xmlns": "http://www.tei-c.org/ns/1.0"})
175
+ sections = []
176
+ for div in divs:
177
+ div_list = list(div.children)
178
+ if len(div_list) == 0:
179
+ heading = ""
180
+ text = ""
181
+ all_paragraphs = []
182
+ elif len(div_list) == 1:
183
+ if isinstance(div_list[0], NavigableString):
184
+ heading = str(div_list[0])
185
+ text = ""
186
+ all_paragraphs = []
187
+ else:
188
+ heading = ""
189
+ text = div_list[0].text
190
+ all_paragraphs = [text]
191
+ else:
192
+ text = []
193
+ heading = div_list[0]
194
+ all_paragraphs = []
195
+ if isinstance(heading, NavigableString):
196
+ heading = str(heading)
197
+ p_all = list(div.children)[1:]
198
+ else:
199
+ heading = ""
200
+ p_all = list(div.children)
201
+ for p in p_all:
202
+ if p is not None:
203
+ try:
204
+ text.append(p.text)
205
+ all_paragraphs.append(p.text)
206
+ except:
207
+ pass
208
+ if not as_list:
209
+ text = "\n".join(text)
210
+ if heading != "" or text != "":
211
+ ref_dict = calculate_number_of_references(div)
212
+ sections.append(
213
+ {
214
+ "heading": heading,
215
+ "text": text,
216
+ "all_paragraphs": all_paragraphs,
217
+ "n_publication_ref": ref_dict["n_publication_ref"],
218
+ "n_figure_ref": ref_dict["n_figure_ref"],
219
+ }
220
+ )
221
+ return sections
222
+
223
+
224
+ def parse_references(article):
225
+ """
226
+ Parse list of references from a given BeautifulSoup of an article
227
+ """
228
+ reference_list = []
229
+ references = article.find("text").find("div", attrs={"type": "references"})
230
+ references = references.find_all(
231
+ "biblstruct") if references is not None else []
232
+ reference_list = []
233
+ for reference in references:
234
+ title = reference.find("title", attrs={"level": "a"})
235
+ if title is None:
236
+ title = reference.find("title", attrs={"level": "m"})
237
+ title = title.text if title is not None else ""
238
+ journal = reference.find("title", attrs={"level": "j"})
239
+ journal = journal.text if journal is not None else ""
240
+ if journal == "":
241
+ journal = reference.find("publisher")
242
+ journal = journal.text if journal is not None else ""
243
+ year = reference.find("date")
244
+ year = year.attrs.get("when") if year is not None else ""
245
+ authors = []
246
+ for author in reference.find_all("author"):
247
+ firstname = author.find("forename", {"type": "first"})
248
+ firstname = firstname.text.strip() if firstname is not None else ""
249
+ middlename = author.find("forename", {"type": "middle"})
250
+ middlename = middlename.text.strip() if middlename is not None else ""
251
+ lastname = author.find("surname")
252
+ lastname = lastname.text.strip() if lastname is not None else ""
253
+ if middlename != "":
254
+ authors.append(firstname + " " + middlename + " " + lastname)
255
+ else:
256
+ authors.append(firstname + " " + lastname)
257
+ authors = "; ".join(authors)
258
+ reference_list.append(
259
+ {"title": title, "journal": journal, "year": year, "authors": authors}
260
+ )
261
+ return reference_list
262
+
263
+
264
+ def parse_figure_caption(article):
265
+ """
266
+ Parse list of figures/tables from a given BeautifulSoup of an article
267
+ """
268
+ figures_list = []
269
+ figures = article.find_all("figure")
270
+ for figure in figures:
271
+ figure_type = figure.attrs.get("type") or ""
272
+ figure_id = figure.attrs["xml:id"] or ""
273
+ label = figure.find("label").text
274
+ if figure_type == "table":
275
+ caption = figure.find("figdesc").text
276
+ data = figure.table.text
277
+ else:
278
+ caption = figure.text
279
+ data = ""
280
+ figures_list.append(
281
+ {
282
+ "figure_label": label,
283
+ "figure_type": figure_type,
284
+ "figure_id": figure_id,
285
+ "figure_caption": caption,
286
+ "figure_data": data,
287
+ }
288
+ )
289
+ return figures_list
290
+
291
+
292
+ def convert_article_soup_to_dict(article, as_list: bool = False):
293
+ """
294
+ Function to convert BeautifulSoup to JSON format
295
+ similar to the output from https://github.com/allenai/science-parse/
296
+
297
+ Parameters
298
+ ==========
299
+ article: BeautifulSoup
300
+
301
+ Output
302
+ ======
303
+ article_json: dict, parsed dictionary of a given article in the following format
304
+ {
305
+ 'title': ...,
306
+ 'abstract': ...,
307
+ 'sections': [
308
+ {'heading': ..., 'text': ...},
309
+ {'heading': ..., 'text': ...},
310
+ ...
311
+ ],
312
+ 'references': [
313
+ {'title': ..., 'journal': ..., 'year': ..., 'authors': ...},
314
+ {'title': ..., 'journal': ..., 'year': ..., 'authors': ...},
315
+ ...
316
+ ],
317
+ 'figures': [
318
+ {'figure_label': ..., 'figure_type': ..., 'figure_id': ..., 'figure_caption': ..., 'figure_data': ...},
319
+ ...
320
+ ]
321
+ }
322
+ """
323
+ article_dict = {}
324
+ if article is not None:
325
+ title = article.find("title", attrs={"type": "main"})
326
+ title = title.text.strip() if title is not None else ""
327
+ article_dict["authors"] = parse_authors(article)
328
+ article_dict["pub_date"] = parse_date(article)
329
+ article_dict["title"] = title
330
+ article_dict["abstract"] = parse_abstract(article)
331
+ article_dict["sections"] = parse_sections(article, as_list=as_list)
332
+ article_dict["references"] = parse_references(article)
333
+ article_dict["figures"] = parse_figure_caption(article)
334
+
335
+ doi = article.find("idno", attrs={"type": "DOI"})
336
+ doi = doi.text if doi is not None else ""
337
+ article_dict["doi"] = doi
338
+
339
+ return article_dict
340
+ else:
341
+ return None
342
+
343
+
344
+ def parse_pdf_to_dict(
345
+ pdf_path: str,
346
+ fulltext: bool = True,
347
+ soup: bool = True,
348
+ as_list: bool = False,
349
+ grobid_url: str = GROBID_URL,
350
+ ):
351
+ """
352
+ Parse the given PDF and return dictionary of the parsed article
353
+
354
+ Parameters
355
+ ==========
356
+ pdf_path: str, path to publication or article
357
+ fulltext: bool, whether to extract fulltext or not
358
+ soup: bool, whether to return BeautifulSoup or not
359
+ as_list: bool, whether to return list of sections or not
360
+ grobid_url: str, url to grobid server, default is `GROBID_URL`
361
+ This could be changed to "https://cloud.science-miner.com/grobid/" for the cloud service
362
+
363
+ Output
364
+ =====
365
+ article_dict: dict, dictionary of an article
366
+ """
367
+ parsed_article = parse_pdf(
368
+ pdf_path, fulltext=fulltext, soup=soup, grobid_url=grobid_url
369
+ )
370
+ article_dict = convert_article_soup_to_dict(
371
+ parsed_article, as_list=as_list)
372
+ return article_dict
373
+
374
+
375
+ def parse_figures(
376
+ pdf_folder: str,
377
+ jar_path: str = PDF_FIGURES_JAR_PATH,
378
+ resolution: int = 300,
379
+ output_folder: str = "figures",
380
+ ):
381
+ """
382
+ Parse figures from the given scientific PDF using pdffigures2
383
+
384
+ Parameters
385
+ ==========
386
+ pdf_folder: str, path to a folder that contains PDF files. A folder must contains only PDF files
387
+ jar_path: str, default path to pdffigures2-assembly-0.0.12-SNAPSHOT.jar file
388
+ resolution: int, resolution of the output figures
389
+ output_folder: str, path to folder that we want to save parsed data (related to figures) and figures
390
+
391
+ Output
392
+ ======
393
+ folder: making a folder of output_folder/data and output_folder/figures of parsed data and figures relatively
394
+ """
395
+ if not op.isdir(output_folder):
396
+ os.makedirs(output_folder)
397
+
398
+ # create ``data`` and ``figures`` subfolder within ``output_folder``
399
+ data_path = op.join(output_folder, "data")
400
+ figure_path = op.join(output_folder, "figures")
401
+ if not op.exists(data_path):
402
+ os.makedirs(data_path)
403
+ if not op.exists(figure_path):
404
+ os.makedirs(figure_path)
405
+
406
+ if op.isdir(data_path) and op.isdir(figure_path):
407
+ args = [
408
+ "java",
409
+ "-jar",
410
+ jar_path,
411
+ pdf_folder,
412
+ "-i",
413
+ str(resolution),
414
+ "-d",
415
+ os.path.join(os.path.abspath(data_path), ""),
416
+ "-m",
417
+ op.join(os.path.abspath(figure_path), ""), # end path with "/"
418
+ ]
419
+ _ = subprocess.run(
420
+ args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=20
421
+ )
422
+ print("Done parsing figures from PDFs!")
423
+ else:
424
+ print("You may have to check of ``data`` and ``figures`` in the the output folder path.")
serve_grobid.sh CHANGED
@@ -11,4 +11,4 @@ fi
11
 
12
  # run GROBID
13
  cd grobid-${GROBID_VERSION} || exit
14
- ./gradlew run &
11
 
12
  # run GROBID
13
  cd grobid-${GROBID_VERSION} || exit
14
+ ./gradlew run
similarity_metric.py ADDED
@@ -0,0 +1,13 @@
1
+ from base_class import SimilarityAlg
2
+ from openai.embeddings_utils import (
3
+ distances_from_embeddings,
4
+ )
5
+
6
+
7
+ class CosineSimilarity(SimilarityAlg):
8
+ def __init__(self) -> None:
9
+ pass
10
+
11
+ @staticmethod
12
+ def __call__(query_embedding, embeddings) -> None:
13
+ return distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")
utils.py ADDED
@@ -0,0 +1,24 @@
1
+
2
+ from typing import Set
3
+
4
+
5
+ def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
6
+ """
7
+ Get filtered list of object variable names.
8
+ :param keys: List of keys to include. If the first key is "not", the remaining keys will be removed from the class keys.
9
+ :return: List of class keys.
10
+ """
11
+ class_keys = obj.__dict__.keys()
12
+ if not keys:
13
+ return class_keys
14
+
15
+ # Remove the passed keys from the class keys.
16
+ if keys[0] == "not":
17
+ return {key for key in class_keys if key not in keys[1:]}
18
+ # Check if all passed keys are valid
19
+ if invalid_keys := set(keys) - class_keys:
20
+ raise ValueError(
21
+ f"Invalid keys: {invalid_keys}",
22
+ )
23
+ # Only return specified keys that are in class_keys
24
+ return {key for key in keys if key in class_keys}
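
For example, applied to a small object (illustrative only):

```python
class Cfg:
    def __init__(self):
        self.engine = "gpt-3.5-turbo"
        self.temperature = 0.5
        self.api_key = "sk-..."

cfg = Cfg()
print(get_filtered_keys_from_object(cfg))                    # all instance keys
print(get_filtered_keys_from_object(cfg, "engine"))          # {"engine"}
print(get_filtered_keys_from_object(cfg, "not", "api_key"))  # everything except api_key
```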