metadata
language: en
tags:
  - text segmentation
  - document chunking
license: apache-2.0
datasets:
  - wikipedia
pipeline_tag: text-classification
base_model: distilbert/distilbert-base-uncased
widget:
  - text: Left context. [SEP] Right context.
  - text: >-
      They have 6 grandchildren. [SEP] Ane is currently coaching Crestwood High
      School's Boys Varsity Soccer.

DistilBERT Cross Segment Document Chunking

This model is a fine-tuned version of distilbert-base-uncased that classifies whether two consecutive sentences come from the same Wikipedia article section. Its intended use is text segmentation and document chunking. It is based on the paper Text Segmentation by Cross Segment Attention by Michal Lukasik, Boris Dadachev, Gonçalo Simões, and Kishore Papineni.

How to use it

One way to use this model is via the HuggingFace transformers TextClassificationPipeline class.

from transformers import (
    AutoModelForSequenceClassification,
    DistilBertTokenizer,
    TextClassificationPipeline
)

model_name = "BlueOrangeDigital/distilbert-cross-segment-document-chunking"

id2label = {0: "SAME", 1: "DIFFERENT"}
label2id = {"SAME": 0, "DIFFERENT": 1}

tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id
)

pairs = [
    "Left context. [SEP] Right context.",
    "They have 6 grandchildren. [SEP] Ane is currently coaching Crestwood High School's Boys Varsity Soccer.",
]
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)

pipe(pairs)

[[{'label': 'SAME', 'score': 0.986},
  {'label': 'DIFFERENT', 'score': 0.015}],
 [{'label': 'SAME', 'score': 0.212},
  {'label': 'DIFFERENT', 'score': 0.788}]]
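To turn pair-level predictions into chunks, one option is to score every pair of adjacent sentences and open a new chunk whenever the DIFFERENT score reaches a threshold. The sketch below assumes the document is already split into sentences. The names `chunk_sentences` and `score_different` and the 0.5 threshold are hypothetical and not part of the model or of transformers. `score_different` stands in for any wrapper that runs the pipeline on one pair and returns the score attached to the DIFFERENT label.

```python
from typing import Callable, List


def chunk_sentences(
    sentences: List[str],
    score_different: Callable[[str], float],
    threshold: float = 0.5,  # assumed cut-off, tune it on your own data
) -> List[List[str]]:
    """Group consecutive sentences, starting a new chunk at predicted boundaries."""
    if not sentences:
        return []
    chunks = [[sentences[0]]]
    for left, right in zip(sentences, sentences[1:]):
        # Same "left [SEP] right" format as the pairs passed to the pipeline.
        pair = f"{left} [SEP] {right}"
        if score_different(pair) >= threshold:
            chunks.append([right])  # boundary predicted: open a new chunk
        else:
            chunks[-1].append(right)  # same section: extend the current chunk
    return chunks
```

Passing the scorer as a function keeps the grouping logic independent of how the model is called, so it can be checked with a stub before the real model is plugged in.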

Training Data

Sentence pairs from 40,000 (train) + 4,000 (validation) Wikipedia articles. Label 1: two consecutive sentences that are not from the same article section. Label 0: every other pair of consecutive sentences.
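The labeling rule can be illustrated with a minimal sketch. It assumes an article is already available as a list of sections, each a list of sentences. Wikipedia parsing and sentence splitting are not shown, and `label_pairs` is a hypothetical helper rather than the preprocessing code used for this model.

```python
from typing import List, Tuple


def label_pairs(sections: List[List[str]]) -> List[Tuple[str, str, int]]:
    """Build (left, right, label) triples from one article's sections."""
    # Flatten the article while remembering which section each sentence is in.
    flat = [(sent, idx) for idx, section in enumerate(sections) for sent in section]
    pairs = []
    for (left, left_sec), (right, right_sec) in zip(flat, flat[1:]):
        # Label 1 marks a section boundary, label 0 any other consecutive pair.
        label = 1 if left_sec != right_sec else 0
        pairs.append((left, right, label))
    return pairs
```

An article with sections of n sentences in total yields n - 1 pairs, of which only the pairs straddling a section break get label 1, which is why label 0 dominates before any resampling.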

Label 0 pairs were undersampled, resulting in a total of 408,753 and 45,417 training and validation pairs, respectively.
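A minimal sketch of random undersampling follows. The sampling ratio used for this model is not stated here, so the `ratio` parameter, the fixed seed, and the `undersample_label0` helper are assumptions for illustration only.

```python
import random
from typing import List, Sequence, Tuple

Pair = Tuple[str, str, int]  # (left sentence, right sentence, label)


def undersample_label0(pairs: Sequence[Pair], ratio: float = 1.0, seed: int = 0) -> List[Pair]:
    """Keep every label-1 pair and at most ratio * (number of label-1 pairs) label-0 pairs."""
    positives = [p for p in pairs if p[2] == 1]
    negatives = [p for p in pairs if p[2] == 0]
    keep = min(len(negatives), int(ratio * len(positives)))
    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    balanced = positives + rng.sample(negatives, keep)
    rng.shuffle(balanced)  # avoid having all label-1 pairs first
    return balanced
```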

Model inputs have the form

[CLS] Left context [SEP] Right context [SEP]

Given DistilBERT's 512-token limit, the left and right contexts are each limited to 255 tokens. A context that exceeds this limit is truncated: the beginning of the left context or the end of the right context is dropped, which keeps the tokens closest to the candidate boundary.
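The truncation can be sketched on plain token lists. This assumes the left context comes first, as in the usage example, and that the tokens nearest the boundary between the two contexts are the ones kept. The string lists stand in for WordPiece tokens, and `build_input` is a hypothetical helper, not the original preprocessing code.

```python
from typing import List

MAX_CONTEXT_TOKENS = 255  # per-side limit stated above


def build_input(left: List[str], right: List[str], max_len: int = MAX_CONTEXT_TOKENS) -> List[str]:
    """Assemble [CLS] left [SEP] right [SEP] from two token lists."""
    # Keep the last max_len tokens of the left context (drop its beginning).
    left = left[max(0, len(left) - max_len):]
    # Keep the first max_len tokens of the right context (drop its end).
    right = right[:max_len]
    return ["[CLS]", *left, "[SEP]", *right, "[SEP]"]
```

For example, `build_input(list("abcde"), list("vwxyz"), max_len=2)` keeps `d`, `e` on the left and `v`, `w` on the right.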

Training Procedure

The model was trained for 2 epochs with a learning rate of 1e-5 and cross-entropy loss on a P100 GPU for 8 hours.
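For reference, the cross-entropy loss for a single pair is the negative log-probability that a softmax over the two output logits assigns to the true label. The sketch below is the standard definition written in plain Python, not the training script; `cross_entropy` is a hypothetical helper name.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Negative log-probability of the true class under a softmax over logits."""
    peak = max(logits)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_norm - logits[label]
```

The reported loss is the mean of this quantity over all pairs in a split. With equal logits the per-pair loss is ln 2, and it shrinks toward 0 as the logit of the true label grows relative to the other.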

Validation Metrics

| Loss | Accuracy | Recall | Precision | F1   |
|------|----------|--------|-----------|------|
| 0.34 | 0.85     | 0.85   | 0.86      | 0.85 |
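Metrics of this kind can be computed from true and predicted labels as in the sketch below. It treats label 1 (DIFFERENT) as the positive class, which is an assumption: the averaging scheme behind the reported numbers is not stated here. `binary_metrics` is a hypothetical helper.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int], positive: int = 1) -> Dict[str, float]:
    """Accuracy, precision, recall and F1 with `positive` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    correct = sum(t == p for t, p in pairs)
    # Guard every division so empty or one-sided inputs return 0.0.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = correct / len(pairs) if pairs else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
```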