File size: 1,256 Bytes
8b6a021
 
 
 
 
 
 
 
3ae37e3
8b6a021
 
 
 
 
 
529f708
8b6a021
f83a2a2
8b6a021
 
 
 
529f708
16981d2
 
529f708
f83a2a2
 
 
529f708
f83a2a2
 
 
529f708
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import re

from langchain_community.callbacks import get_openai_callback
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from prompts.outline_agent import template


class OutlineAgent:
    def __init__(self, model: str = "gpt-4o", temperature: float = 0.3) -> None:
        self._prompt = PromptTemplate(input_variables=["query"], template=template)

        self._llm = ChatOpenAI(model=model, temperature=temperature)

        self._chain = self._prompt | self._llm

    def run(self, query: str) -> tuple[str, list[dict[str, str]], dict[str, int | float]]:
        with get_openai_callback() as cb:
            outline = self._chain.invoke({"query": query}).content.strip()

            tokens = cb.total_tokens
            cost = cb.total_cost

        elements = outline.split("\n\n")

        main_title = elements[0][elements[0].index(":") + 1 :].strip().replace('"', "")

        themes = []
        keys = ["title", "query", "description"]

        for theme in elements[1:]:
            items = theme.strip().split("\n")
            themes.append({keys[i]: item[item.index(":") + 1 :].strip().replace('"', "") for i, item in enumerate(items)})

        return main_title, themes, {"tokens": tokens, "cost": cost}