from typing import List, TypedDict
from llm_config import get_llm_instructor, call_llm
from pydantic import BaseModel, Field
import ui
import prompts
from search import fetch_search_results, format_search_results
import random
import time
from dotenv import load_dotenv
import re

load_dotenv()


class RoundtableMessage(BaseModel):
    """One turn in the warm-start roundtable: a reply plus a hand-off question."""
    response: str = Field(..., title="Your response")
    follow_up: str = Field(..., title="Your follow-up question")
    next_persona: str = Field(..., title="Who you are asking the question to")


class ContentState(TypedDict):
    """Mutable state threaded through one persona's question/answer loop."""
    previous_messages: List[dict]
    content: str
    expert_question: str
    iteration: int
    full_messages: List[str]
    # Fixed typo: this key was declared "refernces", but every read/write in
    # the methods below uses state['references'].
    references: str


class Queries(BaseModel):
    queries: List[str] = Field(..., title="List of queries to search for")


class PersonaQuestion(BaseModel):
    question: str = Field(..., title="Your question for the expert")


class StructuredAnswer(BaseModel):
    answer_response: str = Field(..., title="The response to the question with citations")
    references_used: List[int] = Field(..., title="The references used to answer the question")


# Backward-compatible alias for the original, misspelled class name, in case
# other modules import it under that spelling.
StrucutredAnswer = StructuredAnswer


