File size: 3,336 Bytes
683d0f6
1c79925
683d0f6
1c79925
683d0f6
 
 
 
 
 
 
 
 
 
 
 
1c79925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646e829
 
 
 
 
 
 
1c79925
683d0f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242ee3d
 
 
683d0f6
 
 
 
 
 
 
 
dad70d6
 
 
 
 
683d0f6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
from parse import retrieve
from transfer import retrieve_transfer

def main():
    """Render the sidebar function picker and show the selected page.

    The sidebar lists the available functions; the selected entry decides
    which page function is called on this run of the script.
    """
    # Maps each sidebar entry to the function that renders its page.
    pages = {
        "PromptBench": promptbench,
        "Retrieve Transferability Information": retrieve_transferability_information,
    }

    st.sidebar.title("Choose Function")
    # Streamlit discourages empty widget labels for accessibility reasons;
    # give the radio a real label and hide it so the sidebar looks the same
    # (the title above already serves as the visible heading).
    function_choice = st.sidebar.radio(
        "Function",
        list(pages),
        label_visibility="collapsed",
    )

    render_page = pages.get(function_choice)
    if render_page is not None:
        render_page()

def promptbench():
    """Render the PromptBench page.

    Lets the user pick a model, dataset, attack and prompt type, echoes the
    selection, and, once the "Retrieve" button is pressed, calls ``retrieve``
    with those four values and writes the original/attacked prompt and
    accuracy of every returned record.
    """
    st.title("PromptBench")

    model_name = st.selectbox(
        "Select Model",
        options=["T5", "Vicuna", "UL2", "ChatGPT"],
        index=0,
    )

    dataset_name = st.selectbox(
        "Select Dataset",
        options=[
            "SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI",
            "RTE", "WNLI", "MMLU", "SQuAD V2", "IWSLT 2017", "UN Multi", "Math",
        ],
        index=0,
    )

    attack_name = st.selectbox(
        "Select Attack",
        options=[
            "BertAttack", "CheckList", "DeepWordBug", "StressTest",
            "TextFooler", "TextBugger", "Semantic",
        ],
        index=0,
    )

    prompt_type = st.selectbox(
        "Select Prompt Type",
        options=["zeroshot-task", "zeroshot-role", "fewshot-task", "fewshot-role"],
        index=0,
    )

    # Echo the full selection, including the attack, so the summary matches
    # exactly what is passed to retrieve() below.
    st.write(f"Model: {model_name}")
    st.write(f"Dataset: {dataset_name}")
    st.write(f"Attack: {attack_name}")
    st.write(f"Prompt Type: {prompt_type}")

    if not st.button("Retrieve"):
        return

    results = retrieve(model_name, dataset_name, attack_name, prompt_type)
    # Materialize once so an empty or missing result can be reported instead
    # of silently rendering nothing (or failing on None).
    records = list(results) if results is not None else []
    if not records:
        st.write("No results found for the selected configuration.")
        return

    for result in records:
        st.write(f"Original prompt: {result['origin prompt']}")
        st.write(f"Original acc: {result['origin acc']}")
        st.write(f"Attack prompt: {result['attack prompt']}")
        st.write(f"Attack acc: {result['attack acc']}")
  

def retrieve_transferability_information():
    """Render the transferability page.

    Lets the user pick a source model, a target model, an attack and a shot
    count, calls ``retrieve_transfer`` with them, and shows one expander per
    returned record with the prompts and the source/target accuracies.
    Nothing is retrieved while source and target model are the same.
    """
    model_options = ["T5", "Vicuna", "UL2", "ChatGPT"]

    st.title("Retrieve Transferability Information")
    source_model_name = st.selectbox(
        "Select Source Model",
        options=model_options,
        index=0,
    )

    # Default to a different model than the source; with both defaults at
    # index 0 the page would always open on the "cannot be the same" message
    # with the remaining controls hidden.
    target_model_name = st.selectbox(
        "Select Target Model",
        options=model_options,
        index=1,
    )

    if source_model_name == target_model_name:
        st.write("Source model and target model cannot be the same.")
        return

    attack_name = st.selectbox(
        "Select Attack",
        options=[
            "BertAttack", "CheckList", "DeepWordBug", "StressTest",
            "TextFooler", "TextBugger", "Semantic",
        ],
        index=0,
    )

    # The "Semantic" attack is passed to retrieve_transfer as "translation"
    # (presumably the name used in the stored transfer data; confirm).
    if attack_name == "Semantic":
        attack_name = "translation"

    shot = st.selectbox(
        "Select Shot",
        options=[0, 3],
        index=0,
    )

    data = retrieve_transfer(source_model_name, target_model_name, attack_name, shot)
    # Materialize once so an empty or missing result can be reported instead
    # of silently rendering nothing (or failing on None).
    records = list(data) if data is not None else []
    if not records:
        st.write("No transferability results found for the selected configuration.")
        return

    for d in records:
        with st.expander(f"Dataset: {d['dataset']} Prompt Type: {d['type']}-oriented"):
            st.write(f"Origin prompt: {d['origin_prompt']}")
            st.write(f"Attack prompt: {d['atk_prompt']}")
            st.write(f"Source model: origin acc: {d['origin_acc']}, attack acc: {d['atk_acc']}")
            st.write(f"Target model: origin acc: {d['transfer_ori_acc']}, attack acc: {d['transfer_atk_acc']}")

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()