KSSDS: Korean Sentence Splitter for Dialogue Systems

KSSDS is built on the encoder of lcw99/t5-base-korean-text-summary
and fine-tuned on AI Hub data.

์ด ๋ชจ๋ธ์€ ํ•œ๊ตญ์–ด ๋Œ€ํ™” ์‹œ์Šคํ…œ ์šฉ ๋ฌธ์žฅ ๋ถ„๋ฆฌ๊ธฐ๋กœ,
Whisper์™€ ๊ฐ™์€ STT ๋ชจ๋ธ์ด ์ƒ์„ฑํ•œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ๋ฌธ์žฅ ๋‹จ์œ„๋กœ ๋ถ„๋ฆฌํ•˜๋Š” ๊ฒƒ์„ ๋ชฉํ‘œ๋กœ ๋งŒ๋“ค์–ด์กŒ์Šต๋‹ˆ๋‹ค.

์ž์„ธํ•œ ์„ค๋ช…์€ KSSDS GitHub repository๋ฅผ ์ฐธ๊ณ ํ•ด์ฃผ์„ธ์š”.


How to Use

1. Install via PyPI or GitHub (recommended)

  • PyPI ๋˜๋Š” GitHub๋ฅผ ํ†ตํ•ด ์„ค์น˜ํ•˜๋ฉด ๋”์šฑ ๋” ํŽธ๋ฆฌํ•˜๊ฒŒ ์‚ฌ์šฉํ•˜์‹ค ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • KSSDS GitHub repository์—์„œ ์„ค์ • ๋ฐ ์‚ฌ์šฉ ๋ฐฉ๋ฒ•์„ ํ™•์ธํ•˜์„ธ์š”.
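As a reference, installation from PyPI might look like the line below; the distribution name here is an assumption, so confirm the exact package name in the KSSDS GitHub repository.

pip install kssds  # assumed package name; check the repository for the exact name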

2. Native Hugging Face Hub usage

Hugging Face Hub์—์„œ ๋ชจ๋ธ๊ณผ T5 ์ธ์ฝ”๋”๋ฅผ ์ง์ ‘ ๋‹ค์šด๋กœ๋“œํ•˜์—ฌ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์•„๋ž˜๋Š” ๋ฌธ์žฅ ๋ถ„๋ฆฌ๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ์˜ ์˜ˆ์ œ ์ฝ”๋“œ์ž…๋‹ˆ๋‹ค.

# Import required libraries and modules
from huggingface_hub import hf_hub_download  # download files from the Hugging Face Hub
import sys
import os
from transformers import AutoTokenizer  # Hugging Face tokenizer
import torch
from typing import List

# Hugging Face Hub์—์„œ T5_encoder.py์™€ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ
model_name = "ggomarobot/KSSDS"
file_path = hf_hub_download(repo_id=model_name, filename="T5_encoder.py")  # T5_encoder.py ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ
module_dir = os.path.dirname(file_path)  # T5_encoder.py ํŒŒ์ผ์ด ์žˆ๋Š” ๋””๋ ‰ํ„ฐ๋ฆฌ ๊ฒฝ๋กœ
sys.path.append(module_dir)  # Python ๊ฒฝ๋กœ์— T5_encoder ๋””๋ ‰ํ„ฐ๋ฆฌ ์ถ”๊ฐ€

from T5_encoder import T5ForTokenClassification  # custom T5 encoder model

tokenizer = AutoTokenizer.from_pretrained(model_name)  # load the tokenizer
model = T5ForTokenClassification.from_pretrained(model_name)  # load the custom T5 encoder model

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the selected device
model = model.to(device)
model.eval()  # set the model to evaluation mode

# Maximum input length (in tokens) per chunk
max_length = 512

def preprocess_text(input_text: str) -> List[dict]:
    """
    ๊ธด ์ž…๋ ฅ ํ…์ŠคํŠธ๋ฅผ ์ฒ˜๋ฆฌ ๊ฐ€๋Šฅํ•œ ์ฒญํฌ๋กœ ๋‚˜๋ˆ„๋Š” ํ•จ์ˆ˜.

    Args:
        input_text (str): ์ฒ˜๋ฆฌํ•  ๊ธด ์ž…๋ ฅ ํ…์ŠคํŠธ.
    Returns:
        List[dict]: ์ฒญํฌ ๋‹จ์œ„์˜ input_ids์™€ attention_mask๋ฅผ ํฌํ•จํ•œ ๋ฆฌ์ŠคํŠธ.
    """
    tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)
    chunks = [tokenized_sequence[i:i + max_length] for i in range(0, len(tokenized_sequence), max_length)]
    processed_data = [{"input_ids": chunk, "attention_mask": [1] * len(chunk)} for chunk in chunks]
    return processed_data

def handle_repetitions(sentences: List[str], max_repeats: int = 60, detection_threshold: int = 70, max_phrase_length: int = 2) -> List[str]:
    """
    Handles single-word and phrase repetitions in a list of sentences, ensuring proper order and separation.

    Args:
        sentences (List[str]): List of input sentences to process.
        max_repeats (int): Maximum number of phrase repetitions to allow before splitting.
        detection_threshold (int): Minimum length of text for repetition detection.
        max_phrase_length (int): Maximum length of a phrase to consider for repetition detection.

    Returns:
        List[str]: The processed text split into sentences.
    """
    processed_sentences = []

    for sentence in sentences:
        words = sentence.split()
        if len(words) <= detection_threshold:
            processed_sentences.append(sentence)
            continue

        result_sentences = []
        current_sentence = []
        current_repetition = []

        def flush_sentence():
            """Flush the current sentence into result_sentences."""
            if current_sentence:
                result_sentences.append(" ".join(current_sentence))
                current_sentence.clear()

        def flush_repetition(phrase_length):
            """Flush the current repetition into result_sentences."""
            for i in range(0, len(current_repetition), phrase_length * max_repeats):
                chunk = current_repetition[i:i + phrase_length * max_repeats]
                result_sentences.append(" ".join(chunk))
            current_repetition.clear()

        def find_repeating_phrase(start_idx):
            """Find the smallest repeating phrase starting at the given index."""
            for phrase_length in range(1, max_phrase_length + 1):
                phrase = words[start_idx:start_idx + phrase_length]
                next_idx = start_idx + phrase_length
                if next_idx + phrase_length <= len(words) and words[next_idx:next_idx + phrase_length] == phrase:
                    return phrase
            return None

        i = 0
        while i < len(words):
            repeating_phrase = find_repeating_phrase(i)
            if repeating_phrase:
                # Flush any ongoing sentence before handling repetition
                flush_sentence()

                # Accumulate repeating phrases
                phrase_length = len(repeating_phrase)
                while i + phrase_length <= len(words) and words[i:i + phrase_length] == repeating_phrase:
                    current_repetition.extend(repeating_phrase)
                    i += phrase_length

                # Flush accumulated repetition if it reaches the threshold
                if len(current_repetition) >= phrase_length * max_repeats:
                    flush_repetition(phrase_length)
            else:
                # Add non-repeating words to the current sentence
                if current_repetition:
                    # Flush repetition before starting a new sentence
                    flush_repetition(1)  # Default to single-word repetition
                current_sentence.append(words[i])
                i += 1

        # Flush any remaining tokens
        flush_sentence()
        flush_repetition(1)  # Default to single-word repetition for the last chunk

        processed_sentences.extend(result_sentences)

    return processed_sentences

def segment_predictions(input_ids: List[int], predictions: List[int]) -> List[List[int]]:
    """
    Segment predictions into sentences based on label 1 (sentence-ending).
    Args:
        input_ids (List[int]): List of input token IDs.
        predictions (List[int]): Corresponding prediction labels.
    Returns:
        List[List[int]]: Segmented sentences as lists of token IDs.
    """
    segments = []
    current_segment = []

    for token, label in zip(input_ids, predictions):
        if label == 1:  # Sentence-ending label
            if current_segment:
                current_segment.append(token)
                segments.append(current_segment)
                current_segment = []
            else:
                segments.append([token])
        else:
            current_segment.append(token)

    if current_segment:  # Append any remaining tokens
        segments.append(current_segment)

    return segments


def decode_predictions(input_ids: List[int], predictions: List[int], tokenizer, carry_over=None):
    """
    Decode model predictions into sentences, handling carry-over tokens across chunks.
    Args:
        input_ids (List[int]): Input token IDs.
        predictions (List[int]): Prediction labels.
        tokenizer: Hugging Face tokenizer instance.
        carry_over (List[int], optional): Tokens carried over from the previous chunk.
    Returns:
        Tuple[List[str], List[int]]: Decoded sentences and remaining carry-over tokens.
    """
    if carry_over is None:
        carry_over = []

    sentences = []
    tokens = carry_over + input_ids  # Include carry-over tokens
    labels = [0] * len(carry_over) + predictions  # Carry-over tokens have label 0

    segmented = segment_predictions(tokens, labels)

    for segment in segmented[:-1]:  # Decode all segments except the last one
        sentence = tokenizer.decode(segment, skip_special_tokens=False, clean_up_tokenization_spaces=False).strip()
        if sentence:
            sentences.append(sentence)

    # Handle carry-over for the last segment
    carry_over = segmented[-1] if segmented[-1] and labels[len(tokens) - len(segmented[-1])] != 1 else []

    return sentences, carry_over


def inference(input_text: str) -> List[str]:
    """
    Perform sentence splitting on input text using the HF Hub model.
    Args:
        input_text (str): Input text to split into sentences.
    Returns:
        List[str]: List of split sentences.
    """
    chunks = preprocess_text(input_text)
    carry_over_tokens = []
    sentences = []

    for chunk in chunks:
        # Prepare inputs for the model
        input_ids_tensor = torch.tensor([chunk["input_ids"]], dtype=torch.long, device=device)
        attention_mask_tensor = torch.tensor([chunk["attention_mask"]], dtype=torch.long, device=device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
            predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()

            # Ensure predictions is a list
            if isinstance(predictions, int):
                predictions = [predictions]

        input_ids = chunk["input_ids"]
        decoded_sentences, carry_over_tokens = decode_predictions(input_ids, predictions, tokenizer, carry_over_tokens)
        sentences.extend(decoded_sentences)

    # Process any remaining carry-over tokens
    if carry_over_tokens:
        remaining_sentence = tokenizer.decode(carry_over_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False).strip()
        if remaining_sentence:
            sentences.append(remaining_sentence)

    return handle_repetitions(sentences)

# Usage example
input_text = "์•ˆ๋…•ํ•˜์„ธ์š”. ์˜ค๋Š˜ ๋‚ ์”จ๊ฐ€ ์ฐธ ์ข‹๋„ค์š”. ์ €๋Š” ์‚ฐ์ฑ…์„ ๋‚˜๊ฐˆ ์˜ˆ์ •์ž…๋‹ˆ๋‹ค."
split_sentences = inference(input_text)

# Print the results
for idx, sentence in enumerate(split_sentences):
    print(f"{idx + 1}: {sentence}")

KSSDS_NO_LF ๋ชจ๋ธ ์†Œ๊ฐœ

KSSDS_NO_LF is a model trained without the Length Filter, for ablation studies.
When processing long texts, it splits sentences somewhat less precisely than the main model.

์‚ฌ์šฉ ๋ฐฉ์‹์€ ๋ณธ ๋ชจ๋ธ๊ณผ ๋™์ผํ•˜์ง€๋งŒ
์—ฐ๊ตฌ ๋ฐ ๋น„๊ต ๋ชฉ์  ์™ธ์—๋Š” ์‹ค์ œ ์‚ฌ์šฉ์— ์ ํ•ฉํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

How to Use

์œ„์˜ Hugging Face Hub ๋„ค์ดํ‹ฐ๋ธŒ ๋ฐฉ์‹ ์‚ฌ์šฉ ์ฝ”๋“œ ์Šค๋‹ˆํŽซ์—์„œ model_name์„ "ggomarobot/KSSDS_NO_LF"๋กœ ๊ต์ฒดํ•˜์—ฌ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

# Load tokenizer and model
model_name = "ggomarobot/KSSDS_NO_LF"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForTokenClassification.from_pretrained(model_name)
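After swapping checkpoints, the rest of the pipeline above works unchanged. A minimal sketch, assuming the device setup and the inference() helper defined earlier are still in scope:

# Move the NO_LF model to the same device and run the existing pipeline
model = model.to(device)
model.eval()

for idx, sentence in enumerate(inference("์•ˆ๋…•ํ•˜์„ธ์š”. ์˜ค๋Š˜ ๋‚ ์”จ๊ฐ€ ์ฐธ ์ข‹๋„ค์š”.")):
    print(f"{idx + 1}: {sentence}")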