Spaces:
Runtime error
Runtime error
File size: 3,336 Bytes
683d0f6 1c79925 683d0f6 1c79925 683d0f6 1c79925 646e829 1c79925 683d0f6 242ee3d 683d0f6 dad70d6 683d0f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import streamlit as st
from parse import retrieve
from transfer import retrieve_transfer
def main():
    """Draw the sidebar page selector and render whichever page is chosen.

    The sidebar shows a "Choose Function" title and a radio button listing
    the available pages. The render function registered for the selected
    label is then called to draw the main area.
    """
    st.sidebar.title("Choose Function")
    # Sidebar label -> function that renders that page. Dict insertion order
    # determines the order in which the radio options are listed.
    pages = {
        "PromptBench": promptbench,
        "Retrieve Transferability Information": retrieve_transferability_information,
    }
    # NOTE(review): newer Streamlit releases warn about an empty widget label;
    # consider a real label plus label_visibility="collapsed" if the installed
    # version supports it.
    selected_page = st.sidebar.radio("", list(pages))
    render = pages.get(selected_page)
    if render is not None:
        render()
def promptbench():
    """Render the PromptBench page and show attack results for one configuration.

    The user picks a model, dataset, attack and prompt type. All four
    selections are echoed back. Clicking "Retrieve" calls ``retrieve`` with
    them in that order. For every returned record, the page shows the original
    prompt and accuracy next to the attacked prompt and accuracy. If nothing
    is returned, an explanatory message is shown instead.
    """
    st.title("PromptBench")
    model_name = st.selectbox(
        "Select Model",
        options=["T5", "Vicuna", "UL2", "ChatGPT"],
        index=0,
    )
    dataset_name = st.selectbox(
        "Select Dataset",
        options=[
            "SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI",
            "RTE", "WNLI", "MMLU", "SQuAD V2", "IWSLT 2017", "UN Multi", "Math",
        ],
        index=0,
    )
    attack_name = st.selectbox(
        "Select Attack",
        options=[
            "BertAttack", "CheckList", "DeepWordBug", "StressTest", "TextFooler", "TextBugger", "Semantic",
        ],
        index=0,
    )
    prompt_type = st.selectbox(
        "Select Prompt Type",
        options=["zeroshot-task", "zeroshot-role", "fewshot-task", "fewshot-role"],
        index=0,
    )
    # Echo every selection, including the attack, so the summary matches
    # exactly what is passed to ``retrieve``.
    st.write(f"Model: {model_name}")
    st.write(f"Dataset: {dataset_name}")
    st.write(f"Attack: {attack_name}")
    st.write(f"Prompt Type: {prompt_type}")
    if not st.button("Retrieve"):
        return
    # NOTE(review): the transferability page maps "Semantic" to "translation"
    # before querying, but here the label is passed through unchanged.
    # Confirm which name ``retrieve`` expects.
    # Materialise the result so a None, empty list or exhausted generator can
    # be detected before iterating.
    results = list(retrieve(model_name, dataset_name, attack_name, prompt_type) or [])
    if not results:
        st.write("No results found for this configuration.")
        return
    for result in results:
        # Records are indexed by these string keys (as read here). The exact
        # schema is defined by ``retrieve``; verify against ``parse``.
        st.write("Original prompt: {}".format(result["origin prompt"]))
        st.write("Original acc: {}".format(result["origin acc"]))
        st.write("Attack prompt: {}".format(result["attack prompt"]))
        st.write("Attack acc: {}".format(result["attack acc"]))
def retrieve_transferability_information():
    """Render the page showing how adversarial prompts transfer between models.

    The user picks a source model, a different target model, an attack and a
    shot count (0 or 3). The records returned by ``retrieve_transfer`` are
    listed one per expander. Each expander shows the original and attacked
    prompt, plus the original and attacked accuracy on both the source and the
    target model.
    """
    st.title("Retrieve Transferability Information")
    models = ["T5", "Vicuna", "UL2", "ChatGPT"]
    source_model_name = st.selectbox(
        "Select Source Model",
        options=models,
        index=0,
    )
    # Default the target to a model different from the default source, so the
    # page opens in a valid state rather than immediately showing the error.
    target_model_name = st.selectbox(
        "Select Target Model",
        options=models,
        index=1,
    )
    attack_name = st.selectbox(
        "Select Attack",
        options=[
            "BertAttack", "CheckList", "DeepWordBug", "StressTest", "TextFooler", "TextBugger", "Semantic",
        ],
        index=0,
    )
    shot = st.selectbox(
        "Select Shot",
        options=[0, 3],
        index=0,
    )
    # Validate only after every widget has been drawn. Returning earlier would
    # hide the Attack/Shot widgets, and Streamlit drops the state of widgets
    # that are not rendered in a run, which would reset the user's choices.
    if source_model_name == target_model_name:
        st.write("Source model and target model cannot be the same.")
        return
    # The "Semantic" attack is queried under the name "translation"
    # (presumably the label used by the stored results; verify in ``transfer``).
    if attack_name == "Semantic":
        attack_name = "translation"
    # Materialise so an empty or None result can be reported explicitly.
    data = list(retrieve_transfer(source_model_name, target_model_name, attack_name, shot) or [])
    if not data:
        st.write("No transferability results found for this configuration.")
        return
    for d in data:
        with st.expander(f"Dataset: {d['dataset']} Prompt Type: {d['type']}-oriented"):
            st.write(f"Origin prompt: {d['origin_prompt']}")
            st.write(f"Attack prompt: {d['atk_prompt']}")
            st.write(f"Source model: origin acc: {d['origin_acc']}, attack acc: {d['atk_acc']}")
            st.write(f"Target model: origin acc: {d['transfer_ori_acc']}, attack acc: {d['transfer_atk_acc']}")
# Render the app when this module is executed as the main script.
if __name__ == "__main__":
    main()
|