File size: 1,768 Bytes
98847a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b087e88
98847a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b087e88
 
 
 
 
 
 
 
 
98847a8
b087e88
98847a8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pandas as pd

from pathlib import Path


def plot_data(metric, selected_attack, all_attacks_df):
    """Select the rows of ``all_attacks_df`` whose ``attack`` equals ``selected_attack``.

    NOTE(review): this is currently a stub. The plotting step that would use
    ``metric`` and the selected rows (previously sketched as a Gradio
    ``LinePlot`` of ``strength`` vs. ``metric`` coloured by ``model``) is not
    implemented. The selection is computed and then discarded, ``metric`` is
    ignored, and the function always returns ``None``.
    """
    # Build the boolean row mask, then apply it. Nothing consumes the result yet.
    selected_mask = all_attacks_df.attack == selected_attack
    _selected_rows = all_attacks_df[selected_mask]
    return None


def mk_variations(
    all_attacks_df,
    attacks_with_variations: list[str],
):
    """Bundle the data needed to plot attack variations into a plain dict.

    Args:
        all_attacks_df: pandas DataFrame of per-attack results. It is not
            modified; a filled copy is serialized instead.
        attacks_with_variations: Attack names to offer for selection;
            returned unchanged.

    Returns:
        A dict with three keys:

        * ``"metrics"``: the metric names available for plotting.
          Presumably these are column names of ``all_attacks_df``; verify
          against the caller.
        * ``"attacks_with_variations"``: the ``attacks_with_variations``
          argument, as passed in.
        * ``"all_attacks_df"``: the rows of ``all_attacks_df`` as a list of
          dicts (``DataFrame.to_dict(orient="records")``). Every missing value
          is replaced by the *string* ``"NaN"``.
    """
    # Missing values become the literal string "NaN", not None/JSON null.
    # The apparent intent is JSON serialization: a float NaN would be emitted
    # by json.dumps as the bare token NaN, which is not valid JSON. Consumers
    # therefore see "NaN" strings for gaps. Any numeric column that contains
    # a gap becomes object dtype in the copy.
    filled_df = all_attacks_df.fillna(value="NaN")
    attacks_plot_metrics = [
        "bit_acc",
        "log10_p_value",
        "TPR",
        "FPR",
        "watermark_det_score",
    ]
    return {
        "metrics": attacks_plot_metrics,
        "attacks_with_variations": attacks_with_variations,
        "all_attacks_df": filled_df.to_dict(orient="records"),
    }