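"""Evaluate generated summaries against existing ones.

For each unique text, a fresh summary is generated with ``summary_chain`` and
compared against the existing ``representations_summary`` column by an LLM
judge (``compare_chain``). Scored results are written to ``eval.parquet``.
"""
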
import polars as pl
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from planning_ai.common.utils import Paths
from planning_ai.llms.llm import GPT4o


class SummaryEvaluator(BaseModel):
    """Model for evaluating summaries.

    Attributes:
        score (int): The number of the best summary.
    """

    score: int = Field(...)


def load_templates():
    """Loads the comparison and summary templates from files.

    Returns:
        tuple: A tuple containing the compare template and summary template as strings.
    """
    with open("./planning_ai/eval/eval.txt", "r") as f:
        compare_template = f.read()
    with open("./planning_ai/eval/summary.txt", "r") as f:
        summary_template = f.read()
    return compare_template, summary_template


def initialize_chains(compare_template, summary_template):
    """Initializes the comparison and summary chains.

    Args:
        compare_template (str): The template for comparison.
        summary_template (str): The template for summary.

    Returns:
        tuple: A tuple containing the compare chain and summary chain.
    """
    SLLM = GPT4o.with_structured_output(SummaryEvaluator, strict=True)
    compare_prompt = ChatPromptTemplate([("system", compare_template)])
    compare_chain = compare_prompt | SLLM

    summary_prompt = ChatPromptTemplate([("system", summary_template)])
    summary_chain = summary_prompt | GPT4o | StrOutputParser()

    return compare_chain, summary_chain


def process_summaries(compare_chain, summary_chain):
    """Processes summaries by comparing and scoring them.

    Args:
        compare_chain: The chain used for comparing summaries.
        summary_chain: The chain used for generating summaries.

    Returns:
        polars.DataFrame: A DataFrame containing the original text, summaries, and scores.
    """
    original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
        pl.col("attachments_id").is_null()
    )
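    # Deduplicate (text, existing summary) pairs.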
    summaries1 = original[["text", "representations_summary"]].unique()

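    # Generate a fresh summary for each unique text (one LLM call per row).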
    summaries2 = summaries1[["text"]]
    summaries2 = summaries2.with_columns(
        pl.col("text")
        .map_elements(
            lambda x: summary_chain.invoke({"content": x}), return_dtype=pl.String
        )
        .alias("summary")
    )

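    # Score each pair: the compare chain returns the number of the preferred summary.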
    summaries = summaries1.join(summaries2, on="text")
    summaries = summaries.with_columns(
        pl.struct(["text", "representations_summary", "summary"])
        .map_elements(
            lambda x: compare_chain.invoke(
                {
                    "document": x["text"],
                    "summary_1": x["representations_summary"],
                    "summary_2": x["summary"],
                }
            ).score,
            return_dtype=pl.Int8,
        )
        .alias("score")
    )
    return summaries


def main():
    compare_template, summary_template = load_templates()
    compare_chain, summary_chain = initialize_chains(compare_template, summary_template)
    summaries = process_summaries(compare_chain, summary_chain)
    summaries.write_parquet(Paths.OUT / "eval.parquet")


if __name__ == "__main__":
    main()