DCWIR-Offcial-Demo / textattack /models /helpers /word_cnn_for_classification.py
PFEemp2024's picture
solving GPU error for previous version
4a1df2e
raw
history blame contribute delete
No virus
4.92 kB
"""
Word CNN for Classification
---------------------------------------------------------------------
"""
import json
import os
import torch
from torch import nn as nn
from torch.nn import functional as F
import textattack
from textattack.model_args import TEXTATTACK_MODELS
from textattack.models.helpers import GloveEmbeddingLayer
from textattack.models.helpers.utils import load_cached_state_dict
from textattack.shared import utils
class WordCNNForClassification(nn.Module):
"""A convolutional neural network for text classification.
We use different versions of this network to pretrain models for
text classification.
"""
def __init__(
self,
hidden_size=150,
dropout=0.3,
num_labels=2,
max_seq_length=128,
model_path=None,
emb_layer_trainable=True,
):
super().__init__()
self._config = {
"architectures": "WordCNNForClassification",
"hidden_size": hidden_size,
"dropout": dropout,
"num_labels": num_labels,
"max_seq_length": max_seq_length,
"model_path": model_path,
"emb_layer_trainable": emb_layer_trainable,
}
self.drop = nn.Dropout(dropout)
self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
self.word2id = self.emb_layer.word2id
self.encoder = CNNTextLayer(
self.emb_layer.n_d, widths=[3, 4, 5], filters=hidden_size
)
d_out = 3 * hidden_size
self.out = nn.Linear(d_out, num_labels)
self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
word_id_map=self.word2id,
unk_token_id=self.emb_layer.oovid,
pad_token_id=self.emb_layer.padid,
max_length=max_seq_length,
)
if model_path is not None:
self.load_from_disk(model_path)
self.eval()
def load_from_disk(self, model_path):
# TODO: Consider removing this in the future as well as loading via `model_path` in `__init__`.
import warnings
warnings.warn(
"`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
DeprecationWarning,
stacklevel=2,
)
self.load_state_dict(load_cached_state_dict(model_path))
self.eval()
def save_pretrained(self, output_path):
if not os.path.exists(output_path):
os.makedirs(output_path)
state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
with open(os.path.join(output_path, "config.json"), "w") as f:
json.dump(self._config, f)
@classmethod
def from_pretrained(cls, name_or_path):
"""Load trained Word CNN model by name or from path.
Args:
name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`.
Returns:
:class:`~textattack.models.helpers.WordCNNForClassification` model
"""
if name_or_path in TEXTATTACK_MODELS:
path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
else:
path = name_or_path
config_path = os.path.join(path, "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)
else:
# Default config
config = {
"architectures": "WordCNNForClassification",
"hidden_size": 150,
"dropout": 0.3,
"num_labels": 2,
"max_seq_length": 128,
"model_path": None,
"emb_layer_trainable": True,
}
del config["architectures"]
model = cls(**config)
state_dict = load_cached_state_dict(path)
model.load_state_dict(state_dict)
return model
def forward(self, _input):
emb = self.emb_layer(_input)
emb = self.drop(emb)
output = self.encoder(emb)
output = self.drop(output)
pred = self.out(output)
return pred
def get_input_embeddings(self):
return self.emb_layer.embedding
class CNNTextLayer(nn.Module):
def __init__(self, n_in, widths=[3, 4, 5], filters=100):
super().__init__()
Ci = 1
Co = filters
h = n_in
self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (w, h)) for w in widths])
def forward(self, x):
x = x.unsqueeze(1) # (batch, Ci, len, d)
x = [
F.relu(conv(x)).squeeze(3) for conv in self.convs1
] # [(batch, Co, len), ...]
x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N,Co), ...]
x = torch.cat(x, 1)
return x