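# NOTE: the lines below are residue from the hosting page's file viewer,
# kept as comments so the module remains valid Python.
# nickil's picture
# update app
# 38c2a22
# raw
# history blame
# No virus
# 5.84 kB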
import os
import torch
import datasets
import numpy as np
import pandas as pd
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from transformers import AutoTokenizer, logging
from onnxruntime import InferenceSession
from scipy.special import softmax
from weakly_supervised_parser.model.data_module_loader import DataModule
from weakly_supervised_parser.model.span_classifier import LightningModel
# Disable model checkpoint warnings
logging.set_verbosity_error()
class InsideOutsideStringClassifier:
    """Span classifier trained with PyTorch Lightning and served via ONNX Runtime.

    ``fit`` trains a ``LightningModel`` and exports it to ONNX; ``load_model``
    opens an exported file, after which ``predict_proba`` / ``predict`` score
    data frames that carry a ``"sentence"`` column.
    """

    def __init__(self, model_name_or_path: str, num_labels: int = 2, max_seq_length: int = 256):
        """Store the configuration shared by training and inference.

        Args:
            model_name_or_path: Identifier handed to the tokenizer, the data
                module and the Lightning model.
            num_labels: Number of output classes passed to the model.
            max_seq_length: Length every tokenized input is padded/truncated to.
        """
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.max_seq_length = max_seq_length

    def fit(
        self,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        outputdir: str,
        filename: str,
        devices: int = 1,
        enable_progress_bar: bool = True,
        enable_model_summary: bool = False,
        enable_checkpointing: bool = True,
        logger: bool = False,
        accelerator: str = "auto",
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        learning_rate: float = 5e-6,
        max_epochs: int = 10,
        dataloader_num_workers: int = 16,
        seed: int = 42,
    ):
        """Train the model, validate it, and export it to ``{outputdir}/{filename}.onnx``.

        Args:
            train_df: Training data passed to ``DataModule``.
            eval_df: Validation data passed to ``DataModule``.
            outputdir: Directory for the checkpoint and the ONNX export.
            filename: Base name (without extension) of the saved artifacts.
            devices: Forwarded to ``Trainer``.
            enable_progress_bar: Forwarded to ``Trainer``.
            enable_model_summary: Forwarded to ``Trainer``.
            enable_checkpointing: Forwarded to ``Trainer``; also decides
                whether a ``ModelCheckpoint`` callback is installed.
            logger: Forwarded to ``Trainer``.
            accelerator: Forwarded to ``Trainer``.
            train_batch_size: Batch size for the training loader and model.
            eval_batch_size: Batch size for the validation loader and model.
            learning_rate: Passed to ``LightningModel`` as ``lr``.
            max_epochs: Upper bound on epochs; early stopping may end sooner.
            dataloader_num_workers: Worker count passed to ``DataModule``.
            seed: Seed given to ``seed_everything``.
        """
        # Seed before the data module and model are built so that any random
        # initialisation performed in their constructors is reproducible too.
        seed_everything(seed, workers=True)

        data_module = DataModule(
            model_name_or_path=self.model_name_or_path,
            train_df=train_df,
            eval_df=eval_df,
            test_df=None,
            max_seq_length=self.max_seq_length,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=dataloader_num_workers,
        )

        model = LightningModel(
            model_name_or_path=self.model_name_or_path,
            lr=learning_rate,
            num_labels=self.num_labels,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
        )

        callbacks = [EarlyStopping(monitor="val_loss", patience=2, mode="min", check_finite=True)]
        if enable_checkpointing:
            # Lightning rejects a ModelCheckpoint callback when checkpointing
            # is disabled, so only install it when it is actually wanted.
            callbacks.append(
                ModelCheckpoint(
                    monitor="val_loss",
                    dirpath=outputdir,
                    filename=filename,
                    save_top_k=1,
                    save_weights_only=True,
                    mode="min",
                )
            )

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            enable_progress_bar=enable_progress_bar,
            enable_model_summary=enable_model_summary,
            enable_checkpointing=enable_checkpointing,
            logger=logger,
        )
        trainer.fit(model, data_module)
        trainer.validate(model, data_module.val_dataloader())

        # One real batch serves as the tracing sample for the ONNX export.
        train_batch = next(iter(data_module.train_dataloader()))

        # Without the checkpoint callback nothing may have created the folder.
        if outputdir:
            os.makedirs(outputdir, exist_ok=True)

        # Follow the module's current device instead of hard-coding CUDA so the
        # export also works on CPU-only machines.
        device = model.device
        model.to_onnx(
            file_path=f"{outputdir}/{filename}.onnx",
            input_sample=(train_batch["input_ids"].to(device), train_batch["attention_mask"].to(device)),
            export_params=True,
            opset_version=11,
            input_names=["input", "attention_mask"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    def load_model(self, pre_trained_model_path):
        """Open an exported ONNX model and the matching tokenizer.

        Args:
            pre_trained_model_path: Path to the ``.onnx`` file to load.
        """
        # CPU-only session; list "CUDAExecutionProvider" ahead of the CPU
        # provider to run inference on GPU instead.
        self.model = InferenceSession(pre_trained_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def preprocess_function(self, data):
        """Tokenize ``data["sentence"]`` to fixed-length NumPy features.

        Requires ``load_model`` to have been called (uses ``self.tokenizer``).
        """
        features = self.tokenizer(
            data["sentence"],
            max_length=self.max_seq_length,
            padding="max_length",
            add_special_tokens=True,
            truncation=True,
            return_tensors="np",
        )
        return features

    def process_spans(self, spans, scale_axis):
        """Score one batch of spans with the ONNX session.

        Args:
            spans: Data frame with a ``"sentence"`` column.
            scale_axis: Axis along which the softmax is applied to the
                model's first output.

        Returns:
            The softmax-normalised first output of the session.
        """
        spans_dataset = datasets.Dataset.from_pandas(spans)
        processed = spans_dataset.map(self.preprocess_function, batched=True, batch_size=None)
        inputs = {"input": processed["input_ids"], "attention_mask": processed["attention_mask"]}
        # ONNX Runtime inference does not go through torch autograd, so no
        # torch.no_grad() guard is needed around the call.
        return softmax(self.model.run(None, inputs)[0], axis=scale_axis)

    def predict_proba(self, spans, scale_axis, predict_batch_size):
        """Score ``spans`` in batches of at most ``predict_batch_size`` rows.

        Args:
            spans: Data frame with a ``"sentence"`` column.
            scale_axis: Softmax axis forwarded to ``process_spans``.
            predict_batch_size: Maximum number of rows scored per call.

        Returns:
            The per-row scores, stacked in the original row order.

        Raises:
            ValueError: If ``predict_batch_size`` is smaller than 1.
        """
        if predict_batch_size < 1:
            raise ValueError(f"predict_batch_size must be >= 1, got {predict_batch_size}")

        total = spans.shape[0]
        if total <= predict_batch_size:
            return self.process_spans(spans, scale_axis)

        # Positional slicing caps every batch at predict_batch_size rows
        # (floor-dividing the row count into sections could exceed it) and
        # keeps each chunk a DataFrame.
        outputs = [
            self.process_spans(spans.iloc[start:start + predict_batch_size], scale_axis)
            for start in range(0, total, predict_batch_size)
        ]
        return np.vstack(outputs)

    def predict(self, spans, scale_axis: int = 1, predict_batch_size: int = 128):
        """Return the arg-max class index for every row of ``spans``.

        Args:
            spans: Data frame with a ``"sentence"`` column.
            scale_axis: Softmax axis forwarded to ``predict_proba``.
            predict_batch_size: Batch size forwarded to ``predict_proba``.
        """
        return self.predict_proba(spans, scale_axis, predict_batch_size).argmax(axis=1)
class InsideOutsideStringPredictor:
    """Run Lightning-based prediction for the span classifier on a test frame."""

    def __init__(self, model_name_or_path, max_seq_length, pre_trained_model_path, num_workers=32):
        """Store the settings used to build the data module and load the model.

        Args:
            model_name_or_path: Identifier passed to ``DataModule`` and the model.
            max_seq_length: Maximum sequence length passed to ``DataModule``.
            pre_trained_model_path: Path of the trained model to load.
            num_workers: Data-loader worker count passed to ``DataModule``.
        """
        self.model_name_or_path = model_name_or_path
        self.pre_trained_model_path = pre_trained_model_path
        self.max_seq_length = max_seq_length
        self.num_workers = num_workers

    def predict_proba(self, test_df):
        """Return the ``Trainer.predict`` output for ``test_df``.

        Args:
            test_df: Data passed to ``DataModule`` as ``test_df``.

        Returns:
            Whatever ``Trainer.predict`` returns for the loaded model.
        """
        data_module = DataModule(
            model_name_or_path=self.model_name_or_path,
            train_df=None,
            eval_df=None,
            test_df=test_df,
            max_seq_length=self.max_seq_length,
            num_workers=self.num_workers,
        )

        # The previous body used `model` and `trainer` without ever defining
        # them, so every call raised NameError; build both here.
        # NOTE(review): assumes pre_trained_model_path is a Lightning
        # checkpoint (as written by ModelCheckpoint) and that the remaining
        # LightningModel constructor arguments are stored in it or have
        # defaults -- confirm against LightningModel.
        model = LightningModel.load_from_checkpoint(
            self.pre_trained_model_path,
            model_name_or_path=self.model_name_or_path,
        )
        # Prediction needs neither experiment logging nor checkpoint files.
        trainer = Trainer(accelerator="auto", devices=1, logger=False, enable_checkpointing=False)
        return trainer.predict(model, dataloaders=data_module)