Create app.py

app.py ADDED
@@ -0,0 +1,778 @@
import os
import operator
import uuid
from typing import Annotated, List

from dotenv import load_dotenv
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from langgraph.graph import START, END, StateGraph, MessagesState
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
import gradio as gr
from gradio.themes.utils import colors
from gradio.themes import Base


load_dotenv()

# Default to "" so a missing key does not raise a TypeError when assigned
# back into os.environ.
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "")
os.environ["TAVILY_API_KEY"] = os.getenv("TAVILY_API_KEY", "")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY", "")
os.environ["LANGCHAIN_PROJECT"] = "Research Assistant v1"

llm = ChatOpenAI(model="gpt-4o", temperature=0)

class Analyst(BaseModel):
    affiliation: str = Field(
        description="Primary affiliation of the analyst.",
    )
    name: str = Field(
        description="Name of the analyst."
    )
    role: str = Field(
        description="Role of the analyst in the context of the topic.",
    )
    description: str = Field(
        description="Description of the analyst's focus, concerns, and motives.",
    )

    @property
    def persona(self) -> str:
        return f"Name: {self.name}\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.description}\n"

class Perspectives(BaseModel):
    analysts: List[Analyst] = Field(
        description="Comprehensive list of analysts with their roles and affiliations.",
    )

class GenerateAnalystsState(TypedDict):
    topic: str                   # Research topic
    max_analysts: int            # Number of analysts
    human_analyst_feedback: str  # Human feedback
    analysts: List[Analyst]      # Analysts asking questions

analyst_instructions = """You are tasked with creating a set of AI analyst personas. Follow these instructions carefully:

1. First, review the research topic:
{topic}

2. Examine any editorial feedback that has been optionally provided to guide creation of the analysts:

{human_analyst_feedback}

3. Determine the most interesting themes based upon the documents and/or feedback above.

4. Pick the top {max_analysts} themes.

5. Assign one analyst to each theme."""

def create_analysts(state: GenerateAnalystsState):
    """ Create analysts """

    topic = state['topic']
    max_analysts = state['max_analysts']
    human_analyst_feedback = state.get('human_analyst_feedback', '')

    # Enforce structured output
    structured_llm = llm.with_structured_output(Perspectives)

    # System message
    system_message = analyst_instructions.format(topic=topic,
                                                 human_analyst_feedback=human_analyst_feedback,
                                                 max_analysts=max_analysts)

    # Generate the analysts
    analysts = structured_llm.invoke([SystemMessage(content=system_message)] + [HumanMessage(content="Generate the set of analysts.")])

    # Write the list of analysts to state
    return {"analysts": analysts.analysts}

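# Note: `with_structured_output(Perspectives)` makes the model return a parsed
# `Perspectives` object rather than raw text, so `analysts.analysts` above is a
# list of validated `Analyst` instances. A minimal sketch of the same pattern:
#
#   structured_llm = llm.with_structured_output(Perspectives)
#   perspectives = structured_llm.invoke("Propose two analysts for quantum computing.")
#   print(perspectives.analysts[0].persona)
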
def human_feedback(state: GenerateAnalystsState):
    """ No-op node that should be interrupted on """
    pass

def should_continue(state: GenerateAnalystsState):
    """ Return the next node to execute """

    # Check if human feedback
    human_analyst_feedback = state.get('human_analyst_feedback', None)
    if human_analyst_feedback:
        return "create_analysts"

    # Otherwise end
    return END

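# Note: `should_continue` is not registered as a conditional edge in the graph
# built below; `initiate_all_interviews` performs the same feedback check there.
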
class InterviewState(MessagesState):
    max_num_turns: int                      # Number of conversation turns
    context: Annotated[list, operator.add]  # Source docs
    analyst: Analyst                        # Analyst asking questions
    interview: str                          # Interview transcript
    sections: list                          # Final key we duplicate in outer state for Send() API

class SearchQuery(BaseModel):
    search_query: str = Field(None, description="Search query for retrieval.")

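# Note: the `Annotated[list, operator.add]` reducer lets several nodes write to
# `context` concurrently; LangGraph merges their updates by list concatenation.
# E.g. if search_web returns {"context": [a]} and search_wikipedia returns
# {"context": [b]}, the resulting state holds context == [a, b].
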
question_instructions = """You are an analyst tasked with interviewing an expert to learn about a specific topic.

Your goal is to boil your topic down to interesting and specific insights.

1. Interesting: Insights that people will find surprising or non-obvious.

2. Specific: Insights that avoid generalities and include specific examples from the expert.

Here is your topic of focus and set of goals: {goals}

Begin by introducing yourself using a name that fits your persona, and then ask your question.

Continue to ask questions to drill down and refine your understanding of the topic.

When you are satisfied with your understanding, complete the interview with: "Thank you so much for your help!"

Remember to stay in character throughout your response, reflecting the persona and goals provided to you."""

+
def generate_question(state: InterviewState):
|
155 |
+
""" Node to generate a question """
|
156 |
+
|
157 |
+
# Get state
|
158 |
+
analyst = state["analyst"]
|
159 |
+
messages = state["messages"]
|
160 |
+
|
161 |
+
# Generate question
|
162 |
+
system_message = question_instructions.format(goals=analyst.persona)
|
163 |
+
question = llm.invoke([SystemMessage(content=system_message)]+messages)
|
164 |
+
|
165 |
+
# Write messages to state
|
166 |
+
return {"messages": [question]}
|
167 |
+
|
168 |
+
|
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import get_buffer_string
from langchain_core.documents import Document
import arxiv

# Search query writing
search_instructions = SystemMessage(content="""You will be given a conversation between an analyst and an expert.

Your goal is to generate a well-structured query for use in retrieval and/or web-search related to the conversation.

First, analyze the full conversation.

Pay particular attention to the final question posed by the analyst.

Convert this final question into a well-structured web search query.""")

def search_web(state: InterviewState):
    """ Retrieve docs from web search """

    # Search query
    structured_llm = llm.with_structured_output(SearchQuery)
    search_query = structured_llm.invoke([search_instructions] + state['messages'])

    # Search
    tavily_search = TavilySearchResults(max_results=5)
    search_docs = tavily_search.invoke(search_query.search_query)

    # Debug: print the type and content of search_docs
    print(f"Type of search_docs: {type(search_docs)}")
    print(f"Content of search_docs: {search_docs}")

    # Format
    try:
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document href="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
                for doc in search_docs
            ]
        )
    except TypeError as e:
        print(f"Error in formatting search_docs: {e}")
        # Fallback: treat search_docs as a single string if it's not a list of dicts
        formatted_search_docs = f"<Document>\n{search_docs}\n</Document>"

    return {"context": [formatted_search_docs]}

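# Note: TavilySearchResults.invoke typically returns a list of dicts with "url"
# and "content" keys, which is what the happy path above assumes; the except
# branch catches error payloads that come back as a plain string instead.
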
def search_wikipedia(state: InterviewState):
    """ Retrieve docs from Wikipedia """

    # Search query
    structured_llm = llm.with_structured_output(SearchQuery)
    search_query = structured_llm.invoke([search_instructions] + state['messages'])

    # Search
    search_docs = WikipediaLoader(query=search_query.search_query,
                                  load_max_docs=2).load()

    # Format
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )

    return {"context": [formatted_search_docs]}

def search_arxiv(state: InterviewState):
    """ Retrieve docs from arXiv """

    # Search query
    structured_llm = llm.with_structured_output(SearchQuery)
    search_query = structured_llm.invoke([search_instructions] + state['messages'])

    # Search arXiv
    search = arxiv.Search(
        query=search_query.search_query,
        max_results=10,
        sort_by=arxiv.SortCriterion.Relevance
    )

    # Retrieve results
    search_docs = []
    for result in search.results():
        doc = Document(
            page_content=f"{result.title}\n\n{result.summary}",
            metadata={
                "title": result.title,
                "authors": ", ".join(author.name for author in result.authors),
                "published": result.published.strftime("%Y-%m-%d"),
                "url": result.entry_id,
            }
        )
        search_docs.append(doc)

    # Format
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document title="{doc.metadata["title"]}" authors="{doc.metadata["authors"]}" published="{doc.metadata["published"]}" url="{doc.metadata["url"]}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )

    return {"context": [formatted_search_docs]}

answer_instructions = """You are an expert being interviewed by an analyst.

Here is the analyst's area of focus: {goals}.

Your goal is to answer a question posed by the interviewer.

To answer the question, use this context:

{context}

When answering questions, follow these guidelines:

1. Use only the information provided in the context.

2. Do not introduce external information or make assumptions beyond what is explicitly stated in the context.

3. The context contains sources at the top of each individual document.

4. Include these sources in your answer next to any relevant statements. For example, for source # 1 use [1].

5. List your sources in order at the bottom of your answer. [1] Source 1, [2] Source 2, etc.

6. If the source is: <Document source="assistant/docs/llama3_1.pdf" page="7"/> then just list:

[1] assistant/docs/llama3_1.pdf, page 7

And skip the addition of the brackets as well as the Document source preamble in your citation."""

def generate_answer(state: InterviewState):
    """ Node to answer a question """

    # Get state
    analyst = state["analyst"]
    messages = state["messages"]
    context = state["context"]

    # Answer question
    system_message = answer_instructions.format(goals=analyst.persona, context=context)
    answer = llm.invoke([SystemMessage(content=system_message)] + messages)

    # Name the message as coming from the expert
    answer.name = "expert"

    # Append it to state
    return {"messages": [answer]}

def save_interview(state: InterviewState):
    """ Save the interview """

    # Get messages
    messages = state["messages"]

    # Convert interview to a string
    interview = get_buffer_string(messages)

    # Save to the interview key
    return {"interview": interview}

def route_messages(state: InterviewState,
                   name: str = "expert"):
    """ Route between question and answer """

    # Get messages
    messages = state["messages"]
    max_num_turns = state.get('max_num_turns', 2)

    # Check the number of expert answers
    num_responses = len(
        [m for m in messages if isinstance(m, AIMessage) and m.name == name]
    )

    # End once the expert has answered max_num_turns times
    if num_responses >= max_num_turns:
        return 'save_interview'

    # This router runs after each question-answer pair.
    # Get the last question asked to check if it signals the end of the discussion.
    last_question = messages[-2]

    if "Thank you so much for your help" in last_question.content:
        return 'save_interview'
    return "ask_question"

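# Note: `messages[-2]` is the analyst's most recent question because the
# expert's answer is always the last message when this router runs after
# answer_question.
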
section_writer_instructions = """You are an expert technical writer.

Your task is to create a short, easily digestible section of a report based on a set of source documents.

1. Analyze the content of the source documents:
- The name of each source document is at the start of the document, with the <Document tag.

2. Create a report structure using markdown formatting:
- Use ## for the section title
- Use ### for sub-section headers

3. Write the report following this structure:
a. Title (## header)
b. Summary (### header)
c. Sources (### header)

4. Make your title engaging based upon the focus area of the analyst:
{focus}

5. For the summary section:
- Set up the summary with general background / context related to the focus area of the analyst
- Emphasize what is novel, interesting, or surprising about the insights gathered from the interview
- Create a numbered list of source documents as you use them
- Do not mention the names of interviewers or experts
- Aim for approximately 400 words maximum
- Use numbered sources in your report (e.g., [1], [2]) based on information from the source documents

6. In the Sources section:
- Include all sources used in your report
- Provide full links to relevant websites or specific document paths
- Separate each source by a newline. Use two spaces at the end of each line to create a newline in Markdown.
- It will look like:

### Sources
[1] Link or Document name
[2] Link or Document name

7. Be sure to combine sources. For example this is not correct:

[3] https://ai.meta.com/blog/meta-llama-3-1/
[4] https://ai.meta.com/blog/meta-llama-3-1/

There should be no redundant sources. It should simply be:

[3] https://ai.meta.com/blog/meta-llama-3-1/

8. Final review:
- Ensure the report follows the required structure
- Include no preamble before the title of the report
- Check that all guidelines have been followed"""

def write_section(state: InterviewState):
    """ Node to write a section """

    # Get state
    interview = state["interview"]
    context = state["context"]
    analyst = state["analyst"]

    # Write section using either the gathered source docs from the interview (context) or the interview itself (interview)
    system_message = section_writer_instructions.format(focus=analyst.description)
    section = llm.invoke([SystemMessage(content=system_message)] + [HumanMessage(content=f"Use this source to write your section: {context}")])

    # Append it to state
    return {"sections": [section.content]}

# Add nodes and edges
interview_builder = StateGraph(InterviewState)
interview_builder.add_node("ask_question", generate_question)
interview_builder.add_node("search_web", search_web)
interview_builder.add_node("search_wikipedia", search_wikipedia)
interview_builder.add_node("search_arxiv", search_arxiv)
interview_builder.add_node("answer_question", generate_answer)
interview_builder.add_node("save_interview", save_interview)
interview_builder.add_node("write_section", write_section)

# Flow
interview_builder.add_edge(START, "ask_question")
interview_builder.add_edge("ask_question", "search_web")
interview_builder.add_edge("ask_question", "search_wikipedia")
interview_builder.add_edge("ask_question", "search_arxiv")
interview_builder.add_edge("search_web", "answer_question")
interview_builder.add_edge("search_wikipedia", "answer_question")
interview_builder.add_edge("search_arxiv", "answer_question")
interview_builder.add_conditional_edges("answer_question", route_messages, ['ask_question', 'save_interview'])
interview_builder.add_edge("save_interview", "write_section")
interview_builder.add_edge("write_section", END)

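# Note: fanning "ask_question" out to the three search nodes makes LangGraph run
# them in the same superstep; each returns {"context": [docs]}, and the
# operator.add reducer on InterviewState.context concatenates the three updates
# before "answer_question" executes.
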
pre_defined_run_id = uuid.uuid4()
print("pre_defined_run_id", pre_defined_run_id)

# Interview
memory = MemorySaver()
interview_graph = interview_builder.compile(checkpointer=memory).with_config(run_name="Conduct Interviews")

class ResearchGraphState(TypedDict):
    topic: str                               # Research topic
    max_analysts: int                        # Number of analysts
    human_analyst_feedback: str              # Human feedback
    analysts: List[Analyst]                  # Analysts asking questions
    sections: Annotated[list, operator.add]  # Send() API key
    introduction: str                        # Introduction for the final report
    content: str                             # Content for the final report
    conclusion: str                          # Conclusion for the final report
    final_report: str                        # Final report

from langgraph.constants import Send

def initiate_all_interviews(state: ResearchGraphState):
    """ This is the "map" step where we run each interview sub-graph using the Send API """

    # Check if human feedback
    human_analyst_feedback = state.get('human_analyst_feedback')
    if human_analyst_feedback:
        # Return to create_analysts
        return "create_analysts"

    # Otherwise kick off interviews in parallel via the Send() API
    else:
        topic = state["topic"]
        return [Send("conduct_interview", {"analyst": analyst,
                                           "messages": [HumanMessage(
                                               content=f"So you said you were writing an article on {topic}?"
                                           )
                                           ]}) for analyst in state["analysts"]]

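# Note: returning a list of Send objects from a conditional edge spawns one
# "conduct_interview" sub-graph run per analyst, each with its own private
# state; their "sections" outputs are gathered back into the outer state by
# the operator.add reducer (the "reduce" half of the map-reduce).
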
report_writer_instructions = """You are a technical writer creating a report on this overall topic:

{topic}

You have a team of analysts. Each analyst has done two things:

1. They conducted an interview with an expert on a specific sub-topic.
2. They wrote up their findings in a memo.

Your task:

1. You will be given a collection of memos from your analysts.
2. Think carefully about the insights from each memo.
3. Consolidate these into a crisp overall summary that ties together the central ideas from all of the memos.
4. Summarize the central points in each memo into a cohesive single narrative.

To format your report:

1. Use markdown formatting.
2. Include no preamble for the report.
3. Use no sub-headings.
4. Start your report with a single title header: ## Insights
5. Do not mention any analyst names in your report.
6. Preserve any citations in the memos, which will be annotated in brackets, for example [1] or [2].
7. Create a final, consolidated list of sources and add it to a Sources section with the `## Sources` header.
8. List your sources in order and do not repeat.

[1] Source 1
[2] Source 2

Here are the memos from your analysts to build your report from:

{context}"""

def write_report(state: ResearchGraphState):
    # Full set of sections
    sections = state["sections"]
    topic = state["topic"]

    # Concat all sections together
    formatted_str_sections = "\n\n".join([f"{section}" for section in sections])

    # Summarize the sections into a final report
    system_message = report_writer_instructions.format(topic=topic, context=formatted_str_sections)
    report = llm.invoke([SystemMessage(content=system_message)] + [HumanMessage(content="Write a report based upon these memos.")])
    return {"content": report.content}

intro_conclusion_instructions = """You are a technical writer finishing a report on {topic}.

You will be given all of the sections of the report.

Your job is to write a crisp and compelling introduction or conclusion section.

The user will instruct you whether to write the introduction or conclusion.

Include no preamble for either section.

Target around 100 words, crisply previewing (for the introduction) or recapping (for the conclusion) all of the sections of the report.

Use markdown formatting.

For your introduction, create a compelling title and use the # header for the title.

For your introduction, use ## Introduction as the section header.

For your conclusion, use ## Conclusion as the section header.

Here are the sections to reflect on for writing: {formatted_str_sections}"""

def write_introduction(state: ResearchGraphState):
    # Full set of sections
    sections = state["sections"]
    topic = state["topic"]

    # Concat all sections together
    formatted_str_sections = "\n\n".join([f"{section}" for section in sections])

    # Write the introduction from the sections
    instructions = intro_conclusion_instructions.format(topic=topic, formatted_str_sections=formatted_str_sections)
    intro = llm.invoke([instructions] + [HumanMessage(content="Write the report introduction")])
    return {"introduction": intro.content}

def write_conclusion(state: ResearchGraphState):
    # Full set of sections
    sections = state["sections"]
    topic = state["topic"]

    # Concat all sections together
    formatted_str_sections = "\n\n".join([f"{section}" for section in sections])

    # Write the conclusion from the sections
    instructions = intro_conclusion_instructions.format(topic=topic, formatted_str_sections=formatted_str_sections)
    conclusion = llm.invoke([instructions] + [HumanMessage(content="Write the report conclusion")])
    return {"conclusion": conclusion.content}

def finalize_report(state: ResearchGraphState):
    """ This is the "reduce" step where we gather all the sections, combine them, and reflect on them to write the intro/conclusion """
    # Save full final report
    content = state["content"]
    if content.startswith("## Insights"):
        # removeprefix, not strip: str.strip() removes characters, not a prefix
        content = content.removeprefix("## Insights")
    if "## Sources" in content:
        try:
            content, sources = content.split("\n## Sources\n")
        except ValueError:
            sources = None
    else:
        sources = None

    final_report = state["introduction"] + "\n\n---\n\n" + content + "\n\n---\n\n" + state["conclusion"]
    if sources is not None:
        final_report += "\n\n## Sources\n" + sources
    return {"final_report": final_report}

# Add nodes and edges
builder = StateGraph(ResearchGraphState)
builder.add_node("create_analysts", create_analysts)
builder.add_node("human_feedback", human_feedback)
builder.add_node("conduct_interview", interview_builder.compile())
builder.add_node("write_report", write_report)
builder.add_node("write_introduction", write_introduction)
builder.add_node("write_conclusion", write_conclusion)
builder.add_node("finalize_report", finalize_report)

# Logic
builder.add_edge(START, "create_analysts")
builder.add_edge("create_analysts", "human_feedback")
builder.add_conditional_edges("human_feedback", initiate_all_interviews, ["create_analysts", "conduct_interview"])
builder.add_edge("conduct_interview", "write_report")
builder.add_edge("conduct_interview", "write_introduction")
builder.add_edge("conduct_interview", "write_conclusion")
builder.add_edge(["write_conclusion", "write_report", "write_introduction"], "finalize_report")
builder.add_edge("finalize_report", END)

# Compile
memory = MemorySaver()
graph = builder.compile(interrupt_before=['human_feedback'], checkpointer=memory)

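# Note: `interrupt_before=['human_feedback']` pauses every run at the no-op
# human_feedback node. The Gradio handlers below resume it by writing feedback
# with graph.update_state(..., as_node="human_feedback") and then streaming
# again from the checkpoint.
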
def run_research(topic, max_analysts):
    thread_id = str(uuid.uuid4())
    thread = {"configurable": {"thread_id": thread_id}}

    try:
        events = list(graph.stream({"topic": topic, "max_analysts": max_analysts},
                                   thread,
                                   stream_mode="values"))

        for event in events:
            analysts = event.get('analysts', '')
            if analysts:
                analyst_info = []
                for analyst in analysts:
                    analyst_info.append(f"## {analyst.name}\n"
                                        f"Affiliation: {analyst.affiliation}\n"
                                        f"Role: {analyst.role}\n"
                                        f"Description: {analyst.description}\n"
                                        f"---")
                return "\n\n".join(analyst_info), thread_id, gr.update(visible=True)

        return "No analysts generated. Please try again.", None, gr.update(visible=False)

    except GeneratorExit:
        return "Research process was interrupted. Please try again.", None, gr.update(visible=False)
    except Exception as e:
        return f"An error occurred: {str(e)}", None, gr.update(visible=False)

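# Note: because the graph is compiled with interrupt_before=['human_feedback'],
# the stream above stops once the analysts are generated, so run_research
# returns the analyst list for review instead of the final report.
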
def process_feedback(topic, max_analysts, feedback, thread_id):
    if not thread_id:
        return "Error: No active research session. Please start a new research.", None, gr.update(visible=True)

    thread = {"configurable": {"thread_id": thread_id}}

    try:
        if feedback:
            graph.update_state(thread, {"human_analyst_feedback": feedback}, as_node="human_feedback")

            all_analysts = []
            for event in graph.stream(None, thread, stream_mode="values"):
                analysts = event.get('analysts', [])
                if analysts:
                    all_analysts = analysts  # Update to the latest analyst list

            if all_analysts:
                # Keep only the last max_analysts analysts
                latest_analysts = all_analysts[-max_analysts:]
                analyst_info = []
                for analyst in latest_analysts:
                    analyst_info.append(f"## {analyst.name}\n"
                                        f"Affiliation: {analyst.affiliation}\n"
                                        f"Role: {analyst.role}\n"
                                        f"Description: {analyst.description}\n"
                                        f"---")
                return "\n\n".join(analyst_info), thread_id, gr.update(visible=True)
            return "No new analysts generated. Please try again with different feedback.", thread_id, gr.update(visible=True)
        else:
            # Continue the research process
            graph.update_state(thread, {"human_analyst_feedback": None}, as_node="human_feedback")
            for event in graph.stream(None, thread, stream_mode="updates"):
                pass  # Progress updates could be added here

            final_state = graph.get_state(thread)
            report = final_state.values.get('final_report')
            return report, None, gr.update(visible=False)

    except GeneratorExit:
        return "Feedback process was interrupted. Please try again.", thread_id, gr.update(visible=True)
    except Exception as e:
        return f"An error occurred: {str(e)}", thread_id, gr.update(visible=True)

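# Note: passing None as the input to graph.stream resumes the checkpointed
# thread from the interrupt rather than starting a new run. With feedback, the
# run loops back through create_analysts and pauses at human_feedback again;
# with empty feedback it continues through the interviews and report nodes,
# populating final_report before graph.get_state(thread) is read above.
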
# Theme definition used by the Gradio interface below
theme = Base(
    primary_hue=colors.blue,
    secondary_hue=colors.slate,
    neutral_hue=colors.gray,
    font=("Helvetica", "sans-serif"),
).set(
    body_background_fill="*neutral_50",
    body_background_fill_dark="*neutral_900",
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    button_primary_text_color="white",
    block_title_text_weight="600",
    block_border_width="2px",
    block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
)

# Gradio interface definition
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
# Multi-agent Research Team 👨👨👦👦🔗

Generate a team of AI experts to conduct in-depth research on your chosen topic.
"""
    )

    with gr.Row():
        with gr.Column(scale=3):
            topic = gr.Textbox(label="Research Topic", placeholder="Enter your research topic here...")
        with gr.Column(scale=1):
            max_analysts = gr.Slider(minimum=1, maximum=5, step=1, label="Number of Experts", value=3)

    start_btn = gr.Button("Generate Experts", variant="primary")

    output = gr.Markdown(label="Output")

    with gr.Row(visible=False) as feedback_row:
        with gr.Column(scale=3):
            feedback = gr.Textbox(label="Human-in-the-Loop Feedback (optional)", placeholder="Provide feedback to refine the experts...")
        with gr.Column(scale=2):
            with gr.Row():
                continue_btn = gr.Button("Regenerate Experts", variant="secondary", scale=1)
                finish_btn = gr.Button("Start Research", variant="primary", scale=1)

    thread_id = gr.State(value=None)

    def update_visibility(visible):
        return gr.update(visible=visible)

    start_btn.click(run_research,
                    inputs=[topic, max_analysts],
                    outputs=[output, thread_id, feedback_row])

    continue_btn.click(process_feedback,
                       inputs=[topic, max_analysts, feedback, thread_id],
                       outputs=[output, thread_id, feedback_row])

    finish_btn.click(lambda t, m, f, tid: process_feedback(t, m, "", tid),
                     inputs=[topic, max_analysts, feedback, thread_id],
                     outputs=[output, thread_id, feedback_row])

if __name__ == "__main__":
    demo.launch()