File size: 5,292 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import hashlib
import os
from typing import Any, Dict, List

import pandas as pd

from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    create_batch_prediction_df,
    plot_validation_predictions,
)
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
from llm_studio.src.utils.utils import PatchedAttribute


class Plots:
    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Render a training batch as a cached parquet table.

        Args:
            batch: Collated batch containing "chosen_input_ids" and
                "chosen_labels" entries.
            cfg: Experiment configuration; ``cfg.output_directory`` is used
                as the cache location.

        Returns:
            PlotData pointing at the parquet file with the rendered batch.
        """
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch,
            tokenizer,
            ids_for_tokenized_text="chosen_input_ids",
            labels_column="chosen_labels",
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """
        Plots the data in a scrollable table.
        We limit the number of rows to max 600 to avoid rendering issues in Wave.
        As the data visualization is instantiated on every page load, we cache the
        data visualization in a parquet file.
        """
        # Cache key covers every dataset setting that influences the rendered
        # table, so changing any column selection invalidates the cache.
        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.rejected_answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")

        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)

        conversations_chosen = get_conversation_chains(
            df, cfg, limit_chained_samples=True
        )
        # Re-run chaining with the rejected answer column temporarily
        # substituted for the answer column to obtain the parallel
        # "rejected" conversations in the same order.
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            conversations_rejected = get_conversation_chains(
                df, cfg, limit_chained_samples=True
            )

        # Limit to max 15 prompt-conversation-answer rounds.
        # default=0 keeps an empty dataset from raising ValueError on max().
        max_conversation_length = min(
            max(
                (len(conversation["prompts"]) for conversation in conversations_chosen),
                default=0,
            ),
            15,
        )

        # Show at most 5 example conversations per conversation length.
        conversations_to_display: List = []
        for conversation_length in range(1, max_conversation_length + 1):
            conversations_to_display += [
                (conversation_chosen, conversation_rejected)
                for conversation_chosen, conversation_rejected in zip(
                    conversations_chosen, conversations_rejected
                )
                if len(conversation_chosen["prompts"]) == conversation_length
            ][:5]

        # Convert into a scrollable table by transposing the dataframe.
        # Collect rows in a plain list first: repeated `df.loc[i] = ...`
        # assignment is quadratic in the number of rows in pandas.
        rows: List[List[Any]] = []
        for sample_number, (conversation_chosen, conversation_rejected) in enumerate(
            conversations_to_display
        ):
            if conversation_chosen["systems"][0] != "":
                rows.append(
                    [sample_number, "System", conversation_chosen["systems"][0]]
                )
            for prompt, answer_chosen, answer_rejected in zip(
                conversation_chosen["prompts"],
                conversation_chosen["answers"],
                conversation_rejected["answers"],
            ):
                rows.append([sample_number, "Prompt", prompt])
                if answer_chosen == answer_rejected:
                    # Identical chosen/rejected answers are shown once.
                    rows.append([sample_number, "Answer", answer_chosen])
                else:
                    rows.append([sample_number, "Answer Chosen", answer_chosen])
                    rows.append([sample_number, "Answer Rejected", answer_rejected])

        df_transposed = pd.DataFrame(
            rows, columns=["Sample Number", "Field", "Content"]
        )
        df_transposed["Content"] = df_transposed["Content"].apply(
            format_for_markdown_visualization
        )
        df_transposed.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        """Delegate validation-prediction plotting to the shared helper.

        Args:
            val_outputs: Raw model outputs collected during validation.
            cfg: Experiment configuration.
            val_df: Validation dataframe the outputs correspond to.
            mode: Label for the validation run (forwarded unchanged).

        Returns:
            PlotData produced by ``plot_validation_predictions``.
        """
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)