File size: 707 Bytes
ef9a72e
 
5dbce8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef9a72e
 
 
5dbce8b
ef9a72e
 
 
 
5dbce8b
ef9a72e
 
 
 
5dbce8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pandas import DataFrame

from zeno import (
    DistillReturn,
    MetricReturn,
    ModelReturn,
    ZenoOptions,
    distill,
    metric,
    model,
)


@model
def model_ret(name):
    """Build a pass-through "model" that echoes the dataset's data column.

    Args:
        name: Model identifier passed to this factory. It is not used by the
            returned function. NOTE(review): presumably the model name
            supplied when the factory is invoked — confirm against caller.

    Returns:
        A callable ``(df, ops) -> ModelReturn`` whose ``model_output`` is the
        column of ``df`` selected by ``ops.data_column``.
    """

    def _echo_data_column(df: DataFrame, ops: ZenoOptions):
        # No inference happens here: the data column itself is reported
        # as the model output.
        column_values = df[ops.data_column]
        return ModelReturn(model_output=column_values)

    return _echo_data_column


@distill
def length(df: DataFrame, ops: ZenoOptions):
    """Compute the character length of each entry in the ``prompt`` column.

    Args:
        df: Frame containing a ``prompt`` column.
        ops: Zeno options. Not used by this function.

    Returns:
        ``DistillReturn`` whose ``distill_output`` is the per-row result of
        ``Series.str.len()``. Missing entries yield NaN.
    """
    prompt_lengths = df["prompt"].str.len()
    return DistillReturn(distill_output=prompt_lengths)


@metric
def avg_image_nswf(df: DataFrame, ops: ZenoOptions):
    """Average the ``image_nsfw`` column, skipping missing values.

    NOTE(review): the function name is spelled "nswf" while the column is
    "image_nsfw". It is kept as-is because renaming could break existing
    references to this metric — confirm before fixing.

    Args:
        df: Frame containing an ``image_nsfw`` column.
        ops: Zeno options. Not used by this function.

    Returns:
        ``MetricReturn`` holding the mean as a Python ``float``. This is NaN
        when no non-missing values remain.
    """
    present_scores = df["image_nsfw"].dropna()
    average_score = present_scores.mean()
    return MetricReturn(metric=float(average_score))


@metric
def avg_prompt_nsfw(df: DataFrame, ops: ZenoOptions):
    """Average the ``prompt_nsfw`` column, skipping missing values.

    Args:
        df: Frame containing a ``prompt_nsfw`` column.
        ops: Zeno options. Not used by this function.

    Returns:
        ``MetricReturn`` holding the mean as a Python ``float``. This is NaN
        when no non-missing values remain.
    """
    present_scores = df["prompt_nsfw"].dropna()
    average_score = present_scores.mean()
    return MetricReturn(metric=float(average_score))