language: en
tags:
- text segmentation
- document chunking
license: apache-2.0
datasets:
- wikipedia
pipeline_tag: text-classification
base_model: distilbert/distilbert-base-uncased
widget:
- text: Left context. [SEP] Right context.
- text: >-
    They have 6 grandchildren. [SEP] Ane is currently coaching Crestwood High
    School's Boys Varsity Soccer.
DistilBERT Cross Segment Document Chunking
This model is a fine-tuned version of distilbert-base-uncased that classifies whether two consecutive sentences come from the same Wikipedia article section. It is intended for text segmentation and document chunking. It is based on the paper Text Segmentation by Cross Segment Attention by Michal Lukasik, Boris Dadachev, Gonçalo Simões, and Kishore Papineni.
How to use it
One way to use this model is via the TextClassificationPipeline class from the Hugging Face transformers library.
from transformers import (
    AutoModelForSequenceClassification,
    DistilBertTokenizer,
    TextClassificationPipeline,
)

model_name = "BlueOrangeDigital/distilbert-cross-segment-document-chunking"

# Label 0: both sentences belong to the same section; label 1: they do not.
id2label = {0: "SAME", 1: "DIFFERENT"}
label2id = {"SAME": 0, "DIFFERENT": 1}

tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

# Each input is written as "left context [SEP] right context".
pairs = [
    "Left context. [SEP] Right context.",
    "They have 6 grandchildren. [SEP] Ane is currently coaching Crestwood High School's Boys Varsity Soccer.",
]

# return_all_scores=True returns the score of every label; recent transformers
# releases deprecate it in favor of top_k=None.
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
pipe(pairs)
[[{'label': 'SAME', 'score': 0.986},
{'label': 'DIFFERENT', 'score': 0.015}],
[{'label': 'SAME', 'score': 0.212},
{'label': 'DIFFERENT', 'score': 0.788}]]
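For document chunking, the per-pair scores can be turned into chunks by cutting wherever DIFFERENT is the likelier label. The following is a minimal sketch and not part of the original model card: `build_pairs`, `group_into_chunks`, and the 0.5 threshold are hypothetical choices, and `scores` is assumed to have the nested list shape shown in the output above.

```python
from typing import List

SEP = " [SEP] "


def build_pairs(sentences: List[str]) -> List[str]:
    """Join each sentence with its successor in the 'left [SEP] right' form."""
    return [left + SEP + right for left, right in zip(sentences, sentences[1:])]


def group_into_chunks(
    sentences: List[str],
    scores: List[List[dict]],
    threshold: float = 0.5,  # assumed cut-off, not specified by the card
) -> List[List[str]]:
    """Start a new chunk wherever the DIFFERENT score reaches the threshold."""
    if not sentences:
        return []
    chunks = [[sentences[0]]]
    # scores[i] describes the boundary between sentences[i] and sentences[i + 1].
    for sentence, pair_scores in zip(sentences[1:], scores):
        by_label = {item["label"]: item["score"] for item in pair_scores}
        if by_label["DIFFERENT"] >= threshold:
            chunks.append([sentence])
        else:
            chunks[-1].append(sentence)
    return chunks
```

With the pipeline defined earlier, this would be called as `group_into_chunks(sentences, pipe(build_pairs(sentences)))`. Raising the threshold yields fewer, longer chunks.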
Training Data
Sentence pairs from 40,000 (train) + 4,000 (validation) Wikipedia articles. Label 1: two consecutive sentences that are not from the same article section. Label 0: every other pair of consecutive sentences.
Label 0 pairs were undersampled, resulting in a total of 408,753 and 45,417 training and validation pairs, respectively.
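The labeling and undersampling described above can be sketched as follows. This is an illustration under assumptions, not the authors' preprocessing code: an article is assumed to be a list of sections, each a list of sentences, and `label_pairs`, `undersample_same`, `keep_ratio`, and `seed` are hypothetical names. The card does not state the sampling ratio that was used.

```python
import random
from typing import List, Tuple

Pair = Tuple[str, str, int]  # (left sentence, right sentence, label)


def label_pairs(sections: List[List[str]]) -> List[Pair]:
    """Label consecutive sentence pairs: 1 across a section boundary, else 0."""
    # Flatten the article while remembering which section each sentence came from.
    flat = [(sentence, idx) for idx, section in enumerate(sections) for sentence in section]
    return [
        (left, right, int(left_sec != right_sec))
        for (left, left_sec), (right, right_sec) in zip(flat, flat[1:])
    ]


def undersample_same(pairs: List[Pair], keep_ratio: float, seed: int = 0) -> List[Pair]:
    """Keep every label-1 pair and a random fraction of the label-0 pairs."""
    rng = random.Random(seed)
    return [pair for pair in pairs if pair[2] == 1 or rng.random() < keep_ratio]
```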
Inputs to the model have the form
[CLS] Left context [SEP] Right context [SEP]
Given DistilBERT's 512-token limit, the left and right contexts are each limited to 255 tokens. A context that exceeded this limit was truncated: the beginning of the left context or the end of the right context was cut off.
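A rough sketch of the truncation step on plain token lists follows. It assumes that truncation keeps the tokens nearest the candidate boundary, so the left context keeps its last tokens and the right context keeps its first tokens. `truncate_contexts` is a hypothetical helper, not code from the model authors.

```python
from typing import List, Sequence, Tuple

MAX_CONTEXT_TOKENS = 255  # per-side limit stated in the card


def truncate_contexts(
    left: Sequence,
    right: Sequence,
    max_len: int = MAX_CONTEXT_TOKENS,
) -> Tuple[List, List]:
    """Trim both contexts to at most max_len tokens, keeping the boundary side."""
    # Left context: drop tokens from the beginning, keep the last max_len.
    kept_left = list(left[max(0, len(left) - max_len):])
    # Right context: drop tokens from the end, keep the first max_len.
    kept_right = list(right[:max_len])
    return kept_left, kept_right
```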
Training Procedure
The model was trained for 2 epochs with a learning rate of 1e-5 and cross-entropy loss on a P100 GPU for 8 hours.
Validation Metrics
Loss | Accuracy | Recall | Precision | F1 |
---|---|---|---|---|
0.34 | 0.85 | 0.85 | 0.86 | 0.85 |
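The card does not state how these metrics were computed or averaged. As an illustration only, assuming binary metrics with DIFFERENT (label 1) as the positive class, they could be derived from predictions like this. `binary_metrics` is a hypothetical helper.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int], positive: int = 1) -> Dict[str, float]:
    """Accuracy, precision, recall, and F1 for one positive class."""
    if not y_true:
        raise ValueError("y_true must not be empty")
    matched = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in matched if t == positive and p == positive)
    fp = sum(1 for t, p in matched if t != positive and p == positive)
    fn = sum(1 for t, p in matched if t == positive and p != positive)
    correct = sum(1 for t, p in matched if t == p)

    # Fall back to 0.0 when a denominator is zero (no predicted or no true positives).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": correct / len(matched),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```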