# Hugging Face page metadata captured with the file (not part of the module):
# JerryLiJinyi's picture
# Upload 127 files
# 10b912d verified
# raw
# history blame
# 2.77 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel
from pathlib import Path
class LinearTokenSelector(nn.Module):
    """Token-level extractive selector.

    Wraps a transformer ``encoder`` and a single bias-free linear head that
    projects each token's last hidden state to 2 classes (drop / keep).
    """

    def __init__(self, encoder, embedding_size=768):
        super(LinearTokenSelector, self).__init__()
        self.encoder = encoder  # any HF model that can return hidden_states
        self.classifier = nn.Linear(embedding_size, 2, bias=False)

    def forward(self, x):
        """Return per-token log-probabilities of shape (B, S, 2)."""
        output = self.encoder(x, output_hidden_states=True)
        hidden = output["hidden_states"][-1]  # B * S * H (last layer)
        logits = self.classifier(hidden)
        return F.log_softmax(logits, dim=2)

    def save(self, classifier_path, encoder_path):
        """Persist the linear head via torch.save and the encoder in HF format."""
        state = self.state_dict()
        # Keep only the head's weights; the encoder is saved separately below.
        state = dict((k, v) for k, v in state.items() if k.startswith("classifier"))
        torch.save(state, classifier_path)
        self.encoder.save_pretrained(encoder_path)

    def predict(self, texts, tokenizer, device):
        """Tokenize ``texts``, keep tokens predicted as label 1, decode summaries."""
        input_ids = tokenizer(texts)["input_ids"]
        # Fix: pad with the tokenizer's actual pad id — the previous implicit
        # default of 0 is wrong for models such as RoBERTa (pad_token_id == 1).
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in input_ids],
            batch_first=True,
            padding_value=pad_id,
        ).to(device)
        # Inference only — no autograd graph needed.
        with torch.no_grad():
            logits = self.forward(input_ids)
        argmax_labels = torch.argmax(logits, dim=2)
        return labels_to_summary(input_ids, argmax_labels, tokenizer)
def load_model(model_dir, device="cuda", prefix="best"):
    """Load the checkpoint under ``model_dir/checkpoints`` whose name starts with ``prefix``.

    As before, when several checkpoints match, the last one returned by
    ``iterdir()`` wins.

    Raises:
        FileNotFoundError: when no checkpoint name starts with ``prefix``
            (the original fell through to a confusing ``NameError``).
    """
    if isinstance(model_dir, str):
        model_dir = Path(model_dir)
    checkpoints = model_dir / "checkpoints"
    checkpoint_dir = None
    for p in checkpoints.iterdir():
        if p.name.startswith(f"{prefix}"):
            checkpoint_dir = p
    if checkpoint_dir is None:
        raise FileNotFoundError(
            f"no checkpoint starting with {prefix!r} found in {checkpoints}"
        )
    return load_checkpoint(checkpoint_dir, device=device)
def load_checkpoint(checkpoint_dir, device="cuda"):
    """Rebuild a LinearTokenSelector from a saved checkpoint directory.

    Expects ``encoder.bin`` (an HF save_pretrained directory) and
    ``classifier.bin`` (a torch.save state dict for the linear head)
    inside ``checkpoint_dir``.
    """
    if isinstance(checkpoint_dir, str):
        checkpoint_dir = Path(checkpoint_dir)

    encoder = AutoModel.from_pretrained(checkpoint_dir / "encoder.bin").to(device)
    # Infer the head's input width from the encoder's embedding table.
    hidden_dim = encoder.state_dict()["embeddings.word_embeddings.weight"].shape[1]

    model = LinearTokenSelector(None, hidden_dim).to(device)
    head_state = torch.load(checkpoint_dir / "classifier.bin", map_location=device)
    # Drop any stale encoder weights; only the head belongs in this state dict.
    head_state = {
        key: tensor
        for key, tensor in head_state.items()
        if key.startswith("classifier")
    }
    model.load_state_dict(head_state)
    model.encoder = encoder
    return model.to(device)
def labels_to_summary(input_batch, label_batch, tokenizer):
    """Decode, for each example, the tokens whose predicted label is 1.

    Args:
        input_batch: batch of token-id sequences (tensor rows or lists).
        label_batch: matching batch of 0/1 labels, one per token.
        tokenizer: anything with ``decode(ids, skip_special_tokens=True)``.

    Returns:
        list[str]: one decoded summary per example.
    """
    summaries = []
    for input_ids, labels in zip(input_batch, label_batch):
        # zip pairs each token with its label — no index arithmetic needed.
        selected = [int(tok) for tok, lab in zip(input_ids, labels) if lab == 1]
        summary = tokenizer.decode(selected, skip_special_tokens=True)
        summaries.append(summary)
    return summaries