class ImproveContent:
    """Iteratively improves a report section by simulating expert interviews.

    For each persona, the class alternates between generating an expert
    question and answering it via web search plus a RAG-style LLM call.
    Citations are renumbered with a running counter so reference indices stay
    unique across the whole discussion.
    """

    def __init__(self, section_topic, section_description, section_key_questions, personas):
        self.section_topic = section_topic
        self.section_description = section_description
        self.section_key_questions = section_key_questions
        self.client = get_llm_instructor()
        # Running counter used to assign globally unique citation numbers.
        self.num_search_result = 1
        self.num_interview_rounds = 3
        self.personas = personas
        self.warm_start_rounds = 10

    def create_initial_state(self) -> ContentState:
        """Return a fresh, empty interview state for one persona."""
        return {
            "expert_question": "",
            "iteration": 0,
            'previous_messages': [],
            'full_messages': [],
            'references': '',
        }

    def expert_question_generator(self, persona, state: ContentState) -> ContentState:
        """Generate the persona's next question for the expert and record it in state."""
        response = call_llm(
            instructions=prompts.QUALITY_CHECKER_INSTRUCTIONS,
            additional_messages=state['previous_messages'],
            context={
                "title_description": self.section_description + ":" + self.section_topic,
                "key_questions": self.section_key_questions,
                'persona': persona.persona,
            },
            response_model=PersonaQuestion,
            logging_fn="quality_checker",
        )
        ui.system_sub_update("-------------------")
        ui.system_sub_update(f'{persona.name} ({persona.role},{persona.affiliation}):')
        ui.system_sub_update(response.question)
        ui.system_sub_update("-------------------")
        state["expert_question"] = response.question
        # The question is stored as an 'assistant' message so the next LLM call
        # sees the running conversation.
        state['previous_messages'].append({'role': 'assistant', 'content': response.question})
        state['full_messages'].append(response.question)
        return state

    def replace_references(self, text: str, references_list: List[int]) -> str:
        """Helper method to replace bracketed references with unique numbering.

        Advances ``self.num_search_result`` once per replaced reference.
        NOTE(review): currently not called anywhere in this module — confirm it
        is used externally before removing.
        """
        for idx in references_list:
            text = text.replace(f"[{idx}]", f"[{self.num_search_result}]")
            self.num_search_result += 1
        return text

    def answer_question(self, persona, state: ContentState):
        """Search the web and produce a cited expert answer to the current question.

        This is a generator: it delegates to ``fetch_search_results`` via
        ``yield from`` so the caller can drive/observe search progress, and
        returns the updated state via ``StopIteration.value``.
        """
        queries = call_llm(
            instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
            model_type='fast',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                'persona': persona.persona,
            },
            response_model=Queries,
            logging_fn="improve_content_create_query",
        )
        # Hit the search engine to fetch relevant documents.
        search_results, search_results_list = yield from fetch_search_results(
            queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        if not search_results_list:
            # Retry once with freshly generated queries when the first search
            # came back empty.
            queries = call_llm(
                instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
                model_type='fast',
                context={
                    "section_topic": self.section_topic,
                    "expert_question": state["expert_question"],
                    'persona': persona.persona,
                },
                response_model=Queries,
                logging_fn="improve_content_create_query_fallback",
            )
            search_results, search_results_list = yield from fetch_search_results(
                queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        response = call_llm(
            # NOTE: "IMPORVE" is the actual (misspelled) attribute name exposed
            # by the prompts module; renaming it here would break the lookup.
            instructions=prompts.IMPORVE_CONTENT_ANSWER_QUERY_INSTRUCTION,
            model_type='rag',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                "search_results": search_results,
                'persona': persona.persona,
            },
            response_model=StructuredAnswer,
            logging_fn="improve_content_answer_query",
        )
        state["content"] = response.answer_response
        # NOTE(review): this value is computed but never used afterwards; kept
        # because indexing search_results_list also validates the model's cited
        # indices — confirm and drop if that is not intended.
        references_used = format_search_results(
            [search_results_list[i - 1] for i in response.references_used])
        # Find all unique bracketed references in the search results.
        bracketed_refs = re.findall(r'\[(\d+)\](?=\s*Title:)', search_results)
        # Expand grouped citations such as [2,3,4] into [2][3][4].
        cited_references_raw = re.findall(r'\[(\d+(?:,\s*\d+)*)\]', response.answer_response)
        for group in cited_references_raw:
            nums_list = group.split(',')
            new_string = ''.join(f'[{n.strip()}]' for n in nums_list)
            old_string = f'[{group}]'
            response.answer_response = response.answer_response.replace(old_string, new_string)
        # Replace each per-search reference number with a globally unique one.
        for ref in bracketed_refs:
            search_results = search_results.replace(f'[{ref}]', f"[{self.num_search_result}]")
            response.answer_response = response.answer_response.replace(
                f'[{ref}]', f"[{self.num_search_result}]")
            self.num_search_result += 1
        ui.system_sub_update("-------------------")
        ui.system_sub_update('Content:')
        ui.system_sub_update(response.answer_response)
        ui.system_sub_update("-------------------")
        state['previous_messages'].append({'role': 'user', 'content': response.answer_response})
        state['full_messages'].append(response.answer_response)
        state['references'] = state['references'] + '\n\n' + search_results
        state["iteration"] += 1
        return state

    def create_and_run_interview(self, task_status, update_ui_fn):
        """Runs an iterative process of generating questions and answers until the iteration limit is reached.

        Generator: delegates to ``answer_question`` (and through it to the
        search layer) via ``yield from``; the collected discussion messages are
        the generator's return value.
        """
        self.task_status = task_status
        self.update_ui_fn = update_ui_fn
        discussion_messages = []
        for persona in self.personas:
            ui.system_update(
                f"Starting discussion with : {persona.name}: {persona.role}, {persona.affiliation}")
            state = self.create_initial_state()
            # NOTE(review): "<=" runs num_interview_rounds + 1 rounds (iteration
            # starts at 0) — confirm the off-by-one is intentional.
            while state["iteration"] <= self.num_interview_rounds:
                state = self.expert_question_generator(persona, state)
                state = yield from self.answer_question(persona, state)
            discussion_messages.extend(state['previous_messages'])
            self.final_state = state
        return discussion_messages

    def generate_final_section(self, synopsis):
        """Return (joined discussion text, accumulated references) for the last persona.

        NOTE(review): ``synopsis`` is currently unused — kept for interface
        compatibility; confirm whether it should feed a summarization step.
        """
        return '\n\n'.join(self.final_state['full_messages']), self.final_state['references']

    def warm_start_discussion(self):
        """Warm start the discussion with existing personas"""
        messages = [f"{self.personas[0].name}: Hi! Let's get started!"]
        selected_persona = random.choice(self.personas)
        for _ in range(self.warm_start_rounds):
            # Get the last 5 messages if there are more than 5
            recent_messages = messages[-5:] if len(messages) > 5 else messages
            message = call_llm(
                instructions=prompts.ROUNDTABLE_DISCUSSION_INSTRUCTIONS,
                model_type='fast',
                context={
                    "persona_name": selected_persona.name,
                    "persona_role": selected_persona.role,
                    "persona_affiliation": selected_persona.affiliation,
                    "persona_focus": selected_persona.focus,
                    "personas": "\n\n".join(
                        [p.name + '\n' + p.persona for p in self.personas if p != selected_persona]),
                    "discussion": "\n\n".join(recent_messages),
                },
                response_model=RoundtableMessage,
                logging_fn="roundtable_discussion",
            )
            ui.system_sub_update(
                "\n\n" + selected_persona.name + ": " + message.response + '\n' + message.follow_up)
            messages.append(selected_persona.name + ": " + message.response + '\n' + message.follow_up)
            # Bug fix: the original [...][0] lookup raised IndexError when the
            # LLM named an unknown persona; keep the current speaker instead.
            selected_persona = next(
                (p for p in self.personas if p.name == message.next_persona),
                selected_persona,
            )
            # Crude rate limiting between LLM calls.
            time.sleep(3)
        return messages


if __name__ == "__main__":
    section_name = 'Glean Search in the Enterprise Search Market'
    section_description = 'Positioning and Competition'
    section_key_questions = ['how is glean positioned in the enterprise search market?', "who are the main competitors in this space?"]
    # NOTE(review): the pipeline accesses persona.name/.role/.affiliation/.persona,
    # so plain strings will not work — replace with real persona objects.
    personas = ['\nRole: Business Analyst\nAffiliation: Enterprise Software Consultant\nDescription: Specializes in helping organizations implement and optimize AI-powered tools for improved productivity and knowledge management. Will analyze Glean and Copilot from a business user perspective.\n']
    improve_content = ImproveContent(section_name, section_description, section_key_questions, personas)
    # Bug fix: create_and_run_interview is a generator and requires
    # (task_status, update_ui_fn); drive it to completion and take the return
    # value from StopIteration. The original printed the generator object.
    interview = improve_content.create_and_run_interview(task_status=None, update_ui_fn=None)
    try:
        while True:
            next(interview)
    except StopIteration as stop:
        improved_content = stop.value
    # Bug fix: generate_final_section requires a synopsis argument.
    final_text, final_references = improve_content.generate_final_section(synopsis="")
    print(improved_content)