p-baleine committed on
Commit cfdc527
1 Parent(s): 38b78a6
.gitignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__
+ .ipynb_checkpoints
+
+ .envrc
+
+ !.gitkeep
+
+ .cache/*
README.md ADDED
@@ -0,0 +1 @@
+ # metaanalyser
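
The README added in this commit is only a title. For orientation, here is a minimal, hedged sketch (not part of the commit) of how the modules introduced below might be wired together; it assumes LangChain's `ChatOpenAI` as the model and the standard `Chain.run` interface, and the query string is made up:

```python
# Hypothetical usage sketch, based on the classes and helpers added in this commit.
from langchain.chat_models import ChatOpenAI

from metaanalyser.chains import SROverviewChain, SROutlintChain, SRSectionChain
from metaanalyser.paper import search_on_google_scholar, create_papers_vectorstor

llm = ChatOpenAI(temperature=0)

query = "systematic review of graph neural networks"  # example query (assumption)
papers = search_on_google_scholar(query)               # SerpApi search + arxiv metadata
paper_store = create_papers_vectorstor(papers)         # FAISS index over the paper texts

overview = SROverviewChain(llm=llm).run(query=query, papers=papers)
outline = SROutlintChain(llm=llm).run(query=query, papers=papers, overview=overview)

# Sections are generated one at a time; the keys mirror SRSectionChain.input_keys.
section_chain = SRSectionChain(llm=llm, paper_store=paper_store)
introduction = section_chain.run(
    section_idx=0,
    section_level=2,
    query=query,
    papers=papers,
    overview=overview,
    outline=outline,
)
```
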
metaanalyser/__init__.py ADDED
File without changes
metaanalyser/chains/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .overview import SROverviewChain
+ from .outline import SROutlintChain
+ from .section import SRSectionChain
+
+
+ __all__ = [
+     "SROutlintChain",
+     "SROverviewChain",
+     "SRSectionChain",
+ ]
metaanalyser/chains/base.py ADDED
@@ -0,0 +1,58 @@
+ import logging
+ from langchain.base_language import BaseLanguageModel
+ from langchain.chains.llm import LLMChain
+ from langchain.callbacks.manager import CallbackManagerForChainRun
+ from langchain.output_parsers import RetryWithErrorOutputParser
+ from langchain.prompts.base import BasePromptTemplate
+ from langchain.schema import BaseOutputParser, OutputParserException
+ from typing import Any, Dict, List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class SRBaseChain(LLMChain):
+
+     def _call(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         response = self.generate(inputs, run_manager=run_manager)
+         # We want to check token usage
+         logger.info(f"LLM utilization: {response.llm_output}")
+         return self.create_outputs(response)[0]
+
+     async def _acall(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         response = await self.agenerate(inputs, run_manager=run_manager)
+         logger.info(f"LLM utilization: {response.llm_output}")
+         return self.create_outputs(response)[0]
+
+
+ def maybe_retry_with_error_output_parser(
+     llm: BaseLanguageModel,
+     input_list: List[Dict[str, str]],
+     output: Dict[str, str],
+     output_parser: BaseOutputParser,
+     output_key: str,
+     prompt: BasePromptTemplate,
+ ):
+     retry_parser = RetryWithErrorOutputParser.from_llm(
+         parser=output_parser,
+         llm=llm,
+     )
+
+     try:
+         output_text = output_parser.parse(output[output_key])
+     except OutputParserException as e:
+         logger.warning(f"An error occurred on parsing output, retrying parse, {e}")
+
+         output_text = retry_parser.parse_with_prompt(
+             output[output_key],
+             prompt.format_prompt(**(input_list[0]))
+         )
+
+     return {output_key: output_text}
metaanalyser/chains/outline/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .outline import SROutlintChain
+ from .prompt import Outlint
+
+
+ __all__ = [
+     "Outlint",
+     "SROutlintChain",
+ ]
metaanalyser/chains/outline/outline.py ADDED
@@ -0,0 +1,89 @@
+ from langchain.base_language import BaseLanguageModel
+ from langchain.prompts.base import BasePromptTemplate
+ from langchain.callbacks.manager import CallbackManagerForChainRun
+ from typing import Any, Dict, List, Optional
+
+ from ...paper import (
+     Paper,
+     get_abstract_with_token_limit,
+     get_categories_string,
+ )
+ from ..base import (
+     SRBaseChain,
+     maybe_retry_with_error_output_parser,
+ )
+ from ..overview import Overview
+ from .prompt import OUTLINE_PROMPT, output_parser
+
+
+ class SROutlintChain(SRBaseChain):
+
+     prompt: BasePromptTemplate = OUTLINE_PROMPT
+     nb_categories: int = 3
+     nb_token_limit: int = 2_000
+
+     @property
+     def input_keys(self) -> List[str]:
+         return ["query", "papers", "overview"]
+
+     def _call(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         input_list = get_input_list(
+             self.llm,
+             inputs["query"],
+             inputs["papers"],
+             inputs["overview"],
+             self.nb_categories,
+             self.nb_token_limit,
+         )
+         output = super()._call(input_list, run_manager=run_manager)
+         return maybe_retry_with_error_output_parser(
+             llm=self.llm,
+             input_list=input_list,
+             output=output,
+             output_parser=output_parser,
+             output_key=self.output_key,
+             prompt=self.prompt,
+         )
+
+     async def _acall(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         input_list = get_input_list(
+             self.llm,
+             inputs["query"],
+             inputs["papers"],
+             inputs["overview"],
+             self.nb_categories,
+             self.nb_token_limit,
+         )
+         output = await super()._acall(input_list, run_manager=run_manager)
+         return maybe_retry_with_error_output_parser(
+             llm=self.llm,
+             input_list=input_list,
+             output=output,
+             output_parser=output_parser,
+             output_key=self.output_key,
+             prompt=self.prompt,
+         )
+
+
+ def get_input_list(
+     llm: BaseLanguageModel,
+     query: str,
+     papers: List[Paper],
+     overview: Overview,
+     nb_categories: int,
+     nb_token_limit: int,
+ ):
+     return [{
+         "query": query,
+         "overview": overview,
+         "categories": get_categories_string(papers, nb_categories),
+         "abstracts": get_abstract_with_token_limit(llm, papers, nb_token_limit)
+     }]
metaanalyser/chains/outline/prompt.py ADDED
@@ -0,0 +1,75 @@
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain.prompts import (
+     ChatPromptTemplate,
+     PromptTemplate,
+     SystemMessagePromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+ from pydantic import BaseModel, Field
+ from typing import List, Optional
+
+
+ class Section(BaseModel):
+
+     title: str = Field(description="title of this section")
+     children: Optional[List["Section"]] = Field(description="subsections of this section")
+     description: str = Field(description="brief description of this section (approximately 30 words maximum)")
+     citation_ids: List[int] = Field(description="citation ids to a paper abstract that this section cites")
+
+
+ class Outlint(BaseModel):
+
+     sections: List[Section] = Field(description="sections that make up this systematic review")
+     citations_ids: List[int] = Field(description="citation ids to all paper abstracts cited in this paper")
+
+     def __str__(self):
+         def section_string(idx: int, section: Section):
+             result = [f"{idx}. {section.title}: {section.description}"]
+
+             if not section.children:
+                 return result[0]
+
+             result += [
+                 section_string(f"{idx}.{child_idx}", child)
+                 for child_idx, child in enumerate(section.children, start=1)
+             ]
+             return "\n".join(result)
+
+         return "\n".join([
+             section_string(idx, s)
+             for idx, s in enumerate(self.sections, start=1)
+         ])
+
+
+ output_parser = PydanticOutputParser(pydantic_object=Outlint)
+
+ system_template = "You are a research scientist and interested in {categories}. You are working on writing a systematic review regarding \"{query}\"."
+ system_prompt = SystemMessagePromptTemplate.from_template(system_template)
+
+ human_template = """Build an outline of the systematic review regarding \"{query}\" based on the following list of paper abstracts.
+
+ -----
+ {abstracts}
+ -----
+
+ The following is an overview of this systematic review. Build the outline of the systematic review according to this overview.
+
+ -----
+ {overview}
+ -----
+
+ Devise each section of this outline by citing abstracts from the papers.
+ The first element of the sections should be titled "Introduction" and the last element should be titled "Conclusion".
+
+ {format_instructions}"""
+ human_prompt = HumanMessagePromptTemplate(
+     prompt=PromptTemplate(
+         template=human_template,
+         input_variables=["query", "abstracts", "overview"],
+         partial_variables={
+             "format_instructions": output_parser.get_format_instructions()
+         }
+     )
+ )
+
+ OUTLINE_PROMPT = ChatPromptTemplate.from_messages([system_prompt, human_prompt])
metaanalyser/chains/overview/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .overview import SROverviewChain
+ from .prompt import Overview
+
+
+ __all__ = [
+     "Overview",
+     "SROverviewChain",
+ ]
metaanalyser/chains/overview/overview.py ADDED
@@ -0,0 +1,85 @@
+ from langchain.base_language import BaseLanguageModel
+ from langchain.callbacks.manager import CallbackManagerForChainRun
+ from langchain.prompts.base import BasePromptTemplate
+ from typing import Any, Dict, List, Optional
+
+ from ...paper import (
+     Paper,
+     get_abstract_with_token_limit,
+     get_categories_string,
+ )
+ from ..base import (
+     SRBaseChain,
+     maybe_retry_with_error_output_parser,
+ )
+ from .prompt import OVERVIEW_PROMPT, output_parser
+
+
+ class SROverviewChain(SRBaseChain):
+
+     prompt: BasePromptTemplate = OVERVIEW_PROMPT
+     nb_categories: int = 3
+     nb_token_limit: int = 2_000
+     nb_max_retry: int = 3
+
+     @property
+     def input_keys(self) -> List[str]:
+         return ["query", "papers"]
+
+     def _call(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         input_list = get_input_list(
+             self.llm,
+             inputs["query"],
+             inputs["papers"],
+             self.nb_categories,
+             self.nb_token_limit,
+         )
+         output = super()._call(input_list, run_manager=run_manager)
+         return maybe_retry_with_error_output_parser(
+             llm=self.llm,
+             input_list=input_list,
+             output=output,
+             output_parser=output_parser,
+             output_key=self.output_key,
+             prompt=self.prompt,
+         )
+
+     async def _acall(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         input_list = get_input_list(
+             self.llm,
+             inputs["query"],
+             inputs["papers"],
+             self.nb_categories,
+             self.nb_token_limit,
+         )
+         output = await super()._acall(input_list, run_manager=run_manager)
+         return maybe_retry_with_error_output_parser(
+             llm=self.llm,
+             input_list=input_list,
+             output=output,
+             output_parser=output_parser,
+             output_key=self.output_key,
+             prompt=self.prompt,
+         )
+
+
+ def get_input_list(
+     llm: BaseLanguageModel,
+     query: str,
+     papers: List[Paper],
+     nb_categories: int,
+     nb_token_limit: int,
+ ):
+     return [{
+         "query": query,
+         "categories": get_categories_string(papers, nb_categories),
+         "abstracts": get_abstract_with_token_limit(llm, papers, nb_token_limit)
+     }]
metaanalyser/chains/overview/prompt.py ADDED
@@ -0,0 +1,79 @@
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain.prompts import (
+     ChatPromptTemplate,
+     PromptTemplate,
+     SystemMessagePromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+ from pydantic import BaseModel, Field
+ from typing import List
+
+
+ class Overview(BaseModel):
+
+     title: str = Field(description="title of the systematic review")
+     main_points: List[str] = Field(description="main points that make up the systematic review")
+     overview: str = Field(description="overview of the systematic review")
+
+     def __str__(self):
+         points = "\n - ".join(self.main_points)
+         return f"""
+ Title: {self.title}
+ Points:
+ - {points}
+ Overview: {self.overview}
+ """.strip()
+
+     def _repr_html_(self):
+         main_points = "".join([f"<li>{p}</li>" for p in self.main_points])
+
+         return (
+             "<div>"
+             f" <div><span style=\"font-weight: bold\">Title:</span>"
+             f" <span style=\"margin-left: 5px\">{self.title}</span>"
+             f" </div>"
+             f" <div><span style=\"font-weight: bold\">Main points:</span>"
+             f" <ul style=\"margin: 0 10px\">{main_points}</ul>"
+             f" </div>"
+             f" <div><span style=\"font-weight: bold\">Overview:</span>"
+             f" <span style=\"margin-left: 5px\">{self.overview}</span>"
+             f" </div>"
+             "</div>"
+         )
+
+
+ output_parser = PydanticOutputParser(pydantic_object=Overview)
+
+ system_template = "You are a research scientist and interested in {categories}. You are working on writing a systematic review regarding \"{query}\"."
+ system_prompt = SystemMessagePromptTemplate.from_template(system_template)
+
+ human_template = """Write an overview of the systematic review based on the summary of the following list of paper abstracts.
+
+ -----
+ {abstracts}
+ -----
+
+ This overview should serve as a compass for you as you construct the outline of the systematic review and write down its details.
+
+ Assume that the readers of this systematic review will not be familiar with the field. To make it easy for such readers to understand, list the main points briefly (approximately 30 words maximum) based on the following points.
+
+ - Motivation for this field and the problem this field is trying to solve
+ - Historical background of this field
+ - Future development of this field
+
+ Based on these main points, provide an overview of the systematic review regarding {query} you will write.
+
+ Finally, write the title of the systematic review you are going to write based on this overview.
+
+ {format_instructions}"""
+ human_prompt = HumanMessagePromptTemplate(
+     prompt=PromptTemplate(
+         template=human_template,
+         input_variables=["abstracts", "query"],
+         partial_variables={
+             "format_instructions": output_parser.get_format_instructions()
+         }
+     )
+ )
+
+ OVERVIEW_PROMPT = ChatPromptTemplate.from_messages([system_prompt, human_prompt])
metaanalyser/chains/section/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .section import SRSectionChain
+
+
+ __all__ = [
+     "SRSectionChain"
+ ]
metaanalyser/chains/section/prompt.py ADDED
@@ -0,0 +1,33 @@
+ from langchain.prompts import (
+     ChatPromptTemplate,
+     SystemMessagePromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+
+
+ system_template = """You are a research scientist and interested in {categories}. You are working on writing a systematic review regarding \"{query}\".
+
+ The outline of the systematic review is as follows:
+
+ -----
+ Title: {title}
+ {outline}"""
+ system_prompt = SystemMessagePromptTemplate.from_template(system_template)
+
+ human_template = """Write the "{section_title}" section of this systematic review based on the following list of snippets or abstracts of related papers.
+
+ -----
+ {snippets}
+ -----
+
+ This systematic review should adhere to the following overview:
+
+ {overview}
+
+ Write the "{section_title}" section with respect to this overview. Write the text in markdown format. The title of this section should be suffixed with {section_level} level markdown title (`{md_title_suffix}`). The text of the section should be based on a snippet or abstract and should be clearly cited. The citation should be written at the end of the sentence in the form `[^cite_id]`."""
+ human_prompt = HumanMessagePromptTemplate.from_template(human_template)
+
+ SECTION_PROMPT = ChatPromptTemplate.from_messages([
+     system_prompt,
+     human_prompt,
+ ])
metaanalyser/chains/section/section.py ADDED
@@ -0,0 +1,151 @@
+ from langchain.base_language import BaseLanguageModel
+ from langchain.callbacks.manager import CallbackManagerForChainRun
+ from langchain.prompts.base import BasePromptTemplate
+ from langchain.vectorstores.base import VectorStore
+ from typing import Any, Dict, List, Optional
+
+ from ...paper import (
+     Paper,
+     get_abstract_with_token_limit,
+     get_categories_string,
+ )
+ from ..base import (
+     SRBaseChain,
+     maybe_retry_with_error_output_parser,
+ )
+ from ..outline import Outlint
+ from ..overview import Overview
+ from .prompt import SECTION_PROMPT
+
+
+ class SRSectionChain(SRBaseChain):
+
+     paper_store: VectorStore
+     prompt: BasePromptTemplate = SECTION_PROMPT
+     nb_categories: int = 3
+     nb_token_limit: int = 2_000
+     nb_max_retry: int = 3
+
+     @property
+     def input_keys(self) -> List[str]:
+         # TODO: support nested sections
+         return [
+             "section_idx",
+             "section_level",
+             "query",
+             "papers",
+             "overview",
+             "outline"
+         ]
+
+     def _call(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         input_list = get_input_list(
+             self.llm,
+             self.paper_store,
+             inputs["section_idx"],
+             inputs["section_level"],
+             inputs["query"],
+             inputs["papers"],
+             inputs["overview"],
+             inputs["outline"],
+             self.nb_categories,
+             self.nb_token_limit,
+         )
+         return super()._call(input_list, run_manager=run_manager)
+
+     async def _acall(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, str]:
+         input_list = get_input_list(
+             self.llm,
+             self.paper_store,
+             inputs["section_idx"],
+             inputs["section_level"],
+             inputs["query"],
+             inputs["papers"],
+             inputs["overview"],
+             inputs["outline"],
+             self.nb_categories,
+             self.nb_token_limit,
+         )
+         return await super()._acall(input_list, run_manager=run_manager)
+
+
+ def get_input_list(
+     llm: BaseLanguageModel,
+     paper_store: VectorStore,
+     section_idx: int,
+     section_level: int,
+     query: str,
+     papers: List[Paper],
+     overview: Overview,
+     outline: Outlint,
+     nb_categories: int,
+     nb_token_limit: int,
+     max_paper_store_search_size: int = 100,
+ ):
+     section = outline.sections[section_idx]
+     papers_citation_id_map = {p.citation_id: p for p in papers}
+     related_papers = [
+         papers_citation_id_map[int(citation_id)]
+         for citation_id in section.citation_ids
+     ]
+
+     if not related_papers:
+         # If citation_ids is empty, fall back to all papers
+         # FIXME: using all of them may overflow the token limit??
+         related_papers = papers
+
+     related_snippets = paper_store.similarity_search(
+         f"{section.title} {section.description}",
+         k=max_paper_store_search_size,
+     )
+
+     # Include the abstracts of all papers cited by the overview in the snippets,
+     # then collect related passages from the vectorstore until nb_token_limit is reached
+
+     def get_snippet(title, citation_id, text):
+         text = text.replace("\n", " ")
+         return f"""
+ Title: {title}
+ citation_id: {citation_id}
+ Text: {text}
+ """
+
+     snippets = [get_snippet(p.title, p.citation_id, p.summary) for p in related_papers]
+     total_num_tokens = llm.get_num_tokens("\n".join(snippets).strip())
+     idx = 0
+
+     while idx < len(related_snippets):
+         snippet = related_snippets[idx]
+         snippet_text = get_snippet(
+             snippet.metadata["title"],
+             snippet.metadata["citation_id"],
+             snippet.page_content,
+         )
+         num_tokens = llm.get_num_tokens(snippet_text)
+
+         if total_num_tokens + num_tokens > nb_token_limit:
+             break
+
+         snippets.append(snippet_text)
+         total_num_tokens += num_tokens
+         idx += 1
+
+     return [{
+         "query": query,
+         "title": overview.title,
+         "overview": overview,
+         "section_title": section.title,
+         "section_level": section_level,
+         "md_title_suffix": "#" * section_level,
+         "outline": outline,
+         "categories": get_categories_string(papers, nb_categories),
+         "snippets": "\n".join(snippets).strip(),
+     }]
metaanalyser/memory.py ADDED
@@ -0,0 +1,10 @@
+ import os
+ from joblib import Memory
+
+
+ CACHE_DIR = os.environ.get(
+     "METAANALYSER_CACHE_DIR",
+     os.path.join(os.path.relpath(os.path.dirname(__file__)), "..", ".cache")
+ )
+
+ memory = Memory(CACHE_DIR, verbose=0)
metaanalyser/paper/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .paper import (
+     Paper,
+     get_abstract_with_token_limit,
+     get_categories_string,
+     search_on_google_scholar,
+ )
+ from .vectorstore import create_papers_vectorstor
+
+
+ __all__ = [
+     "Paper",
+     "create_papers_vectorstor",
+     "get_abstract_with_token_limit",
+     "get_categories_string",
+     "search_on_google_scholar",
+ ]
metaanalyser/paper/arxiv_categories.py ADDED
@@ -0,0 +1,172 @@
+ # Obtained by scraping https://arxiv.org/category_taxonomy
+ # TODO: handle newly added categories
+ CATEGORY_NAME_ID_MAP = {
+     'cs.AI': 'Artificial Intelligence',
+     'cs.AR': 'Hardware Architecture',
+     'cs.CC': 'Computational Complexity',
+     'cs.CE': 'Computational Engineering, Finance, and Science',
+     'cs.CG': 'Computational Geometry',
+     'cs.CL': 'Computation and Language',
+     'cs.CR': 'Cryptography and Security',
+     'cs.CV': 'Computer Vision and Pattern Recognition',
+     'cs.CY': 'Computers and Society',
+     'cs.DB': 'Databases',
+     'cs.DC': 'Distributed, Parallel, and Cluster Computing',
+     'cs.DL': 'Digital Libraries',
+     'cs.DM': 'Discrete Mathematics',
+     'cs.DS': 'Data Structures and Algorithms',
+     'cs.ET': 'Emerging Technologies',
+     'cs.FL': 'Formal Languages and Automata Theory',
+     'cs.GL': 'General Literature',
+     'cs.GR': 'Graphics',
+     'cs.GT': 'Computer Science and Game Theory',
+     'cs.HC': 'Human-Computer Interaction',
+     'cs.IR': 'Information Retrieval',
+     'cs.IT': 'Information Theory',
+     'cs.LG': 'Machine Learning',
+     'cs.LO': 'Logic in Computer Science',
+     'cs.MA': 'Multiagent Systems',
+     'cs.MM': 'Multimedia',
+     'cs.MS': 'Mathematical Software',
+     'cs.NA': 'Numerical Analysis',
+     'cs.NE': 'Neural and Evolutionary Computing',
+     'cs.NI': 'Networking and Internet Architecture',
+     'cs.OH': 'Other Computer Science',
+     'cs.OS': 'Operating Systems',
+     'cs.PF': 'Performance',
+     'cs.PL': 'Programming Languages',
+     'cs.RO': 'Robotics',
+     'cs.SC': 'Symbolic Computation',
+     'cs.SD': 'Sound',
+     'cs.SE': 'Software Engineering',
+     'cs.SI': 'Social and Information Networks',
+     'cs.SY': 'Systems and Control',
+     'econ.EM': 'Econometrics',
+     'econ.GN': 'General Economics',
+     'econ.TH': 'Theoretical Economics',
+     'eess.AS': 'Audio and Speech Processing',
+     'eess.IV': 'Image and Video Processing',
+     'eess.SP': 'Signal Processing',
+     'eess.SY': 'Systems and Control',
+     'math.AC': 'Commutative Algebra',
+     'math.AG': 'Algebraic Geometry',
+     'math.AP': 'Analysis of PDEs',
+     'math.AT': 'Algebraic Topology',
+     'math.CA': 'Classical Analysis and ODEs',
+     'math.CO': 'Combinatorics',
+     'math.CT': 'Category Theory',
+     'math.CV': 'Complex Variables',
+     'math.DG': 'Differential Geometry',
+     'math.DS': 'Dynamical Systems',
+     'math.FA': 'Functional Analysis',
+     'math.GM': 'General Mathematics',
+     'math.GN': 'General Topology',
+     'math.GR': 'Group Theory',
+     'math.GT': 'Geometric Topology',
+     'math.HO': 'History and Overview',
+     'math.IT': 'Information Theory',
+     'math.KT': 'K-Theory and Homology',
+     'math.LO': 'Logic',
+     'math.MG': 'Metric Geometry',
+     'math.MP': 'Mathematical Physics',
+     'math.NA': 'Numerical Analysis',
+     'math.NT': 'Number Theory',
+     'math.OA': 'Operator Algebras',
+     'math.OC': 'Optimization and Control',
+     'math.PR': 'Probability',
+     'math.QA': 'Quantum Algebra',
+     'math.RA': 'Rings and Algebras',
+     'math.RT': 'Representation Theory',
+     'math.SG': 'Symplectic Geometry',
+     'math.SP': 'Spectral Theory',
+     'math.ST': 'Statistics Theory',
+     'astro-ph': 'Astrophysics',
+     'astro-ph.CO': 'Cosmology and Nongalactic Astrophysics',
+     'astro-ph.EP': 'Earth and Planetary Astrophysics',
+     'astro-ph.GA': 'Astrophysics of Galaxies',
+     'astro-ph.HE': 'High Energy Astrophysical Phenomena',
+     'astro-ph.IM': 'Instrumentation and Methods for Astrophysics',
+     'astro-ph.SR': 'Solar and Stellar Astrophysics',
+     'cond-mat': 'Condensed Matter',
+     'cond-mat.dis-nn': 'Disordered Systems and Neural Networks',
+     'cond-mat.mes-hall': 'Mesoscale and Nanoscale Physics',
+     'cond-mat.mtrl-sci': 'Materials Science',
+     'cond-mat.other': 'Other Condensed Matter',
+     'cond-mat.quant-gas': 'Quantum Gases',
+     'cond-mat.soft': 'Soft Condensed Matter',
+     'cond-mat.stat-mech': 'Statistical Mechanics',
+     'cond-mat.str-el': 'Strongly Correlated Electrons',
+     'cond-mat.supr-con': 'Superconductivity',
+     'General Relativity and Quantum Cosmology': 'gr-qc',
+     'gr-qc': 'General Relativity and Quantum Cosmology',
+     'High Energy Physics - Experiment': 'hep-ex',
+     'hep-ex': 'High Energy Physics - Experiment',
+     'High Energy Physics - Lattice': 'hep-lat',
+     'hep-lat': 'High Energy Physics - Lattice',
+     'High Energy Physics - Phenomenology': 'hep-ph',
+     'hep-ph': 'High Energy Physics - Phenomenology',
+     'High Energy Physics - Theory': 'hep-th',
+     'hep-th': 'High Energy Physics - Theory',
+     'Mathematical Physics': 'math-ph',
+     'math-ph': 'Mathematical Physics',
+     'nlin': 'Nonlinear Sciences',
+     'nlin.AO': 'Adaptation and Self-Organizing Systems',
+     'nlin.CD': 'Chaotic Dynamics',
+     'nlin.CG': 'Cellular Automata and Lattice Gases',
+     'nlin.PS': 'Pattern Formation and Solitons',
+     'nlin.SI': 'Exactly Solvable and Integrable Systems',
+     'Nuclear Experiment': 'nucl-ex',
+     'nucl-ex': 'Nuclear Experiment',
+     'Nuclear Theory': 'nucl-th',
+     'nucl-th': 'Nuclear Theory',
+     'physics': 'Physics',
+     'physics.acc-ph': 'Accelerator Physics',
+     'physics.ao-ph': 'Atmospheric and Oceanic Physics',
+     'physics.app-ph': 'Applied Physics',
+     'physics.atm-clus': 'Atomic and Molecular Clusters',
+     'physics.atom-ph': 'Atomic Physics',
+     'physics.bio-ph': 'Biological Physics',
+     'physics.chem-ph': 'Chemical Physics',
+     'physics.class-ph': 'Classical Physics',
+     'physics.comp-ph': 'Computational Physics',
+     'physics.data-an': 'Data Analysis, Statistics and Probability',
+     'physics.ed-ph': 'Physics Education',
+     'physics.flu-dyn': 'Fluid Dynamics',
+     'physics.gen-ph': 'General Physics',
+     'physics.geo-ph': 'Geophysics',
+     'physics.hist-ph': 'History and Philosophy of Physics',
+     'physics.ins-det': 'Instrumentation and Detectors',
+     'physics.med-ph': 'Medical Physics',
+     'physics.optics': 'Optics',
+     'physics.plasm-ph': 'Plasma Physics',
+     'physics.pop-ph': 'Popular Physics',
+     'physics.soc-ph': 'Physics and Society',
+     'physics.space-ph': 'Space Physics',
+     'Quantum Physics': 'quant-ph',
+     'quant-ph': 'Quantum Physics',
+     'q-bio.BM': 'Biomolecules',
+     'q-bio.CB': 'Cell Behavior',
+     'q-bio.GN': 'Genomics',
+     'q-bio.MN': 'Molecular Networks',
+     'q-bio.NC': 'Neurons and Cognition',
+     'q-bio.OT': 'Other Quantitative Biology',
+     'q-bio.PE': 'Populations and Evolution',
+     'q-bio.QM': 'Quantitative Methods',
+     'q-bio.SC': 'Subcellular Processes',
+     'q-bio.TO': 'Tissues and Organs',
+     'q-fin.CP': 'Computational Finance',
+     'q-fin.EC': 'Economics',
+     'q-fin.GN': 'General Finance',
+     'q-fin.MF': 'Mathematical Finance',
+     'q-fin.PM': 'Portfolio Management',
+     'q-fin.PR': 'Pricing of Securities',
+     'q-fin.RM': 'Risk Management',
+     'q-fin.ST': 'Statistical Finance',
+     'q-fin.TR': 'Trading and Market Microstructure',
+     'stat.AP': 'Applications',
+     'stat.CO': 'Computation',
+     'stat.ME': 'Methodology',
+     'stat.ML': 'Machine Learning',
+     'stat.OT': 'Other Statistics',
+     'stat.TH': 'Statistics Theory'
+ }
metaanalyser/paper/paper.py ADDED
@@ -0,0 +1,289 @@
+ import arxiv
+ import datetime
+ import logging
+ import re
+ import tempfile
+ from collections import Counter
+ from langchain.base_language import BaseLanguageModel
+ from langchain.utilities import SerpAPIWrapper
+ from pdfminer.high_level import extract_text
+ from pydantic import BaseModel
+ from tqdm.auto import tqdm
+ from typing import List, Optional
+
+ from ..memory import memory
+ from .arxiv_categories import CATEGORY_NAME_ID_MAP
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Citation(BaseModel):
+
+     title: str
+     snippet: str
+
+
+ class GoogleScholarItem(BaseModel):
+
+     result_id: str
+     title: str
+     link: str
+     nb_cited: int
+     citations: List[Citation]
+
+     @property
+     def mla_citiation(self) -> str:
+         mla = [c for c in self.citations if c.title == 'MLA']
+
+         if mla:
+             return mla[0]
+
+     @classmethod
+     def from_google_scholar_result(cls, result):
+         result_id = result["result_id"]
+         link = result["link"] if "link" in result else ""
+         nb_cited = (
+             result["inline_links"]["cited_by"]["total"]
+             if "cited_by" in result["inline_links"] else 0
+         )
+         citations = [
+             Citation(title=c["title"], snippet=c["snippet"]) for c in
+             fetch_google_scholar_cite(result_id)["citations"]
+         ]
+
+         return cls(
+             result_id=result_id,
+             title=result["title"],
+             link=link,
+             nb_cited=nb_cited,
+             citations=citations,
+         )
+
+
+ class Paper(BaseModel):
+     """Represents a paper; in addition to the information obtained from Google Scholar it holds fields such as the doi and the summary.
+
+     NOTE: sources other than serpapi are conceivable, but for now a Paper always originates from a serpapi search result.
+     """
+
+     citation_id: int
+     google_scholar_item: GoogleScholarItem
+     entry_id: str
+     summary: str
+     published: datetime.datetime
+     primary_category: str
+     categories: List[str]
+     text: str
+     doi: Optional[str]
+
+     @property
+     def google_scholar_result_id(self):
+         return self.google_scholar_item.result_id
+
+     @property
+     def title(self) -> str:
+         return self.google_scholar_item.title
+
+     @property
+     def link(self) -> str:
+         return self.google_scholar_item.link
+
+     @property
+     def nb_cited(self) -> int:
+         return self.google_scholar_item.nb_cited
+
+     @property
+     def citations(self) -> str:
+         return self.google_scholar_item.citations
+
+     @property
+     def mla_citiation(self) -> str:
+         return self.google_scholar_item.mla_citiation
+
+     @classmethod
+     def from_google_scholar_result(cls, citation_id, result):
+         google_scholar_item = GoogleScholarItem.from_google_scholar_result(result)
+         arxiv_result = fetch_arxiv_result(google_scholar_item.link)
+
+         def get_category(c):
+             if c not in CATEGORY_NAME_ID_MAP:
+                 logger.warning(f'Category {c} is not found in CATEGORY_NAME_ID_MAP.')
+                 return None
+             return CATEGORY_NAME_ID_MAP[c]
+
+         primary_category = get_category(arxiv_result.primary_category)
+         categories = [
+             c for c in [get_category(c) for c in arxiv_result.categories]
+             if c
+         ]
+
+         return cls(
+             citation_id=citation_id,
+             google_scholar_item=google_scholar_item,
+             entry_id=arxiv_result.entry_id,
+             summary=arxiv_result.summary,
+             published=arxiv_result.published,
+             primary_category=primary_category,
+             categories=categories,
+             doi=arxiv_result.doi,
+             text=get_text_from_arxiv_search_result(arxiv_result),
+         )
+
+     def _repr_html_(self):
+         def get_category_string():
+             # The first element of categories is basically the primary_category
+             if not self.categories:
+                 return ""
+
+             result = f"<span style=\"font-weight: bold\">{self.categories[0]}</span>"
+
+             if len(self.categories) == 1:
+                 return result
+
+             return f"{result}; " + "; ".join([c for c in self.categories[1:]])
+
+         return (
+             "<div>"
+             f" Title:&nbsp;<a href=\"{self.link}\" target=\"_blank\">{self.title}</a><br/>"
+             f" Citation:&nbsp;[{self.citation_id}] {self.mla_citiation.snippet}<br/>"
+             f" Cited by:&nbsp;{self.nb_cited}<br/>"
+             f" Published:&nbsp;{self.published}<br/>"
+             f" Categories:&nbsp;{get_category_string()}<br/>"
+             f" Summary:&nbsp;{self.summary}<br/>"
+             "</div>"
+         )
+
+
+ def search_on_google_scholar(
+     query: str,
+     approved_domains: List[str] = ["arxiv.org"],
+     n: int = 10,
+ ) -> List[Paper]:
+     """Returns the results of querying SerpApi's Google Scholar API with query.
+     Only papers from the domains listed in approved_domains are considered.
+     At most n results are returned.
+     """
+
+     def fetch(start=0):
+         def valid_item(i):
+             if "link" not in i:
+                 return False
+
+             domain = re.match(r"https?://([^/]+)", i["link"])
+
+             if not domain or domain.group(1) not in approved_domains:
+                 return False
+
+             return True
+
+         search_result = fetch_google_scholar(query, start)
+
+         return [i for i in search_result if valid_item(i)]
+
+     result = []
+     start = 0
+
+     while len(result) < n:
+         # FIXME: as it stands, this loops forever when there are fewer than n results in total
+         logger.info(f"Looking for `{query}` on Google Scholar, offset: {start}...")
+         result += fetch(start)
+         start += 10
+
+     logger.info("Collecting details...")
+
+     return [
+         Paper.from_google_scholar_result(id, i)
+         for id, i in tqdm(enumerate(result[:n], start=1))
+     ]
+
+
+ def get_categories_string(papers: List[Paper], n: int = 3) -> str:
+     categories = Counter(sum([p.categories for p in papers], []))
+     common = categories.most_common(n)
+
+     if not common:
+         return "Artificial Intelligence"
+
+     if len(common) == 1:
+         return common[0][0]
+
+     if len(common) == 2:
+         return " and ".join([c[0] for c in common])
+
+     *lst, last = common
+
+     return ", ".join([c[0] for c in lst]) + f" and {last[0]}"
+
+
+ def get_abstract_with_token_limit(
+     model: BaseLanguageModel,
+     papers: List[Paper],
+     limit: int,
+     separator: str = "\n",
+ ) -> str:
+     def get_summary(paper: Paper):
+         summary = paper.summary.replace("\n", " ")
+         return f"""
+ Title: {paper.title}
+ citation_id: {paper.citation_id}
+ Summary: {summary}
+ """
+
+     summaries = []
+     total_num_tokens = 0
+     idx = 0
+
+     while idx < len(papers):
+         summary = get_summary(papers[idx])
+         num_tokens = model.get_num_tokens(summary)
+
+         if total_num_tokens + num_tokens > limit:
+             break
+
+         summaries.append(summary)
+         total_num_tokens += num_tokens
+         idx += 1
+
+     result = separator.join(summaries).strip()
+
+     logger.info(
+         f'Number of papers: {len(summaries)}, '
+         f'number of tokens: {total_num_tokens}, text: {result[:100]}...'
+     )
+
+     return result
+
+
+ @memory.cache
+ def fetch_google_scholar(query: str, start: int) -> dict:
+     serpapi = SerpAPIWrapper(params={
+         "engine": "google_scholar",
+         "gl": "us",
+         "hl": "en",
+         "start": start,
+     })
+     return serpapi.results(query)["organic_results"]
+
+
+ @memory.cache
+ def fetch_google_scholar_cite(google_scholar_id: str) -> dict:
+     serpapi = SerpAPIWrapper(params={"engine": "google_scholar_cite"})
+     return serpapi.results(google_scholar_id)
+
+
+ @memory.cache
+ def fetch_arxiv_result(arxiv_abs_link: str) -> arxiv.Result:
+     m = re.match(r"https://arxiv\.org/abs/(.+)", arxiv_abs_link)
+     assert m is not None, f"{arxiv_abs_link} should be an arxiv link"
+     arxiv_id = m.group(1)
+     return next(arxiv.Search(id_list=[arxiv_id]).results())
+
+
+ @memory.cache
+ def get_text_from_arxiv_search_result(
+     arxiv_search_result: arxiv.Result
+ ) -> str:
+     with tempfile.TemporaryDirectory() as d:
+         file_path = arxiv_search_result.download_pdf(dirpath=d)
+         return extract_text(file_path)
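
A short usage sketch for the helpers defined in paper.py; this is an assumption-laden example rather than part of the commit (it presumes `SERPAPI_API_KEY` is set for `SerpAPIWrapper`, that the hits are arxiv-hosted, and uses a made-up query):

```python
# Hypothetical sketch of the paper helpers; results are cached on disk via metaanalyser.memory.
from langchain.chat_models import ChatOpenAI

from metaanalyser.paper import (
    search_on_google_scholar,
    get_categories_string,
    get_abstract_with_token_limit,
)

llm = ChatOpenAI(temperature=0)

papers = search_on_google_scholar("chain of thought prompting", n=5)  # needs SERPAPI_API_KEY

# A readable category summary, e.g. "Computation and Language and Machine Learning" (example output).
print(get_categories_string(papers, n=3))

# Concatenated abstracts, truncated so the total stays under the token budget.
print(get_abstract_with_token_limit(llm, papers, limit=1_000))
```
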
metaanalyser/paper/vectorstore.py ADDED
@@ -0,0 +1,58 @@
+ import logging
+ from langchain.embeddings import OpenAIEmbeddings
+ from langchain.text_splitter import SpacyTextSplitter
+ from langchain.vectorstores import FAISS
+ from tqdm.auto import tqdm
+ from typing import List
+
+ from .paper import Paper
+
+ logger = logging.getLogger(__name__)
+
+
+ def create_papers_vectorstor(
+     papers: List[Paper],
+     tiktoken_encoder_model_name: str = "gpt-3.5-turbo",
+     chunk_size: int = 150,
+     chunk_overlap: int = 10,
+ ) -> FAISS:
+     splitter = SpacyTextSplitter.from_tiktoken_encoder(
+         model_name=tiktoken_encoder_model_name,
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+     )
+
+     logger.info(
+         f"Creating vector store,"
+         f" {tiktoken_encoder_model_name=}"
+         f", {chunk_size=}, {chunk_overlap=}"
+     )
+
+     docs = splitter.create_documents(
+         [p.text.replace("\n", " ") for p in tqdm(papers)],
+         metadatas=[
+             {
+                 'google_scholar_result_id': p.google_scholar_result_id,
+                 'title': p.title,
+                 'link': p.link,
+                 'nb_cited': p.nb_cited,
+                 'citation_id': p.citation_id,
+                 'entry_id': p.entry_id,
+                 'published': str(p.published),
+                 'primary_category': p.primary_category,
+                 'categories': ", ".join(p.categories),
+                 'doi': p.doi,
+                 'citiation': p.mla_citiation.snippet,
+             } for p in papers
+         ]
+     )
+
+     embeddings = OpenAIEmbeddings()
+     db = FAISS.from_documents(docs, embeddings)
+
+     logger.info(
+         f"Vector store is created from {len(papers)} papers,"
+         f" document size={len(docs)}"
+     )
+
+     return db
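
The FAISS store returned by create_papers_vectorstor is what SRSectionChain receives as paper_store. A minimal sketch of querying it directly (assuming `OPENAI_API_KEY` is set for `OpenAIEmbeddings`; the query strings are made up):

```python
# Hypothetical sketch: build the vector store and search it the way SRSectionChain does.
from metaanalyser.paper import search_on_google_scholar, create_papers_vectorstor

papers = search_on_google_scholar("retrieval augmented generation")
db = create_papers_vectorstor(papers)

# SRSectionChain.get_input_list queries with "<section title> <section description>".
docs = db.similarity_search("Introduction motivation for retrieval augmented generation", k=5)

for doc in docs:
    print(doc.metadata["citation_id"], doc.metadata["title"])
    print(doc.page_content[:200])
```
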