from langchain import ConversationChain, PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseRetriever
from pydantic import BaseModel, Field
from pydantic_redis import Model, Store

from edu_assistant.learning_tasks.base import BaseTask
from edu_assistant.utils.langchain_utils import (
    escape_for_prompt,
    load_gpt4_llm,
    load_llm,
)
from edu_assistant.utils.redis_utils import get_redis_config

TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. The AI acts according to the instructions below:
---
{instruction}
---

The coding problem:
---
{problem}
---

Student's code:
---
{answer}
---

Current conversation:
{{chat_history}}
Human: {{input}}
AI:"""

KNOWLEDGE_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. The AI acts according to the instructions below:
---
{instruction}
---

The coding problem:
---
{problem}
---

Student's code:
```
{answer}
```

Extra information that might be helpful for you:
---
{{context}}
---

Current conversation:
{{chat_history}}
Human: {{question}}
AI:
"""

DEFAULT_INSTRUCTION = """Act as a C++ professional and review the student's code.
The code is written by a student aged 5-10 and is most likely buggy or poorly performing.
"""

DEFAULT_FIRST_QUESTION = "请问这段代码中有什么问题吗?"  # "What problems are there in this code?"


class CodingProblem(Model):
    _primary_key_field: str = "title"

    title: str = Field()
    question: str = Field()
    standard_answer: str = Field(default="")
    analysis: str = Field(default="")
    language: str = Field(default="")
    extra: list[str] = Field(default_factory=list)

    # TODO: Cache expr() with the pydantic 2 computed_field decorator
    # once langchain supports pydantic 2.

    @staticmethod
    def enable_redis_orm():
        store = Store(
            name="coding_problems",
            redis_config=get_redis_config(),
            life_span_in_seconds=3600 * 24 * 30,  # 30 days
        )
        store.register_model(CodingProblem)

    def expr(self, lang=""):
        lang = lang or self.language
        expr = f"## Question\n\n---\n{escape_for_prompt(self.question)}\n---\n\n"
        if self.standard_answer:
            expr += (
                "## Standard Answer (There might be others)\n\n"
                f"```{lang}\n{escape_for_prompt(self.standard_answer)}\n```\n\n"
            )
        if self.analysis:
            expr += f"## Analysis\n\n---\n{escape_for_prompt(self.analysis)}\n---\n\n"
        expr += "## Extra\n\n" + escape_for_prompt("".join(self.extra)) + "\n"
        return expr

    def __str__(self):
        return self.expr()


class CodingAnswer(BaseModel):
    answer: str = Field()
    extra: list[str] = Field(default_factory=list)

    def expr(self, lang=""):
        expr = f"Answer:\n```{lang}\n{escape_for_prompt(self.answer)}\n```\n"
        expr += escape_for_prompt("".join(self.extra)) + "\n"
        return expr

    def __str__(self):
        return self.expr()
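
# The sketch below is illustrative only (not part of the original module): it
# shows how CodingProblem and CodingAnswer render into prompt text via expr().
# The title/question/answer strings are hypothetical sample data.
def _demo_expr() -> None:
    problem = CodingProblem(
        title="sum-two-numbers",  # hypothetical primary key
        question="Read two integers and print their sum.",
        standard_answer="int a, b;\nstd::cin >> a >> b;\nstd::cout << a + b;",
        language="c++",
    )
    answer = CodingAnswer(answer="int a, b;\nstd::cin >> a >> b;\nstd::cout << a - b;")
    print(problem.expr())  # "## Question" / "## Standard Answer" / "## Extra" sections
    print(answer.expr(lang="c++"))  # student's code inside a ```c++ fence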
class CodingProblemAnalysis(BaseTask):
    HISTORY_KEY = "chat_history"

    def __init__(
        self,
        instruction: str = DEFAULT_INSTRUCTION,
        first_question: str = DEFAULT_FIRST_QUESTION,
        lang: str = "",
        knowledge: BaseRetriever = None,
        enable_gpt4: bool = False,
    ):
        self.instruction = instruction
        self.first_question = first_question
        self.lang = lang
        self.enable_gpt4 = enable_gpt4
        # TODO: load the threshold value from config.
        self.vectordbkwargs = {"score_threshold": 0.9}  # Qdrant cosine similarity; higher is better.
        if knowledge:
            self._input_key = "question"
            self._output_key = "answer"
        else:
            self._input_key = "input"
            self._output_key = "response"
        self._session_store = {}
        self._knowledge = knowledge
        self._init_llm()

    @staticmethod
    def build_coding_problem(
        question: str,
        standard_answer: str = "",
        analysis: str = "",
        extra: list[str] = None,
        title: str = "",
    ):
        # `title` is the Redis primary key of CodingProblem and must be supplied,
        # since the field has no default; it is only meaningful for persisted problems.
        extra = [] if extra is None else extra
        return CodingProblem(
            title=title,
            question=question,
            standard_answer=standard_answer,
            analysis=analysis,
            extra=extra,
        )

    @staticmethod
    def build_coding_answer(answer: str, extra: list[str] = None):
        extra = [] if extra is None else extra
        return CodingAnswer(answer=answer, extra=extra)

    def start_analysis(self, problem: CodingProblem, answer: CodingAnswer, first_question: str = None) -> dict:
        """Start the analysis of a coding problem and a (possibly incorrect) answer.

        Args:
            problem (CodingProblem): a coding problem.
            answer (CodingAnswer): the student's answer to the problem.
            first_question (str, optional): overrides the default first question.

        Returns:
            dict: question answer and metadata, including the new "session_id".
        """
        chain = self._build_chain(problem, answer)
        session_id = self._create_session_id()
        self._session_store[session_id] = chain
        args = {
            self._input_key: first_question if first_question else self.first_question,
            self.HISTORY_KEY: "",
        }
        # TODO: ConversationalRetrievalChain should support vectordbkwargs.
        # if self._knowledge:
        #     args["vectordbkwargs"] = self.vectordbkwargs
        result = chain(args)
        result["session_id"] = session_id
        return result

    def ask(self, question: str, session_id: str) -> dict:
        """Ask a follow-up question about a coding problem.

        Args:
            question (str): question for the LLM.
            session_id (str): identifies a problem-and-answer session.

        Returns:
            dict: question answer and metadata.
        """
        assert question
        if session_id not in self._session_store:
            return {}
        chain = self._session_store[session_id]
        args = {self._input_key: question}
        # if self._knowledge:
        #     args["vectordbkwargs"] = self.vectordbkwargs
        result = chain(args)
        result["session_id"] = session_id
        return result

    def _init_llm(self):
        self._main_llm = load_gpt4_llm() if self.enable_gpt4 else load_llm()
        self._secondary_llm = load_llm()

    def _build_chain(self, problem: CodingProblem, answer: CodingAnswer) -> Chain:
        memory = ConversationBufferMemory(
            memory_key=self.HISTORY_KEY, output_key=self._output_key, return_messages=True
        )
        lang = problem.language or self.lang
        if not self._knowledge:
            prompt = PromptTemplate.from_template(
                TEMPLATE.format(
                    instruction=self.instruction,
                    problem=problem.expr(lang=lang),
                    answer=answer.expr(lang=lang),
                )
            )
            return ConversationChain(
                llm=self._main_llm,
                memory=memory,
                prompt=prompt,
            )
        else:
            prompt = PromptTemplate.from_template(
                KNOWLEDGE_TEMPLATE.format(
                    instruction=self.instruction,
                    problem=problem.expr(lang=lang),
                    answer=answer.expr(lang=lang),
                )
            )
            return ConversationalRetrievalChain.from_llm(
                llm=self._main_llm,
                memory=memory,
                retriever=self._knowledge,
                condense_question_llm=self._secondary_llm,
                return_source_documents=True,
                combine_docs_chain_kwargs={"prompt": prompt},
            )
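
# Illustrative usage sketch (not part of the original module). It assumes the
# LLM loaders in edu_assistant.utils.langchain_utils are configured with valid
# credentials; the problem and answer strings are hypothetical sample data.
if __name__ == "__main__":
    task = CodingProblemAnalysis()
    problem = CodingProblemAnalysis.build_coding_problem(
        question="Read two integers and print their sum.",
        standard_answer="std::cout << a + b;",
    )
    answer = CodingProblemAnalysis.build_coding_answer(answer="std::cout << a - b;")

    result = task.start_analysis(problem, answer)
    # Without a knowledge retriever, ConversationChain's output key is "response".
    print(result["response"])

    # Follow-up questions reuse the session created by start_analysis().
    follow_up = task.ask("How can the student fix it?", session_id=result["session_id"])
    print(follow_up["response"])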