# Uploaded by PFEemp2024 — commit 4a1df2e: "solving GPU error for previous version".
# (File-viewer residue preserved as comments: raw | history blame contribute delete | No virus | 3.72 kB)
import json
import os
import random
import numpy as np
import torch
import textattack
# Default torch device for TextAttack. The TA_DEVICE environment variable
# (e.g. "cpu", "cuda:1") overrides auto-detection; when it is unset or empty,
# "cuda" is used if available, otherwise "cpu". The value is always wrapped
# in torch.device so callers get the same type regardless of where it came
# from (previously it was a plain str whenever TA_DEVICE was set).
device = torch.device(
    os.environ.get("TA_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")
)
def html_style_from_dict(style_dict):
    """Turn a dict of CSS properties into an HTML ``style`` attribute string.

    Example::

        {'color': 'red', 'height': '100px'}

    becomes::

        style="color: red;height: 100px;"

    Each entry is rendered as ``key: value;`` in the dict's iteration order,
    with no separator between entries. Keys and values are converted with
    ``str()``, so numeric values (e.g. ``{'width': 100}``) are accepted too.

    Args:
        style_dict: Mapping of CSS property names to values.

    Returns:
        The string ``style="..."``. An empty dict yields ``style=""``.
    """
    declarations = "".join(
        str(key) + ": " + str(value) + ";" for key, value in style_dict.items()
    )
    return 'style="{}"'.format(declarations)
def html_table_from_rows(rows, title=None, header=None, style_dict=None):
    """Render ``rows`` as an HTML table inside a container ``<div>``.

    Output layout: ``<div [style]>[<h1>title</h1>]<table class="table">``
    followed by an optional header row (``<th>`` cells), one ``<tr>`` per row
    (``<td>`` cells), and then ``</table></div>``.

    Cell and title contents are inserted verbatim (``str()`` of each element,
    not HTML-escaped), so callers may pass pre-formatted HTML.

    Args:
        rows: Iterable of rows; each row is an iterable of cell values.
        title: Optional heading placed above the table.
        header: Optional iterable of column headings.
        style_dict: Optional CSS properties for the container div, rendered
            with ``html_style_from_dict``.

    Returns:
        The HTML string.
    """

    def render_row(cells, tag):
        # One <tr> with every cell wrapped in <tag>...</tag>.
        return (
            "<tr>"
            + "".join("<" + tag + ">" + str(cell) + "</" + tag + ">" for cell in cells)
            + "</tr>"
        )

    # Stylize the container div.
    if style_dict:
        parts = ["<div {}>".format(html_style_from_dict(style_dict))]
    else:
        parts = ["<div>"]
    if title:
        parts.append("<h1>{}</h1>".format(title))
    # Append the table after the div/title prefix. The original reassigned
    # table_html here, which discarded the opening <div> and the title and
    # left the closing </div> below unmatched.
    parts.append('<table class="table">')
    if header:
        parts.append(render_row(header, "th"))
    for row in rows:
        parts.append(render_row(row, "td"))
    # Close both the table and the container div.
    parts.append("</table></div>")
    return "".join(parts)
def get_textattack_model_num_labels(model_name, model_path):
    """Read ``train_args.json`` and return a trained model's number of labels.

    ``model_path`` is resolved to a local directory through
    ``textattack.shared.utils.download_from_s3``. Then ``train_args.json`` is
    read from that directory.

    Args:
        model_name: Unused here; kept for signature compatibility with callers.
        model_path: Model location passed to ``download_from_s3``.

    Returns:
        The ``"num_labels"`` entry of ``train_args.json``. Returns 2 when the
        file is missing (a warning is logged) or when the key is absent.
    """
    model_cache_path = textattack.shared.utils.download_from_s3(model_path)
    train_args_path = os.path.join(model_cache_path, "train_args.json")
    if not os.path.exists(train_args_path):
        # Logger.warn is a deprecated alias; warning is the supported name.
        textattack.shared.logger.warning(
            f"train_args.json not found in model path {model_path}. Defaulting to 2 labels."
        )
        return 2
    # Use a context manager so the file handle is closed right away instead
    # of being left to garbage collection. JSON is read as UTF-8.
    with open(train_args_path, encoding="utf-8") as f:
        args = json.load(f)
    # Assumes the JSON top level is an object (dict) — TODO confirm.
    return args.get("num_labels", 2)
def load_textattack_model_from_path(model_name, model_path):
    """Load a pre-trained TextAttack model from its name and path.

    The architecture is chosen from the prefix of ``model_name``:

    - ``"lstm..."`` -> ``LSTMForClassification``
    - ``"cnn..."``  -> ``WordCNNForClassification``
    - ``"t5..."``   -> ``T5ForTextToText``

    For example, model_name "lstm-yelp" and model path
    "models/classification/lstm/yelp".

    Args:
        model_name: Model identifier; only its prefix is inspected.
        model_path: Path given to the model constructor. For LSTM/CNN models
            it is also used to look up the label count through
            ``get_textattack_model_num_labels``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` does not start with a known prefix.
    """
    colored_model_name = textattack.shared.utils.color_text(
        model_name, color="blue", method="ansi"
    )
    if model_name.startswith("lstm"):
        num_labels = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack LSTM: {colored_model_name}"
        )
        model = textattack.models.helpers.LSTMForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("cnn"):
        num_labels = get_textattack_model_num_labels(model_name, model_path)
        textattack.shared.logger.info(
            f"Loading pre-trained TextAttack CNN: {colored_model_name}"
        )
        model = textattack.models.helpers.WordCNNForClassification(
            model_path=model_path, num_labels=num_labels
        )
    elif model_name.startswith("t5"):
        # The T5 text-to-text model takes only the path; no label count.
        model = textattack.models.helpers.T5ForTextToText(model_path)
    else:
        # Dispatch is on model_name, so report it. Reporting only the path
        # hid the value that actually failed to match.
        raise ValueError(
            f"Unknown textattack model {model_name} (model path: {model_path})"
        )
    return model
def set_seed(random_seed):
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and CUDA).

    Args:
        random_seed: Integer seed given to every generator, in this order:
            ``random.seed``, ``np.random.seed``, ``torch.manual_seed``,
            ``torch.cuda.manual_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)
def hashable(key):
    """Return True if ``hash(key)`` succeeds and False if it raises TypeError.

    The value is actually hashed, so containers holding unhashable items
    (e.g. ``(1, [2])``) are reported as not hashable.
    """
    try:
        hash(key)
    except TypeError:
        return False
    else:
        return True
def sigmoid(n):
    """Logistic function ``1 / (1 + exp(-n))``, computed with NumPy.

    Works on scalars and arrays (elementwise through ``np.exp``).
    """
    exp_neg = np.exp(-n)
    return 1 / (1 + exp_neg)
# Module-level dict, empty at import time. Presumably a shared registry that
# other modules populate and read; confirm against its users.
GLOBAL_OBJECTS = {}
# Separator character "^". Judging by its name, presumably used to split
# argument strings; verify against callers.
ARGS_SPLIT_TOKEN = "^"