File size: 898 Bytes
fbeb50a
 
 
 
e3b702e
 
fbeb50a
 
3ae37e3
e3b702e
fbeb50a
e3b702e
fbeb50a
 
 
1423671
fbeb50a
e3b702e
fbeb50a
 
 
 
5c397d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from langchain_community.callbacks import get_openai_callback
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from prompts.accountability_agent import template


class AccountabilityAgent:
    """Run the accountability prompt through an OpenAI chat model.

    The agent pipes a ``PromptTemplate`` built from the module-level
    ``template`` (imported from ``prompts.accountability_agent``) into a
    ``ChatOpenAI`` model. The template is declared with two input variables,
    ``query`` and ``synthesis``.
    """

    def __init__(self, model: str = "gpt-4o", temperature: float = 0.7) -> None:
        """Build the prompt -> model chain.

        Args:
            model: Name of the OpenAI chat model handed to ``ChatOpenAI``.
            temperature: Sampling temperature handed to ``ChatOpenAI``.
        """
        # The prompt is built before the model, preserving the original
        # order in which constructor errors (if any) would surface.
        self._prompt = PromptTemplate(
            template=template,
            input_variables=["query", "synthesis"],
        )
        self._llm = ChatOpenAI(temperature=temperature, model=model)

        # LCEL pipe: the formatted prompt is fed directly into the chat model.
        self._chain = self._prompt | self._llm

    def run(self, query: str, synthesis: str) -> tuple[str, dict[str, int | float]]:
        """Invoke the chain once and report usage.

        Args:
            query: Value substituted for the ``query`` template variable.
            synthesis: Value substituted for the ``synthesis`` template variable.

        Returns:
            A pair of (model reply content with surrounding whitespace
            stripped, ``{"tokens": ..., "cost": ...}`` as reported by
            ``get_openai_callback``'s ``total_tokens`` / ``total_cost``).
        """
        prompt_inputs = {"query": query, "synthesis": synthesis}

        with get_openai_callback() as usage_tracker:
            reply = self._chain.invoke(prompt_inputs)
            answer = reply.content.strip()
            # Read the counters while the tracking context is still open.
            usage_report = {
                "tokens": usage_tracker.total_tokens,
                "cost": usage_tracker.total_cost,
            }

        return answer, usage_report