Spaces:
Runtime error
Runtime error
organize code
Browse files- app.py +8 -45
- metaanalyser/chains/__init__.py +2 -0
- metaanalyser/chains/outline/__init__.py +2 -1
- metaanalyser/chains/outline/outline.py +1 -1
- metaanalyser/chains/outline/prompt.py +8 -3
- metaanalyser/chains/overview/overview.py +1 -1
- metaanalyser/chains/section/prompt.py +1 -1
- metaanalyser/chains/section/section.py +61 -37
- metaanalyser/chains/sr.py +137 -0
app.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
-
# from typing import Optional, Tuple
|
4 |
import gradio as gr
|
5 |
from langchain.chat_models import ChatOpenAI
|
6 |
|
7 |
-
from metaanalyser.chains import
|
8 |
-
from metaanalyser.paper import search_on_google_scholar, create_papers_vectorstor
|
9 |
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
@@ -13,48 +11,13 @@ logging.basicConfig()
|
|
13 |
logging.getLogger("metaanalyser").setLevel(level=logging.DEBUG)
|
14 |
|
15 |
|
16 |
-
def run(query: str):
|
17 |
-
|
18 |
-
|
19 |
-
db = create_papers_vectorstor(papers)
|
20 |
-
overview_chain = SROverviewChain(llm=llm, verbose=True)
|
21 |
-
outline_chain = SROutlintChain(llm=llm, verbose=True)
|
22 |
-
section_chain = SRSectionChain(
|
23 |
-
llm=llm,
|
24 |
-
paper_store=db,
|
25 |
-
verbose=True
|
26 |
-
)
|
27 |
-
|
28 |
-
overview = overview_chain.run({"query": query, "papers": papers})
|
29 |
-
outline = outline_chain.run({"query": query, "papers": papers, "overview": overview})
|
30 |
-
|
31 |
-
sections_as_md = []
|
32 |
-
|
33 |
-
for section_idx in range(len(outline.sections)):
|
34 |
-
# TODO: 入れ子のセクションに対応する
|
35 |
-
sections_as_md.append(section_chain.run({
|
36 |
-
"section_idx": section_idx,
|
37 |
-
"section_level": 2,
|
38 |
-
"query": query,
|
39 |
-
"papers": papers,
|
40 |
-
"overview": overview,
|
41 |
-
"outline": outline
|
42 |
-
}))
|
43 |
-
|
44 |
-
sr = f"# {overview.title}\n\n{overview.overview}\n\n## Table of contents\n\n{outline}\n\n"
|
45 |
-
sr += "\n\n".join(sections_as_md)
|
46 |
-
sr += "\n\n## References\n"
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
for citation_id in outline.citations_ids:
|
52 |
-
citation = papers_citation_id_map[int(citation_id)]
|
53 |
-
citations.append(f"[^{citation_id}]: [{citation.mla_citiation.snippet}]({citation.link})")
|
54 |
-
|
55 |
-
sr += "\n\n".join(citations)
|
56 |
-
|
57 |
-
return sr
|
58 |
|
59 |
|
60 |
def set_openai_api_key(api_key: str):
|
@@ -65,7 +28,6 @@ def set_serpapi_api_key(api_key: str):
|
|
65 |
os.environ["SERPAPI_API_KEY"] = api_key
|
66 |
|
67 |
|
68 |
-
# block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
|
69 |
block = gr.Blocks()
|
70 |
|
71 |
with block:
|
@@ -94,6 +56,7 @@ with block:
|
|
94 |
placeholder="the query for Google Scholar",
|
95 |
lines=1,
|
96 |
)
|
|
|
97 |
submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
|
98 |
|
99 |
gr.Examples(
|
|
|
1 |
import logging
|
2 |
import os
|
|
|
3 |
import gradio as gr
|
4 |
from langchain.chat_models import ChatOpenAI
|
5 |
|
6 |
+
from metaanalyser.chains import SRChain
|
|
|
7 |
|
8 |
|
9 |
logger = logging.getLogger(__name__)
|
|
|
11 |
logging.getLogger("metaanalyser").setLevel(level=logging.DEBUG)
|
12 |
|
13 |
|
14 |
+
def run(query: str, chain: SRChain):
    """Generate a systematic review for ``query``.

    Args:
        query: Search query; forwarded to the chain under the ``"query"`` key.
        chain: Accepted to keep the existing call signature. It is not used:
            a fresh ``SRChain`` is built on every call.

    Returns:
        Whatever ``SRChain.run`` returns for the query.

    Raises:
        gr.Error: If ``OPENAI_API_KEY`` or ``SERPAPI_API_KEY`` is missing
            from the process environment.
    """
    # Both keys must be present before any API call is attempted.
    required_keys = ("OPENAI_API_KEY", "SERPAPI_API_KEY")
    if any(key not in os.environ for key in required_keys):
        raise gr.Error("Please paste your OpenAI (https://platform.openai.com/) key and SerpAPI (https://serpapi.com/) key to use.")

    llm = ChatOpenAI(temperature=0)
    chain = SRChain(llm=llm, verbose=True)
    return chain.run({"query": query})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
def set_openai_api_key(api_key: str):
|
|
|
28 |
os.environ["SERPAPI_API_KEY"] = api_key
|
29 |
|
30 |
|
|
|
31 |
block = gr.Blocks()
|
32 |
|
33 |
with block:
|
|
|
56 |
placeholder="the query for Google Scholar",
|
57 |
lines=1,
|
58 |
)
|
59 |
+
|
60 |
submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
|
61 |
|
62 |
gr.Examples(
|
metaanalyser/chains/__init__.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
from .overview import SROverviewChain
|
2 |
from .outline import SROutlintChain
|
3 |
from .section import SRSectionChain
|
|
|
4 |
|
5 |
|
6 |
__all__ = [
|
|
|
7 |
"SROutlintChain",
|
8 |
"SROverviewChain",
|
9 |
"SRSectionChain",
|
|
|
1 |
from .overview import SROverviewChain
|
2 |
from .outline import SROutlintChain
|
3 |
from .section import SRSectionChain
|
4 |
+
from .sr import SRChain
|
5 |
|
6 |
|
7 |
__all__ = [
|
8 |
+
"SRChain",
|
9 |
"SROutlintChain",
|
10 |
"SROverviewChain",
|
11 |
"SRSectionChain",
|
metaanalyser/chains/outline/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
from .outline import SROutlintChain
|
2 |
-
from .prompt import Outlint
|
3 |
|
4 |
|
5 |
__all__ = [
|
6 |
"Outlint",
|
|
|
7 |
"SROutlintChain",
|
8 |
]
|
|
|
1 |
from .outline import SROutlintChain
|
2 |
+
from .prompt import Outlint, Section
|
3 |
|
4 |
|
5 |
__all__ = [
|
6 |
"Outlint",
|
7 |
+
"Section",
|
8 |
"SROutlintChain",
|
9 |
]
|
metaanalyser/chains/outline/outline.py
CHANGED
@@ -20,7 +20,7 @@ class SROutlintChain(SRBaseChain):
|
|
20 |
|
21 |
prompt: BasePromptTemplate = OUTLINE_PROMPT
|
22 |
nb_categories: int = 3
|
23 |
-
nb_token_limit: int =
|
24 |
|
25 |
@property
|
26 |
def input_keys(self) -> List[str]:
|
|
|
20 |
|
21 |
prompt: BasePromptTemplate = OUTLINE_PROMPT
|
22 |
nb_categories: int = 3
|
23 |
+
nb_token_limit: int = 1_500
|
24 |
|
25 |
@property
|
26 |
def input_keys(self) -> List[str]:
|
metaanalyser/chains/outline/prompt.py
CHANGED
@@ -23,20 +23,24 @@ class Outlint(BaseModel):
|
|
23 |
citations_ids: List[int] = Field(description="citation ids to all paper abstracts cited in this paper")
|
24 |
|
25 |
def __str__(self):
|
26 |
-
def section_string(idx: int, section: Section):
|
27 |
result = [f"{idx}. {section.title}: {section.description}"]
|
28 |
|
29 |
if not section.children:
|
30 |
return result[0]
|
31 |
|
32 |
result += [
|
33 |
-
section_string(
|
|
|
|
|
|
|
|
|
34 |
for child_idx, child in enumerate(section.children, start=1)
|
35 |
]
|
36 |
return "\n".join(result)
|
37 |
|
38 |
return "\n".join([
|
39 |
-
section_string(idx, s)
|
40 |
for idx, s in enumerate(self.sections, start=1)
|
41 |
])
|
42 |
|
@@ -60,6 +64,7 @@ The following is an overview of this systematic review. Build the outline of the
|
|
60 |
|
61 |
Device each section of this outline by citing abstracts from the papers.
|
62 |
The beginning of element of the sections should by titled "Introduction" and last element of the sections should be titled "Conclusion".
|
|
|
63 |
|
64 |
{format_instructions}"""
|
65 |
human_prompt = HumanMessagePromptTemplate(
|
|
|
23 |
citations_ids: List[int] = Field(description="citation ids to all paper abstracts cited in this paper")
|
24 |
|
25 |
def __str__(self):
    """Render the outline as a numbered, newline-separated list.

    Each section becomes one ``"<n>. <title>: <description>"`` line.
    Children are numbered from 1 below their parent and their number is
    prefixed with one space per nesting level.
    """

    def render(label, node: Section, depth: int) -> str:
        # The heading line of this node comes first, then every descendant.
        lines = [f"{label}. {node.title}: {node.description}"]
        for number, child in enumerate(node.children or [], start=1):
            child_label = " " * (depth + 1) + str(number)
            lines.append(render(child_label, child, depth + 1))
        return "\n".join(lines)

    rendered = []
    for number, top_level in enumerate(self.sections, start=1):
        rendered.append(render(number, top_level, 0))
    return "\n".join(rendered)
|
46 |
|
|
|
64 |
|
65 |
Device each section of this outline by citing abstracts from the papers.
|
66 |
The beginning of element of the sections should by titled "Introduction" and last element of the sections should be titled "Conclusion".
|
67 |
+
It is preferred that sections be divided into more child sections. Each section can have up to two child sections.
|
68 |
|
69 |
{format_instructions}"""
|
70 |
human_prompt = HumanMessagePromptTemplate(
|
metaanalyser/chains/overview/overview.py
CHANGED
@@ -19,7 +19,7 @@ class SROverviewChain(SRBaseChain):
|
|
19 |
|
20 |
prompt: BasePromptTemplate = OVERVIEW_PROMPT
|
21 |
nb_categories: int = 3
|
22 |
-
nb_token_limit: int =
|
23 |
nb_max_retry: int = 3
|
24 |
|
25 |
@property
|
|
|
19 |
|
20 |
prompt: BasePromptTemplate = OVERVIEW_PROMPT
|
21 |
nb_categories: int = 3
|
22 |
+
nb_token_limit: int = 1_500
|
23 |
nb_max_retry: int = 3
|
24 |
|
25 |
@property
|
metaanalyser/chains/section/prompt.py
CHANGED
@@ -24,7 +24,7 @@ This systematic review should adhere to the following overview:
|
|
24 |
|
25 |
{overview}
|
26 |
|
27 |
-
Write the "{section_title}" section with respect to this overview. Write the text in markdown format. The title of this section should bu suffixed with {section_level} level markdown title (`{md_title_suffix}`). The text of the section should be based on a snippet or abstact and should be clearly cited. The citation should be written at the end of the sentence in the form `[
|
28 |
human_prompt = HumanMessagePromptTemplate.from_template(human_template)
|
29 |
|
30 |
SECTION_PROMPT = ChatPromptTemplate.from_messages([
|
|
|
24 |
|
25 |
{overview}
|
26 |
|
27 |
+
Write the "{section_title}: {section_description}" section with respect to this overview. Write the text in markdown format. The title of this section should bu suffixed with {section_level} level markdown title (`{md_title_suffix}`). The text of the section should be based on a snippet or abstact and should be clearly cited. The citation should be written at the end of the sentence in the form `[^<ID>]` where `ID` refers to the citation_id."""
|
28 |
human_prompt = HumanMessagePromptTemplate.from_template(human_template)
|
29 |
|
30 |
SECTION_PROMPT = ChatPromptTemplate.from_messages([
|
metaanalyser/chains/section/section.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from langchain.base_language import BaseLanguageModel
|
|
|
2 |
from langchain.callbacks.manager import CallbackManagerForChainRun
|
3 |
from langchain.prompts.base import BasePromptTemplate
|
4 |
from langchain.vectorstores.base import VectorStore
|
|
|
5 |
from typing import Any, Dict, List, Optional
|
6 |
|
7 |
from ...paper import (
|
@@ -23,7 +25,7 @@ class SRSectionChain(SRBaseChain):
|
|
23 |
paper_store: VectorStore
|
24 |
prompt: BasePromptTemplate = SECTION_PROMPT
|
25 |
nb_categories: int = 3
|
26 |
-
nb_token_limit: int =
|
27 |
nb_max_retry: int = 3
|
28 |
|
29 |
@property
|
@@ -31,11 +33,11 @@ class SRSectionChain(SRBaseChain):
|
|
31 |
# TODO: 入れ子に対応する
|
32 |
return [
|
33 |
"section_idx",
|
34 |
-
"section_level",
|
35 |
"query",
|
36 |
"papers",
|
37 |
"overview",
|
38 |
-
"outline"
|
|
|
39 |
]
|
40 |
|
41 |
def _call(
|
@@ -47,11 +49,11 @@ class SRSectionChain(SRBaseChain):
|
|
47 |
self.llm,
|
48 |
self.paper_store,
|
49 |
inputs["section_idx"],
|
50 |
-
inputs["section_level"],
|
51 |
inputs["query"],
|
52 |
inputs["papers"],
|
53 |
inputs["overview"],
|
54 |
inputs["outline"],
|
|
|
55 |
self.nb_categories,
|
56 |
self.nb_token_limit,
|
57 |
)
|
@@ -66,69 +68,90 @@ class SRSectionChain(SRBaseChain):
|
|
66 |
self.llm,
|
67 |
self.paper_store,
|
68 |
inputs["section_idx"],
|
69 |
-
inputs["section_level"],
|
70 |
inputs["query"],
|
71 |
inputs["papers"],
|
72 |
inputs["overview"],
|
73 |
inputs["outline"],
|
|
|
74 |
self.nb_categories,
|
75 |
self.nb_token_limit,
|
76 |
)
|
77 |
return super()._acall(input_list, run_manager=run_manager)
|
78 |
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
def get_input_list(
|
81 |
llm: BaseLanguageModel,
|
82 |
paper_store: VectorStore,
|
83 |
section_idx: int,
|
84 |
-
section_level: int,
|
85 |
query: str,
|
86 |
papers: List[Paper],
|
87 |
overview: Overview,
|
88 |
outline: Outlint,
|
|
|
89 |
nb_categories: int,
|
90 |
nb_token_limit: int,
|
91 |
max_paper_store_search_size: int = 100,
|
92 |
):
|
93 |
-
section =
|
94 |
papers_citation_id_map = {p.citation_id: p for p in papers}
|
95 |
-
related_papers = [
|
96 |
-
papers_citation_id_map[int(citation_id)]
|
97 |
-
for citation_id in section.citation_ids
|
98 |
-
]
|
99 |
|
100 |
-
if
|
|
|
|
|
|
|
|
|
|
|
101 |
# citation_ids が空なら全部を対象とする
|
102 |
-
|
103 |
-
related_papers = papers
|
104 |
-
|
105 |
-
related_snippets = paper_store.similarity_search(
|
106 |
-
f"{section.title} {section.description}",
|
107 |
-
k=max_paper_store_search_size,
|
108 |
-
)
|
109 |
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
def get_snippet(
|
114 |
-
text = text.replace("\n", " ")
|
115 |
return f"""
|
116 |
-
Title: {title}
|
117 |
-
citation_id: {citation_id}
|
118 |
Text: {text}
|
119 |
"""
|
120 |
|
121 |
-
snippets = [
|
122 |
-
total_num_tokens =
|
123 |
idx = 0
|
124 |
|
125 |
-
while idx < len(
|
126 |
-
|
127 |
-
snippet_text = get_snippet(
|
128 |
-
snippet.metadata["title"],
|
129 |
-
snippet.metadata["citation_id"],
|
130 |
-
snippet.page_content,
|
131 |
-
)
|
132 |
num_tokens = llm.get_num_tokens(snippet_text)
|
133 |
|
134 |
if total_num_tokens + num_tokens > nb_token_limit:
|
@@ -142,9 +165,10 @@ Text: {text}
|
|
142 |
"query": query,
|
143 |
"title": overview.title,
|
144 |
"overview": overview,
|
145 |
-
"section_title": section.title,
|
146 |
-
"
|
147 |
-
"
|
|
|
148 |
"outline": outline,
|
149 |
"categories": get_categories_string(papers, nb_categories),
|
150 |
"snippets": "\n".join(snippets).strip(),
|
|
|
1 |
from langchain.base_language import BaseLanguageModel
|
2 |
+
from langchain.docstore.document import Document
|
3 |
from langchain.callbacks.manager import CallbackManagerForChainRun
|
4 |
from langchain.prompts.base import BasePromptTemplate
|
5 |
from langchain.vectorstores.base import VectorStore
|
6 |
+
from pydantic import BaseModel
|
7 |
from typing import Any, Dict, List, Optional
|
8 |
|
9 |
from ...paper import (
|
|
|
25 |
paper_store: VectorStore
|
26 |
prompt: BasePromptTemplate = SECTION_PROMPT
|
27 |
nb_categories: int = 3
|
28 |
+
nb_token_limit: int = 1_500
|
29 |
nb_max_retry: int = 3
|
30 |
|
31 |
@property
|
|
|
33 |
# TODO: 入れ子に対応する
|
34 |
return [
|
35 |
"section_idx",
|
|
|
36 |
"query",
|
37 |
"papers",
|
38 |
"overview",
|
39 |
+
"outline",
|
40 |
+
"flatten_sections",
|
41 |
]
|
42 |
|
43 |
def _call(
|
|
|
49 |
self.llm,
|
50 |
self.paper_store,
|
51 |
inputs["section_idx"],
|
|
|
52 |
inputs["query"],
|
53 |
inputs["papers"],
|
54 |
inputs["overview"],
|
55 |
inputs["outline"],
|
56 |
+
inputs["flatten_sections"],
|
57 |
self.nb_categories,
|
58 |
self.nb_token_limit,
|
59 |
)
|
|
|
68 |
self.llm,
|
69 |
self.paper_store,
|
70 |
inputs["section_idx"],
|
|
|
71 |
inputs["query"],
|
72 |
inputs["papers"],
|
73 |
inputs["overview"],
|
74 |
inputs["outline"],
|
75 |
+
inputs["flatten_sections"],
|
76 |
self.nb_categories,
|
77 |
self.nb_token_limit,
|
78 |
)
|
79 |
return super()._acall(input_list, run_manager=run_manager)
|
80 |
|
81 |
|
82 |
+
class TextSplit(BaseModel):
    """Helper record for ``get_input_list``.

    Bundles one piece of text with the title and citation id of the paper
    it belongs to.
    """

    title: str
    citation_id: int
    text: str

    @classmethod
    def from_paper(cls, paper: Paper) -> "TextSplit":
        """Build a split from a paper, using the paper's summary as text."""
        fields = {
            "title": paper.title,
            "citation_id": paper.citation_id,
            "text": paper.summary,
        }
        return cls(**fields)

    @classmethod
    def from_snippet(cls, snippet: Document) -> "TextSplit":
        """Build a split from a document and its ``metadata`` mapping.

        The metadata must provide the ``"title"`` and ``"citation_id"`` keys;
        the document's ``page_content`` becomes the text.
        """
        meta = snippet.metadata
        return cls(
            title=meta["title"],
            citation_id=meta["citation_id"],
            text=snippet.page_content,
        )
|
105 |
+
|
106 |
+
|
107 |
def get_input_list(
|
108 |
llm: BaseLanguageModel,
|
109 |
paper_store: VectorStore,
|
110 |
section_idx: int,
|
|
|
111 |
query: str,
|
112 |
papers: List[Paper],
|
113 |
overview: Overview,
|
114 |
outline: Outlint,
|
115 |
+
flatten_sections,
|
116 |
nb_categories: int,
|
117 |
nb_token_limit: int,
|
118 |
max_paper_store_search_size: int = 100,
|
119 |
):
|
120 |
+
section = flatten_sections[section_idx]
|
121 |
papers_citation_id_map = {p.citation_id: p for p in papers}
|
|
|
|
|
|
|
|
|
122 |
|
123 |
+
if section.section.citation_ids:
|
124 |
+
related_splits = [
|
125 |
+
TextSplit.from_paper(papers_citation_id_map[int(citation_id)])
|
126 |
+
for citation_id in section.section.citation_ids
|
127 |
+
]
|
128 |
+
else:
|
129 |
# citation_ids が空なら全部を対象とする
|
130 |
+
related_splits = [TextSplit.from_paper(p) for p in papers]
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
+
related_splits += [
|
133 |
+
TextSplit.from_snippet(snippet) for snippet in
|
134 |
+
paper_store.similarity_search(
|
135 |
+
f"{section.section.title} {section.section.description}",
|
136 |
+
k=max_paper_store_search_size,
|
137 |
+
)
|
138 |
+
]
|
139 |
|
140 |
+
def get_snippet(split: TextSplit):
|
141 |
+
text = split.text.replace("\n", " ")
|
142 |
return f"""
|
143 |
+
Title: {split.title}
|
144 |
+
citation_id: {split.citation_id}
|
145 |
Text: {text}
|
146 |
"""
|
147 |
|
148 |
+
snippets = []
|
149 |
+
total_num_tokens = 0
|
150 |
idx = 0
|
151 |
|
152 |
+
while idx < len(related_splits):
|
153 |
+
split = related_splits[idx]
|
154 |
+
snippet_text = get_snippet(split)
|
|
|
|
|
|
|
|
|
155 |
num_tokens = llm.get_num_tokens(snippet_text)
|
156 |
|
157 |
if total_num_tokens + num_tokens > nb_token_limit:
|
|
|
165 |
"query": query,
|
166 |
"title": overview.title,
|
167 |
"overview": overview,
|
168 |
+
"section_title": section.section.title,
|
169 |
+
"section_description": section.section.description,
|
170 |
+
"section_level": section.level,
|
171 |
+
"md_title_suffix": "#" * section.level,
|
172 |
"outline": outline,
|
173 |
"categories": get_categories_string(papers, nb_categories),
|
174 |
"snippets": "\n".join(snippets).strip(),
|
metaanalyser/chains/sr.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from langchain.base_language import BaseLanguageModel
|
3 |
+
from langchain.chains.base import Chain
|
4 |
+
from langchain.callbacks.manager import CallbackManagerForChainRun
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from typing import Any, Dict, List, Optional
|
7 |
+
|
8 |
+
from ..paper import Paper, search_on_google_scholar, create_papers_vectorstor
|
9 |
+
from .outline import SROutlintChain, Outlint, Section
|
10 |
+
from .overview import SROverviewChain, Overview
|
11 |
+
from .section import SRSectionChain
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class SRChain(Chain):
    """End-to-end chain that writes a systematic review for a query.

    Steps, in order: search Google Scholar, write the overview, build the
    outline, create the paper vector store, write every (flattened) section,
    and assemble the final text with ``create_output``.
    """

    llm: BaseLanguageModel
    output_key: str = "text"

    @property
    def input_keys(self) -> List[str]:
        """The chain takes a single input: the search query."""
        return ["query"]

    @property
    def output_keys(self) -> List[str]:
        """The chain emits a single output stored under ``output_key``."""
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the whole pipeline for ``inputs["query"]``.

        Returns:
            A one-entry dict mapping ``output_key`` to the assembled review.
        """
        query = inputs["query"]

        logger.info(f"Searching `{query}` on Google Scholar.")
        papers = search_on_google_scholar(query)

        logger.info("Writing an overview of the paper.")
        overview_chain = SROverviewChain(llm=self.llm, verbose=self.verbose)
        overview: Overview = overview_chain.run({"query": query, "papers": papers})

        logger.info("Building the outline of the paper.")
        outline_chain = SROutlintChain(llm=self.llm, verbose=self.verbose)
        outline_inputs = {"query": query, "papers": papers, "overview": overview}
        outline: Outlint = outline_chain.run(outline_inputs)

        logger.info("Creating vector store.")
        db = create_papers_vectorstor(papers)

        section_chain = SRSectionChain(llm=self.llm, paper_store=db, verbose=self.verbose)
        flatten_sections = get_flatten_sections(outline)
        total = len(flatten_sections)

        # Everything except the section index is identical for each section.
        shared_inputs = {
            "query": query,
            "papers": papers,
            "overview": overview,
            "outline": outline,
            "flatten_sections": flatten_sections,
        }

        sections_as_md = []
        for section_idx, _ in enumerate(flatten_sections):
            logger.info(f"Writing sections: [{section_idx + 1} / {total}]")
            section_md = section_chain.run({"section_idx": section_idx, **shared_inputs})
            sections_as_md.append(section_md)

        review = create_output(outline, overview, papers, flatten_sections, sections_as_md)
        return {self.output_key: review}
|
74 |
+
|
75 |
+
|
76 |
+
class FlattenSection(BaseModel):
    """Helper class representing one outline section for ``SRChain``.

    Pairs a section with the nesting level at which it sits in the outline.
    """

    # Nesting level of the section within the outline.
    level: int
    # The outline section itself.
    section: Section
|
83 |
+
|
84 |
+
|
85 |
+
def get_flatten_sections(
    outline: Outlint,
    start_section_level: int = 2,
) -> List[FlattenSection]:
    """Flatten the outline's section tree into a pre-order list.

    Args:
        outline: Outline whose ``sections`` (and their ``children``) are walked.
        start_section_level: Level assigned to the top-level sections; each
            nesting step below adds one.

    Returns:
        One ``FlattenSection`` per section, parents before their children.
    """
    flattened: List[FlattenSection] = []

    def visit(node: Section, depth: int) -> None:
        # Record the parent first so it precedes all of its descendants.
        flattened.append(FlattenSection(level=depth, section=node))
        for child in node.children or []:
            visit(child, depth + 1)

    for top_level in outline.sections:
        visit(top_level, start_section_level)

    return flattened
|
106 |
+
|
107 |
+
|
108 |
+
def create_output(
    outline: Outlint,
    overview: Overview,
    papers: List[Paper],
    flatten_sections: List[FlattenSection],
    sections_as_md: List[str],
) -> str:
    """Assemble the final markdown document.

    The layout is: title, overview, table of contents (``str(outline)``),
    the section texts, then a ``References`` heading followed by one
    footnote definition per cited paper.

    Args:
        outline: Outline; its ``citations_ids`` are included in the references.
        overview: Provides ``title`` and ``overview`` for the document head.
        papers: Papers that citation ids are resolved against
            (by ``citation_id``).
        flatten_sections: Flattened sections; each section's ``citation_ids``
            are included in the references.
        sections_as_md: Already-written section texts, joined by blank lines.

    Returns:
        The complete document as a single string.
    """
    papers_citation_id_map = {p.citation_id: p for p in papers}

    # Normalise every id to int *before* de-duplicating so that e.g. "3" and
    # 3 do not produce two footnotes for the same paper.
    cited_ids = {int(citation_id) for citation_id in outline.citations_ids}
    for flat_section in flatten_sections:
        cited_ids.update(
            int(citation_id) for citation_id in flat_section.section.citation_ids
        )

    citations = []

    # Sorted so that the reference list has a stable, ascending order.
    for citation_id in sorted(cited_ids):
        citation = papers_citation_id_map.get(citation_id)

        if citation is None:
            # An id that matches no paper must not abort the whole document.
            logging.getLogger(__name__).warning(
                "Skipping unknown citation id %s in references.", citation_id
            )
            continue

        citations.append(
            f"[^{citation_id}]: "
            f"[{citation.mla_citiation.snippet}]({citation.link})"
        )

    return (
        f"# {overview.title}\n\n{overview.overview}\n\n"
        + f"## Table of contents\n\n{outline}\n\n"
        + "\n\n".join(sections_as_md)
        + "\n\n## References\n"
        + "\n\n".join(citations)
    )
|