# DCWIR-Offcial-Demo / textattack/models/helpers/lstm_for_classification.py
"""
LSTM for Classification
---------------------------------------------------------------------
"""
import json
import os
import torch
from torch import nn as nn
import textattack
from textattack.model_args import TEXTATTACK_MODELS
from textattack.models.helpers import GloveEmbeddingLayer
from textattack.models.helpers.utils import load_cached_state_dict
from textattack.shared import utils
class LSTMForClassification(nn.Module):
"""A long short-term memory neural network for text classification.
We use different versions of this network to pretrain models for
text classification.
"""
def __init__(
self,
hidden_size=150,
depth=1,
dropout=0.3,
num_labels=2,
max_seq_length=128,
model_path=None,
emb_layer_trainable=True,
):
super().__init__()
self._config = {
"architectures": "LSTMForClassification",
"hidden_size": hidden_size,
"depth": depth,
"dropout": dropout,
"num_labels": num_labels,
"max_seq_length": max_seq_length,
"model_path": model_path,
"emb_layer_trainable": emb_layer_trainable,
}
if depth <= 1:
            # Fix error where we ask for non-zero dropout with only 1 layer.
            # torch.nn.LSTM applies dropout only between recurrent layers and
            # never after the last one, so with a single layer a non-zero
            # dropout value would only trigger a warning.
dropout = 0
self.drop = nn.Dropout(dropout)
self.emb_layer_trainable = emb_layer_trainable
self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
self.word2id = self.emb_layer.word2id
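        # Bidirectional encoder: each direction gets hidden_size // 2 units, so
        # the concatenated forward/backward output has dimension hidden_size.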
self.encoder = nn.LSTM(
input_size=self.emb_layer.n_d,
hidden_size=hidden_size // 2,
num_layers=depth,
dropout=dropout,
bidirectional=True,
)
d_out = hidden_size
self.out = nn.Linear(d_out, num_labels)
self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
word_id_map=self.word2id,
unk_token_id=self.emb_layer.oovid,
pad_token_id=self.emb_layer.padid,
max_length=max_seq_length,
)
if model_path is not None:
self.load_from_disk(model_path)
self.eval()
def load_from_disk(self, model_path):
# TODO: Consider removing this in the future as well as loading via `model_path` in `__init__`.
import warnings
warnings.warn(
"`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
DeprecationWarning,
stacklevel=2,
)
self.load_state_dict(load_cached_state_dict(model_path))
self.eval()
def save_pretrained(self, output_path):
if not os.path.exists(output_path):
os.makedirs(output_path)
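        # Move every tensor to CPU before serializing so the checkpoint can be
        # loaded on machines without the GPU it was trained on.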
state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
torch.save(
state_dict,
os.path.join(output_path, "pytorch_model.bin"),
)
with open(os.path.join(output_path, "config.json"), "w") as f:
json.dump(self._config, f)
@classmethod
def from_pretrained(cls, name_or_path):
"""Load trained LSTM model by name or from path.
Args:
name_or_path (:obj:`str`): Name of the model (e.g. "lstm-imdb") or model saved via :meth:`save_pretrained`.
Returns:
:class:`~textattack.models.helpers.LSTMForClassification` model
"""
if name_or_path in TEXTATTACK_MODELS:
# path = utils.download_if_needed(TEXTATTACK_MODELS[name_or_path])
path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
else:
path = name_or_path
config_path = os.path.join(path, "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)
else:
# Default config
config = {
"architectures": "LSTMForClassification",
"hidden_size": 150,
"depth": 1,
"dropout": 0.3,
"num_labels": 2,
"max_seq_length": 128,
"model_path": None,
"emb_layer_trainable": True,
}
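        # "architectures" is config metadata only; drop it so the remaining
        # keys map directly onto the __init__ arguments.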
del config["architectures"]
model = cls(**config)
state_dict = load_cached_state_dict(path)
model.load_state_dict(state_dict)
return model
def forward(self, _input):
        # Ensure the LSTM's weights sit in a single contiguous chunk of memory
        # (otherwise cuDNN warns and re-compacts them on every call).
self.encoder.flatten_parameters()
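        # _input holds token ids of shape (batch, seq_len); the LSTM expects
        # (seq_len, batch), hence the transpose below.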
emb = self.emb_layer(_input.t())
emb = self.drop(emb)
output, hidden = self.encoder(emb)
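        # Max-pool over the time dimension (dim 0) to get one fixed-size
        # vector per example.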
output = torch.max(output, dim=0)[0]
output = self.drop(output)
pred = self.out(output)
return pred
def get_input_embeddings(self):
return self.emb_layer.embedding
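
# ---------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module). It assumes the "lstm-imdb" key listed in TEXTATTACK_MODELS is
# reachable for download; a local directory written by `save_pretrained`
# works the same way. It uses TextAttack's PyTorchModelWrapper (from
# textattack.models.wrappers) to run the model on raw strings.
if __name__ == "__main__":
    model = LSTMForClassification.from_pretrained("lstm-imdb")
    # Wrap the model with its GloveTokenizer so raw strings can be scored
    # directly; the wrapper handles tokenization, batching, and device moves.
    wrapper = textattack.models.wrappers.PyTorchModelWrapper(model, model.tokenizer)
    logits = wrapper(["an unexpectedly charming film", "a tedious, joyless slog"])
    print(logits)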