import random
from difflib import Differ

from textattack.attack_recipes import BAEGarg2019
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
from findfile import find_files
from flask import Flask
from textattack import Attacker


class ModelWrapper(HuggingFaceModelWrapper):
    def __init__(self, model):
        # Wrap a classifier that exposes infer(text, ...) and returns class probabilities.
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        # TextAttack passes a batch of strings; return one probability vector per input.
        outputs = []
        for text_input in text_inputs:
            raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
            outputs.append(raw_outputs["probs"])
        return outputs
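

# Illustrative sketch (not part of the original app): ModelWrapper only assumes the wrapped
# object exposes infer(text, print_result=False, ...) -> {"probs": [...]}, as used above.
# The _DummyClassifier below is hypothetical and exists only to show that contract.
class _DummyClassifier:
    def infer(self, text, print_result=False, **kwargs):
        # Always predict class 1 with 90% confidence.
        return {"probs": [0.1, 0.9]}


def _demo_model_wrapper():
    wrapper = ModelWrapper(_DummyClassifier())
    # Returns [[0.1, 0.9], [0.1, 0.9]].
    return wrapper(["a sample sentence", "another sample sentence"])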


class SentAttacker:
    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)

        recipe = recipe_class.build(model_wrapper)
        # WordNet-based transformations default to English; the language can be
        # overridden here if needed, e.g.:
        # recipe.transformation.language = "en"

        # Attacker requires a dataset; a one-item placeholder is enough because
        # examples are attacked individually at inference time.
        _dataset = [("", 0)]
        _dataset = Dataset(_dataset)

        self.attacker = Attacker(recipe, _dataset)
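

# Illustrative sketch (assumption, not original code): one way SentAttacker could be used to
# perturb a single sentence. It assumes Attacker keeps the built recipe on its `.attack`
# attribute and that Attack.attack(example, ground_truth_output) accepts a raw string, as in
# recent TextAttack releases; `victim_model` is a hypothetical classifier compatible with
# ModelWrapper above.
def _demo_sent_attacker(victim_model, text="the movie was surprisingly good", label=1):
    sent_attacker = SentAttacker(victim_model)
    attack_result = sent_attacker.attacker.attack.attack(text, label)
    # attack_result.perturbed_text() holds the adversarial example if the attack succeeded.
    return attack_result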


def diff_texts(text1, text2):
    # Word-level diff of two texts as (word, tag) pairs, where the tag is "+" for insertions,
    # "-" for deletions, and None for unchanged words. Differ's "?" hint lines are skipped.
    d = Differ()

    text1_words = text1.split()
    text2_words = text2.split()

    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1_words, text2_words)
        if not token.startswith("?")
    ]
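

# Illustrative sketch (not part of the original app): what diff_texts() returns for a
# one-word substitution.
def _demo_diff_texts():
    # -> [("the", None), ("movie", None), ("was", None), ("great", "-"), ("awful", "+")]
    return diff_texts("the movie was great", "the movie was awful")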


def get_ensembled_tad_results(results):
    # Majority vote: count how often each predicted label occurs and return the most frequent one.
    target_dict = {}
    for r in results:
        target_dict[r["label"]] = (
            target_dict.get(r["label"]) + 1 if r["label"] in target_dict else 1
        )

    return dict(zip(target_dict.values(), target_dict.keys()))[
        max(target_dict.values())
    ]
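

# Illustrative sketch (not part of the original app): majority voting over three hypothetical
# predictions; label 1 occurs twice, so it wins.
def _demo_ensembled_tad_results():
    results = [{"label": 1}, {"label": 0}, {"label": 1}]
    return get_ensembled_tad_results(results)  # -> 1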


def get_sst2_example():
    # File-name fragments that should never be picked up as dataset files.
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "sst2"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    # Each line is "<text>$LABEL$<label>".
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)

    return random.choice(data)
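

# Illustrative sketch (assumption, not original code): the example getters above and below
# expect plain-text files whose lines follow the "<text>$LABEL$<label>" layout. The file name
# and sentences below are hypothetical and only document that format.
def _write_demo_sst2_file(path="./text_defense.sst2.test.dat"):
    lines = [
        "a gripping , well-acted thriller$LABEL$1",
        "the plot never rises above tedium$LABEL$0",
    ]
    with open(path, mode="w", encoding="utf8") as fout:
        fout.write("\n".join(lines))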


def get_agnews_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "agnews"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)

    return random.choice(data)


def get_amazon_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "amazon"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)

    return random.choice(data)


def get_imdb_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "imdb"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)

    return random.choice(data)