File size: 6,783 Bytes
f3141ae
 
 
 
 
8886757
f3141ae
 
 
 
b470e8d
 
 
 
 
f3141ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b894e4
 
 
 
f3141ae
 
2b894e4
 
 
f3141ae
2b894e4
 
 
 
f3141ae
 
 
 
 
 
 
 
 
 
8886757
f3141ae
 
 
b470e8d
8886757
f3141ae
b470e8d
 
 
 
 
 
 
 
 
 
 
 
2b894e4
 
 
 
 
b470e8d
 
 
f3141ae
 
 
b470e8d
f3141ae
7b8dacd
f3141ae
 
 
 
 
 
 
 
56bc33d
f3141ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b894e4
f3141ae
 
2b894e4
f3141ae
2b894e4
f3141ae
 
 
 
 
 
2b894e4
f3141ae
2b894e4
f3141ae
 
 
 
 
 
2b894e4
f3141ae
2b894e4
f3141ae
 
 
 
 
 
 
 
2b894e4
f3141ae
2b894e4
f3141ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from typing import Callable
import hashlib

import mesop as me

import components as mex
from state import Prompt


@me.component
def prompt_eval_table(
  prompts: list[Prompt],
  on_select_rating: Callable | None = None,
  on_click_run: Callable | None = None,
):
  """Render a grid that compares the responses of several prompt versions.

  The first prompt in `prompts` drives the table: its variables become the
  variable columns and each of its responses becomes one table row. Every
  prompt contributes a "model response" column and a "rating" column.
  Responses of the other prompts are matched to a row through the key built
  from the row's variable values.

  Args:
    prompts: Prompt versions to compare. Nothing is rendered when empty.
    on_select_rating: Handler passed to the rating select as
      `on_selection_change`. The select's key is
      `rating_<version>_<response_index>`.
    on_click_run: Handler passed to the "run" button as `on_click`. The
      button's key is `run_<version>_<response_index>_<selected_index>`, where
      `<response_index>` is `-1` if that version has no response for the row
      and `<selected_index>` is the response index within the first prompt.
  """
  # Nothing to compare; this also avoids indexing into an empty list below.
  if not prompts:
    return

  columns = _make_table_meta(prompts)
  response_map = _make_response_map(prompts)

  with me.box(
    style=me.Style(
      display="grid",
      grid_template_columns=" ".join(column.get("size", "1fr") for column in columns),
      overflow_x="scroll",
    )
  ):
    # First header row. Only the model response columns define `header_1`
    # (the prompt version); every other cell is rendered with empty text.
    for column in columns:
      with me.box(style=_HEADER_STYLE):
        me.text(column.get("header_1", ""))

    # Second header row. Rating columns show the average rating (if any rating
    # exists) instead of their `header_2` label.
    for column in columns:
      with me.box(style=_HEADER_STYLE):
        if column["type"] == "model_rating":
          avg_rating = _calculate_avg_rating_from_prompt(column["prompt"])
          if avg_rating is not None:
            with me.tooltip(message="Average rating"):
              me.text(
                f"{avg_rating:.2f}",
                style=me.Style(
                  font_size=13, color=me.theme_var("on-surface-variant"), text_align="center"
                ),
              )
        elif column["type"] == "variable":
          me.text(
            column.get("header_2", ""), style=me.Style(font_size=13, color=me.theme_var("primary"))
          )
        else:
          me.text(
            column.get("header_2", ""),
            style=me.Style(font_size=13, color=me.theme_var("on-surface-variant")),
          )

    # One table row per response of the first prompt.
    for row_index, example in enumerate(prompts[0].responses):
      response_key = _make_variables_key(example["variables"])
      for column in columns:
        if column["type"] == "index":
          with me.box(style=_INDEX_STYLE):
            me.text(str(row_index))
        elif column["type"] == "variable":
          with me.box(style=_MARKDOWN_BOX_STYLE):
            mex.markdown(example["variables"][column["variable_name"]], has_copy_to_clipboard=True)
        elif column["type"] == "model_response":
          with me.box(style=_MARKDOWN_BOX_STYLE):
            prompt_response = response_map[column["prompt"].version].get(response_key)
            if prompt_response and prompt_response[0]["output"]:
              mex.markdown(prompt_response[0]["output"], has_copy_to_clipboard=True)
            else:
              # No output for this version yet: offer a button to run it.
              with me.box(
                style=me.Style(
                  display="flex", height="100%", justify_content="center", align_items="center"
                )
              ):
                response_index = prompt_response[1] if prompt_response else "-1"
                _, selected_version_response_index = response_map[prompts[0].version].get(
                  response_key
                )
                with me.content_button(
                  key=f"run_{column['prompt'].version}_{response_index}_{selected_version_response_index}",
                  on_click=on_click_run,
                  style=me.Style(
                    background=me.theme_var("secondary-container"),
                    color=me.theme_var("on-secondary-container"),
                    # An int rather than the unit-less string "10", which is
                    # not a valid CSS length.
                    border_radius=10,
                  ),
                ):
                  with me.tooltip(message="Run prompt"):
                    me.icon("play_arrow")
        elif column["type"] == "model_rating":
          with me.box(style=_RATING_STYLE):
            prompt_response = response_map[column["prompt"].version].get(response_key)
            # A rating can only be given once there is an output to rate.
            if prompt_response and prompt_response[0]["output"]:
              me.select(
                value=prompt_response[0].get("rating", "0"),
                options=[
                  me.SelectOption(label="1", value="1"),
                  me.SelectOption(label="2", value="2"),
                  me.SelectOption(label="3", value="3"),
                  me.SelectOption(label="4", value="4"),
                  me.SelectOption(label="5", value="5"),
                ],
                on_selection_change=on_select_rating,
                key=f"rating_{column['prompt'].version}_{prompt_response[1]}",
                style=me.Style(width=60),
              )


