"""
Word CNN for Classification
---------------------------------------------------------------------
"""

import json
import os
import warnings

import torch
from torch import nn
from torch.nn import functional as F

import textattack
from textattack.model_args import TEXTATTACK_MODELS
from textattack.models.helpers import GloveEmbeddingLayer
from textattack.models.helpers.utils import load_cached_state_dict
from textattack.shared import utils


class WordCNNForClassification(nn.Module):
    """A convolutional neural network for text classification.

    We use different versions of this network to pretrain models for
    text classification.
    """

    def __init__(
        self,
        hidden_size=150,
        dropout=0.3,
        num_labels=2,
        max_seq_length=128,
        model_path=None,
        emb_layer_trainable=True,
    ):
        super().__init__()
        self._config = {
            "architectures": "WordCNNForClassification",
            "hidden_size": hidden_size,
            "dropout": dropout,
            "num_labels": num_labels,
            "max_seq_length": max_seq_length,
            "model_path": model_path,
            "emb_layer_trainable": emb_layer_trainable,
        }
        self.drop = nn.Dropout(dropout)
        self.emb_layer = GloveEmbeddingLayer(emb_layer_trainable=emb_layer_trainable)
        self.word2id = self.emb_layer.word2id
        # One bank of `hidden_size` filters per width, so the encoder emits
        # 3 * hidden_size pooled features per example.
        self.encoder = CNNTextLayer(
            self.emb_layer.n_d, widths=[3, 4, 5], filters=hidden_size
        )
        d_out = 3 * hidden_size  # len(widths) * filters
        self.out = nn.Linear(d_out, num_labels)
        self.tokenizer = textattack.models.tokenizers.GloveTokenizer(
            word_id_map=self.word2id,
            unk_token_id=self.emb_layer.oovid,
            pad_token_id=self.emb_layer.padid,
            max_length=max_seq_length,
        )

        if model_path is not None:
            self.load_from_disk(model_path)
        self.eval()

    def load_from_disk(self, model_path):
        warnings.warn(
            "`load_from_disk` method is deprecated. Please save and load using `save_pretrained` and `from_pretrained` methods.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.load_state_dict(load_cached_state_dict(model_path))
        self.eval()

    def save_pretrained(self, output_path):
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        # Move tensors to CPU so the checkpoint loads on machines without a GPU.
        state_dict = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(state_dict, os.path.join(output_path, "pytorch_model.bin"))
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump(self._config, f)
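
    # A minimal save/load round trip (a sketch; "/tmp/my_cnn" is a
    # hypothetical path):
    #
    #   model = WordCNNForClassification(num_labels=2)
    #   model.save_pretrained("/tmp/my_cnn")
    #   reloaded = WordCNNForClassification.from_pretrained("/tmp/my_cnn")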

    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained Word CNN model by name or from path.

        Args:
            name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or path to a model saved via :meth:`save_pretrained`.
        Returns:
            :class:`~textattack.models.helpers.WordCNNForClassification` model
        """
        if name_or_path in TEXTATTACK_MODELS:
            # A registered model name: download the checkpoint from S3.
            path = utils.download_from_s3(TEXTATTACK_MODELS[name_or_path])
        else:
            path = name_or_path

        config_path = os.path.join(path, "config.json")

        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            # Fall back to defaults for older checkpoints saved without a config.json.
            config = {
                "architectures": "WordCNNForClassification",
                "hidden_size": 150,
                "dropout": 0.3,
                "num_labels": 2,
                "max_seq_length": 128,
                "model_path": None,
                "emb_layer_trainable": True,
            }
        del config["architectures"]  # informational only; not a constructor argument
        model = cls(**config)
        state_dict = load_cached_state_dict(path)
        model.load_state_dict(state_dict)
        return model
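
    # Example usage (a sketch; "cnn-imdb" is a name registered in
    # TEXTATTACK_MODELS, and the tokenizer is assumed to expose an `encode`
    # method returning padded token ids):
    #
    #   model = WordCNNForClassification.from_pretrained("cnn-imdb")
    #   ids = torch.tensor([model.tokenizer.encode("a gripping, well-acted film")])
    #   logits = model(ids)  # shape: (1, num_labels)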

    def forward(self, _input):
        # (batch, seq_len) token ids -> (batch, seq_len, embed_dim) embeddings.
        emb = self.emb_layer(_input)
        emb = self.drop(emb)
        # Convolve and max-pool down to a fixed-size (batch, 3 * hidden_size) vector.
        output = self.encoder(emb)
        output = self.drop(output)
        # Project the pooled features to one logit per label.
        pred = self.out(output)
        return pred

    def get_input_embeddings(self):
        return self.emb_layer.embedding


class CNNTextLayer(nn.Module):
    def __init__(self, n_in, widths=(3, 4, 5), filters=100):
        super().__init__()
        Ci = 1  # input channels
        Co = filters  # output channels per filter width
        h = n_in  # embedding dimension
        # One Conv2d per filter width; each kernel spans the full embedding dim.
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (w, h)) for w in widths])

    def forward(self, x):
        x = x.unsqueeze(1)  # add a channel dim: (batch, 1, seq_len, embed_dim)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # max over time
        x = torch.cat(x, 1)  # (batch, len(widths) * filters)
        return x
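
# Shape walkthrough for CNNTextLayer (derived from the code above, with batch
# size B, sequence length T, embedding dim h, and `filters` Co):
#   input x:              (B, T, h)
#   after unsqueeze(1):   (B, 1, T, h)
#   after conv + squeeze: one (B, Co, T - w + 1) tensor per width w
#   after max_pool1d:     one (B, Co) tensor per width
#   after cat:            (B, len(widths) * Co)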