Spaces:
Sleeping
Sleeping
| from abc import abstractmethod | |
| from functools import cached_property | |
| from pathlib import Path | |
| from typing import Any, Callable | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate | |
| from langchain_core.runnables import Runnable | |
| from pydantic import BaseModel | |
| from talentum_score.flows.base import FlowState | |
class Node(Runnable[FlowState, FlowState]):
    """Base type for flow steps: a ``Runnable`` from ``FlowState`` to ``FlowState``.

    Only the attribute annotations below are declared here; no values are
    assigned, so they are presumably provided by subclasses (confirm against
    concrete nodes).
    """

    # Identifier of the node (annotation only; no default assigned here).
    name: str
    # Human-readable summary of the node's purpose (annotation only).
    description: str
def _default_parse_output(result: Any) -> Any:
    """Unwrap the ``"parsed"`` entry of a structured-output result.

    If ``result`` is a dict whose ``"parsed"`` value is not ``None``, that value
    is returned; otherwise ``result`` is returned unchanged (e.g. a plain model
    message, or an ``include_raw`` dict whose parsing failed).
    """
    if isinstance(result, dict) and result.get("parsed") is not None:
        return result["parsed"]
    return result


class BaseLLMNode(Node):
    """Flow node that sends a templated query to a chat model.

    The system prompt is rendered from a template file, ``query`` is formatted
    with the current flow state, and the model's (optionally structured) reply
    is post-processed by ``parse_output_fn`` and returned under ``output_key``.
    Subclasses must implement :meth:`get_llm`.
    """

    def __init__(
        self,
        system_prompt_path: Path | str,
        query: str,
        output_key: str,
        model: str,
        system_prompt_kwargs: dict[str, str] | None = None,
        structured_output: type[BaseModel] | None = None,
        parse_output_fn: Callable[[Any], Any] = _default_parse_output,
    ):
        """Store the node configuration.

        Args:
            system_prompt_path: Path to the system-prompt template file.
            query: Human-message text; at invoke time it is passed through
                ``str.format(state=state)``, so it may reference ``{state}``.
            output_key: Key under which the parsed model output is returned.
            model: Model identifier. Not read by this class except as a
                fallback name in an error message; presumably consumed by
                subclass ``get_llm`` implementations.
            system_prompt_kwargs: Values substituted into the system-prompt
                template. ``None`` means no substitutions.
            structured_output: Optional pydantic model class. When set, the LLM
                is wrapped with ``with_structured_output(..., include_raw=True)``.
            parse_output_fn: Post-processes the chain result. The default
                unwraps the ``"parsed"`` entry of a structured-output result.
        """
        self.system_prompt_path = system_prompt_path
        # None sentinel instead of a ``{}`` default: a default dict would be
        # shared by every instance constructed without explicit kwargs.
        self.system_prompt_kwargs = (
            {} if system_prompt_kwargs is None else system_prompt_kwargs
        )
        self.structured_output = structured_output
        self.query = query
        self.model = model
        self.output_key = output_key
        self.parse_output_fn = parse_output_fn

    @cached_property
    def system_prompt(self) -> str:
        """Rendered system-prompt text (template file formatted with kwargs).

        Cached per instance: the file is read and formatted on first access.
        """
        template = SystemMessagePromptTemplate.from_template_file(
            self.system_prompt_path,
            input_variables=list(self.system_prompt_kwargs),
        )
        return template.format(**self.system_prompt_kwargs).content

    @cached_property
    def prompt(self) -> ChatPromptTemplate:
        """Chat prompt: the rendered system prompt plus an ``{input}`` human turn."""
        # The system prompt is already fully rendered, but ChatPromptTemplate
        # parses message strings as f-string templates again. Escape braces so
        # literal ``{``/``}`` in the prompt (e.g. JSON examples) are kept as-is
        # instead of being treated as missing input variables.
        system_text = self.system_prompt.replace("{", "{{").replace("}", "}}")
        return ChatPromptTemplate.from_messages(
            [("system", system_text), ("human", "{input}")]
        )

    @abstractmethod
    def get_llm(self, **kwargs) -> BaseChatModel:
        """Return the chat model to run; receives the ``**kwargs`` given to ``invoke``."""
        ...

    def invoke(
        self, state: FlowState, context: dict[str, Any] | None = None, **kwargs
    ) -> FlowState:
        """Run the prompt -> LLM chain for ``state``.

        Args:
            state: Current flow state; substituted into ``query``.
            context: Not read here. Presumably the Runnable config supplied by
                the caller; defaults to ``None`` to match ``Runnable.invoke``.
            **kwargs: Forwarded to :meth:`get_llm`.

        Returns:
            A single-key mapping ``{output_key: parse_output_fn(result)}``.

        Raises:
            ValueError: If ``structured_output`` is set but the LLM has no
                ``with_structured_output`` method.
        """
        llm = self.get_llm(**kwargs)
        # Only require structured-output support when it was actually requested.
        if self.structured_output is not None:
            if not hasattr(llm, "with_structured_output"):
                # Not every chat model exposes ``model_name``; fall back to the
                # configured model id so building the message cannot itself fail.
                llm_name = getattr(llm, "model_name", self.model)
                raise ValueError(f"LLM {llm_name} does not support structured output")
            # include_raw=True returns a dict with "raw"/"parsed"/"parsing_error";
            # the default parse_output_fn unwraps "parsed".
            llm = llm.with_structured_output(self.structured_output, include_raw=True)
        chain = self.prompt | llm
        result = chain.invoke({"input": self.query.format(state=state)})
        return {self.output_key: self.parse_output_fn(result)}