def _hash_string(string: str) -> str:
  encoded_string = string.encode("utf-8")
  result = hashlib.md5(encoded_string)
  return result.hexdigest()


def _make_variables_key(variables: dict[str, str]) -> str:
  return _hash_string("--".join(variables.values()))


def _make_response_map(prompts: list[Prompt]) -> dict:
  """Index every prompt's responses by prompt version and variables key.

  Returns:
    A dict mapping each prompt version to a dict that maps a variables key
    (see `_make_variables_key`) to a `(response, response_index)` tuple. If
    several responses of one prompt share a key, the last one wins.
  """
  return {
    prompt.version: {
      _make_variables_key(response["variables"]): (response, position)
      for position, response in enumerate(prompt.responses)
    }
    for prompt in prompts
  }


def _make_table_meta(prompts: list[Prompt]) -> list[dict]:
  """Build the ordered column descriptors for the evaluation table.

  The result starts with one "index" column, followed by one "variable" column
  per variable of the first prompt, followed by a "model_response" and a
  "model_rating" column for every prompt.

  Args:
    prompts: Prompt versions shown in the table; the first one supplies the
      variable columns.

  Returns:
    A list of dicts. Every dict has a "type"; the other keys ("header_1",
    "header_2", "size", "variable_name", "index", "prompt") are only present
    on the column types that need them.

  Raises:
    IndexError: If `prompts` is empty.
  """
  columns: list[dict] = [{"type": "index"}]

  columns.extend(
    {
      "type": "variable",
      "header_2": "{{" + variable + "}}",
      "size": "20fr",
      "variable_name": variable,
    }
    for variable in prompts[0].variables
  )

  for i, prompt in enumerate(prompts):
    columns.append(
      {
        "type": "model_response",
        "header_1": "Version " + str(prompt.version),
        "header_2": "Model response",
        "index": i,
        "size": "20fr",
        "prompt": prompt,
      }
    )
    columns.append(
      {
        "type": "model_rating",
        "header_1": "",
        "header_2": "Rating",
        "prompt": prompt,
      }
    )
  return columns


def _calculate_avg_rating_from_prompt(prompt: Prompt) -> float | None:
  """Return the mean rating over the prompt's rated responses.

  Responses whose "rating" is missing or falsy are skipped. Returns None when
  no response has a rating.
  """
  rating_sum = 0
  rating_count = 0
  for response in prompt.responses:
    rating = response.get("rating")
    if not rating:
      continue
    rating_sum += int(rating)
    rating_count += 1

  if rating_count == 0:
    return None
  return rating_sum / float(rating_count)


_BORDER_SIDE = me.BorderSide(width=1, style="solid", color=me.theme_var("outline-variant"))


def _cell_style(**overrides) -> me.Style:
  """Return a table cell style: the shared background, border and text color plus `overrides`."""
  return me.Style(
    background=me.theme_var("surface-container-lowest"),
    border=me.Border.all(_BORDER_SIDE),
    color=me.theme_var("on-surface"),
    **overrides,
  )


# Cells of the two header rows.
_HEADER_STYLE = _cell_style(
  font_size=15,
  font_weight="bold",
  padding=me.Padding.all(10),
)

# Cells of the row index column.
_INDEX_STYLE = _cell_style(
  font_size=14,
  padding=me.Padding.all(10),
  text_align="center",
)

# Cells holding markdown (variable values and model responses); tall content
# scrolls vertically inside the cell.
_MARKDOWN_BOX_STYLE = _cell_style(
  font_size=14,
  padding=me.Padding.all(10),
  max_height=300,
  min_width=300,
  overflow_y="scroll",
)

# Cells holding the rating select.
_RATING_STYLE = _cell_style(
  padding=me.Padding.all(10),
)