metadata
language:
  - en

Text Classification Toxicity

This model is a fine-tuned version of nreimers/MiniLMv2-L6-H384-distilled-from-BERT-Large, trained on the dataset of the first Jigsaw Kaggle competition using unitary/toxic-bert as the teacher model. The original unquantized model can be found here.

This model predicts only two labels (toxicity and severe toxicity). For the model with all labels, refer to this page.
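The card names unitary/toxic-bert as the teacher but does not state the distillation objective. As a minimal sketch, assuming the student is trained to match the teacher's per-label sigmoid probabilities (soft targets) with binary cross-entropy, the loss could look like the following. `soft_target_bce` is a hypothetical name, and the teacher logits are assumed to cover the same labels the student predicts.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def soft_target_bce(
    student_logits: np.ndarray, teacher_logits: np.ndarray, eps: float = 1e-7
) -> float:
    """Binary cross-entropy between the student's per-label probabilities and
    the teacher's per-label probabilities (used as soft targets), averaged over
    all examples and labels. Both inputs have shape (batch, num_labels)."""
    p_teacher = sigmoid(teacher_logits)
    # Clip so log() never sees exactly 0 or 1
    p_student = np.clip(sigmoid(student_logits), eps, 1.0 - eps)
    loss = -(p_teacher * np.log(p_student) + (1.0 - p_teacher) * np.log(1.0 - p_student))
    return float(loss.mean())


# Hypothetical logits for one example and two labels
teacher = np.array([[2.0, -1.0]])
print(soft_target_bce(teacher, teacher))   # student agrees with teacher
print(soft_target_bce(-teacher, teacher))  # student disagrees: larger loss
```

For a fixed teacher, this loss is smallest when the student's probabilities equal the teacher's, which is what makes it usable as a distillation signal.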

Usage

Installation

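pip install tokenizers
pip install onnxruntime
# Git LFS is needed so the clone fetches the actual .onnx weights, not pointer files
git lfs install
git clone https://huggingface.co/minuva/MiniLMv2-toxic-jigsaw-lite-onnx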

Load the Model

import os
import json

import numpy as np
from tokenizers import Tokenizer
from onnxruntime import InferenceSession


model_name = "minuva/MiniLMv2-toxic-jigsaw-lite-onnx"
# Tokenizer from the Hub: pad each batch to its longest sequence, truncate at 256 tokens
tokenizer = Tokenizer.from_pretrained(model_name)
tokenizer.enable_padding()
tokenizer.enable_truncation(max_length=256)
batch_size = 16

texts = ["This is pure trash"]
outputs = []
# Quantized ONNX graph from the repository cloned above, run on CPU
model = InferenceSession(
    "MiniLMv2-toxic-jigsaw-lite-onnx/model_optimized_quantized.onnx",
    providers=["CPUExecutionProvider"],
)

# config.json maps output indices to label names (id2label)
with open(os.path.join("MiniLMv2-toxic-jigsaw-lite-onnx", "config.json"), "r") as f:
    config = json.load(f)

output_names = [node.name for node in model.get_outputs()]
input_names = [node.name for node in model.get_inputs()]

for start in range(0, len(texts), batch_size):
    # Tokenize one batch of at most batch_size texts
    encodings = tokenizer.encode_batch(texts[start : start + batch_size])
    # Explicit int64: np.vstack on Python ints can give int32 on some platforms
    # (e.g. Windows with NumPy < 2), which the ONNX graph may reject
    inputs = {
        "input_ids": np.array([e.ids for e in encodings], dtype=np.int64),
        "attention_mask": np.array([e.attention_mask for e in encodings], dtype=np.int64),
        "token_type_ids": np.array([e.type_ids for e in encodings], dtype=np.int64),
    }

    # Make sure every input the graph declares is available, then keep only those
    for input_name in input_names:
        if input_name not in inputs:
            raise ValueError(f"Input name {input_name} not found in inputs")
    inputs = {input_name: inputs[input_name] for input_name in input_names}

    # The graph has a single output: logits of shape (batch, num_labels)
    logits = model.run(output_names=output_names, input_feed=inputs)[0]
    outputs.append(logits)

outputs = np.concatenate(outputs, axis=0)
# Each label gets its own sigmoid, so the scores do not sum to 1
probs = 1 / (1 + np.exp(-outputs))

results = []
for item in probs:
    labels = [config["id2label"][str(idx)] for idx in range(len(item))]
    scores = [float(s) for s in item]
    results.append({"labels": labels, "scores": scores})

# Keep the highest-scoring label for each input text
res = []
for result in results:
    joined = list(zip(result["labels"], result["scores"]))
    max_score = max(joined, key=lambda x: x[1])
    res.append(max_score)

res
# [('toxic', 0.6553249955177307)]
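The snippet above reports only the highest-scoring label for each text. Because every label is scored with its own sigmoid, the labels are not mutually exclusive, so a caller may prefer every label whose probability clears a threshold. A minimal sketch, assuming a 0.5 threshold (not a value from this card) and a hypothetical `id2label` dict shaped like the one in `config.json`; `labels_above_threshold` and the label name `severe_toxic` are illustrative only:

```python
import numpy as np


def labels_above_threshold(probs: np.ndarray, id2label: dict, threshold: float = 0.5) -> list:
    """For each row of per-label sigmoid probabilities, return the
    (label, score) pairs whose probability is at least `threshold`."""
    results = []
    for row in probs:
        results.append(
            [(id2label[str(idx)], float(p)) for idx, p in enumerate(row) if p >= threshold]
        )
    return results


# Hypothetical probabilities for two texts; keys are strings, as in config.json
probs = np.array([[0.66, 0.02], [0.10, 0.01]])
id2label = {"0": "toxic", "1": "severe_toxic"}
print(labels_above_threshold(probs, id2label))
# [[('toxic', 0.66)], []]
```

In the inference code above, the sigmoid probabilities and `config["id2label"]` would be the natural inputs; an empty list means no label passed the threshold.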

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 6e-05
  • train_batch_size: 48
  • eval_batch_size: 48
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 10
  • warmup_ratio: 0.1
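The `linear` scheduler with `warmup_ratio: 0.1` can be sketched as follows, assuming Hugging Face `transformers` semantics: warmup steps = ceil(warmup_ratio times total steps), a linear ramp from 0 to the peak rate, then linear decay to 0. The total step count is an assumption, since the card does not give the dataset size; without gradient accumulation it would be num_epochs times ceil(num_examples / train_batch_size). `linear_warmup_lr` is a hypothetical helper name.

```python
import math


def linear_warmup_lr(
    step: int, total_steps: int, peak_lr: float = 6e-5, warmup_ratio: float = 0.1
) -> float:
    """Learning rate at a given optimizer step for linear warmup + linear decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp linearly from 0 up to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    # Decay linearly from peak_lr down to 0 at total_steps
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Hypothetical run of 1,000 optimizer steps in total
total = 1000
for s in (0, 50, 100, 550, 1000):
    print(s, linear_warmup_lr(s, total))
```

With 1,000 steps, warmup lasts 100 steps, the peak rate of 6e-05 is reached at step 100, and the rate is halved by step 550.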

Metrics (comparison with teacher model)

| Teacher (params) | Student (params) | Set (metric) | Score (teacher) | Score (student) |
|---|---|---|---|---|
| unitary/toxic-bert (110M) | MiniLMv2-toxic-jigsaw-lite (23M) | Test (ROC_AUC) | 0.982677 | 0.9806 |
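The table reports ROC AUC without saying how the two labels were combined. A minimal NumPy sketch, assuming a macro average (the mean of per-label ROC AUC, the column-wise average used in the original Jigsaw competition), computes each label's AUC as the probability that a random positive example outscores a random negative one, counting ties as half. The pairwise comparison costs O(positives times negatives), which is fine for illustration but not for a full test set; `roc_auc` and `mean_column_auc` are hypothetical helper names.

```python
import numpy as np


def roc_auc(y_true, y_score) -> float:
    """ROC AUC via the rank (Mann-Whitney) formulation: P(pos > neg) + 0.5 * P(tie)."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true], y_score[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    # Compare every positive score with every negative score
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def mean_column_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Macro average: ROC AUC per label column, then the unweighted mean."""
    return float(np.mean([roc_auc(y_true[:, j], y_score[:, j]) for j in range(y_true.shape[1])]))


# Hypothetical labels and scores for four texts and two labels
y_true = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_score = np.array([[0.1, 0.2], [0.4, 0.9], [0.35, 0.1], [0.8, 0.7]])
print(mean_column_auc(y_true, y_score))
```

Here the first label's scores misorder one positive/negative pair (AUC 0.75) and the second label's scores are perfectly separated (AUC 1.0), so the macro average is 0.875.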

Deployment

Check this repository to see how to deploy this model in a serverless environment with fast CPU inference and low resource usage.