p-baleine committed
Commit 5a5f604
1 Parent(s): ec36bd2

organize code

app.py CHANGED
@@ -1,11 +1,9 @@
 import logging
 import os
-# from typing import Optional, Tuple
 import gradio as gr
 from langchain.chat_models import ChatOpenAI
 
-from metaanalyser.chains import SRSectionChain, SROverviewChain, SROutlintChain
-from metaanalyser.paper import search_on_google_scholar, create_papers_vectorstor
+from metaanalyser.chains import SRChain
 
 
 logger = logging.getLogger(__name__)
@@ -13,48 +11,13 @@ logging.basicConfig()
 logging.getLogger("metaanalyser").setLevel(level=logging.DEBUG)
 
 
-def run(query: str):
-    llm = ChatOpenAI(temperature=0)
-    papers = search_on_google_scholar(query)
-    db = create_papers_vectorstor(papers)
-    overview_chain = SROverviewChain(llm=llm, verbose=True)
-    outline_chain = SROutlintChain(llm=llm, verbose=True)
-    section_chain = SRSectionChain(
-        llm=llm,
-        paper_store=db,
-        verbose=True
-    )
-
-    overview = overview_chain.run({"query": query, "papers": papers})
-    outline = outline_chain.run({"query": query, "papers": papers, "overview": overview})
-
-    sections_as_md = []
-
-    for section_idx in range(len(outline.sections)):
-        # TODO: support nested sections
-        sections_as_md.append(section_chain.run({
-            "section_idx": section_idx,
-            "section_level": 2,
-            "query": query,
-            "papers": papers,
-            "overview": overview,
-            "outline": outline
-        }))
-
-    sr = f"# {overview.title}\n\n{overview.overview}\n\n## Table of contents\n\n{outline}\n\n"
-    sr += "\n\n".join(sections_as_md)
-    sr += "\n\n## References\n"
-
-    papers_citation_id_map = {p.citation_id: p for p in papers}
-    citations = []
-
-    for citation_id in outline.citations_ids:
-        citation = papers_citation_id_map[int(citation_id)]
-        citations.append(f"[^{citation_id}]: [{citation.mla_citiation.snippet}]({citation.link})")
-
-    sr += "\n\n".join(citations)
-
-    return sr
+def run(query: str):
+    if "OPENAI_API_KEY" not in os.environ or "SERPAPI_API_KEY" not in os.environ:
+        raise gr.Error("Please paste your OpenAI (https://platform.openai.com/) key and SerpAPI (https://serpapi.com/) key to use.")
+
+    llm = ChatOpenAI(temperature=0)
+    chain = SRChain(llm=llm, verbose=True)
+    return chain.run({"query": query})
 
 
 def set_openai_api_key(api_key: str):
@@ -65,7 +28,6 @@ def set_serpapi_api_key(api_key: str):
 os.environ["SERPAPI_API_KEY"] = api_key
 
 
-# block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
 block = gr.Blocks()
 
 with block:
@@ -94,6 +56,7 @@ with block:
         placeholder="the query for Google Scholar",
         lines=1,
     )
+
    submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
 
    gr.Examples(
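
After this refactor the Gradio handler is a thin wrapper around SRChain. A minimal sketch of how `run` might be wired into the Blocks UI; the `query`, `output`, and `submit` component names here are assumptions, since the actual wiring lives outside this hunk:

    # Sketch only: hypothetical component names, not part of this diff.
    with gr.Blocks() as demo:
        query = gr.Textbox(placeholder="the query for Google Scholar", lines=1)
        submit = gr.Button(value="Send", variant="secondary")
        output = gr.Markdown()
        # run(query) returns the systematic review as one markdown string.
        submit.click(fn=run, inputs=[query], outputs=[output])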
metaanalyser/chains/__init__.py CHANGED
@@ -1,9 +1,11 @@
 from .overview import SROverviewChain
 from .outline import SROutlintChain
 from .section import SRSectionChain
+from .sr import SRChain
 
 
 __all__ = [
+    "SRChain",
     "SROutlintChain",
     "SROverviewChain",
     "SRSectionChain",
metaanalyser/chains/outline/__init__.py CHANGED
@@ -1,8 +1,9 @@
 from .outline import SROutlintChain
-from .prompt import Outlint
+from .prompt import Outlint, Section
 
 
 __all__ = [
     "Outlint",
+    "Section",
     "SROutlintChain",
 ]
metaanalyser/chains/outline/outline.py CHANGED
@@ -20,7 +20,7 @@ class SROutlintChain(SRBaseChain):
 
     prompt: BasePromptTemplate = OUTLINE_PROMPT
     nb_categories: int = 3
-    nb_token_limit: int = 2_000
+    nb_token_limit: int = 1_500
 
     @property
     def input_keys(self) -> List[str]:
metaanalyser/chains/outline/prompt.py CHANGED
@@ -23,20 +23,24 @@ class Outlint(BaseModel):
     citations_ids: List[int] = Field(description="citation ids to all paper abstracts cited in this paper")
 
     def __str__(self):
-        def section_string(idx: int, section: Section):
+        def section_string(idx, section: Section, indent_level: int):
             result = [f"{idx}. {section.title}: {section.description}"]
 
             if not section.children:
                 return result[0]
 
             result += [
-                section_string(f"{idx}.{child_idx}", child)
+                section_string(
+                    (" " * (indent_level + 1)) + f"{child_idx}",
+                    child,
+                    indent_level + 1
+                )
                 for child_idx, child in enumerate(section.children, start=1)
             ]
             return "\n".join(result)
 
         return "\n".join([
-            section_string(idx, s)
+            section_string(idx, s, 0)
             for idx, s in enumerate(self.sections, start=1)
         ])
@@ -60,6 +64,7 @@ The following is an overview of this systematic review. Build the outline of the
 
 Devise each section of this outline by citing abstracts from the papers.
 The first element of the sections should be titled "Introduction" and the last element should be titled "Conclusion".
+It is preferred that sections be divided into more child sections. Each section can have up to two child sections.
 
 {format_instructions}"""
 human_prompt = HumanMessagePromptTemplate(
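
The new `section_string` signature replaces dotted child numbering (`1.1`) with whitespace indentation, one space per nesting level. A standalone sketch of that numbering logic, using a simplified stand-in for the real `Section` model:

    from typing import List
    from pydantic import BaseModel

    class Section(BaseModel):  # simplified stand-in for the real model
        title: str
        description: str
        children: List["Section"] = []

    Section.update_forward_refs()

    def section_string(idx, section, indent_level):
        result = [f"{idx}. {section.title}: {section.description}"]
        if not section.children:
            return result[0]
        result += [
            # child numbering restarts at 1, indented by nesting depth
            section_string((" " * (indent_level + 1)) + f"{child_idx}", child, indent_level + 1)
            for child_idx, child in enumerate(section.children, start=1)
        ]
        return "\n".join(result)

    methods = Section(
        title="Methods",
        description="study selection",
        children=[Section(title="Search strategy", description="query design")],
    )
    print(section_string(1, methods, 0))
    # 1. Methods: study selection
    #  1. Search strategy: query design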
metaanalyser/chains/overview/overview.py CHANGED
@@ -19,7 +19,7 @@ class SROverviewChain(SRBaseChain):
 
     prompt: BasePromptTemplate = OVERVIEW_PROMPT
     nb_categories: int = 3
-    nb_token_limit: int = 2_000
+    nb_token_limit: int = 1_500
     nb_max_retry: int = 3
 
     @property
metaanalyser/chains/section/prompt.py CHANGED
@@ -24,7 +24,7 @@ This systematic review should adhere to the following overview:
 
 {overview}
 
-Write the "{section_title}" section with respect to this overview. Write the text in markdown format. The title of this section should be suffixed with a {section_level} level markdown title (`{md_title_suffix}`). The text of the section should be based on a snippet or abstract and should be clearly cited. The citation should be written at the end of the sentence in the form `[^cite_id]`."""
+Write the "{section_title}: {section_description}" section with respect to this overview. Write the text in markdown format. The title of this section should be suffixed with a {section_level} level markdown title (`{md_title_suffix}`). The text of the section should be based on a snippet or abstract and should be clearly cited. The citation should be written at the end of the sentence in the form `[^<ID>]` where `ID` refers to the citation_id."""
 human_prompt = HumanMessagePromptTemplate.from_template(human_template)
 
 SECTION_PROMPT = ChatPromptTemplate.from_messages([
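
The `[^<ID>]` form asks the model to emit markdown footnotes keyed by citation_id, which `create_output` in the new sr.py later resolves into the References list. Illustratively, with placeholder ids and titles:

    Some claim supported by the paper with citation_id 3 [^3].

    ## References

    [^3]: [<mla_citiation.snippet of paper 3>](<paper 3 link>)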
metaanalyser/chains/section/section.py CHANGED
@@ -1,7 +1,9 @@
 from langchain.base_language import BaseLanguageModel
+from langchain.docstore.document import Document
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.prompts.base import BasePromptTemplate
 from langchain.vectorstores.base import VectorStore
+from pydantic import BaseModel
 from typing import Any, Dict, List, Optional
 
 from ...paper import (
@@ -23,7 +25,7 @@ class SRSectionChain(SRBaseChain):
     paper_store: VectorStore
     prompt: BasePromptTemplate = SECTION_PROMPT
     nb_categories: int = 3
-    nb_token_limit: int = 2_000
+    nb_token_limit: int = 1_500
     nb_max_retry: int = 3
 
     @property
@@ -31,11 +33,11 @@ class SRSectionChain(SRBaseChain):
         # TODO: support nesting
         return [
             "section_idx",
-            "section_level",
             "query",
             "papers",
             "overview",
-            "outline"
+            "outline",
+            "flatten_sections",
         ]
 
     def _call(
@@ -47,11 +49,11 @@ class SRSectionChain(SRBaseChain):
             self.llm,
             self.paper_store,
             inputs["section_idx"],
-            inputs["section_level"],
             inputs["query"],
             inputs["papers"],
             inputs["overview"],
             inputs["outline"],
+            inputs["flatten_sections"],
             self.nb_categories,
             self.nb_token_limit,
         )
@@ -66,69 +68,90 @@ class SRSectionChain(SRBaseChain):
             self.llm,
             self.paper_store,
             inputs["section_idx"],
-            inputs["section_level"],
             inputs["query"],
             inputs["papers"],
             inputs["overview"],
             inputs["outline"],
+            inputs["flatten_sections"],
             self.nb_categories,
             self.nb_token_limit,
         )
         return super()._acall(input_list, run_manager=run_manager)
 
 
+class TextSplit(BaseModel):
+    """Helper class for get_input_list."""
+
+    title: str
+    citation_id: int
+    text: str
+
+    @classmethod
+    def from_paper(cls, paper: Paper) -> "TextSplit":
+        return cls(
+            title=paper.title,
+            citation_id=paper.citation_id,
+            text=paper.summary,
+        )
+
+    @classmethod
+    def from_snippet(cls, snippet: Document) -> "TextSplit":
+        return cls(
+            title=snippet.metadata["title"],
+            citation_id=snippet.metadata["citation_id"],
+            text=snippet.page_content,
+        )
+
+
 def get_input_list(
     llm: BaseLanguageModel,
     paper_store: VectorStore,
     section_idx: int,
-    section_level: int,
     query: str,
     papers: List[Paper],
     overview: Overview,
     outline: Outlint,
+    flatten_sections,
     nb_categories: int,
     nb_token_limit: int,
     max_paper_store_search_size: int = 100,
 ):
-    section = outline.sections[section_idx]
+    section = flatten_sections[section_idx]
     papers_citation_id_map = {p.citation_id: p for p in papers}
-    related_papers = [
-        papers_citation_id_map[int(citation_id)]
-        for citation_id in section.citation_ids
-    ]
 
-    if not related_papers:
+    if section.section.citation_ids:
+        related_splits = [
+            TextSplit.from_paper(papers_citation_id_map[int(citation_id)])
+            for citation_id in section.section.citation_ids
+        ]
+    else:
         # if citation_ids is empty, fall back to all papers
-        # FIXME: won't this overflow the budget if everything is included??
-        related_papers = papers
-
-    related_snippets = paper_store.similarity_search(
-        f"{section.title} {section.description}",
-        k=max_paper_store_search_size,
-    )
+        related_splits = [TextSplit.from_paper(p) for p in papers]
 
-    # include every abstract of the papers cited by the overview in the snippets,
-    # then gather related passages from the vectorstore until nb_token_limit is reached
+    related_splits += [
+        TextSplit.from_snippet(snippet) for snippet in
+        paper_store.similarity_search(
+            f"{section.section.title} {section.section.description}",
+            k=max_paper_store_search_size,
+        )
+    ]
 
-    def get_snippet(title, citation_id, text):
-        text = text.replace("\n", " ")
+    def get_snippet(split: TextSplit):
+        text = split.text.replace("\n", " ")
         return f"""
-Title: {title}
-citation_id: {citation_id}
+Title: {split.title}
+citation_id: {split.citation_id}
 Text: {text}
 """
 
-    snippets = [get_snippet(p.title, p.citation_id, p.summary) for p in related_papers]
-    total_num_tokens = llm.get_num_tokens("\n".join(snippets).strip())
+    snippets = []
+    total_num_tokens = 0
     idx = 0
 
-    while idx < len(related_snippets):
-        snippet = related_snippets[idx]
-        snippet_text = get_snippet(
-            snippet.metadata["title"],
-            snippet.metadata["citation_id"],
-            snippet.page_content,
-        )
+    while idx < len(related_splits):
+        split = related_splits[idx]
+        snippet_text = get_snippet(split)
         num_tokens = llm.get_num_tokens(snippet_text)
 
         if total_num_tokens + num_tokens > nb_token_limit:
@@ -142,9 +165,10 @@ Text: {text}
         "query": query,
         "title": overview.title,
         "overview": overview,
-        "section_title": section.title,
-        "section_level": section_level,
-        "md_title_suffix": "#" * section_level,
+        "section_title": section.section.title,
+        "section_description": section.section.description,
+        "section_level": section.level,
+        "md_title_suffix": "#" * section.level,
         "outline": outline,
         "categories": get_categories_string(papers, nb_categories),
         "snippets": "\n".join(snippets).strip(),
metaanalyser/chains/sr.py ADDED
@@ -0,0 +1,137 @@
+import logging
+from langchain.base_language import BaseLanguageModel
+from langchain.chains.base import Chain
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from pydantic import BaseModel
+from typing import Any, Dict, List, Optional
+
+from ..paper import Paper, search_on_google_scholar, create_papers_vectorstor
+from .outline import SROutlintChain, Outlint, Section
+from .overview import SROverviewChain, Overview
+from .section import SRSectionChain
+
+logger = logging.getLogger(__name__)
+
+
+class SRChain(Chain):
+
+    llm: BaseLanguageModel
+    output_key: str = "text"
+
+    @property
+    def input_keys(self) -> List[str]:
+        return ["query"]
+
+    @property
+    def output_keys(self) -> List[str]:
+        return [self.output_key]
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        query = inputs["query"]
+        logger.info(f"Searching `{query}` on Google Scholar.")
+        papers = search_on_google_scholar(query)
+
+        logger.info("Writing an overview of the paper.")
+        overview_chain = SROverviewChain(llm=self.llm, verbose=self.verbose)
+        overview: Overview = overview_chain.run({"query": query, "papers": papers})
+
+        logger.info("Building the outline of the paper.")
+        outline_chain = SROutlintChain(llm=self.llm, verbose=self.verbose)
+        outline: Outlint = outline_chain.run({
+            "query": query,
+            "papers": papers,
+            "overview": overview
+        })
+
+        logger.info("Creating vector store.")
+        db = create_papers_vectorstor(papers)
+
+        section_chain = SRSectionChain(llm=self.llm, paper_store=db, verbose=self.verbose)
+        flatten_sections = get_flatten_sections(outline)
+        sections_as_md = []
+
+        for section_idx in range(len(flatten_sections)):
+            logger.info(f"Writing sections: [{section_idx + 1} / {len(flatten_sections)}]")
+
+            sections_as_md.append(
+                section_chain.run({
+                    "section_idx": section_idx,
+                    "query": query,
+                    "papers": papers,
+                    "overview": overview,
+                    "outline": outline,
+                    "flatten_sections": flatten_sections,
+                })
+            )
+
+        return {
+            self.output_key: create_output(outline, overview, papers, flatten_sections, sections_as_md)
+        }
+
+
+class FlattenSection(BaseModel):
+    """Helper class representing a flattened section for SRChain."""
+
+    level: int
+    section: Section
+
+
+def get_flatten_sections(
+    outline: Outlint,
+    start_section_level: int = 2,
+) -> List[FlattenSection]:
+    def inner(section_level, section: Section) -> List[FlattenSection]:
+        result = FlattenSection(level=section_level, section=section)
+
+        if not section.children:
+            return [result]
+
+        return (
+            [result] + sum([
+                inner(section_level + 1, child)
+                for child in section.children
+            ], [])
+        )
+
+    return sum([
+        inner(start_section_level, section)
+        for section in outline.sections
+    ], [])
+
+
+def create_output(
+    outline: Outlint,
+    overview: Overview,
+    papers: List[Paper],
+    flatten_sections: List[FlattenSection],
+    sections_as_md: List[str],
+) -> str:
+    papers_citation_id_map = {p.citation_id: p for p in papers}
+    all_citation_ids = list(set(
+        outline.citations_ids + sum([
+            s.section.citation_ids for s in flatten_sections
+        ], [])
+    ))
+
+    citations = []
+
+    for citation_id in all_citation_ids:
+        citation = papers_citation_id_map[int(citation_id)]
+        citations.append(
+            f"[^{citation_id}]: "
+            f"[{citation.mla_citiation.snippet}]({citation.link})"
+        )
+
+    return (
+        f"# {overview.title}\n\n{overview.overview}\n\n"
+        + f"## Table of contents\n\n{outline}\n\n"
+        + "\n\n".join(sections_as_md)
+        + "\n\n## References\n"
+        + "\n\n".join(citations)
+    )
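
SRChain is now the single entry point for the pipeline that the old app.py `run` spelled out by hand: search, overview, outline, flatten, per-section generation, assembly. A minimal usage sketch, assuming OPENAI_API_KEY and SERPAPI_API_KEY are set and the query is illustrative:

    from langchain.chat_models import ChatOpenAI
    from metaanalyser.chains import SRChain

    llm = ChatOpenAI(temperature=0)
    chain = SRChain(llm=llm, verbose=True)
    # Returns the systematic review as a single markdown string.
    markdown = chain.run({"query": "large language model agents"})
    print(markdown)

Note that `get_flatten_sections` is a pre-order traversal (each section precedes its children), so sections are written in table-of-contents order, with `level` driving the markdown heading depth via `md_title_suffix`.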