"""Streamlit front end for browsing PromptBench attack results.

The sidebar switches between two pages:

* "PromptBench": show original vs. attacked prompts and their accuracies for
  one model / dataset / attack / prompt-type combination (``parse.retrieve``).
* "Retrieve Transferability Information": show how attack prompts for a
  source model score on a different target model
  (``transfer.retrieve_transfer``).
"""

import streamlit as st

from parse import retrieve
from transfer import retrieve_transfer

# Option lists shared by both pages; defined once so the pages cannot drift apart.
MODEL_OPTIONS = ["T5", "Vicuna", "UL2", "ChatGPT"]
DATASET_OPTIONS = [
    "SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI", "RTE", "WNLI",
    "MMLU", "SQuAD V2", "IWSLT 2017", "UN Multi", "Math",
]
ATTACK_OPTIONS = [
    "BertAttack", "CheckList", "DeepWordBug", "StressTest",
    "TextFooler", "TextBugger", "Semantic",
]
PROMPT_TYPE_OPTIONS = ["zeroshot-task", "zeroshot-role", "fewshot-task", "fewshot-role"]
SHOT_OPTIONS = [0, 3]


def main():
    """Render the sidebar page selector and draw the chosen page."""
    # Insertion order of this dict is the order of the radio options.
    pages = {
        "PromptBench": promptbench,
        "Retrieve Transferability Information": retrieve_transferability_information,
    }
    st.sidebar.title("Choose Function")
    function_choice = st.sidebar.radio("", list(pages))
    pages[function_choice]()


def promptbench():
    """Page: look up attack results for one model/dataset/attack/prompt type.

    Results are fetched only after the "Retrieve" button is pressed. Each
    result returned by ``retrieve`` is read as a mapping with the keys
    "origin prompt", "origin acc", "attack prompt" and "attack acc".
    """
    st.title("PromptBench")

    model_name = st.selectbox("Select Model", options=MODEL_OPTIONS, index=0)
    dataset_name = st.selectbox("Select Dataset", options=DATASET_OPTIONS, index=0)
    # NOTE(review): unlike the transferability page, "Semantic" is passed through
    # unchanged here; presumably parse.retrieve expects that label — confirm.
    attack_name = st.selectbox("Select Attack", options=ATTACK_OPTIONS, index=0)
    prompt_type = st.selectbox("Select Prompt Type", options=PROMPT_TYPE_OPTIONS, index=0)

    st.write(f"Model: {model_name}")
    st.write(f"Dataset: {dataset_name}")
    st.write(f"Attack: {attack_name}")
    st.write(f"Prompt Type: {prompt_type}")

    if not st.button("Retrieve"):
        return

    # Materialize so an empty result (list or generator) can be detected.
    results = list(retrieve(model_name, dataset_name, attack_name, prompt_type))
    if not results:
        st.info("No results found for this selection.")
        return

    for result in results:
        st.write(f"Original prompt: {result['origin prompt']}")
        st.write(f"Original acc: {result['origin acc']}")
        st.write(f"Attack prompt: {result['attack prompt']}")
        st.write(f"Attack acc: {result['attack acc']}")


def retrieve_transferability_information():
    """Page: compare attack prompts on a source model against a target model.

    Each record returned by ``retrieve_transfer`` is read as a mapping with the
    keys "dataset", "type", "origin_prompt", "atk_prompt", "origin_acc",
    "atk_acc", "transfer_ori_acc" and "transfer_atk_acc".
    """
    st.title("Retrieve Transferability Information")

    source_model_name = st.selectbox("Select Source Model", options=MODEL_OPTIONS, index=0)
    # Default the target to a different model than the source. With both at
    # index 0 the page would open in the "cannot be the same" error state.
    target_model_name = st.selectbox("Select Target Model", options=MODEL_OPTIONS, index=1)
    if source_model_name == target_model_name:
        st.warning("Source model and target model cannot be the same.")
        return

    attack_name = st.selectbox("Select Attack", options=ATTACK_OPTIONS, index=0)
    if attack_name == "Semantic":
        # The transfer lookup uses the key "translation" for the semantic attack.
        attack_name = "translation"

    shot = st.selectbox("Select Shot", options=SHOT_OPTIONS, index=0)

    data = list(retrieve_transfer(source_model_name, target_model_name, attack_name, shot))
    if not data:
        st.info("No transferability results found for this selection.")
        return

    for d in data:
        with st.expander(f"Dataset: {d['dataset']} Prompt Type: {d['type']}-oriented"):
            st.write(f"Origin prompt: {d['origin_prompt']}")
            st.write(f"Attack prompt: {d['atk_prompt']}")
            st.write(f"Source model: origin acc: {d['origin_acc']}, attack acc: {d['atk_acc']}")
            st.write(
                f"Target model: origin acc: {d['transfer_ori_acc']}, "
                f"attack acc: {d['transfer_atk_acc']}"
            )


if __name__ == "__main__":
    main()