jc132 committed on
Commit
ddb4fb3
1 Parent(s): 666cba1

Create app.py

Files changed (1)
  1. app.py +778 -0
app.py ADDED
@@ -0,0 +1,778 @@
+ import os
+ from dotenv import load_dotenv
+ from typing import List
+ from typing_extensions import TypedDict
+ from pydantic import BaseModel, Field
+ from langgraph.graph import START, END, StateGraph
+ from langgraph.checkpoint.memory import MemorySaver
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+ import operator
+ from typing import Annotated
+ from langgraph.graph import MessagesState
+ from langchain_openai import ChatOpenAI
+ import gradio as gr
+ import uuid
+ from gradio.themes.utils import colors
+ from gradio.themes import Base
+
+ load_dotenv()
+
+ os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
+ os.environ["TAVILY_API_KEY"] = os.getenv('TAVILY_API_KEY')
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
+ os.environ["LANGCHAIN_API_KEY"] = os.getenv('LANGCHAIN_API_KEY')
+ os.environ["LANGCHAIN_PROJECT"] = "Research Assistant v1"
+
+ llm = ChatOpenAI(model="gpt-4o", temperature=0)
+
+ class Analyst(BaseModel):
+     affiliation: str = Field(
+         description="Primary affiliation of the analyst.",
+     )
+     name: str = Field(
+         description="Name of the analyst."
+     )
+     role: str = Field(
+         description="Role of the analyst in the context of the topic.",
+     )
+     description: str = Field(
+         description="Description of the analyst focus, concerns, and motives.",
+     )
+
+     @property
+     def persona(self) -> str:
+         return f"Name: {self.name}\nRole: {self.role}\nAffiliation: {self.affiliation}\nDescription: {self.description}\n"
+
+ class Perspectives(BaseModel):
+     analysts: List[Analyst] = Field(
+         description="Comprehensive list of analysts with their roles and affiliations.",
+     )
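+
+ # Illustrative example (not executed; the names here are hypothetical):
+ # a generated analyst such as
+ #   Analyst(name="Dr. Ada Chen", role="Safety researcher",
+ #           affiliation="AI Policy Lab", description="Focus on evaluation gaps")
+ # renders via .persona as "Name: Dr. Ada Chen\nRole: Safety researcher\n..."
+ # and is injected into the interview prompts below.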
+
+ class GenerateAnalystsState(TypedDict):
+     topic: str # Research topic
+     max_analysts: int # Number of analysts
+     human_analyst_feedback: str # Human feedback
+     analysts: List[Analyst] # Generated analysts
+
+ analyst_instructions = """You are tasked with creating a set of AI analyst personas. Follow these instructions carefully:
+
+ 1. First, review the research topic:
+ {topic}
+
+ 2. Examine any editorial feedback that has been optionally provided to guide creation of the analysts:
+
+ {human_analyst_feedback}
+
+ 3. Determine the most interesting themes based upon the documents and/or feedback above.
+
+ 4. Pick the top {max_analysts} themes.
+
+ 5. Assign one analyst to each theme."""
+
+ def create_analysts(state: GenerateAnalystsState):
+     """ Create analysts """
+
+     topic = state['topic']
+     max_analysts = state['max_analysts']
+     human_analyst_feedback = state.get('human_analyst_feedback', '')
+
+     # Enforce structured output
+     structured_llm = llm.with_structured_output(Perspectives)
+
+     # System message
+     system_message = analyst_instructions.format(topic=topic,
+                                                  human_analyst_feedback=human_analyst_feedback,
+                                                  max_analysts=max_analysts)
+
+     # Generate analysts
+     analysts = structured_llm.invoke([SystemMessage(content=system_message)] + [HumanMessage(content="Generate the set of analysts.")])
+
+     # Write the list of analysts to state
+     return {"analysts": analysts.analysts}
+
+ def human_feedback(state: GenerateAnalystsState):
+     """ No-op node that should be interrupted on """
+     pass
+
+ def should_continue(state: GenerateAnalystsState):
+     """ Return the next node to execute """
+
+     # Check if human feedback
+     human_analyst_feedback = state.get('human_analyst_feedback', None)
+     if human_analyst_feedback:
+         return "create_analysts"
+
+     # Otherwise end
+     return END
+
+ class InterviewState(MessagesState):
+     max_num_turns: int # Number of turns in the conversation
+     context: Annotated[list, operator.add] # Source docs
+     analyst: Analyst # Analyst asking questions
+     interview: str # Interview transcript
+     sections: list # Final key we duplicate in outer state for Send() API
+
+ class SearchQuery(BaseModel):
+     search_query: str = Field(None, description="Search query for retrieval.")
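+
+ # Because context is Annotated[list, operator.add], the partial updates
+ # returned by the three search nodes below (each {"context": [docs]}) are
+ # concatenated rather than overwritten when the nodes run in parallel.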
+
+ question_instructions = """You are an analyst tasked with interviewing an expert to learn about a specific topic.
+
+ Your goal is to boil the conversation down to interesting and specific insights related to your topic.
+
+ 1. Interesting: Insights that people will find surprising or non-obvious.
+
+ 2. Specific: Insights that avoid generalities and include specific examples from the expert.
+
+ Here is your topic of focus and set of goals: {goals}
+
+ Begin by introducing yourself using a name that fits your persona, and then ask your question.
+
+ Continue to ask questions to drill down and refine your understanding of the topic.
+
+ When you are satisfied with your understanding, complete the interview with: "Thank you so much for your help!"
+
+ Remember to stay in character throughout your response, reflecting the persona and goals provided to you."""
+
+ def generate_question(state: InterviewState):
+     """ Node to generate a question """
+
+     # Get state
+     analyst = state["analyst"]
+     messages = state["messages"]
+
+     # Generate question
+     system_message = question_instructions.format(goals=analyst.persona)
+     question = llm.invoke([SystemMessage(content=system_message)] + messages)
+
+     # Write messages to state
+     return {"messages": [question]}
+
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_core.messages import get_buffer_string
+ from langchain_core.documents import Document
+ import arxiv
+
+ # Search query writing
+ search_instructions = SystemMessage(content="""You will be given a conversation between an analyst and an expert.
+
+ Your goal is to generate a well-structured query for use in retrieval and/or web search related to the conversation.
+
+ First, analyze the full conversation.
+
+ Pay particular attention to the final question posed by the analyst.
+
+ Convert this final question into a well-structured web search query.""")
+
+ def search_web(state: InterviewState):
+     """ Retrieve docs from web search """
+
+     # Search query
+     structured_llm = llm.with_structured_output(SearchQuery)
+     search_query = structured_llm.invoke([search_instructions] + state['messages'])
+
+     # Search
+     tavily_search = TavilySearchResults(max_results=5)
+     search_docs = tavily_search.invoke(search_query.search_query)
+
+     # Debug: Print the type and content of search_docs
+     print(f"Type of search_docs: {type(search_docs)}")
+     print(f"Content of search_docs: {search_docs}")
+
+     # Format
+     try:
+         formatted_search_docs = "\n\n---\n\n".join(
+             [
+                 f'<Document href="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
+                 for doc in search_docs
+             ]
+         )
+     except TypeError as e:
+         print(f"Error in formatting search_docs: {e}")
+         # Fallback: treat search_docs as a single string if it's not iterable
+         formatted_search_docs = f"<Document>\n{search_docs}\n</Document>"
+
+     return {"context": [formatted_search_docs]}
+
+ def search_wikipedia(state: InterviewState):
+     """ Retrieve docs from wikipedia """
+
+     # Search query
+     structured_llm = llm.with_structured_output(SearchQuery)
+     search_query = structured_llm.invoke([search_instructions] + state['messages'])
+
+     # Search
+     search_docs = WikipediaLoader(query=search_query.search_query,
+                                   load_max_docs=2).load()
+
+     # Format
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+
+     return {"context": [formatted_search_docs]}
+
+ def search_arxiv(state: InterviewState):
+     """ Retrieve docs from arxiv """
+
+     # Search query
+     structured_llm = llm.with_structured_output(SearchQuery)
+     search_query = structured_llm.invoke([search_instructions] + state['messages'])
+
+     # Search arXiv
+     search = arxiv.Search(
+         query=search_query.search_query,
+         max_results=10,
+         sort_by=arxiv.SortCriterion.Relevance
+     )
+
+     # Retrieve results
+     search_docs = []
+     for result in search.results():
+         doc = Document(
+             page_content=f"{result.title}\n\n{result.summary}",
+             metadata={
+                 "title": result.title,
+                 "authors": ", ".join(author.name for author in result.authors),
+                 "published": result.published.strftime("%Y-%m-%d"),
+                 "url": result.entry_id,
+             }
+         )
+         search_docs.append(doc)
+
+     # Format
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document title="{doc.metadata["title"]}" authors="{doc.metadata["authors"]}" published="{doc.metadata["published"]}" url="{doc.metadata["url"]}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+
+     return {"context": [formatted_search_docs]}
+
+ answer_instructions = """You are an expert being interviewed by an analyst.
+
+ Here is the analyst's area of focus: {goals}.
+
+ Your goal is to answer a question posed by the interviewer.
+
+ To answer the question, use this context:
+
+ {context}
+
+ When answering questions, follow these guidelines:
+
+ 1. Use only the information provided in the context.
+
+ 2. Do not introduce external information or make assumptions beyond what is explicitly stated in the context.
+
+ 3. The context contains sources at the top of each individual document.
+
+ 4. Include these sources in your answer next to any relevant statements. For example, for source #1 use [1].
+
+ 5. List your sources in order at the bottom of your answer. [1] Source 1, [2] Source 2, etc.
+
+ 6. If the source is: <Document source="assistant/docs/llama3_1.pdf" page="7"/> then just list:
+
+ [1] assistant/docs/llama3_1.pdf, page 7
+
+ And skip the addition of the brackets as well as the Document source preamble in your citation."""
+
+ def generate_answer(state: InterviewState):
+     """ Node to answer a question """
+
+     # Get state
+     analyst = state["analyst"]
+     messages = state["messages"]
+     context = state["context"]
+
+     # Answer question
+     system_message = answer_instructions.format(goals=analyst.persona, context=context)
+     answer = llm.invoke([SystemMessage(content=system_message)] + messages)
+
+     # Name the message as coming from the expert
+     answer.name = "expert"
+
+     # Append it to state
+     return {"messages": [answer]}
+
+ def save_interview(state: InterviewState):
+     """ Save interviews """
+
+     # Get messages
+     messages = state["messages"]
+
+     # Convert interview to a string
+     interview = get_buffer_string(messages)
+
+     # Save to interviews key
+     return {"interview": interview}
+
+ def route_messages(state: InterviewState,
+                    name: str = "expert"):
+     """ Route between question and answer """
+
+     # Get messages
+     messages = state["messages"]
+     max_num_turns = state.get('max_num_turns', 2)
+
+     # Check the number of expert answers
+     num_responses = len(
+         [m for m in messages if isinstance(m, AIMessage) and m.name == name]
+     )
+
+     # End if the expert has answered the maximum number of turns
+     if num_responses >= max_num_turns:
+         return 'save_interview'
+
+     # This router runs after each question-answer pair
+     # Get the last question asked to check if it signals the end of the discussion
+     last_question = messages[-2]
+
+     if "Thank you so much for your help" in last_question.content:
+         return 'save_interview'
+     return "ask_question"
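+
+ # This relies on the router running right after an expert answer, so
+ # messages[-2] is the analyst's most recent question.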
+
+ section_writer_instructions = """You are an expert technical writer.
+
+ Your task is to create a short, easily digestible section of a report based on a set of source documents.
+
+ 1. Analyze the content of the source documents:
+ - The name of each source document is at the start of the document, with the <Document tag.
+
+ 2. Create a report structure using markdown formatting:
+ - Use ## for the section title
+ - Use ### for sub-section headers
+
+ 3. Write the report following this structure:
+ a. Title (## header)
+ b. Summary (### header)
+ c. Sources (### header)
+
+ 4. Make your title engaging based upon the focus area of the analyst:
+ {focus}
+
+ 5. For the summary section:
+ - Set up the summary with general background / context related to the focus area of the analyst
+ - Emphasize what is novel, interesting, or surprising about the insights gathered from the interview
+ - Create a numbered list of source documents as you use them
+ - Do not mention the names of interviewers or experts
+ - Aim for approximately 400 words maximum
+ - Use numbered sources in your report (e.g., [1], [2]) based on information from the source documents
+
+ 6. In the Sources section:
+ - Include all sources used in your report
+ - Provide full links to relevant websites or specific document paths
+ - Separate each source by a newline. Use two spaces at the end of each line to create a newline in Markdown.
+ - It will look like:
+
+ ### Sources
+ [1] Link or Document name
+ [2] Link or Document name
+
+ 7. Be sure to combine sources. For example, this is not correct:
+
+ [3] https://ai.meta.com/blog/meta-llama-3-1/
+ [4] https://ai.meta.com/blog/meta-llama-3-1/
+
+ There should be no redundant sources. It should simply be:
+
+ [3] https://ai.meta.com/blog/meta-llama-3-1/
+
+ 8. Final review:
+ - Ensure the report follows the required structure
+ - Include no preamble before the title of the report
+ - Check that all guidelines have been followed"""
+
+ def write_section(state: InterviewState):
+     """ Node to write a section """
+
+     # Get state
+     interview = state["interview"]
+     context = state["context"]
+     analyst = state["analyst"]
+
+     # Write the section using either the gathered source docs from the interview (context) or the interview itself (interview)
+     system_message = section_writer_instructions.format(focus=analyst.description)
+     section = llm.invoke([SystemMessage(content=system_message)] + [HumanMessage(content=f"Use this source to write your section: {context}")])
+
+     # Append it to state
+     return {"sections": [section.content]}
+
+ # Add nodes and edges
+ interview_builder = StateGraph(InterviewState)
+ interview_builder.add_node("ask_question", generate_question)
+ interview_builder.add_node("search_web", search_web)
+ interview_builder.add_node("search_wikipedia", search_wikipedia)
+ interview_builder.add_node("search_arxiv", search_arxiv)
+ interview_builder.add_node("answer_question", generate_answer)
+ interview_builder.add_node("save_interview", save_interview)
+ interview_builder.add_node("write_section", write_section)
+
+ # Flow
+ interview_builder.add_edge(START, "ask_question")
+ interview_builder.add_edge("ask_question", "search_web")
+ interview_builder.add_edge("ask_question", "search_wikipedia")
+ interview_builder.add_edge("ask_question", "search_arxiv")
+ interview_builder.add_edge("search_web", "answer_question")
+ interview_builder.add_edge("search_wikipedia", "answer_question")
+ interview_builder.add_edge("search_arxiv", "answer_question")
+ interview_builder.add_conditional_edges("answer_question", route_messages, ['ask_question', 'save_interview'])
+ interview_builder.add_edge("save_interview", "write_section")
+ interview_builder.add_edge("write_section", END)
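+
+ # Interview flow: ask_question fans out to all three search nodes, the
+ # merged context feeds answer_question, and route_messages loops back to
+ # ask_question until max_num_turns is hit or the analyst closes with
+ # "Thank you so much for your help!".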
+
+ pre_defined_run_id = uuid.uuid4()
+ print("pre_defined_run_id", pre_defined_run_id)
+
+ # Interview
+ memory = MemorySaver()
+ interview_graph = interview_builder.compile(checkpointer=memory).with_config(run_name="Conduct Interviews")
+
+ class ResearchGraphState(TypedDict):
+     topic: str # Research topic
+     max_analysts: int # Number of analysts
+     human_analyst_feedback: str # Human feedback
+     analysts: List[Analyst] # Analysts asking questions
+     sections: Annotated[list, operator.add] # Send() API key
+     introduction: str # Introduction for the final report
+     content: str # Content for the final report
+     conclusion: str # Conclusion for the final report
+     final_report: str # Final report
+
+ from langgraph.constants import Send
+
+ def initiate_all_interviews(state: ResearchGraphState):
+     """ This is the "map" step where we run each interview sub-graph using the Send API """
+
+     # Check if human feedback
+     human_analyst_feedback = state.get('human_analyst_feedback')
+     if human_analyst_feedback:
+         # Return to create_analysts
+         return "create_analysts"
+
+     # Otherwise kick off interviews in parallel via the Send() API
+     else:
+         topic = state["topic"]
+         return [Send("conduct_interview", {"analyst": analyst,
+                                            "messages": [HumanMessage(
+                                                content=f"So you said you were writing an article on {topic}?"
+                                            )]}) for analyst in state["analysts"]]
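+
+ # Each Send("conduct_interview", {...}) spawns one run of the compiled
+ # interview sub-graph with its own analyst and seed message; the sections
+ # each run returns are merged back through the operator.add reducer on
+ # ResearchGraphState["sections"].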
+
+ report_writer_instructions = """You are a technical writer creating a report on this overall topic:
+
+ {topic}
+
+ You have a team of analysts. Each analyst has done two things:
+
+ 1. They conducted an interview with an expert on a specific sub-topic.
+ 2. They wrote up their findings in a memo.
+
+ Your task:
+
+ 1. You will be given a collection of memos from your analysts.
+ 2. Think carefully about the insights from each memo.
+ 3. Consolidate these into a crisp overall summary that ties together the central ideas from all of the memos.
+ 4. Summarize the central points in each memo into a cohesive single narrative.
+
+ To format your report:
+
+ 1. Use markdown formatting.
+ 2. Include no preamble for the report.
+ 3. Use no sub-headings.
+ 4. Start your report with a single title header: ## Insights
+ 5. Do not mention any analyst names in your report.
+ 6. Preserve any citations in the memos, which will be annotated in brackets, for example [1] or [2].
+ 7. Create a final, consolidated list of sources and add it to a Sources section with the `## Sources` header.
+ 8. List your sources in order and do not repeat any.
+
+ [1] Source 1
+ [2] Source 2
+
+ Here are the memos from your analysts to build your report from:
+
+ {context}"""
+
+ def write_report(state: ResearchGraphState):
+     # Full set of sections
+     sections = state["sections"]
+     topic = state["topic"]
+
+     # Concat all sections together
+     formatted_str_sections = "\n\n".join([f"{section}" for section in sections])
+
+     # Summarize the sections into a final report
+     system_message = report_writer_instructions.format(topic=topic, context=formatted_str_sections)
+     report = llm.invoke([SystemMessage(content=system_message)] + [HumanMessage(content="Write a report based upon these memos.")])
+     return {"content": report.content}
+
+ intro_conclusion_instructions = """You are a technical writer finishing a report on {topic}.
+
+ You will be given all of the sections of the report.
+
+ Your job is to write a crisp and compelling introduction or conclusion section.
+
+ The user will instruct you whether to write the introduction or conclusion.
+
+ Include no preamble for either section.
+
+ Target around 100 words, crisply previewing (for the introduction) or recapping (for the conclusion) all of the sections of the report.
+
+ Use markdown formatting.
+
+ For your introduction, create a compelling title and use the # header for the title.
+
+ For your introduction, use ## Introduction as the section header.
+
+ For your conclusion, use ## Conclusion as the section header.
+
+ Here are the sections to reflect on for writing: {formatted_str_sections}"""
+
+ def write_introduction(state: ResearchGraphState):
+     # Full set of sections
+     sections = state["sections"]
+     topic = state["topic"]
+
+     # Concat all sections together
+     formatted_str_sections = "\n\n".join([f"{section}" for section in sections])
+
+     # Write the introduction from the sections
+     instructions = intro_conclusion_instructions.format(topic=topic, formatted_str_sections=formatted_str_sections)
+     intro = llm.invoke([SystemMessage(content=instructions)] + [HumanMessage(content="Write the report introduction")])
+     return {"introduction": intro.content}
+
+ def write_conclusion(state: ResearchGraphState):
+     # Full set of sections
+     sections = state["sections"]
+     topic = state["topic"]
+
+     # Concat all sections together
+     formatted_str_sections = "\n\n".join([f"{section}" for section in sections])
+
+     # Write the conclusion from the sections
+     instructions = intro_conclusion_instructions.format(topic=topic, formatted_str_sections=formatted_str_sections)
+     conclusion = llm.invoke([SystemMessage(content=instructions)] + [HumanMessage(content="Write the report conclusion")])
+     return {"conclusion": conclusion.content}
+
+ def finalize_report(state: ResearchGraphState):
+     """ This is the "reduce" step where we gather all the sections, combine them, and reflect on them to write the intro/conclusion """
+     # Save full final report
+     content = state["content"]
+     if content.startswith("## Insights"):
+         content = content.removeprefix("## Insights")  # str.strip() would remove characters, not the prefix
+     if "## Sources" in content:
+         try:
+             content, sources = content.split("\n## Sources\n")
+         except ValueError:
+             sources = None
+     else:
+         sources = None
+
+     final_report = state["introduction"] + "\n\n---\n\n" + content + "\n\n---\n\n" + state["conclusion"]
+     if sources is not None:
+         final_report += "\n\n## Sources\n" + sources
+     return {"final_report": final_report}
+
+ # Add nodes and edges
+ builder = StateGraph(ResearchGraphState)
+ builder.add_node("create_analysts", create_analysts)
+ builder.add_node("human_feedback", human_feedback)
+ builder.add_node("conduct_interview", interview_builder.compile())
+ builder.add_node("write_report", write_report)
+ builder.add_node("write_introduction", write_introduction)
+ builder.add_node("write_conclusion", write_conclusion)
+ builder.add_node("finalize_report", finalize_report)
+
+ # Logic
+ builder.add_edge(START, "create_analysts")
+ builder.add_edge("create_analysts", "human_feedback")
+ builder.add_conditional_edges("human_feedback", initiate_all_interviews, ["create_analysts", "conduct_interview"])
+ builder.add_edge("conduct_interview", "write_report")
+ builder.add_edge("conduct_interview", "write_introduction")
+ builder.add_edge("conduct_interview", "write_conclusion")
+ builder.add_edge(["write_conclusion", "write_report", "write_introduction"], "finalize_report")
+ builder.add_edge("finalize_report", END)
+
+ # Compile
+ memory = MemorySaver()
+ graph = builder.compile(interrupt_before=['human_feedback'], checkpointer=memory)
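+
+ # Because of interrupt_before=['human_feedback'], the graph pauses after
+ # create_analysts; a caller can then call
+ # graph.update_state(thread, {"human_analyst_feedback": ...}, as_node="human_feedback")
+ # and resume by streaming with input None, as the Gradio handlers below do.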
+
+ def run_research(topic, max_analysts):
+     thread_id = str(uuid.uuid4())
+     thread = {"configurable": {"thread_id": thread_id}}
+
+     try:
+         events = list(graph.stream({"topic": topic, "max_analysts": max_analysts},
+                                    thread,
+                                    stream_mode="values"))
+
+         for event in events:
+             analysts = event.get('analysts', '')
+             if analysts:
+                 analyst_info = []
+                 for analyst in analysts:
+                     analyst_info.append(f"## {analyst.name}\n"
+                                         f"Affiliation: {analyst.affiliation}\n"
+                                         f"Role: {analyst.role}\n"
+                                         f"Description: {analyst.description}\n"
+                                         f"---")
+                 return "\n\n".join(analyst_info), thread_id, gr.update(visible=True)
+
+         return "No analysts generated. Please try again.", None, gr.update(visible=False)
+
+     except GeneratorExit:
+         return "Research process was interrupted. Please try again.", None, gr.update(visible=False)
+     except Exception as e:
+         return f"An error occurred: {str(e)}", None, gr.update(visible=False)
+
+ def process_feedback(topic, max_analysts, feedback, thread_id):
+     if not thread_id:
+         return "Error: No active research session. Please start a new research.", None, gr.update(visible=True)
+
+     thread = {"configurable": {"thread_id": thread_id}}
+
+     try:
+         if feedback:
+             graph.update_state(thread, {"human_analyst_feedback": feedback}, as_node="human_feedback")
+
+             all_analysts = []
+             for event in graph.stream(None, thread, stream_mode="values"):
+                 analysts = event.get('analysts', [])
+                 if analysts:
+                     all_analysts = analysts  # Update to the latest list of analysts
+
+             if all_analysts:
+                 # Keep only the last max_analysts analysts
+                 latest_analysts = all_analysts[-max_analysts:]
+                 analyst_info = []
+                 for analyst in latest_analysts:
+                     analyst_info.append(f"## {analyst.name}\n"
+                                         f"Affiliation: {analyst.affiliation}\n"
+                                         f"Role: {analyst.role}\n"
+                                         f"Description: {analyst.description}\n"
+                                         f"---")
+                 return "\n\n".join(analyst_info), thread_id, gr.update(visible=True)
+             return "No new analysts generated. Please try again with different feedback.", thread_id, gr.update(visible=True)
+         else:
+             # Continue the research process
+             graph.update_state(thread, {"human_analyst_feedback": None}, as_node="human_feedback")
+             for event in graph.stream(None, thread, stream_mode="updates"):
+                 pass  # Progress updates could be added here
+
+             final_state = graph.get_state(thread)
+             report = final_state.values.get('final_report')
+             return report, None, gr.update(visible=False)
+
+     except GeneratorExit:
+         return "Feedback process was interrupted. Please try again.", thread_id, gr.update(visible=True)
+     except Exception as e:
+         return f"An error occurred: {str(e)}", thread_id, gr.update(visible=True)
+
+ # Theme definition (used by the Gradio Blocks demo below)
+ theme = Base(
+     primary_hue=colors.blue,
+     secondary_hue=colors.slate,
+     neutral_hue=colors.gray,
+     font=("Helvetica", "sans-serif"),
+ ).set(
+     body_background_fill="*neutral_50",
+     body_background_fill_dark="*neutral_900",
+     button_primary_background_fill="*primary_600",
+     button_primary_background_fill_hover="*primary_700",
+     button_primary_text_color="white",
+     block_title_text_weight="600",
+     block_border_width="2px",
+     block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
+ )
+
+ # Gradio interface definition
+ with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
+     gr.Markdown(
+         """
+         # Multi-agent Research Team 👨‍👨‍👦‍👦🔗
+
+         Generate a team of AI experts to conduct in-depth research on your chosen topic.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             topic = gr.Textbox(label="Research Topic", placeholder="Enter your research topic here...")
+         with gr.Column(scale=1):
+             max_analysts = gr.Slider(minimum=1, maximum=5, step=1, label="Number of Experts", value=3)
+
+     start_btn = gr.Button("Generate Experts", variant="primary")
+
+     output = gr.Markdown(label="Output")
+
+     with gr.Row(visible=False) as feedback_row:
+         with gr.Column(scale=3):
+             feedback = gr.Textbox(label="Human-in-the-Loop Feedback (optional)", placeholder="Provide feedback to refine the experts...")
+         with gr.Column(scale=2):
+             with gr.Row():
+                 continue_btn = gr.Button("Regenerate Experts", variant="secondary", scale=1)
+                 finish_btn = gr.Button("Start Research", variant="primary", scale=1)
+
+     thread_id = gr.State(value=None)
+
+     def update_visibility(visible):
+         return gr.update(visible=visible)
+
+     start_btn.click(run_research,
+                     inputs=[topic, max_analysts],
+                     outputs=[output, thread_id, feedback_row])
+
+     continue_btn.click(process_feedback,
+                        inputs=[topic, max_analysts, feedback, thread_id],
+                        outputs=[output, thread_id, feedback_row])
+
+     finish_btn.click(lambda t, m, f, tid: process_feedback(t, m, "", tid),
+                      inputs=[topic, max_analysts, feedback, thread_id],
+                      outputs=[output, thread_id, feedback_row])
+
+ if __name__ == "__main__":
+     demo.launch()
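+
+ # demo.launch() serves the UI locally (Gradio defaults to
+ # http://127.0.0.1:7860); on Hugging Face Spaces the same entry point is used.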